1use std::{
2 collections::{BTreeMap, HashMap, HashSet},
3 sync::{Arc, Mutex, OnceLock, RwLock},
4 time::Duration,
5};
6
7use crate::{
8 ASKit, Agent, AgentContext, AgentData, AgentError, AgentOutput, AgentSpec, AgentValue, AsAgent,
9 Message, ToolCall, askit_agent, async_trait,
10};
11use im::{Vector, vector};
12use regex::RegexSet;
13use tokio::sync::{Mutex as AsyncMutex, oneshot};
14
// Category string passed to `#[askit_agent(category = ...)]` by every agent in this file.
const CATEGORY: &str = "Core/Tool";

// Pin names referenced by the `inputs` / `outputs` lists of the agents below.
const PIN_MESSAGE: &str = "message";
const PIN_PATTERNS: &str = "patterns";
const PIN_TOOLS: &str = "tools";
const PIN_TOOL_CALL: &str = "tool_call";
const PIN_TOOL_IN: &str = "tool_in";
const PIN_TOOL_OUT: &str = "tool_out";
const PIN_VALUE: &str = "value";

// Config keys: `CONFIG_TOOLS` is read by `CallToolMessageAgent`; the
// `CONFIG_TOOL_*` keys are read by `StreamToolAgent`.
const CONFIG_TOOLS: &str = "tools";
const CONFIG_TOOL_NAME: &str = "name";
const CONFIG_TOOL_DESCRIPTION: &str = "description";
const CONFIG_TOOL_PARAMETERS: &str = "parameters";
29
/// Descriptive metadata for a tool: its name, a description, and an optional
/// JSON value describing its parameters.
#[derive(Clone, Debug)]
pub struct ToolInfo {
    /// Name of the tool; the registry in this file keys tools by this value.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// Optional JSON description of the parameters.
    /// NOTE(review): presumably a JSON Schema object — confirm with callers.
    pub parameters: Option<serde_json::Value>,
}
36
/// A callable tool that can be stored in the tool registry of this file.
#[async_trait]
pub trait Tool {
    /// Returns the metadata describing this tool.
    fn info(&self) -> &ToolInfo;

