agent_stream_kit/
tool.rs

1use std::{
2    collections::{BTreeMap, HashMap, HashSet},
3    sync::{Arc, Mutex, OnceLock, RwLock},
4    time::Duration,
5};
6
7use crate::{
8    ASKit, Agent, AgentContext, AgentData, AgentError, AgentOutput, AgentSpec, AgentValue, AsAgent,
9    Message, ToolCall, askit_agent, async_trait,
10};
11use im::{Vector, vector};
12use regex::RegexSet;
13use tokio::sync::{Mutex as AsyncMutex, oneshot};
14
15const CATEGORY: &str = "Core/Tool";
16
17const PIN_MESSAGE: &str = "message";
18const PIN_PATTERNS: &str = "patterns";
19const PIN_TOOLS: &str = "tools";
20const PIN_TOOL_CALL: &str = "tool_call";
21const PIN_TOOL_IN: &str = "tool_in";
22const PIN_TOOL_OUT: &str = "tool_out";
23const PIN_VALUE: &str = "value";
24
25const CONFIG_TOOLS: &str = "tools";
26const CONFIG_TOOL_NAME: &str = "name";
27const CONFIG_TOOL_DESCRIPTION: &str = "description";
28const CONFIG_TOOL_PARAMETERS: &str = "parameters";
29
30#[derive(Clone, Debug)]
31pub struct ToolInfo {
32    pub name: String,
33    pub description: String,
34    pub parameters: Option<serde_json::Value>,
35}
36
37/// Trait for Tool implementations.
38#[async_trait]
39pub trait Tool {
40    fn info(&self) -> &ToolInfo;
41
42    /// Call the tool with the given context and arguments.
43    async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError>;
44}
45
46impl From<ToolInfo> for AgentValue {
47    fn from(info: ToolInfo) -> Self {
48        let mut obj: BTreeMap<String, AgentValue> = BTreeMap::new();
49        obj.insert("name".to_string(), AgentValue::from(info.name));
50        obj.insert(
51            "description".to_string(),
52            AgentValue::from(info.description),
53        );
54        if let Some(params) = &info.parameters {
55            if let Ok(params_value) = AgentValue::from_serialize(params) {
56                obj.insert("parameters".to_string(), params_value);
57            }
58        }
59        AgentValue::object(obj.into())
60    }
61}
62
63#[derive(Clone)]
64struct ToolEntry {
65    info: ToolInfo,
66    tool: Arc<Box<dyn Tool + Send + Sync>>,
67}
68
69impl ToolEntry {
70    fn new<T: Tool + Send + Sync + 'static>(tool: T) -> Self {
71        Self {
72            info: tool.info().clone(),
73            tool: Arc::new(Box::new(tool)),
74        }
75    }
76}
77
78struct ToolRegistry {
79    tools: HashMap<String, ToolEntry>,
80}
81
82impl ToolRegistry {
83    fn new() -> Self {
84        Self {
85            tools: HashMap::new(),
86        }
87    }
88
89    fn register_tool<T: Tool + Send + Sync + 'static>(&mut self, tool: T) {
90        let name = tool.info().name.to_string();
91        let entry = ToolEntry::new(tool);
92        self.tools.insert(name, entry);
93    }
94
95    fn unregister_tool(&mut self, name: &str) {
96        self.tools.remove(name);
97    }
98
99    fn get_tool(&self, name: &str) -> Option<Arc<Box<dyn Tool + Send + Sync>>> {
100        self.tools.get(name).map(|entry| entry.tool.clone())
101    }
102}
103
104// Global registry instance.
105static TOOL_REGISTRY: OnceLock<RwLock<ToolRegistry>> = OnceLock::new();
106
107fn registry() -> &'static RwLock<ToolRegistry> {
108    TOOL_REGISTRY.get_or_init(|| RwLock::new(ToolRegistry::new()))
109}
110
111/// Register a new tool.
112pub fn register_tool<T: Tool + Send + Sync + 'static>(tool: T) {
113    registry().write().unwrap().register_tool(tool);
114}
115
116/// Unregister a tool by name.
117pub fn unregister_tool(name: &str) {
118    registry().write().unwrap().unregister_tool(name);
119}
120
121/// List all registered tool infos.
122pub fn list_tool_infos() -> Vec<ToolInfo> {
123    registry()
124        .read()
125        .unwrap()
126        .tools
127        .values()
128        .map(|entry| entry.info.clone())
129        .collect()
130}
131
132/// List registerd tool infos filtered by patterns.
133pub fn list_tool_infos_patterns(patterns: &str) -> Result<Vec<ToolInfo>, regex::Error> {
134    // Split patterns by newline and trim whitespace
135    let patterns = patterns
136        .lines()
137        .map(|line| line.trim())
138        .filter(|line| !line.is_empty())
139        .collect::<Vec<&str>>();
140    let reg_set = RegexSet::new(&patterns)?;
141    let tool_names = registry()
142        .read()
143        .unwrap()
144        .tools
145        .values()
146        .filter_map(|entry| {
147            if reg_set.is_match(&entry.info.name) {
148                Some(entry.info.clone())
149            } else {
150                None
151            }
152        })
153        .collect();
154    Ok(tool_names)
155}
156
157/// Get a tool by name.
158pub fn get_tool(name: &str) -> Option<Arc<Box<dyn Tool + Send + Sync>>> {
159    registry().read().unwrap().get_tool(name)
160}
161
162/// Call a tool by name.
163pub async fn call_tool(
164    ctx: AgentContext,
165    name: &str,
166    args: AgentValue,
167) -> Result<AgentValue, AgentError> {
168    let tool = {
169        let guard = registry().read().unwrap();
170        guard.get_tool(name)
171    };
172
173    let Some(tool) = tool else {
174        return Err(AgentError::Other(format!("Tool '{}' not found", name)));
175    };
176
177    tool.call(ctx, args).await
178}
179
180pub async fn call_tools(
181    ctx: &AgentContext,
182    tool_calls: &Vector<ToolCall>,
183) -> Result<Vector<Message>, AgentError> {
184    if tool_calls.is_empty() {
185        return Ok(vector![]);
186    };
187    let mut resp_messages = vec![];
188
189    for call in tool_calls {
190        let args: AgentValue =
191            AgentValue::from_json(call.function.parameters.clone()).map_err(|e| {
192                AgentError::InvalidValue(format!("Failed to parse tool call parameters: {}", e))
193            })?;
194        let tool_resp = call_tool(ctx.clone(), call.function.name.as_str(), args).await?;
195        resp_messages.push(Message::tool(
196            call.function.name.clone(),
197            tool_resp.to_json().to_string(),
198        ));
199    }
200
201    Ok(resp_messages.into())
202}
203
204// Agents
205
206#[askit_agent(
207    title="List Tools",
208    category=CATEGORY,
209    inputs=[PIN_PATTERNS],
210    outputs=[PIN_TOOLS],
211)]
212pub struct ListToolsAgent {
213    data: AgentData,
214}
215
216#[async_trait]
217impl AsAgent for ListToolsAgent {
218    fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
219        Ok(Self {
220            data: AgentData::new(askit, id, spec),
221        })
222    }
223
224    async fn process(
225        &mut self,
226        ctx: AgentContext,
227        _pin: String,
228        value: AgentValue,
229    ) -> Result<(), AgentError> {
230        let Some(patterns) = value.as_str() else {
231            return Err(AgentError::InvalidValue(
232                "patterns input must be a string".to_string(),
233            ));
234        };
235
236        let tools = if !patterns.is_empty() {
237            list_tool_infos_patterns(patterns)
238                .map_err(|e| AgentError::InvalidValue(format!("Invalid regex patterns: {}", e)))?
239        } else {
240            list_tool_infos()
241        };
242        let tools = tools
243            .into_iter()
244            .map(|tool| tool.into())
245            .collect::<Vector<AgentValue>>();
246        let tools_array = AgentValue::array(tools);
247
248        self.output(ctx, PIN_TOOLS, tools_array).await?;
249
250        Ok(())
251    }
252}
253
254#[askit_agent(
255    title="Stream Tool",
256    category=CATEGORY,
257    inputs=[PIN_TOOL_OUT],
258    outputs=[PIN_TOOL_IN],
259    string_config(name=CONFIG_TOOL_NAME),
260    text_config(name=CONFIG_TOOL_DESCRIPTION),
261    object_config(name=CONFIG_TOOL_PARAMETERS),
262)]
263pub struct StreamToolAgent {
264    data: AgentData,
265    name: String,
266    description: String,
267    parameters: Option<serde_json::Value>,
268    pending: Arc<Mutex<HashMap<usize, oneshot::Sender<AgentValue>>>>,
269}
270
271impl StreamToolAgent {
272    fn start_tool_call(
273        &mut self,
274        ctx: AgentContext,
275        args: AgentValue,
276    ) -> Result<oneshot::Receiver<AgentValue>, AgentError> {
277        let (tx, rx) = oneshot::channel();
278
279        self.pending.lock().unwrap().insert(ctx.id(), tx);
280        self.try_output(ctx.clone(), PIN_TOOL_IN, args)?;
281
282        Ok(rx)
283    }
284}
285
286#[async_trait]
287impl AsAgent for StreamToolAgent {
288    fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
289        let def_name = spec.def_name.clone();
290        let configs = spec.configs.clone();
291        let name = configs
292            .as_ref()
293            .and_then(|c| c.get_string(CONFIG_TOOL_NAME).ok())
294            .unwrap_or_else(|| def_name.clone());
295        let description = configs
296            .as_ref()
297            .and_then(|c| c.get_string(CONFIG_TOOL_DESCRIPTION).ok())
298            .unwrap_or_default();
299        let parameters = configs
300            .as_ref()
301            .and_then(|c| c.get(CONFIG_TOOL_PARAMETERS).ok())
302            .and_then(|v| serde_json::to_value(v).ok());
303        Ok(Self {
304            data: AgentData::new(askit, id, spec),
305            name,
306            description,
307            parameters,
308            pending: Arc::new(Mutex::new(HashMap::new())),
309        })
310    }
311
312    fn configs_changed(&mut self) -> Result<(), AgentError> {
313        self.name = self.configs()?.get_string_or_default(CONFIG_TOOL_NAME);
314        self.description = self
315            .configs()?
316            .get_string_or_default(CONFIG_TOOL_DESCRIPTION);
317        self.parameters = self
318            .configs()?
319            .get(CONFIG_TOOL_PARAMETERS)
320            .ok()
321            .and_then(|v| serde_json::to_value(v).ok());
322
323        // TODO: update registered tool info
324
325        Ok(())
326    }
327
328    async fn start(&mut self) -> Result<(), AgentError> {
329        let agent_handle = self
330            .askit()
331            .get_agent(self.id())
332            .ok_or_else(|| AgentError::AgentNotFound(self.id().to_string()))?;
333        let tool = StreamTool::new(
334            self.name.clone(),
335            self.description.clone(),
336            self.parameters.clone(),
337            agent_handle,
338        );
339        register_tool(tool);
340        Ok(())
341    }
342
343    async fn stop(&mut self) -> Result<(), AgentError> {
344        unregister_tool(&self.name);
345        self.pending.lock().unwrap().clear();
346        Ok(())
347    }
348
349    async fn process(
350        &mut self,
351        ctx: AgentContext,
352        _pin: String,
353        value: AgentValue,
354    ) -> Result<(), AgentError> {
355        if let Some(tx) = self.pending.lock().unwrap().remove(&ctx.id()) {
356            let _ = tx.send(value);
357        }
358        Ok(())
359    }
360}
361
362struct StreamTool {
363    info: ToolInfo,
364    agent: Arc<AsyncMutex<Box<dyn Agent>>>,
365}
366
367impl StreamTool {
368    fn new(
369        name: String,
370        description: String,
371        parameters: Option<serde_json::Value>,
372        agent: Arc<AsyncMutex<Box<dyn Agent>>>,
373    ) -> Self {
374        Self {
375            info: ToolInfo {
376                name: name,
377                description: description,
378                parameters: parameters,
379            },
380            agent,
381        }
382    }
383
384    async fn tool_call(
385        &self,
386        ctx: AgentContext,
387        args: AgentValue,
388    ) -> Result<AgentValue, AgentError> {
389        // Kick off the tool call while holding the lock, then drop it before awaiting the result
390        let rx = {
391            let mut guard = self.agent.lock().await;
392            let Some(stream_tool_agent) = guard.as_agent_mut::<StreamToolAgent>() else {
393                return Err(AgentError::Other(
394                    "Agent is not StreamToolAgent".to_string(),
395                ));
396            };
397            stream_tool_agent.start_tool_call(ctx, args)?
398        };
399
400        tokio::time::timeout(Duration::from_secs(60), rx)
401            .await
402            .map_err(|_| AgentError::Other("tool_call timed out".to_string()))?
403            .map_err(|_| AgentError::Other("tool_out dropped".to_string()))
404    }
405}
406
407#[async_trait]
408impl Tool for StreamTool {
409    fn info(&self) -> &ToolInfo {
410        &self.info
411    }
412
413    async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
414        self.tool_call(ctx, args).await
415    }
416}
417
418// Call Tool Message Agent
419#[askit_agent(
420    title="Call Tool Message",
421    category=CATEGORY,
422    inputs=[PIN_MESSAGE],
423    outputs=[PIN_MESSAGE],
424    string_config(name=CONFIG_TOOLS),
425)]
426pub struct CallToolMessageAgent {
427    data: AgentData,
428}
429
430#[async_trait]
431impl AsAgent for CallToolMessageAgent {
432    fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
433        Ok(Self {
434            data: AgentData::new(askit, id, spec),
435        })
436    }
437
438    async fn process(
439        &mut self,
440        ctx: AgentContext,
441        _pin: String,
442        value: AgentValue,
443    ) -> Result<(), AgentError> {
444        let Some(message) = value.as_message() else {
445            return Ok(());
446        };
447        let Some(mut tool_calls) = message.tool_calls.clone() else {
448            return Ok(());
449        };
450
451        // Filter tools
452        let config_tools = self.configs()?.get_string_or_default(CONFIG_TOOLS);
453        if !config_tools.is_empty() {
454            let tools = list_tool_infos_patterns(&config_tools)
455                .map_err(|e| AgentError::InvalidValue(format!("Invalid regex patterns: {}", e)))?;
456            // FIXME: cache allowed tool names
457            let allowed_tool_names: HashSet<String> = tools.into_iter().map(|t| t.name).collect();
458            tool_calls = tool_calls
459                .iter()
460                .filter(|call| allowed_tool_names.contains(&call.function.name))
461                .cloned()
462                .collect();
463        }
464
465        let resp_messages = call_tools(&ctx, &tool_calls).await?;
466        for resp_msg in resp_messages {
467            self.output(ctx.clone(), PIN_MESSAGE, AgentValue::message(resp_msg))
468                .await?;
469        }
470        Ok(())
471    }
472}
473
474// Call Tool Agent
475#[askit_agent(
476    title="Call Tool",
477    category=CATEGORY,
478    inputs=[PIN_TOOL_CALL],
479    outputs=[PIN_VALUE],
480)]
481pub struct CallToolAgent {
482    data: AgentData,
483}
484
485#[async_trait]
486impl AsAgent for CallToolAgent {
487    fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
488        Ok(Self {
489            data: AgentData::new(askit, id, spec),
490        })
491    }
492
493    async fn process(
494        &mut self,
495        ctx: AgentContext,
496        _pin: String,
497        value: AgentValue,
498    ) -> Result<(), AgentError> {
499        let obj = value.as_object().ok_or_else(|| {
500            AgentError::InvalidValue("tool_call input must be an object".to_string())
501        })?;
502        let tool_name = obj.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
503            AgentError::InvalidValue("tool_call.name must be a string".to_string())
504        })?;
505        let tool_parameters = obj.get("parameters").cloned().unwrap_or(AgentValue::unit());
506
507        let resp = call_tool(ctx.clone(), tool_name, tool_parameters).await?;
508        self.output(ctx, PIN_VALUE, resp).await?;
509
510        Ok(())
511    }
512}