Skip to main content

modular_agent_core/
tool.rs

//! Tool registry and agents for LLM function calling.
//!
//! This module provides infrastructure for registering, managing, and invoking tools
//! that can be called by LLMs. It includes:
//!
//! - A global tool registry for registering and looking up tools by name
//! - The `Tool` trait for implementing custom tools
//! - Agents for working with tools in workflows:
//!   - `ListToolsAgent` - Lists available tools matching a pattern
//!   - `PresetToolAgent` - Exposes a workflow as a callable tool
//!   - `CallToolMessageAgent` - Processes tool calls from LLM messages
//!   - `CallToolAgent` - Directly invokes a tool by name
14
15#![cfg(feature = "llm")]
16
17use std::{
18    collections::{BTreeMap, HashMap, HashSet},
19    sync::{Arc, Mutex, OnceLock, RwLock},
20    time::Duration,
21};
22
23use crate::{
24    ModularAgent, Agent, AgentContext, AgentData, AgentError, AgentOutput, AgentSpec, AgentValue, AsAgent,
25    Message, ToolCall, async_trait, modular_agent,
26};
27use im::{Vector, vector};
28use regex::RegexSet;
29use tokio::sync::{Mutex as AsyncMutex, oneshot};
30
/// Category under which every agent in this module is listed.
const CATEGORY: &str = "Core/Tool";

// ---- Port names -------------------------------------------------------------
const PORT_MESSAGE: &str = "message";
const PORT_PATTERNS: &str = "patterns";
const PORT_TOOLS: &str = "tools";
const PORT_TOOL_CALL: &str = "tool_call";
const PORT_TOOL_IN: &str = "tool_in";
const PORT_TOOL_OUT: &str = "tool_out";
const PORT_VALUE: &str = "value";

