Skip to main content

codewhale_tools/
lib.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use tokio::sync::RwLock;
12
/// A marker describing something a tool does or needs in order to run.
///
/// Variants carry no data, so values are cheap to copy, compare, and hash.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ToolCapability {
    /// The tool reads data but never modifies any state.
    ReadOnly,
    /// The tool writes to the filesystem.
    WritesFiles,
    /// The tool runs arbitrary shell commands.
    ExecutesCode,
    /// The tool issues network requests.
    Network,
    /// The tool is able to run inside a sandbox.
    Sandboxable,
    /// The tool needs user approval before it executes.
    RequiresApproval,
}
29
/// How much user consent a tool needs before it may run.
///
/// The default is [`ApprovalRequirement::Auto`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ApprovalRequirement {
    /// No approval needed: intended for safe, read-only operations.
    #[default]
    Auto,
    /// Approval is suggested, but the user may skip it.
    Suggest,
    /// Explicit user approval is always required.
    Required,
}
41
/// Errors that can occur during tool execution.
///
/// Every variant's `Display` text begins with `"Failed to ..."`, so messages
/// read consistently when they are surfaced to a user or model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolError {
    /// The input was malformed; `message` explains how.
    InvalidInput { message: String },
    /// A required input field was absent.
    MissingField { field: String },
    /// `path` resolves outside the workspace.
    PathEscape { path: PathBuf },
    /// The tool ran but failed; `message` explains why.
    ExecutionFailed { message: String },
    /// The operation ran longer than `seconds` seconds.
    Timeout { seconds: u64 },
    /// The tool could not be located.
    NotAvailable { message: String },
    /// Execution was not authorized.
    PermissionDenied { message: String },
}

impl std::fmt::Display for ToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidInput { message } => {
                write!(f, "Failed to validate input: {message}")
            }
            Self::MissingField { field } => {
                write!(
                    f,
                    "Failed to validate input: missing required field '{field}'"
                )
            }
            Self::PathEscape { path } => {
                write!(
                    f,
                    "Failed to resolve path '{}': path escapes workspace",
                    path.display()
                )
            }
            Self::ExecutionFailed { message } => {
                write!(f, "Failed to execute tool: {message}")
            }
            Self::Timeout { seconds } => {
                write!(
                    f,
                    "Failed to execute tool: operation timed out after {seconds}s"
                )
            }
            Self::NotAvailable { message } => {
                write!(f, "Failed to locate tool: {message}")
            }
            Self::PermissionDenied { message } => {
                write!(f, "Failed to authorize tool execution: {message}")
            }
        }
    }
}

impl std::error::Error for ToolError {}