    /// Invokes the tool with `args` under the context `ctx`.
    ///
    /// Resolves to the tool's result value, or to an `AgentError` on failure.
    async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError>;
}
45
46impl From<ToolInfo> for AgentValue {
47 fn from(info: ToolInfo) -> Self {
48 let mut obj: BTreeMap<String, AgentValue> = BTreeMap::new();
49 obj.insert("name".to_string(), AgentValue::from(info.name));
50 obj.insert(
51 "description".to_string(),
52 AgentValue::from(info.description),
53 );
54 if let Some(params) = &info.parameters {
55 if let Ok(params_value) = AgentValue::from_serialize(params) {
56 obj.insert("parameters".to_string(), params_value);
57 }
58 }
59 AgentValue::object(obj.into())
60 }
61}
62
63#[derive(Clone)]
64struct ToolEntry {
65 info: ToolInfo,
66 tool: Arc<Box<dyn Tool + Send + Sync>>,
67}
68
69impl ToolEntry {
70 fn new<T: Tool + Send + Sync + 'static>(tool: T) -> Self {
71 Self {
72 info: tool.info().clone(),
73 tool: Arc::new(Box::new(tool)),
74 }
75 }
76}
77
78struct ToolRegistry {
79 tools: HashMap<String, ToolEntry>,
80}
81
82impl ToolRegistry {
83 fn new() -> Self {
84 Self {
85 tools: HashMap::new(),
86 }
87 }
88
89 fn register_tool<T: Tool + Send + Sync + 'static>(&mut self, tool: T) {
90 let name = tool.info().name.to_string();
91 let entry = ToolEntry::new(tool);
92 self.tools.insert(name, entry);
93 }
94
95 fn unregister_tool(&mut self, name: &str) {
96 self.tools.remove(name);
97 }
98
99 fn get_tool(&self, name: &str) -> Option<Arc<Box<dyn Tool + Send + Sync>>> {
100 self.tools.get(name).map(|entry| entry.tool.clone())
101 }
102}
103
// Global tool registry; populated lazily on the first call to `registry()`.
static TOOL_REGISTRY: OnceLock<RwLock<ToolRegistry>> = OnceLock::new();
106
107fn registry() -> &'static RwLock<ToolRegistry> {
108 TOOL_REGISTRY.get_or_init(|| RwLock::new(ToolRegistry::new()))
109}
110
111pub fn register_tool<T: Tool + Send + Sync + 'static>(tool: T) {
113 registry().write().unwrap().register_tool(tool);
114}
115
116pub fn unregister_tool(name: &str) {
118 registry().write().unwrap().unregister_tool(name);
119}
120
121pub fn list_tool_infos() -> Vec<ToolInfo> {
123 registry()
124 .read()
125 .unwrap()
126 .tools
127 .values()
128 .map(|entry| entry.info.clone())
129 .collect()
130}
131
132pub fn list_tool_infos_patterns(patterns: &str) -> Result<Vec<ToolInfo>, regex::Error> {
134 let patterns = patterns
136 .lines()
137 .map(|line| line.trim())
138 .filter(|line| !line.is_empty())
139 .collect::<Vec<&str>>();
140 let reg_set = RegexSet::new(&patterns)?;
141 let tool_names = registry()
142 .read()
143 .unwrap()
144 .tools
145 .values()
146 .filter_map(|entry| {
147 if reg_set.is_match(&entry.info.name) {
148 Some(entry.info.clone())
149 } else {
150 None
151 }
152 })
153 .collect();
154 Ok(tool_names)
155}
156
157pub fn get_tool(name: &str) -> Option<Arc<Box<dyn Tool + Send + Sync>>> {
159 registry().read().unwrap().get_tool(name)
160}
161
162pub async fn call_tool(
164 ctx: AgentContext,
165 name: &str,
166 args: AgentValue,
167) -> Result<AgentValue, AgentError> {
168 let tool = {
169 let guard = registry().read().unwrap();
170 guard.get_tool(name)
171 };
172
173 let Some(tool) = tool else {
174 return Err(AgentError::Other(format!("Tool '{}' not found", name)));
175 };
176
177 tool.call(ctx, args).await
178}
179
180pub async fn call_tools(
181 ctx: &AgentContext,
182 tool_calls: &Vector<ToolCall>,
183) -> Result<Vector<Message>, AgentError> {
184 if tool_calls.is_empty() {
185 return Ok(vector![]);
186 };
187 let mut resp_messages = vec![];
188
189 for call in tool_calls {
190 let args: AgentValue =
191 AgentValue::from_json(call.function.parameters.clone()).map_err(|e| {
192 AgentError::InvalidValue(format!("Failed to parse tool call parameters: {}", e))
193 })?;
194 let tool_resp = call_tool(ctx.clone(), call.function.name.as_str(), args).await?;
195 resp_messages.push(Message::tool(
196 call.function.name.clone(),
197 tool_resp.to_json().to_string(),
198 ));
199 }
200
201 Ok(resp_messages.into())
202}
203
// Agent that lists registered tools: it takes name patterns on the
// `patterns` pin and emits the matching tool descriptions on the `tools` pin.
#[askit_agent(
    title="List Tools",
    category=CATEGORY,
    inputs=[PIN_PATTERNS],
    outputs=[PIN_TOOLS],
)]
pub struct ListToolsAgent {
    // Common agent state built by `AgentData::new`.
    data: AgentData,
}
215
216#[async_trait]
217impl AsAgent for ListToolsAgent {
218 fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
219 Ok(Self {
220 data: AgentData::new(askit, id, spec),
221 })
222 }
223
224 async fn process(
225 &mut self,
226 ctx: AgentContext,
227 _pin: String,
228 value: AgentValue,
229 ) -> Result<(), AgentError> {
230 let Some(patterns) = value.as_str() else {
231 return Err(AgentError::InvalidValue(
232 "patterns input must be a string".to_string(),
233 ));
234 };
235
236 let tools = if !patterns.is_empty() {
237 list_tool_infos_patterns(patterns)
238 .map_err(|e| AgentError::InvalidValue(format!("Invalid regex patterns: {}", e)))?
239 } else {
240 list_tool_infos()
241 };
242 let tools = tools
243 .into_iter()
244 .map(|tool| tool.into())
245 .collect::<Vector<AgentValue>>();
246 let tools_array = AgentValue::array(tools);
247
248 self.output(ctx, PIN_TOOLS, tools_array).await?;
249
250 Ok(())
251 }
252}
253
254#[askit_agent(
255 title="Stream Tool",
256 category=CATEGORY,
257 inputs=[PIN_TOOL_OUT],
258 outputs=[PIN_TOOL_IN],
259 string_config(name=CONFIG_TOOL_NAME),
260 text_config(name=CONFIG_TOOL_DESCRIPTION),
261 object_config(name=CONFIG_TOOL_PARAMETERS),
262)]
263pub struct StreamToolAgent {
264 data: AgentData,
265 name: String,
266 description: String,
267 parameters: Option<serde_json::Value>,
268 pending: Arc<Mutex<HashMap<usize, oneshot::Sender<AgentValue>>>>,
269}
270
271impl StreamToolAgent {
272 fn start_tool_call(
273 &mut self,
274 ctx: AgentContext,
275 args: AgentValue,
276 ) -> Result<oneshot::Receiver<AgentValue>, AgentError> {
277 let (tx, rx) = oneshot::channel();
278
279 self.pending.lock().unwrap().insert(ctx.id(), tx);
280 self.try_output(ctx.clone(), PIN_TOOL_IN, args)?;
281
282 Ok(rx)
283 }
284}
285
286#[async_trait]
287impl AsAgent for StreamToolAgent {
288 fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
289 let def_name = spec.def_name.clone();
290 let configs = spec.configs.clone();
291 let name = configs
292 .as_ref()
293 .and_then(|c| c.get_string(CONFIG_TOOL_NAME).ok())
294 .unwrap_or_else(|| def_name.clone());
295 let description = configs
296 .as_ref()
297 .and_then(|c| c.get_string(CONFIG_TOOL_DESCRIPTION).ok())
298 .unwrap_or_default();
299 let parameters = configs
300 .as_ref()
301 .and_then(|c| c.get(CONFIG_TOOL_PARAMETERS).ok())
302 .and_then(|v| serde_json::to_value(v).ok());
303 Ok(Self {
304 data: AgentData::new(askit, id, spec),
305 name,
306 description,
307 parameters,
308 pending: Arc::new(Mutex::new(HashMap::new())),
309 })
310 }
311
312 fn configs_changed(&mut self) -> Result<(), AgentError> {
313 self.name = self.configs()?.get_string_or_default(CONFIG_TOOL_NAME);
314 self.description = self
315 .configs()?
316 .get_string_or_default(CONFIG_TOOL_DESCRIPTION);
317 self.parameters = self
318 .configs()?
319 .get(CONFIG_TOOL_PARAMETERS)
320 .ok()
321 .and_then(|v| serde_json::to_value(v).ok());
322
323 Ok(())
326 }
327
328 async fn start(&mut self) -> Result<(), AgentError> {
329 let agent_handle = self
330 .askit()
331 .get_agent(self.id())
332 .ok_or_else(|| AgentError::AgentNotFound(self.id().to_string()))?;
333 let tool = StreamTool::new(
334 self.name.clone(),
335 self.description.clone(),
336 self.parameters.clone(),
337 agent_handle,
338 );
339 register_tool(tool);
340 Ok(())
341 }
342
343 async fn stop(&mut self) -> Result<(), AgentError> {
344 unregister_tool(&self.name);
345 self.pending.lock().unwrap().clear();
346 Ok(())
347 }
348
349 async fn process(
350 &mut self,
351 ctx: AgentContext,
352 _pin: String,
353 value: AgentValue,
354 ) -> Result<(), AgentError> {
355 if let Some(tx) = self.pending.lock().unwrap().remove(&ctx.id()) {
356 let _ = tx.send(value);
357 }
358 Ok(())
359 }
360}
361
/// `Tool` implementation that forwards its calls to a `StreamToolAgent`.
struct StreamTool {
    // Metadata reported through `Tool::info`.
    info: ToolInfo,
    // Handle to the agent that services the calls; expected to wrap a
    // `StreamToolAgent` (checked on every call).
    agent: Arc<AsyncMutex<Box<dyn Agent>>>,
}
366
367impl StreamTool {
368 fn new(
369 name: String,
370 description: String,
371 parameters: Option<serde_json::Value>,
372 agent: Arc<AsyncMutex<Box<dyn Agent>>>,
373 ) -> Self {
374 Self {
375 info: ToolInfo {
376 name: name,
377 description: description,
378 parameters: parameters,
379 },
380 agent,
381 }
382 }
383
384 async fn tool_call(
385 &self,
386 ctx: AgentContext,
387 args: AgentValue,
388 ) -> Result<AgentValue, AgentError> {
389 let rx = {
391 let mut guard = self.agent.lock().await;
392 let Some(stream_tool_agent) = guard.as_agent_mut::<StreamToolAgent>() else {
393 return Err(AgentError::Other(
394 "Agent is not StreamToolAgent".to_string(),
395 ));
396 };
397 stream_tool_agent.start_tool_call(ctx, args)?
398 };
399
400 tokio::time::timeout(Duration::from_secs(60), rx)
401 .await
402 .map_err(|_| AgentError::Other("tool_call timed out".to_string()))?
403 .map_err(|_| AgentError::Other("tool_out dropped".to_string()))
404 }
405}
406
407#[async_trait]
408impl Tool for StreamTool {
409 fn info(&self) -> &ToolInfo {
410 &self.info
411 }
412
413 async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
414 self.tool_call(ctx, args).await
415 }
416}
417
// Agent that executes the tool calls carried by an incoming message and emits
// one message per tool response on the `message` pin. The `tools` config can
// hold name patterns that restrict which calls are executed.
#[askit_agent(
    title="Call Tool Message",
    category=CATEGORY,
    inputs=[PIN_MESSAGE],
    outputs=[PIN_MESSAGE],
    string_config(name=CONFIG_TOOLS),
)]
pub struct CallToolMessageAgent {
    // Common agent state built by `AgentData::new`.
    data: AgentData,
}
429
430#[async_trait]
431impl AsAgent for CallToolMessageAgent {
432 fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
433 Ok(Self {
434 data: AgentData::new(askit, id, spec),
435 })
436 }
437
438 async fn process(
439 &mut self,
440 ctx: AgentContext,
441 _pin: String,
442 value: AgentValue,
443 ) -> Result<(), AgentError> {
444 let Some(message) = value.as_message() else {
445 return Ok(());
446 };
447 let Some(mut tool_calls) = message.tool_calls.clone() else {
448 return Ok(());
449 };
450
451 let config_tools = self.configs()?.get_string_or_default(CONFIG_TOOLS);
453 if !config_tools.is_empty() {
454 let tools = list_tool_infos_patterns(&config_tools)
455 .map_err(|e| AgentError::InvalidValue(format!("Invalid regex patterns: {}", e)))?;
456 let allowed_tool_names: HashSet<String> = tools.into_iter().map(|t| t.name).collect();
458 tool_calls = tool_calls
459 .iter()
460 .filter(|call| allowed_tool_names.contains(&call.function.name))
461 .cloned()
462 .collect();
463 }
464
465 let resp_messages = call_tools(&ctx, &tool_calls).await?;
466 for resp_msg in resp_messages {
467 self.output(ctx.clone(), PIN_MESSAGE, AgentValue::message(resp_msg))
468 .await?;
469 }
470 Ok(())
471 }
472}
473
// Agent that calls a single tool described by an object on the `tool_call`
// pin and emits the tool's result on the `value` pin.
#[askit_agent(
    title="Call Tool",
    category=CATEGORY,
    inputs=[PIN_TOOL_CALL],
    outputs=[PIN_VALUE],
)]
pub struct CallToolAgent {
    // Common agent state built by `AgentData::new`.
    data: AgentData,
}
484
485#[async_trait]
486impl AsAgent for CallToolAgent {
487 fn new(askit: ASKit, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
488 Ok(Self {
489 data: AgentData::new(askit, id, spec),
490 })
491 }
492
493 async fn process(
494 &mut self,
495 ctx: AgentContext,
496 _pin: String,
497 value: AgentValue,
498 ) -> Result<(), AgentError> {
499 let obj = value.as_object().ok_or_else(|| {
500 AgentError::InvalidValue("tool_call input must be an object".to_string())
501 })?;
502 let tool_name = obj.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
503 AgentError::InvalidValue("tool_call.name must be a string".to_string())
504 })?;
505 let tool_parameters = obj.get("parameters").cloned().unwrap_or(AgentValue::unit());
506
507 let resp = call_tool(ctx.clone(), tool_name, tool_parameters).await?;
508 self.output(ctx, PIN_VALUE, resp).await?;
509
510 Ok(())
511 }
512}