Skip to main content

codewhale_tools/
lib.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
12
13tokio::task_local! {
14    static TOOL_EXECUTION_LOCK_HELD: ();
15}
16
/// A property a tool may declare about what it does or needs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolCapability {
    /// Only reads data; never modifies state.
    ReadOnly,
    /// Writes to the filesystem.
    WritesFiles,
    /// Runs arbitrary shell commands.
    ExecutesCode,
    /// Issues network requests.
    Network,
    /// Is able to run inside a sandbox.
    Sandboxable,
    /// Needs user approval before it runs.
    RequiresApproval,
}
33
/// How strongly a tool needs user sign-off before it runs.
///
/// `Default` yields [`ApprovalRequirement::Auto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ApprovalRequirement {
    /// No approval needed: safe, read-only operations.
    #[default]
    Auto,
    /// Approval is suggested, but the user may skip it.
    Suggest,
    /// Explicit user approval is always required.
    Required,
}
45
46/// Errors that can occur during tool execution.
47#[derive(Debug, Clone, thiserror::Error)]
48pub enum ToolError {
49    #[error("Failed to validate input: {message}")]
50    InvalidInput { message: String },
51    #[error("Failed to validate input: missing required field '{field}'")]
52    MissingField { field: String },
53    #[error("Failed to resolve path '{}': path escapes workspace", path.display())]
54    PathEscape { path: PathBuf },
55    #[error("Failed to execute tool: {message}")]
56    ExecutionFailed { message: String },
57    #[error("Failed to execute tool: operation timed out after {seconds}s")]
58    Timeout { seconds: u64 },
59    #[error("Failed to locate tool: {message}")]
60    NotAvailable { message: String },
61    #[error("Failed to authorize tool execution: {message}")]
62    PermissionDenied { message: String },
63}
64
65impl ToolError {
66    #[must_use]
67    pub fn invalid_input(msg: impl Into<String>) -> Self {
68        Self::InvalidInput {
69            message: msg.into(),
70        }
71    }
72
73    #[must_use]
74    pub fn missing_field(field: impl Into<String>) -> Self {
75        Self::MissingField {
76            field: field.into(),
77        }
78    }
79
80    #[must_use]
81    pub fn execution_failed(msg: impl Into<String>) -> Self {
82        Self::ExecutionFailed {
83            message: msg.into(),
84        }
85    }
86
87    #[must_use]
88    pub fn path_escape(path: impl Into<PathBuf>) -> Self {
89        Self::PathEscape { path: path.into() }
90    }
91
92    #[must_use]
93    pub fn not_available(msg: impl Into<String>) -> Self {
94        Self::NotAvailable {
95            message: msg.into(),
96        }
97    }
98
99    #[must_use]
100    pub fn permission_denied(msg: impl Into<String>) -> Self {
101        Self::PermissionDenied {
102            message: msg.into(),
103        }
104    }
105}
106
107/// Result of a tool execution.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ToolResult {
110    /// The output content, which may be JSON or plain text.
111    pub content: String,
112    /// Whether the execution was successful.
113    pub success: bool,
114    /// Optional structured metadata.
115    #[serde(skip_serializing_if = "Option::is_none")]
116    pub metadata: Option<Value>,
117}
118
119impl ToolResult {
120    /// Create a successful result with content.
121    #[must_use]
122    pub fn success(content: impl Into<String>) -> Self {
123        Self {
124            content: content.into(),
125            success: true,
126            metadata: None,
127        }
128    }
129
130    /// Create an error result with message.
131    #[must_use]
132    pub fn error(message: impl Into<String>) -> Self {
133        Self {
134            content: message.into(),
135            success: false,
136            metadata: None,
137        }
138    }
139
140    /// Create a successful result from JSON.
141    pub fn json<T: Serialize>(value: &T) -> std::result::Result<Self, serde_json::Error> {
142        Ok(Self {
143            content: serde_json::to_string_pretty(value)?,
144            success: true,
145            metadata: None,
146        })
147    }
148
149    /// Add metadata to the result.
150    #[must_use]
151    pub fn with_metadata(mut self, metadata: Value) -> Self {
152        self.metadata = Some(metadata);
153        self
154    }
155}
156
157/// Helper to extract a required string field from JSON input.
158pub fn required_str<'a>(input: &'a Value, field: &str) -> std::result::Result<&'a str, ToolError> {
159    input.get(field).and_then(Value::as_str).ok_or_else(|| {
160        // When the field is missing, list the fields the caller *did*
161        // supply so the model can spot the mismatch without a retry.
162        let provided: Vec<&str> = input
163            .as_object()
164            .map(|obj| obj.keys().map(|k| k.as_str()).collect())
165            .unwrap_or_default();
166        if provided.is_empty() {
167            ToolError::missing_field(field)
168        } else {
169            let hint = format!(
170                "missing required field '{field}'. Input provided: {}",
171                provided.join(", ")
172            );
173            ToolError::invalid_input(hint)
174        }
175    })
176}
177
178/// Helper to extract an optional string field from JSON input.
179#[must_use]
180pub fn optional_str<'a>(input: &'a Value, field: &str) -> Option<&'a str> {
181    input.get(field).and_then(Value::as_str)
182}
183
184/// Helper to extract a required u64 field from JSON input.
185pub fn required_u64(input: &Value, field: &str) -> std::result::Result<u64, ToolError> {
186    input
187        .get(field)
188        .and_then(Value::as_u64)
189        .ok_or_else(|| ToolError::missing_field(field))
190}
191
192/// Helper to extract an optional u64 field with default.
193#[must_use]
194pub fn optional_u64(input: &Value, field: &str, default: u64) -> u64 {
195    input.get(field).and_then(Value::as_u64).unwrap_or(default)
196}
197
198/// Helper to extract an optional bool field with default.
199#[must_use]
200pub fn optional_bool(input: &Value, field: &str, default: bool) -> bool {
201    input.get(field).and_then(Value::as_bool).unwrap_or(default)
202}
203
204/// Specification that describes a tool available in the registry.
205///
206/// Contains the tool's name, its JSON input/output schemas, and
207/// execution constraints such as timeout and parallelism.
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct ToolSpec {
210    /// Unique name used to look up the tool.
211    pub name: String,
212    /// JSON Schema describing the tool's expected input parameters.
213    pub input_schema: Value,
214    /// JSON Schema describing the tool's output format.
215    pub output_schema: Value,
216    /// Whether multiple invocations of this tool may run concurrently.
217    pub supports_parallel_tool_calls: bool,
218    /// Optional per-call timeout in milliseconds; `None` means no timeout.
219    pub timeout_ms: Option<u64>,
220}
221
222/// A [`ToolSpec`] together with its runtime configuration.
223///
224/// Wraps a `ToolSpec` and exposes the parallelism flag directly so the
225/// dispatcher can check it without digging into the inner spec.
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct ConfiguredToolSpec {
228    /// The underlying tool specification.
229    pub spec: ToolSpec,
230    /// Whether this tool supports concurrent invocations.
231    pub supports_parallel_tool_calls: bool,
232}
233
234/// Identifies where a tool call originated from.
235#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
236#[serde(rename_all = "snake_case")]
237pub enum ToolCallSource {
238    /// Direct invocation from the model or user.
239    Direct,
240    /// Invocation through the JavaScript REPL environment.
241    JsRepl,
242}
243
244/// A tool invocation request before it has been validated and dispatched.
245///
246/// Contains the tool name, its input payload, and metadata about where the
247/// call originated.
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct ToolCall {
250    /// Name of the tool to invoke.
251    pub name: String,
252    /// The input payload for the tool.
253    pub payload: ToolPayload,
254    /// Where this call originated (direct or REPL).
255    pub source: ToolCallSource,
256    /// Optional raw tool-call identifier from the upstream provider.
257    pub raw_tool_call_id: Option<String>,
258}
259
260impl ToolCall {
261    /// Derive the execution subject for this call.
262    ///
263    /// For local shell payloads this returns the shell command and its
264    /// working directory; for all other payloads the tool name and the
265    /// provided `fallback_cwd` are returned instead. The third element
266    /// of the tuple is a human-readable kind label (`"shell"` or `"tool"`).
267    pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
268        match &self.payload {
269            ToolPayload::LocalShell { params } => (
270                params.command.clone(),
271                params
272                    .cwd
273                    .clone()
274                    .unwrap_or_else(|| fallback_cwd.to_string()),
275                "shell",
276            ),
277            _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
278        }
279    }
280}
281
282/// A validated tool invocation ready to be handled.
283///
284/// Created by the registry after a [`ToolCall`] passes validation, this
285/// carries all the context a [`ToolHandler`] needs to execute the tool.
286#[derive(Debug, Clone)]
287pub struct ToolInvocation {
288    /// Unique identifier for this invocation (generated or from the provider).
289    pub call_id: String,
290    /// Name of the tool being invoked.
291    pub tool_name: String,
292    /// The input payload for the tool.
293    pub payload: ToolPayload,
294    /// Where this invocation originated.
295    pub source: ToolCallSource,
296}
297
298/// Errors that can occur during tool dispatch and execution.
299///
300/// Unlike [`ToolError`], which represents input validation failures within
301/// a tool, `FunctionCallError` covers problems at the dispatch layer: the
302/// tool was not found, its kind did not match, it was rejected because it
303/// is mutating, it timed out, was cancelled, or its handler returned an
304/// error.
305#[derive(Debug, Clone, Serialize, Deserialize)]
306pub enum FunctionCallError {
307    /// No tool with the given name is registered.
308    ToolNotFound { name: String },
309    /// The payload kind does not match the handler's expected kind.
310    KindMismatch { expected: ToolKind, got: ToolKind },
311    /// The tool is mutating but `allow_mutating` was `false`.
312    MutatingToolRejected { name: String },
313    /// The tool execution exceeded its configured timeout.
314    TimedOut { name: String, timeout_ms: u64 },
315    /// The tool execution was cancelled.
316    Cancelled { name: String },
317    /// The tool handler returned an error.
318    ExecutionFailed { name: String, error: String },
319}
320
321/// Trait implemented by concrete tool handlers.
322///
323/// Each registered tool is backed by a handler that reports its kind,
324/// whether it is mutating, and performs the actual execution.
325#[async_trait]
326pub trait ToolHandler: Send + Sync {
327    /// The [`ToolKind`] this handler expects (e.g. `Function` or `Mcp`).
328    fn kind(&self) -> ToolKind;
329
330    /// Returns `true` if `kind` matches this handler's expected kind.
331    ///
332    /// The default implementation compares against [`kind()`](ToolHandler::kind).
333    fn matches_kind(&self, kind: ToolKind) -> bool {
334        self.kind() == kind
335    }
336
337    /// Whether this tool performs side-effects that require user approval.
338    ///
339    /// Defaults to `false` (read-only / safe).
340    fn is_mutating(&self) -> bool {
341        false
342    }
343
344    /// Execute the tool with the given invocation context.
345    async fn handle(
346        &self,
347        invocation: ToolInvocation,
348    ) -> std::result::Result<ToolOutput, FunctionCallError>;
349}
350
351/// Manages concurrent tool execution via a read/write lock.
352///
353/// Parallel-safe tools acquire a read lock (allowing overlap), while
354/// serial tools acquire a write lock (exclusive access). Reentrant calls
355/// (e.g. a tool invoking another tool) skip locking to avoid deadlock.
356#[derive(Debug)]
357pub struct ToolCallRuntime {
358    execution_lock: Arc<RwLock<()>>,
359}
360
361impl Default for ToolCallRuntime {
362    fn default() -> Self {
363        Self {
364            execution_lock: Arc::new(RwLock::new(())),
365        }
366    }
367}
368
369#[derive(Debug)]
370enum ToolExecutionGuard {
371    Parallel(#[allow(dead_code)] OwnedRwLockReadGuard<()>),
372    Serial(#[allow(dead_code)] OwnedRwLockWriteGuard<()>),
373    Reentrant,
374}
375
376impl ToolCallRuntime {
377    async fn acquire(&self, supports_parallel: bool) -> ToolExecutionGuard {
378        if TOOL_EXECUTION_LOCK_HELD.try_with(|_| ()).is_ok() {
379            return ToolExecutionGuard::Reentrant;
380        }
381
382        if supports_parallel {
383            ToolExecutionGuard::Parallel(self.execution_lock.clone().read_owned().await)
384        } else {
385            ToolExecutionGuard::Serial(self.execution_lock.clone().write_owned().await)
386        }
387    }
388}
389
390/// Central registry that maps tool names to their specs and handlers.
391///
392/// Use [`register()`](ToolRegistry::register) to add tools, then
393/// [`dispatch()`](ToolRegistry::dispatch) to invoke them. The registry
394/// owns a [`ToolCallRuntime`] that manages concurrent execution.
395#[derive(Default)]
396pub struct ToolRegistry {
397    handlers: HashMap<String, Arc<dyn ToolHandler>>,
398    specs: HashMap<String, ConfiguredToolSpec>,
399    runtime: ToolCallRuntime,
400}
401
402impl ToolRegistry {
403    /// Register a tool with its specification and handler.
404    ///
405    /// The tool's name is taken from `spec.name`. Returns an error if
406    /// registration fails (currently infallible, but the `Result` is
407    /// reserved for future validation).
408    pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
409        let name = spec.name.clone();
410        self.specs.insert(
411            name.clone(),
412            ConfiguredToolSpec {
413                supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
414                spec,
415            },
416        );
417        self.handlers.insert(name, handler);
418        Ok(())
419    }
420
421    /// Return the configured specs for every registered tool.
422    pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
423        self.specs.values().cloned().collect()
424    }
425
426    /// Validate and execute a tool call.
427    ///
428    /// Looks up the tool by name, verifies the payload kind matches the
429    /// handler, enforces the `allow_mutating` guard, acquires the
430    /// appropriate execution lock, and forwards the call to the handler.
431    /// Returns a [`FunctionCallError`] if any validation step fails or
432    /// the handler returns an error.
433    pub async fn dispatch(
434        &self,
435        call: ToolCall,
436        allow_mutating: bool,
437    ) -> std::result::Result<ToolOutput, FunctionCallError> {
438        let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
439            FunctionCallError::ToolNotFound {
440                name: call.name.clone(),
441            }
442        })?;
443        let configured =
444            self.specs
445                .get(&call.name)
446                .cloned()
447                .ok_or_else(|| FunctionCallError::ToolNotFound {
448                    name: call.name.clone(),
449                })?;
450
451        let payload_kind = tool_payload_kind(&call.payload);
452        let expected = handler.kind();
453        if !handler.matches_kind(payload_kind) {
454            return Err(FunctionCallError::KindMismatch {
455                expected,
456                got: payload_kind,
457            });
458        }
459        if handler.is_mutating() && !allow_mutating {
460            return Err(FunctionCallError::MutatingToolRejected { name: call.name });
461        }
462
463        let invocation = ToolInvocation {
464            call_id: call
465                .raw_tool_call_id
466                .clone()
467                .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
468            tool_name: call.name.clone(),
469            payload: call.payload,
470            source: call.source,
471        };
472
473        let _guard = self
474            .runtime
475            .acquire(configured.supports_parallel_tool_calls)
476            .await;
477
478        TOOL_EXECUTION_LOCK_HELD
479            .scope(
480                (),
481                self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation),
482            )
483            .await
484    }
485
486    async fn execute_with_timeout(
487        &self,
488        handler: Arc<dyn ToolHandler>,
489        timeout_ms: Option<u64>,
490        invocation: ToolInvocation,
491    ) -> std::result::Result<ToolOutput, FunctionCallError> {
492        if let Some(timeout_ms) = timeout_ms {
493            let name = invocation.tool_name.clone();
494            match tokio::time::timeout(
495                Duration::from_millis(timeout_ms),
496                handler.handle(invocation),
497            )
498            .await
499            {
500                Ok(result) => result,
501                Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
502            }
503        } else {
504            handler.handle(invocation).await
505        }
506    }
507}
508
509fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
510    match payload {
511        ToolPayload::Mcp { .. } => ToolKind::Mcp,
512        ToolPayload::Function { .. }
513        | ToolPayload::Custom { .. }
514        | ToolPayload::LocalShell { .. } => ToolKind::Function,
515    }
516}
517
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn tool_result_json_round_trips_content() {
        let value = json!({ "ok": true });
        let rendered = ToolResult::json(&value).expect("json");
        assert!(rendered.success, "json() should produce a successful result");
        assert!(
            rendered.content.contains("\"ok\": true"),
            "unexpected content: {}",
            rendered.content
        );
    }

    #[test]
    fn helper_extractors_validate_shape() {
        let input = json!({ "name": "demo", "count": 7, "enabled": true });

        let name = required_str(&input, "name").expect("name");
        assert_eq!(name, "demo");
        assert_eq!(optional_u64(&input, "count", 0), 7);
        assert!(optional_bool(&input, "enabled", false));

        // "name" holds a string, so the u64 extractor reports it as missing.
        let wrong_type = required_u64(&input, "name");
        assert!(matches!(wrong_type, Err(ToolError::MissingField { .. })));
    }

    #[test]
    fn required_str_reports_provided_fields_on_missing_required_field() {
        let input = json!({ "path": "src/lib.rs", "content": "new body" });
        let message = required_str(&input, "replace")
            .expect_err("replace is missing")
            .to_string();
        for needle in [
            "missing required field 'replace'",
            "Input provided:",
            "path",
            "content",
        ] {
            assert!(message.contains(needle), "{message:?} should mention {needle:?}");
        }
    }

    #[test]
    fn tool_error_display_matches_legacy_text() {
        let rendered = ToolError::missing_field("path").to_string();
        assert_eq!(
            rendered,
            "Failed to validate input: missing required field 'path'"
        );
    }
}