impl ToolError {
    /// Build a [`ToolError::InvalidInput`] with the given explanation.
    #[must_use]
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput {
            message: msg.into(),
        }
    }

    /// Build a [`ToolError::MissingField`] naming the absent field.
    #[must_use]
    pub fn missing_field(field: impl Into<String>) -> Self {
        Self::MissingField {
            field: field.into(),
        }
    }

    /// Build a [`ToolError::ExecutionFailed`] with the given explanation.
    #[must_use]
    pub fn execution_failed(msg: impl Into<String>) -> Self {
        Self::ExecutionFailed {
            message: msg.into(),
        }
    }

    /// Build a [`ToolError::PathEscape`] for the offending path.
    #[must_use]
    pub fn path_escape(path: impl Into<PathBuf>) -> Self {
        Self::PathEscape { path: path.into() }
    }

    /// Build a [`ToolError::Timeout`] for an operation that exceeded `seconds`.
    ///
    /// Completes the set of constructor helpers: every other variant already
    /// had one.
    #[must_use]
    pub fn timeout(seconds: u64) -> Self {
        Self::Timeout { seconds }
    }

    /// Build a [`ToolError::NotAvailable`] with the given explanation.
    #[must_use]
    pub fn not_available(msg: impl Into<String>) -> Self {
        Self::NotAvailable {
            message: msg.into(),
        }
    }

    /// Build a [`ToolError::PermissionDenied`] with the given explanation.
    #[must_use]
    pub fn permission_denied(msg: impl Into<String>) -> Self {
        Self::PermissionDenied {
            message: msg.into(),
        }
    }
}
135
136/// Result of a tool execution.
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ToolResult {
139    /// The output content, which may be JSON or plain text.
140    pub content: String,
141    /// Whether the execution was successful.
142    pub success: bool,
143    /// Optional structured metadata.
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub metadata: Option<Value>,
146}
147
148impl ToolResult {
149    /// Create a successful result with content.
150    #[must_use]
151    pub fn success(content: impl Into<String>) -> Self {
152        Self {
153            content: content.into(),
154            success: true,
155            metadata: None,
156        }
157    }
158
159    /// Create an error result with message.
160    #[must_use]
161    pub fn error(message: impl Into<String>) -> Self {
162        Self {
163            content: message.into(),
164            success: false,
165            metadata: None,
166        }
167    }
168
169    /// Create a successful result from JSON.
170    pub fn json<T: Serialize>(value: &T) -> std::result::Result<Self, serde_json::Error> {
171        Ok(Self {
172            content: serde_json::to_string_pretty(value)?,
173            success: true,
174            metadata: None,
175        })
176    }
177
178    /// Add metadata to the result.
179    #[must_use]
180    pub fn with_metadata(mut self, metadata: Value) -> Self {
181        self.metadata = Some(metadata);
182        self
183    }
184}
185
186/// Helper to extract a required string field from JSON input.
187pub fn required_str<'a>(input: &'a Value, field: &str) -> std::result::Result<&'a str, ToolError> {
188    input.get(field).and_then(Value::as_str).ok_or_else(|| {
189        // When the field is missing, list the fields the caller *did*
190        // supply so the model can spot the mismatch without a retry.
191        let provided: Vec<&str> = input
192            .as_object()
193            .map(|obj| obj.keys().map(|k| k.as_str()).collect())
194            .unwrap_or_default();
195        if provided.is_empty() {
196            ToolError::missing_field(field)
197        } else {
198            let hint = format!(
199                "missing required field '{field}'. Input provided: {}",
200                provided.join(", ")
201            );
202            ToolError::invalid_input(hint)
203        }
204    })
205}
206
207/// Helper to extract an optional string field from JSON input.
208#[must_use]
209pub fn optional_str<'a>(input: &'a Value, field: &str) -> Option<&'a str> {
210    input.get(field).and_then(Value::as_str)
211}
212
213/// Helper to extract a required u64 field from JSON input.
214pub fn required_u64(input: &Value, field: &str) -> std::result::Result<u64, ToolError> {
215    input
216        .get(field)
217        .and_then(Value::as_u64)
218        .ok_or_else(|| ToolError::missing_field(field))
219}
220
221/// Helper to extract an optional u64 field with default.
222#[must_use]
223pub fn optional_u64(input: &Value, field: &str, default: u64) -> u64 {
224    input.get(field).and_then(Value::as_u64).unwrap_or(default)
225}
226
227/// Helper to extract an optional bool field with default.
228#[must_use]
229pub fn optional_bool(input: &Value, field: &str, default: bool) -> bool {
230    input.get(field).and_then(Value::as_bool).unwrap_or(default)
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub struct ToolSpec {
235    pub name: String,
236    pub input_schema: Value,
237    pub output_schema: Value,
238    pub supports_parallel_tool_calls: bool,
239    pub timeout_ms: Option<u64>,
240}
241
242#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct ConfiguredToolSpec {
244    pub spec: ToolSpec,
245    pub supports_parallel_tool_calls: bool,
246}
247
248#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
249#[serde(rename_all = "snake_case")]
250pub enum ToolCallSource {
251    Direct,
252    JsRepl,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct ToolCall {
257    pub name: String,
258    pub payload: ToolPayload,
259    pub source: ToolCallSource,
260    pub raw_tool_call_id: Option<String>,
261}
262
263impl ToolCall {
264    pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
265        match &self.payload {
266            ToolPayload::LocalShell { params } => (
267                params.command.clone(),
268                params
269                    .cwd
270                    .clone()
271                    .unwrap_or_else(|| fallback_cwd.to_string()),
272                "shell",
273            ),
274            _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
275        }
276    }
277}
278
279#[derive(Debug, Clone)]
280pub struct ToolInvocation {
281    pub call_id: String,
282    pub tool_name: String,
283    pub payload: ToolPayload,
284    pub source: ToolCallSource,
285}
286
287#[derive(Debug, Clone, Serialize, Deserialize)]
288pub enum FunctionCallError {
289    ToolNotFound { name: String },
290    KindMismatch { expected: ToolKind, got: ToolKind },
291    MutatingToolRejected { name: String },
292    TimedOut { name: String, timeout_ms: u64 },
293    Cancelled { name: String },
294    ExecutionFailed { name: String, error: String },
295}
296
297#[async_trait]
298pub trait ToolHandler: Send + Sync {
299    fn kind(&self) -> ToolKind;
300    fn matches_kind(&self, kind: ToolKind) -> bool {
301        self.kind() == kind
302    }
303    fn is_mutating(&self) -> bool {
304        false
305    }
306    async fn handle(
307        &self,
308        invocation: ToolInvocation,
309    ) -> std::result::Result<ToolOutput, FunctionCallError>;
310}
311
312#[derive(Debug, Default)]
313pub struct ToolCallRuntime {
314    pub parallel_execution: Arc<RwLock<()>>,
315}
316
317#[derive(Default)]
318pub struct ToolRegistry {
319    handlers: HashMap<String, Arc<dyn ToolHandler>>,
320    specs: HashMap<String, ConfiguredToolSpec>,
321    runtime: ToolCallRuntime,
322}
323
324impl ToolRegistry {
325    pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
326        let name = spec.name.clone();
327        self.specs.insert(
328            name.clone(),
329            ConfiguredToolSpec {
330                supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
331                spec,
332            },
333        );
334        self.handlers.insert(name, handler);
335        Ok(())
336    }
337
338    pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
339        self.specs.values().cloned().collect()
340    }
341
342    pub async fn dispatch(
343        &self,
344        call: ToolCall,
345        allow_mutating: bool,
346    ) -> std::result::Result<ToolOutput, FunctionCallError> {
347        let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
348            FunctionCallError::ToolNotFound {
349                name: call.name.clone(),
350            }
351        })?;
352        let configured =
353            self.specs
354                .get(&call.name)
355                .cloned()
356                .ok_or_else(|| FunctionCallError::ToolNotFound {
357                    name: call.name.clone(),
358                })?;
359
360        let payload_kind = tool_payload_kind(&call.payload);
361        let expected = handler.kind();
362        if !handler.matches_kind(payload_kind) {
363            return Err(FunctionCallError::KindMismatch {
364                expected,
365                got: payload_kind,
366            });
367        }
368        if handler.is_mutating() && !allow_mutating {
369            return Err(FunctionCallError::MutatingToolRejected { name: call.name });
370        }
371
372        let invocation = ToolInvocation {
373            call_id: call
374                .raw_tool_call_id
375                .clone()
376                .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
377            tool_name: call.name.clone(),
378            payload: call.payload,
379            source: call.source,
380        };
381
382        if configured.supports_parallel_tool_calls {
383            let _guard = self.runtime.parallel_execution.read().await;
384            self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
385                .await
386        } else {
387            let _guard = self.runtime.parallel_execution.write().await;
388            self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
389                .await
390        }
391    }
392
393    async fn execute_with_timeout(
394        &self,
395        handler: Arc<dyn ToolHandler>,
396        timeout_ms: Option<u64>,
397        invocation: ToolInvocation,
398    ) -> std::result::Result<ToolOutput, FunctionCallError> {
399        if let Some(timeout_ms) = timeout_ms {
400            let name = invocation.tool_name.clone();
401            match tokio::time::timeout(
402                Duration::from_millis(timeout_ms),
403                handler.handle(invocation),
404            )
405            .await
406            {
407                Ok(result) => result,
408                Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
409            }
410        } else {
411            handler.handle(invocation).await
412        }
413    }
414}
415
416fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
417    match payload {
418        ToolPayload::Mcp { .. } => ToolKind::Mcp,
419        ToolPayload::Function { .. }
420        | ToolPayload::Custom { .. }
421        | ToolPayload::LocalShell { .. } => ToolKind::Function,
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tool_result_json_round_trips_content() {
        let payload = json!({ "ok": true });
        let rendered = ToolResult::json(&payload).expect("json");
        assert!(rendered.success);
        assert!(rendered.content.contains("\"ok\": true"));
    }

    #[test]
    fn helper_extractors_validate_shape() {
        let input = json!({ "name": "demo", "count": 7, "enabled": true });

        let name = required_str(&input, "name").expect("name");
        assert_eq!(name, "demo");
        assert_eq!(optional_u64(&input, "count", 0), 7);
        assert!(optional_bool(&input, "enabled", false));

        // A present-but-non-integer field is reported as missing by `required_u64`.
        let wrong_type = required_u64(&input, "name");
        assert!(matches!(wrong_type, Err(ToolError::MissingField { .. })));
    }

    #[test]
    fn required_str_reports_provided_fields_on_missing_required_field() {
        let input = json!({ "path": "src/lib.rs", "content": "new body" });
        let message = required_str(&input, "replace")
            .expect_err("replace is missing")
            .to_string();
        for needle in [
            "missing required field 'replace'",
            "Input provided:",
            "path",
            "content",
        ] {
            assert!(
                message.contains(needle),
                "expected {needle:?} in {message:?}"
            );
        }
    }

    #[test]
    fn tool_error_display_matches_legacy_text() {
        let rendered = ToolError::missing_field("path").to_string();
        assert_eq!(
            rendered,
            "Failed to validate input: missing required field 'path'"
        );
    }
}