// deepseek_tools/lib.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use deepseek_protocol::{ToolKind, ToolOutput, ToolPayload};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolSpec {
14    pub name: String,
15    pub input_schema: Value,
16    pub output_schema: Value,
17    pub supports_parallel_tool_calls: bool,
18    pub timeout_ms: Option<u64>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ConfiguredToolSpec {
23    pub spec: ToolSpec,
24    pub supports_parallel_tool_calls: bool,
25}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub enum ToolCallSource {
30    Direct,
31    JsRepl,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ToolCall {
36    pub name: String,
37    pub payload: ToolPayload,
38    pub source: ToolCallSource,
39    pub raw_tool_call_id: Option<String>,
40}
41
42impl ToolCall {
43    pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
44        match &self.payload {
45            ToolPayload::LocalShell { params } => (
46                params.command.clone(),
47                params
48                    .cwd
49                    .clone()
50                    .unwrap_or_else(|| fallback_cwd.to_string()),
51                "shell",
52            ),
53            _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
54        }
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct ToolInvocation {
60    pub call_id: String,
61    pub tool_name: String,
62    pub payload: ToolPayload,
63    pub source: ToolCallSource,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub enum FunctionCallError {
68    ToolNotFound { name: String },
69    KindMismatch { expected: ToolKind, got: ToolKind },
70    MutatingToolRejected { name: String },
71    TimedOut { name: String, timeout_ms: u64 },
72    Cancelled { name: String },
73    ExecutionFailed { name: String, error: String },
74}
75
76#[async_trait]
77pub trait ToolHandler: Send + Sync {
78    fn kind(&self) -> ToolKind;
79    fn matches_kind(&self, kind: ToolKind) -> bool {
80        self.kind() == kind
81    }
82    fn is_mutating(&self) -> bool {
83        false
84    }
85    async fn handle(
86        &self,
87        invocation: ToolInvocation,
88    ) -> std::result::Result<ToolOutput, FunctionCallError>;
89}
90
91#[derive(Debug, Default)]
92pub struct ToolCallRuntime {
93    pub parallel_execution: Arc<RwLock<()>>,
94}
95
96#[derive(Default)]
97pub struct ToolRegistry {
98    handlers: HashMap<String, Arc<dyn ToolHandler>>,
99    specs: HashMap<String, ConfiguredToolSpec>,
100    runtime: ToolCallRuntime,
101}
102
103impl ToolRegistry {
104    pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
105        let name = spec.name.clone();
106        self.specs.insert(
107            name.clone(),
108            ConfiguredToolSpec {
109                supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
110                spec,
111            },
112        );
113        self.handlers.insert(name, handler);
114        Ok(())
115    }
116
117    pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
118        self.specs.values().cloned().collect()
119    }
120
121    pub async fn dispatch(
122        &self,
123        call: ToolCall,
124        allow_mutating: bool,
125    ) -> std::result::Result<ToolOutput, FunctionCallError> {
126        let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
127            FunctionCallError::ToolNotFound {
128                name: call.name.clone(),
129            }
130        })?;
131        let configured =
132            self.specs
133                .get(&call.name)
134                .cloned()
135                .ok_or_else(|| FunctionCallError::ToolNotFound {
136                    name: call.name.clone(),
137                })?;
138
139        let payload_kind = tool_payload_kind(&call.payload);
140        let expected = handler.kind();
141        if !handler.matches_kind(payload_kind) {
142            return Err(FunctionCallError::KindMismatch {
143                expected,
144                got: payload_kind,
145            });
146        }
147        if handler.is_mutating() && !allow_mutating {
148            return Err(FunctionCallError::MutatingToolRejected { name: call.name });
149        }
150
151        let invocation = ToolInvocation {
152            call_id: call
153                .raw_tool_call_id
154                .clone()
155                .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
156            tool_name: call.name.clone(),
157            payload: call.payload,
158            source: call.source,
159        };
160
161        if configured.supports_parallel_tool_calls {
162            let _guard = self.runtime.parallel_execution.read().await;
163            self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
164                .await
165        } else {
166            let _guard = self.runtime.parallel_execution.write().await;
167            self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
168                .await
169        }
170    }
171
172    async fn execute_with_timeout(
173        &self,
174        handler: Arc<dyn ToolHandler>,
175        timeout_ms: Option<u64>,
176        invocation: ToolInvocation,
177    ) -> std::result::Result<ToolOutput, FunctionCallError> {
178        if let Some(timeout_ms) = timeout_ms {
179            let name = invocation.tool_name.clone();
180            match tokio::time::timeout(
181                Duration::from_millis(timeout_ms),
182                handler.handle(invocation),
183            )
184            .await
185            {
186                Ok(result) => result,
187                Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
188            }
189        } else {
190            handler.handle(invocation).await
191        }
192    }
193}
194
195fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
196    match payload {
197        ToolPayload::Mcp { .. } => ToolKind::Mcp,
198        ToolPayload::Function { .. }
199        | ToolPayload::Custom { .. }
200        | ToolPayload::LocalShell { .. } => ToolKind::Function,
201    }
202}