agent_base/tool/mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::{json, Value};
6use tokio::sync::broadcast;
7
8use crate::llm::LlmClient;
9use crate::types::{AgentResult, AgentError, AgentEvent, SessionId};
10use crate::engine::SessionStore;
11
12pub mod mcp;
13pub mod policy;
14pub mod subagent;
15
16pub use mcp::{McpClient, McpToolInfo, McpToolRegistry};
17pub use subagent::{SubAgentSessionPolicy, SubAgentTool};
18
19pub use policy::ToolPolicy;
20
21#[derive(Clone, Debug, Default)]
22pub struct ToolOutput {
23    pub summary: String,
24    pub raw: Option<Value>,
25    pub control_flow: ToolControlFlow,
26    pub truncated: bool,
27}
28
/// Control-flow hint attached to every [`ToolOutput`].
///
/// NOTE(review): how `Break` vs `Continue` is acted upon is decided by whoever consumes
/// `ToolOutput` (not visible in this module).
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`. That lets
/// callers compare and copy it freely instead of pattern-matching on clones.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ToolControlFlow {
    /// Default value, also used by [`TypedTool`]s that do not override `control_flow`.
    #[default]
    Break,
    /// The alternative to `Break`.
    Continue,
}
35
36#[derive(Clone)]
37pub struct ToolContext {
38    pub session_id: SessionId,
39    pub event_bus: broadcast::Sender<AgentEvent>,
40    pub llm_client: Option<Arc<dyn LlmClient>>,
41    pub session_store: Option<Arc<dyn SessionStore>>,
42}
43
44#[async_trait]
45pub trait Tool: Send + Sync {
46    fn name(&self) -> &'static str;
47    fn definition(&self) -> Value;
48    async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput>;
49}
50
51#[async_trait]
52pub trait TypedTool: Send + Sync {
53    type Args: serde::de::DeserializeOwned;
54    type Output: serde::Serialize;
55
56    fn name(&self) -> &'static str;
57    fn description(&self) -> &'static str;
58    fn parameters_schema(&self) -> Value;
59    async fn call_typed(&self, args: Self::Args, ctx: &ToolContext) -> AgentResult<Self::Output>;
60
61    fn control_flow() -> ToolControlFlow
62    where
63        Self: Sized,
64    {
65        ToolControlFlow::Break
66    }
67
68    fn format_output(&self, output: Self::Output) -> String {
69        serde_json::to_string(&output).unwrap_or_default()
70    }
71}
72
73#[async_trait]
74impl<T: TypedTool + Send + Sync + 'static> Tool for T {
75    fn name(&self) -> &'static str {
76        TypedTool::name(self)
77    }
78
79    fn definition(&self) -> Value {
80        json!({
81            "type": "function",
82            "function": {
83                "name": self.name(),
84                "description": self.description(),
85                "parameters": self.parameters_schema(),
86            }
87        })
88    }
89
90    async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput> {
91        let typed_args: T::Args = serde_json::from_value(args.clone())
92            .map_err(|_| AgentError::ToolArgsInvalid {
93                name: self.name().to_string(),
94                raw: args.to_string(),
95            })?;
96        let output = self.call_typed(typed_args, ctx).await?;
97        let output_json = serde_json::to_value(&output).ok();
98        let summary = self.format_output(output);
99        Ok(ToolOutput {
100            summary,
101            raw: output_json,
102            control_flow: T::control_flow(),
103            truncated: false,
104        })
105    }
106}
107
108pub(crate) type ToolRef = Arc<dyn Tool>;
109
110#[derive(Clone, Default)]
111pub struct ToolRegistry {
112    tools: HashMap<String, ToolRef>,
113}
114
115impl ToolRegistry {
116    pub fn register(&mut self, tool: impl Tool + 'static) {
117        self.tools.insert(tool.name().to_string(), Arc::new(tool));
118    }
119
120    pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
121        self.tools.insert(tool.name().to_string(), tool);
122    }
123
124    pub fn get(&self, name: &str) -> Option<ToolRef> {
125        self.tools.get(name).cloned()
126    }
127
128    pub fn definitions(&self) -> Vec<Value> {
129        self.tools.values().map(|tool| tool.definition()).collect()
130    }
131
132    pub fn len(&self) -> usize {
133        self.tools.len()
134    }
135
136    pub fn is_empty(&self) -> bool {
137        self.tools.is_empty()
138    }
139}