1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::{json, Value};
6use tokio::sync::broadcast;
7
8use crate::llm::LlmClient;
9use crate::types::{AgentResult, AgentError, AgentEvent, SessionId};
10use crate::engine::SessionStore;
11
12pub mod mcp;
13pub mod policy;
14pub mod subagent;
15
16pub use mcp::{McpClient, McpToolInfo, McpToolRegistry};
17pub use subagent::{SubAgentSessionPolicy, SubAgentTool};
18
19pub use policy::ToolPolicy;
20
21#[derive(Clone, Debug, Default)]
22pub struct ToolOutput {
23 pub summary: String,
24 pub raw: Option<Value>,
25 pub control_flow: ToolControlFlow,
26 pub truncated: bool,
27}
28
/// Control-flow signal carried alongside a tool's output.
///
/// NOTE(review): the code that consumes this value is not in this file;
/// presumably the agent loop stops on `Break` and keeps iterating on
/// `Continue` — confirm against the engine.
#[derive(Debug, Clone, Default)]
pub enum ToolControlFlow {
    /// The variant produced by `ToolControlFlow::default()`.
    #[default]
    Break,
    /// The non-default alternative to `Break`.
    Continue,
}
35
36#[derive(Clone)]
37pub struct ToolContext {
38 pub session_id: SessionId,
39 pub event_bus: broadcast::Sender<AgentEvent>,
40 pub llm_client: Option<Arc<dyn LlmClient>>,
41 pub session_store: Option<Arc<dyn SessionStore>>,
42}
43
44#[async_trait]
45pub trait Tool: Send + Sync {
46 fn name(&self) -> &'static str;
47 fn definition(&self) -> Value;
48 async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput>;
49}
50
51#[async_trait]
52pub trait TypedTool: Send + Sync {
53 type Args: serde::de::DeserializeOwned;
54 type Output: serde::Serialize;
55
56 fn name(&self) -> &'static str;
57 fn description(&self) -> &'static str;
58 fn parameters_schema(&self) -> Value;
59 async fn call_typed(&self, args: Self::Args, ctx: &ToolContext) -> AgentResult<Self::Output>;
60
61 fn control_flow() -> ToolControlFlow
62 where
63 Self: Sized,
64 {
65 ToolControlFlow::Break
66 }
67
68 fn format_output(&self, output: Self::Output) -> String {
69 serde_json::to_string(&output).unwrap_or_default()
70 }
71}
72
73#[async_trait]
74impl<T: TypedTool + Send + Sync + 'static> Tool for T {
75 fn name(&self) -> &'static str {
76 TypedTool::name(self)
77 }
78
79 fn definition(&self) -> Value {
80 json!({
81 "type": "function",
82 "function": {
83 "name": self.name(),
84 "description": self.description(),
85 "parameters": self.parameters_schema(),
86 }
87 })
88 }
89
90 async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput> {
91 let typed_args: T::Args = serde_json::from_value(args.clone())
92 .map_err(|_| AgentError::ToolArgsInvalid {
93 name: self.name().to_string(),
94 raw: args.to_string(),
95 })?;
96 let output = self.call_typed(typed_args, ctx).await?;
97 let output_json = serde_json::to_value(&output).ok();
98 let summary = self.format_output(output);
99 Ok(ToolOutput {
100 summary,
101 raw: output_json,
102 control_flow: T::control_flow(),
103 truncated: false,
104 })
105 }
106}
107
108pub(crate) type ToolRef = Arc<dyn Tool>;
109
110#[derive(Clone, Default)]
111pub struct ToolRegistry {
112 tools: HashMap<String, ToolRef>,
113}
114
115impl ToolRegistry {
116 pub fn register(&mut self, tool: impl Tool + 'static) {
117 self.tools.insert(tool.name().to_string(), Arc::new(tool));
118 }
119
120 pub fn register_arc(&mut self, tool: Arc<dyn Tool>) {
121 self.tools.insert(tool.name().to_string(), tool);
122 }
123
124 pub fn get(&self, name: &str) -> Option<ToolRef> {
125 self.tools.get(name).cloned()
126 }
127
128 pub fn definitions(&self) -> Vec<Value> {
129 self.tools.values().map(|tool| tool.definition()).collect()
130 }
131
132 pub fn len(&self) -> usize {
133 self.tools.len()
134 }
135
136 pub fn is_empty(&self) -> bool {
137 self.tools.is_empty()
138 }
139}