// ---- Configuration keys -----------------------------------------------------
const CONFIG_TOOLS: &str = "tools";
const CONFIG_TOOL_NAME: &str = "name";
const CONFIG_TOOL_DESCRIPTION: &str = "description";
const CONFIG_TOOL_PARAMETERS: &str = "parameters";
45
46/// Metadata describing a tool available for LLM function calling.
47///
48/// This information is typically sent to the LLM to describe what tools
49/// are available and how to call them.
50#[derive(Clone, Debug)]
51pub struct ToolInfo {
52    /// Unique name identifying the tool.
53    pub name: String,
54
55    /// Human-readable description of what the tool does.
56    pub description: String,
57
58    /// JSON Schema describing the tool's parameters (optional).
59    pub parameters: Option<serde_json::Value>,
60}
61
62/// Trait for implementing callable tools.
63///
64/// Tools are functions that can be invoked by LLMs during conversations.
65/// Implement this trait to create custom tools that can be registered
66/// with the global tool registry.
67///
68/// # Example
69///
70/// ```ignore
71/// use modular_agent_core::{Tool, ToolInfo, AgentContext, AgentValue, AgentError, async_trait};
72///
73/// struct MyTool {
74///     info: ToolInfo,
75/// }
76///
77/// #[async_trait]
78/// impl Tool for MyTool {
79///     fn info(&self) -> &ToolInfo {
80///         &self.info
81///     }
82///
83///     async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
84///         // Tool implementation
85///         Ok(AgentValue::string("result"))
86///     }
87/// }
88/// ```
89#[async_trait]
90pub trait Tool {
91    /// Returns metadata about this tool.
92    fn info(&self) -> &ToolInfo;
93
94    /// Invokes the tool with the given context and arguments.
95    ///
96    /// # Arguments
97    ///
98    /// * `ctx` - The agent context for this invocation
99    /// * `args` - Arguments passed to the tool (typically from LLM)
100    ///
101    /// # Returns
102    ///
103    /// The tool's result as an `AgentValue`, or an error if the call fails.
104    async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError>;
105}
106
107impl From<ToolInfo> for AgentValue {
108    fn from(info: ToolInfo) -> Self {
109        let mut obj: BTreeMap<String, AgentValue> = BTreeMap::new();
110        obj.insert("name".to_string(), AgentValue::from(info.name));
111        obj.insert(
112            "description".to_string(),
113            AgentValue::from(info.description),
114        );
115        if let Some(params) = &info.parameters {
116            if let Ok(params_value) = AgentValue::from_serialize(params) {
117                obj.insert("parameters".to_string(), params_value);
118            }
119        }
120        AgentValue::object(obj.into())
121    }
122}
123
124/// Internal entry for a registered tool.
125#[derive(Clone)]
126struct ToolEntry {
127    info: ToolInfo,
128    tool: Arc<Box<dyn Tool + Send + Sync>>,
129}
130
131impl ToolEntry {
132    /// Creates a new tool entry from a tool implementation.
133    fn new<T: Tool + Send + Sync + 'static>(tool: T) -> Self {
134        Self {
135            info: tool.info().clone(),
136            tool: Arc::new(Box::new(tool)),
137        }
138    }
139}
140
141/// Thread-safe registry for managing tools.
142struct ToolRegistry {
143    tools: HashMap<String, ToolEntry>,
144}
145
146impl ToolRegistry {
147    /// Creates a new empty tool registry.
148    fn new() -> Self {
149        Self {
150            tools: HashMap::new(),
151        }
152    }
153
154    /// Registers a tool with the registry.
155    fn register_tool<T: Tool + Send + Sync + 'static>(&mut self, tool: T) {
156        let name = tool.info().name.to_string();
157        let entry = ToolEntry::new(tool);
158        self.tools.insert(name, entry);
159    }
160
161    /// Removes a tool from the registry by name.
162    fn unregister_tool(&mut self, name: &str) {
163        self.tools.remove(name);
164    }
165
166    /// Retrieves a tool by name, if it exists.
167    fn get_tool(&self, name: &str) -> Option<Arc<Box<dyn Tool + Send + Sync>>> {
168        self.tools.get(name).map(|entry| entry.tool.clone())
169    }
170}
171
172/// Global tool registry instance.
173static TOOL_REGISTRY: OnceLock<RwLock<ToolRegistry>> = OnceLock::new();
174
175/// Returns the global tool registry, initializing it if necessary.
176fn registry() -> &'static RwLock<ToolRegistry> {
177    TOOL_REGISTRY.get_or_init(|| RwLock::new(ToolRegistry::new()))
178}
179
180/// Registers a tool with the global registry.
181///
182/// The tool will be available for lookup and invocation by its name.
183/// If a tool with the same name already exists, it will be replaced.
184///
185/// # Arguments
186///
187/// * `tool` - The tool implementation to register
188pub fn register_tool<T: Tool + Send + Sync + 'static>(tool: T) {
189    registry().write().unwrap().register_tool(tool);
190}
191
192/// Removes a tool from the global registry by name.
193///
194/// # Arguments
195///
196/// * `name` - The name of the tool to unregister
197pub fn unregister_tool(name: &str) {
198    registry().write().unwrap().unregister_tool(name);
199}
200
201/// Returns information about all registered tools.
202///
203/// # Returns
204///
205/// A vector of `ToolInfo` for all currently registered tools.
206pub fn list_tool_infos() -> Vec<ToolInfo> {
207    registry()
208        .read()
209        .unwrap()
210        .tools
211        .values()
212        .map(|entry| entry.info.clone())
213        .collect()
214}
215
216/// Returns tool information for tools matching the given regex patterns.
217///
218/// Patterns are newline-separated regular expressions. A tool is included
219/// if its name matches any of the patterns.
220///
221/// # Arguments
222///
223/// * `patterns` - Newline-separated regex patterns to match tool names
224///
225/// # Returns
226///
227/// A vector of `ToolInfo` for tools whose names match the patterns.
228///
229/// # Errors
230///
231/// Returns an error if any of the patterns are invalid regular expressions.
232pub fn list_tool_infos_patterns(patterns: &str) -> Result<Vec<ToolInfo>, regex::Error> {
233    // Split patterns by newline and trim whitespace
234    let patterns = patterns
235        .lines()
236        .map(|line| line.trim())
237        .filter(|line| !line.is_empty())
238        .collect::<Vec<&str>>();
239    let reg_set = RegexSet::new(&patterns)?;
240    let tool_names = registry()
241        .read()
242        .unwrap()
243        .tools
244        .values()
245        .filter_map(|entry| {
246            if reg_set.is_match(&entry.info.name) {
247                Some(entry.info.clone())
248            } else {
249                None
250            }
251        })
252        .collect();
253    Ok(tool_names)
254}
255
256/// Retrieves a tool by name from the global registry.
257///
258/// # Arguments
259///
260/// * `name` - The name of the tool to retrieve
261///
262/// # Returns
263///
264/// The tool if found, or `None` if no tool with that name is registered.
265pub fn get_tool(name: &str) -> Option<Arc<Box<dyn Tool + Send + Sync>>> {
266    registry().read().unwrap().get_tool(name)
267}
268
269/// Invokes a tool by name with the given arguments.
270///
271/// # Arguments
272///
273/// * `ctx` - The agent context for the invocation
274/// * `name` - The name of the tool to call
275/// * `args` - Arguments to pass to the tool
276///
277/// # Returns
278///
279/// The tool's result, or an error if the tool is not found or fails.
280pub async fn call_tool(
281    ctx: AgentContext,
282    name: &str,
283    args: AgentValue,
284) -> Result<AgentValue, AgentError> {
285    let tool = {
286        let guard = registry().read().unwrap();
287        guard.get_tool(name)
288    };
289
290    let Some(tool) = tool else {
291        return Err(AgentError::Other(format!("Tool '{}' not found", name)));
292    };
293
294    tool.call(ctx, args).await
295}
296
297/// Executes multiple tool calls and returns the results as messages.
298///
299/// Processes each tool call sequentially and returns tool response messages
300/// suitable for continuing an LLM conversation.
301///
302/// # Arguments
303///
304/// * `ctx` - The agent context for the invocations
305/// * `tool_calls` - The tool calls to execute
306///
307/// # Returns
308///
309/// A vector of tool response messages, one for each tool call.
310pub async fn call_tools(
311    ctx: &AgentContext,
312    tool_calls: &Vector<ToolCall>,
313) -> Result<Vector<Message>, AgentError> {
314    if tool_calls.is_empty() {
315        return Ok(vector![]);
316    };
317    let mut resp_messages = vec![];
318
319    for call in tool_calls {
320        let args: AgentValue =
321            AgentValue::from_json(call.function.parameters.clone()).map_err(|e| {
322                AgentError::InvalidValue(format!("Failed to parse tool call parameters: {}", e))
323            })?;
324        let tool_resp = call_tool(ctx.clone(), call.function.name.as_str(), args).await?;
325        let mut msg = Message::tool(
326            call.function.name.clone(),
327            tool_resp.to_json().to_string(),
328        );
329        msg.id = call.function.id.clone();
330        resp_messages.push(msg);
331    }
332
333    Ok(resp_messages.into())
334}
335
336// ============================================================================
337// Tool Agents
338// ============================================================================
339
340/// Agent that lists available tools.
341///
342/// Outputs tool information for all registered tools, optionally filtered
343/// by regex patterns provided on the input port.
344///
345/// # Inputs
346///
347/// * `patterns` - Optional regex patterns (newline-separated) to filter tools
348///
349/// # Outputs
350///
351/// * `tools` - Array of tool information objects
352#[modular_agent(
353    title="List Tools",
354    category=CATEGORY,
355    inputs=[PORT_PATTERNS],
356    outputs=[PORT_TOOLS],
357)]
358pub struct ListToolsAgent {
359    data: AgentData,
360}
361
362#[async_trait]
363impl AsAgent for ListToolsAgent {
364    fn new(ma: ModularAgent, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
365        Ok(Self {
366            data: AgentData::new(ma, id, spec),
367        })
368    }
369
370    async fn process(
371        &mut self,
372        ctx: AgentContext,
373        _port: String,
374        value: AgentValue,
375    ) -> Result<(), AgentError> {
376        let Some(patterns) = value.as_str() else {
377            return Err(AgentError::InvalidValue(
378                "patterns input must be a string".to_string(),
379            ));
380        };
381
382        let tools = if !patterns.is_empty() {
383            list_tool_infos_patterns(patterns)
384                .map_err(|e| AgentError::InvalidValue(format!("Invalid regex patterns: {}", e)))?
385        } else {
386            list_tool_infos()
387        };
388        let tools = tools
389            .into_iter()
390            .map(|tool| tool.into())
391            .collect::<Vector<AgentValue>>();
392        let tools_array = AgentValue::array(tools);
393
394        self.output(ctx, PORT_TOOLS, tools_array).await?;
395
396        Ok(())
397    }
398}
399
400/// Agent that exposes a workflow as a callable tool.
401///
402/// This agent registers itself as a tool that can be invoked by LLMs.
403/// When called, it forwards the arguments to the `tool_in` output port
404/// and waits for a response on the `tool_out` input port.
405///
406/// # Configuration
407///
408/// * `name` - The tool name (defaults to agent definition name)
409/// * `description` - Human-readable description of the tool
410/// * `parameters` - JSON Schema describing the tool's parameters
411///
412/// # Ports
413///
414/// * Input `tool_out` - Receives the tool's result from the workflow
415/// * Output `tool_in` - Emits the tool call arguments to the workflow
416#[modular_agent(
417    title="Preset Tool",
418    category=CATEGORY,
419    inputs=[PORT_TOOL_OUT],
420    outputs=[PORT_TOOL_IN],
421    string_config(name=CONFIG_TOOL_NAME),
422    text_config(name=CONFIG_TOOL_DESCRIPTION),
423    object_config(name=CONFIG_TOOL_PARAMETERS),
424)]
425pub struct PresetToolAgent {
426    data: AgentData,
427    name: String,
428    description: String,
429    parameters: Option<serde_json::Value>,
430    /// Pending tool calls awaiting results, keyed by context ID.
431    pending: Arc<Mutex<HashMap<usize, oneshot::Sender<AgentValue>>>>,
432}
433
434impl PresetToolAgent {
435    /// Initiates a tool call and returns a receiver for the result.
436    ///
437    /// Emits the arguments to the workflow and registers a pending receiver
438    /// that will be fulfilled when the result arrives on the input port.
439    fn start_tool_call(
440        &mut self,
441        ctx: AgentContext,
442        args: AgentValue,
443    ) -> Result<oneshot::Receiver<AgentValue>, AgentError> {
444        let (tx, rx) = oneshot::channel();
445
446        self.pending.lock().unwrap().insert(ctx.id(), tx);
447        self.try_output(ctx.clone(), PORT_TOOL_IN, args)?;
448
449        Ok(rx)
450    }
451}
452
453#[async_trait]
454impl AsAgent for PresetToolAgent {
455    fn new(ma: ModularAgent, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
456        let def_name = spec.def_name.clone();
457        let configs = spec.configs.clone();
458        let name = configs
459            .as_ref()
460            .and_then(|c| c.get_string(CONFIG_TOOL_NAME).ok())
461            .unwrap_or_else(|| def_name.clone());
462        let description = configs
463            .as_ref()
464            .and_then(|c| c.get_string(CONFIG_TOOL_DESCRIPTION).ok())
465            .unwrap_or_default();
466        let parameters = configs
467            .as_ref()
468            .and_then(|c| c.get(CONFIG_TOOL_PARAMETERS).ok())
469            .and_then(|v| serde_json::to_value(v).ok());
470        Ok(Self {
471            data: AgentData::new(ma, id, spec),
472            name,
473            description,
474            parameters,
475            pending: Arc::new(Mutex::new(HashMap::new())),
476        })
477    }
478
479    fn configs_changed(&mut self) -> Result<(), AgentError> {
480        self.name = self.configs()?.get_string_or_default(CONFIG_TOOL_NAME);
481        self.description = self
482            .configs()?
483            .get_string_or_default(CONFIG_TOOL_DESCRIPTION);
484        self.parameters = self
485            .configs()?
486            .get(CONFIG_TOOL_PARAMETERS)
487            .ok()
488            .and_then(|v| serde_json::to_value(v).ok());
489
490        // TODO: update registered tool info
491
492        Ok(())
493    }
494
495    async fn start(&mut self) -> Result<(), AgentError> {
496        let agent_handle = self
497            .ma()
498            .get_agent(self.id())
499            .ok_or_else(|| AgentError::AgentNotFound(self.id().to_string()))?;
500        let tool = PresetTool::new(
501            self.name.clone(),
502            self.description.clone(),
503            self.parameters.clone(),
504            agent_handle,
505        );
506        register_tool(tool);
507        Ok(())
508    }
509
510    async fn stop(&mut self) -> Result<(), AgentError> {
511        unregister_tool(&self.name);
512        self.pending.lock().unwrap().clear();
513        Ok(())
514    }
515
516    async fn process(
517        &mut self,
518        ctx: AgentContext,
519        _port: String,
520        value: AgentValue,
521    ) -> Result<(), AgentError> {
522        if let Some(tx) = self.pending.lock().unwrap().remove(&ctx.id()) {
523            let _ = tx.send(value);
524        }
525        Ok(())
526    }
527}
528
529/// Internal Tool implementation that delegates to a PresetToolAgent.
530struct PresetTool {
531    info: ToolInfo,
532    agent: Arc<AsyncMutex<Box<dyn Agent>>>,
533}
534
535impl PresetTool {
536    /// Creates a new PresetTool wrapping a PresetToolAgent.
537    fn new(
538        name: String,
539        description: String,
540        parameters: Option<serde_json::Value>,
541        agent: Arc<AsyncMutex<Box<dyn Agent>>>,
542    ) -> Self {
543        Self {
544            info: ToolInfo {
545                name: name,
546                description: description,
547                parameters: parameters,
548            },
549            agent,
550        }
551    }
552
553    /// Executes a tool call through the wrapped agent.
554    ///
555    /// Times out after 60 seconds if no response is received.
556    async fn tool_call(
557        &self,
558        ctx: AgentContext,
559        args: AgentValue,
560    ) -> Result<AgentValue, AgentError> {
561        // Kick off the tool call while holding the lock, then drop it before awaiting the result
562        let rx = {
563            let mut guard = self.agent.lock().await;
564            let Some(preset_tool_agent) = guard.as_agent_mut::<PresetToolAgent>() else {
565                return Err(AgentError::Other(
566                    "Agent is not PresetToolAgent".to_string(),
567                ));
568            };
569            preset_tool_agent.start_tool_call(ctx, args)?
570        };
571
572        tokio::time::timeout(Duration::from_secs(60), rx)
573            .await
574            .map_err(|_| AgentError::Other("tool_call timed out".to_string()))?
575            .map_err(|_| AgentError::Other("tool_out dropped".to_string()))
576    }
577}
578
579#[async_trait]
580impl Tool for PresetTool {
581    fn info(&self) -> &ToolInfo {
582        &self.info
583    }
584
585    async fn call(&self, ctx: AgentContext, args: AgentValue) -> Result<AgentValue, AgentError> {
586        self.tool_call(ctx, args).await
587    }
588}
589
590/// Agent that processes tool calls from LLM messages.
591///
592/// When an LLM response contains tool calls, this agent executes them
593/// and outputs the results as tool response messages.
594///
595/// # Configuration
596///
597/// * `tools` - Optional regex patterns to filter which tools can be called
598///
599/// # Ports
600///
601/// * Input `message` - LLM message that may contain tool calls
602/// * Output `message` - Tool response messages (one per tool call)
603#[modular_agent(
604    title="Call Tool Message",
605    category=CATEGORY,
606    inputs=[PORT_MESSAGE],
607    outputs=[PORT_MESSAGE],
608    string_config(name=CONFIG_TOOLS),
609)]
610pub struct CallToolMessageAgent {
611    data: AgentData,
612}
613
614#[async_trait]
615impl AsAgent for CallToolMessageAgent {
616    fn new(ma: ModularAgent, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
617        Ok(Self {
618            data: AgentData::new(ma, id, spec),
619        })
620    }
621
622    async fn process(
623        &mut self,
624        ctx: AgentContext,
625        _port: String,
626        value: AgentValue,
627    ) -> Result<(), AgentError> {
628        let Some(message) = value.as_message() else {
629            return Ok(());
630        };
631        let Some(mut tool_calls) = message.tool_calls.clone() else {
632            return Ok(());
633        };
634
635        // Filter tools
636        let config_tools = self.configs()?.get_string_or_default(CONFIG_TOOLS);
637        if !config_tools.is_empty() {
638            let tools = list_tool_infos_patterns(&config_tools)
639                .map_err(|e| AgentError::InvalidValue(format!("Invalid regex patterns: {}", e)))?;
640            // FIXME: cache allowed tool names
641            let allowed_tool_names: HashSet<String> = tools.into_iter().map(|t| t.name).collect();
642            tool_calls = tool_calls
643                .iter()
644                .filter(|call| allowed_tool_names.contains(&call.function.name))
645                .cloned()
646                .collect();
647        }
648
649        let resp_messages = call_tools(&ctx, &tool_calls).await?;
650        for resp_msg in resp_messages {
651            self.output(ctx.clone(), PORT_MESSAGE, AgentValue::message(resp_msg))
652                .await?;
653        }
654        Ok(())
655    }
656}
657
658/// Agent that directly invokes a tool by name.
659///
660/// Takes a tool call specification (name and parameters) and invokes
661/// the corresponding registered tool, outputting the result.
662///
663/// # Ports
664///
665/// * Input `tool_call` - Object with `name` (string) and optional `parameters`
666/// * Output `value` - The tool's return value
667#[modular_agent(
668    title="Call Tool",
669    category=CATEGORY,
670    inputs=[PORT_TOOL_CALL],
671    outputs=[PORT_VALUE],
672)]
673pub struct CallToolAgent {
674    data: AgentData,
675}
676
677#[async_trait]
678impl AsAgent for CallToolAgent {
679    fn new(ma: ModularAgent, id: String, spec: AgentSpec) -> Result<Self, AgentError> {
680        Ok(Self {
681            data: AgentData::new(ma, id, spec),
682        })
683    }
684
685    async fn process(
686        &mut self,
687        ctx: AgentContext,
688        _port: String,
689        value: AgentValue,
690    ) -> Result<(), AgentError> {
691        let obj = value.as_object().ok_or_else(|| {
692            AgentError::InvalidValue("tool_call input must be an object".to_string())
693        })?;
694        let tool_name = obj.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
695            AgentError::InvalidValue("tool_call.name must be a string".to_string())
696        })?;
697        let tool_parameters = obj.get("parameters").cloned().unwrap_or(AgentValue::unit());
698
699        let resp = call_tool(ctx.clone(), tool_name, tool_parameters).await?;
700        self.output(ctx, PORT_VALUE, resp).await?;
701
702        Ok(())
703    }
704}