1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use deepseek_protocol::{ToolKind, ToolOutput, ToolPayload};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolSpec {
14 pub name: String,
15 pub input_schema: Value,
16 pub output_schema: Value,
17 pub supports_parallel_tool_calls: bool,
18 pub timeout_ms: Option<u64>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ConfiguredToolSpec {
23 pub spec: ToolSpec,
24 pub supports_parallel_tool_calls: bool,
25}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub enum ToolCallSource {
30 Direct,
31 JsRepl,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ToolCall {
36 pub name: String,
37 pub payload: ToolPayload,
38 pub source: ToolCallSource,
39 pub raw_tool_call_id: Option<String>,
40}
41
42impl ToolCall {
43 pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
44 match &self.payload {
45 ToolPayload::LocalShell { params } => (
46 params.command.clone(),
47 params
48 .cwd
49 .clone()
50 .unwrap_or_else(|| fallback_cwd.to_string()),
51 "shell",
52 ),
53 _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
54 }
55 }
56}
57
58#[derive(Debug, Clone)]
59pub struct ToolInvocation {
60 pub call_id: String,
61 pub tool_name: String,
62 pub payload: ToolPayload,
63 pub source: ToolCallSource,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub enum FunctionCallError {
68 ToolNotFound { name: String },
69 KindMismatch { expected: ToolKind, got: ToolKind },
70 MutatingToolRejected { name: String },
71 TimedOut { name: String, timeout_ms: u64 },
72 Cancelled { name: String },
73 ExecutionFailed { name: String, error: String },
74}
75
76#[async_trait]
77pub trait ToolHandler: Send + Sync {
78 fn kind(&self) -> ToolKind;
79 fn matches_kind(&self, kind: ToolKind) -> bool {
80 self.kind() == kind
81 }
82 fn is_mutating(&self) -> bool {
83 false
84 }
85 async fn handle(
86 &self,
87 invocation: ToolInvocation,
88 ) -> std::result::Result<ToolOutput, FunctionCallError>;
89}
90
91#[derive(Debug, Default)]
92pub struct ToolCallRuntime {
93 pub parallel_execution: Arc<RwLock<()>>,
94}
95
96#[derive(Default)]
97pub struct ToolRegistry {
98 handlers: HashMap<String, Arc<dyn ToolHandler>>,
99 specs: HashMap<String, ConfiguredToolSpec>,
100 runtime: ToolCallRuntime,
101}
102
103impl ToolRegistry {
104 pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
105 let name = spec.name.clone();
106 self.specs.insert(
107 name.clone(),
108 ConfiguredToolSpec {
109 supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
110 spec,
111 },
112 );
113 self.handlers.insert(name, handler);
114 Ok(())
115 }
116
117 pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
118 self.specs.values().cloned().collect()
119 }
120
121 pub async fn dispatch(
122 &self,
123 call: ToolCall,
124 allow_mutating: bool,
125 ) -> std::result::Result<ToolOutput, FunctionCallError> {
126 let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
127 FunctionCallError::ToolNotFound {
128 name: call.name.clone(),
129 }
130 })?;
131 let configured =
132 self.specs
133 .get(&call.name)
134 .cloned()
135 .ok_or_else(|| FunctionCallError::ToolNotFound {
136 name: call.name.clone(),
137 })?;
138
139 let payload_kind = tool_payload_kind(&call.payload);
140 let expected = handler.kind();
141 if !handler.matches_kind(payload_kind) {
142 return Err(FunctionCallError::KindMismatch {
143 expected,
144 got: payload_kind,
145 });
146 }
147 if handler.is_mutating() && !allow_mutating {
148 return Err(FunctionCallError::MutatingToolRejected { name: call.name });
149 }
150
151 let invocation = ToolInvocation {
152 call_id: call
153 .raw_tool_call_id
154 .clone()
155 .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
156 tool_name: call.name.clone(),
157 payload: call.payload,
158 source: call.source,
159 };
160
161 if configured.supports_parallel_tool_calls {
162 let _guard = self.runtime.parallel_execution.read().await;
163 self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
164 .await
165 } else {
166 let _guard = self.runtime.parallel_execution.write().await;
167 self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
168 .await
169 }
170 }
171
172 async fn execute_with_timeout(
173 &self,
174 handler: Arc<dyn ToolHandler>,
175 timeout_ms: Option<u64>,
176 invocation: ToolInvocation,
177 ) -> std::result::Result<ToolOutput, FunctionCallError> {
178 if let Some(timeout_ms) = timeout_ms {
179 let name = invocation.tool_name.clone();
180 match tokio::time::timeout(
181 Duration::from_millis(timeout_ms),
182 handler.handle(invocation),
183 )
184 .await
185 {
186 Ok(result) => result,
187 Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
188 }
189 } else {
190 handler.handle(invocation).await
191 }
192 }
193}
194
195fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
196 match payload {
197 ToolPayload::Mcp { .. } => ToolKind::Mcp,
198 ToolPayload::Function { .. }
199 | ToolPayload::Custom { .. }
200 | ToolPayload::LocalShell { .. } => ToolKind::Function,
201 }
202}