Skip to main content

codewhale_tools/
lib.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
12
13tokio::task_local! {
14    static TOOL_EXECUTION_LOCK_HELD: ();
15}
16
17/// Capabilities that a tool may have or require.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
19pub enum ToolCapability {
20    /// Tool only reads data, never modifies state.
21    ReadOnly,
22    /// Tool writes to the filesystem.
23    WritesFiles,
24    /// Tool executes arbitrary shell commands.
25    ExecutesCode,
26    /// Tool makes network requests.
27    Network,
28    /// Tool can be run in a sandbox.
29    Sandboxable,
30    /// Tool requires user approval before execution.
31    RequiresApproval,
32}
33
34/// Approval requirement for a tool.
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
36pub enum ApprovalRequirement {
37    /// Never needs approval: safe read-only operations.
38    #[default]
39    Auto,
40    /// Suggest approval but allow user to skip.
41    Suggest,
42    /// Always require explicit user approval.
43    Required,
44}
45
46/// Errors that can occur during tool execution.
47#[derive(Debug, Clone)]
48pub enum ToolError {
49    InvalidInput { message: String },
50    MissingField { field: String },
51    PathEscape { path: PathBuf },
52    ExecutionFailed { message: String },
53    Timeout { seconds: u64 },
54    NotAvailable { message: String },
55    PermissionDenied { message: String },
56}
57
58impl std::fmt::Display for ToolError {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        match self {
61            Self::InvalidInput { message } => {
62                write!(f, "Failed to validate input: {message}")
63            }
64            Self::MissingField { field } => {
65                write!(
66                    f,
67                    "Failed to validate input: missing required field '{field}'"
68                )
69            }
70            Self::PathEscape { path } => {
71                write!(
72                    f,
73                    "Failed to resolve path '{}': path escapes workspace",
74                    path.display()
75                )
76            }
77            Self::ExecutionFailed { message } => {
78                write!(f, "Failed to execute tool: {message}")
79            }
80            Self::Timeout { seconds } => {
81                write!(
82                    f,
83                    "Failed to execute tool: operation timed out after {seconds}s"
84                )
85            }
86            Self::NotAvailable { message } => {
87                write!(f, "Failed to locate tool: {message}")
88            }
89            Self::PermissionDenied { message } => {
90                write!(f, "Failed to authorize tool execution: {message}")
91            }
92        }
93    }
94}
95
96impl std::error::Error for ToolError {}
97
98impl ToolError {
99    #[must_use]
100    pub fn invalid_input(msg: impl Into<String>) -> Self {
101        Self::InvalidInput {
102            message: msg.into(),
103        }
104    }
105
106    #[must_use]
107    pub fn missing_field(field: impl Into<String>) -> Self {
108        Self::MissingField {
109            field: field.into(),
110        }
111    }
112
113    #[must_use]
114    pub fn execution_failed(msg: impl Into<String>) -> Self {
115        Self::ExecutionFailed {
116            message: msg.into(),
117        }
118    }
119
120    #[must_use]
121    pub fn path_escape(path: impl Into<PathBuf>) -> Self {
122        Self::PathEscape { path: path.into() }
123    }
124
125    #[must_use]
126    pub fn not_available(msg: impl Into<String>) -> Self {
127        Self::NotAvailable {
128            message: msg.into(),
129        }
130    }
131
132    #[must_use]
133    pub fn permission_denied(msg: impl Into<String>) -> Self {
134        Self::PermissionDenied {
135            message: msg.into(),
136        }
137    }
138}
139
140/// Result of a tool execution.
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ToolResult {
143    /// The output content, which may be JSON or plain text.
144    pub content: String,
145    /// Whether the execution was successful.
146    pub success: bool,
147    /// Optional structured metadata.
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub metadata: Option<Value>,
150}
151
152impl ToolResult {
153    /// Create a successful result with content.
154    #[must_use]
155    pub fn success(content: impl Into<String>) -> Self {
156        Self {
157            content: content.into(),
158            success: true,
159            metadata: None,
160        }
161    }
162
163    /// Create an error result with message.
164    #[must_use]
165    pub fn error(message: impl Into<String>) -> Self {
166        Self {
167            content: message.into(),
168            success: false,
169            metadata: None,
170        }
171    }
172
173    /// Create a successful result from JSON.
174    pub fn json<T: Serialize>(value: &T) -> std::result::Result<Self, serde_json::Error> {
175        Ok(Self {
176            content: serde_json::to_string_pretty(value)?,
177            success: true,
178            metadata: None,
179        })
180    }
181
182    /// Add metadata to the result.
183    #[must_use]
184    pub fn with_metadata(mut self, metadata: Value) -> Self {
185        self.metadata = Some(metadata);
186        self
187    }
188}
189
190/// Helper to extract a required string field from JSON input.
191pub fn required_str<'a>(input: &'a Value, field: &str) -> std::result::Result<&'a str, ToolError> {
192    input.get(field).and_then(Value::as_str).ok_or_else(|| {
193        // When the field is missing, list the fields the caller *did*
194        // supply so the model can spot the mismatch without a retry.
195        let provided: Vec<&str> = input
196            .as_object()
197            .map(|obj| obj.keys().map(|k| k.as_str()).collect())
198            .unwrap_or_default();
199        if provided.is_empty() {
200            ToolError::missing_field(field)
201        } else {
202            let hint = format!(
203                "missing required field '{field}'. Input provided: {}",
204                provided.join(", ")
205            );
206            ToolError::invalid_input(hint)
207        }
208    })
209}
210
211/// Helper to extract an optional string field from JSON input.
212#[must_use]
213pub fn optional_str<'a>(input: &'a Value, field: &str) -> Option<&'a str> {
214    input.get(field).and_then(Value::as_str)
215}
216
217/// Helper to extract a required u64 field from JSON input.
218pub fn required_u64(input: &Value, field: &str) -> std::result::Result<u64, ToolError> {
219    input
220        .get(field)
221        .and_then(Value::as_u64)
222        .ok_or_else(|| ToolError::missing_field(field))
223}
224
225/// Helper to extract an optional u64 field with default.
226#[must_use]
227pub fn optional_u64(input: &Value, field: &str, default: u64) -> u64 {
228    input.get(field).and_then(Value::as_u64).unwrap_or(default)
229}
230
231/// Helper to extract an optional bool field with default.
232#[must_use]
233pub fn optional_bool(input: &Value, field: &str, default: bool) -> bool {
234    input.get(field).and_then(Value::as_bool).unwrap_or(default)
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct ToolSpec {
239    pub name: String,
240    pub input_schema: Value,
241    pub output_schema: Value,
242    pub supports_parallel_tool_calls: bool,
243    pub timeout_ms: Option<u64>,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct ConfiguredToolSpec {
248    pub spec: ToolSpec,
249    pub supports_parallel_tool_calls: bool,
250}
251
252#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
253#[serde(rename_all = "snake_case")]
254pub enum ToolCallSource {
255    Direct,
256    JsRepl,
257}
258
259#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct ToolCall {
261    pub name: String,
262    pub payload: ToolPayload,
263    pub source: ToolCallSource,
264    pub raw_tool_call_id: Option<String>,
265}
266
267impl ToolCall {
268    pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
269        match &self.payload {
270            ToolPayload::LocalShell { params } => (
271                params.command.clone(),
272                params
273                    .cwd
274                    .clone()
275                    .unwrap_or_else(|| fallback_cwd.to_string()),
276                "shell",
277            ),
278            _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
279        }
280    }
281}
282
283#[derive(Debug, Clone)]
284pub struct ToolInvocation {
285    pub call_id: String,
286    pub tool_name: String,
287    pub payload: ToolPayload,
288    pub source: ToolCallSource,
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub enum FunctionCallError {
293    ToolNotFound { name: String },
294    KindMismatch { expected: ToolKind, got: ToolKind },
295    MutatingToolRejected { name: String },
296    TimedOut { name: String, timeout_ms: u64 },
297    Cancelled { name: String },
298    ExecutionFailed { name: String, error: String },
299}
300
301#[async_trait]
302pub trait ToolHandler: Send + Sync {
303    fn kind(&self) -> ToolKind;
304    fn matches_kind(&self, kind: ToolKind) -> bool {
305        self.kind() == kind
306    }
307    fn is_mutating(&self) -> bool {
308        false
309    }
310    async fn handle(
311        &self,
312        invocation: ToolInvocation,
313    ) -> std::result::Result<ToolOutput, FunctionCallError>;
314}
315
316#[derive(Debug)]
317pub struct ToolCallRuntime {
318    /// Preserve read/write tool execution semantics: parallel-safe tools may
319    /// overlap, while serial tools run exclusively.
320    execution_lock: Arc<RwLock<()>>,
321}
322
323impl Default for ToolCallRuntime {
324    fn default() -> Self {
325        Self {
326            execution_lock: Arc::new(RwLock::new(())),
327        }
328    }
329}
330
331#[derive(Debug)]
332enum ToolExecutionGuard {
333    Parallel(#[allow(dead_code)] OwnedRwLockReadGuard<()>),
334    Serial(#[allow(dead_code)] OwnedRwLockWriteGuard<()>),
335    Reentrant,
336}
337
338impl ToolCallRuntime {
339    async fn acquire(&self, supports_parallel: bool) -> ToolExecutionGuard {
340        if TOOL_EXECUTION_LOCK_HELD.try_with(|_| ()).is_ok() {
341            return ToolExecutionGuard::Reentrant;
342        }
343
344        if supports_parallel {
345            ToolExecutionGuard::Parallel(self.execution_lock.clone().read_owned().await)
346        } else {
347            ToolExecutionGuard::Serial(self.execution_lock.clone().write_owned().await)
348        }
349    }
350}
351
352#[derive(Default)]
353pub struct ToolRegistry {
354    handlers: HashMap<String, Arc<dyn ToolHandler>>,
355    specs: HashMap<String, ConfiguredToolSpec>,
356    runtime: ToolCallRuntime,
357}
358
359impl ToolRegistry {
360    pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
361        let name = spec.name.clone();
362        self.specs.insert(
363            name.clone(),
364            ConfiguredToolSpec {
365                supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
366                spec,
367            },
368        );
369        self.handlers.insert(name, handler);
370        Ok(())
371    }
372
373    pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
374        self.specs.values().cloned().collect()
375    }
376
377    pub async fn dispatch(
378        &self,
379        call: ToolCall,
380        allow_mutating: bool,
381    ) -> std::result::Result<ToolOutput, FunctionCallError> {
382        let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
383            FunctionCallError::ToolNotFound {
384                name: call.name.clone(),
385            }
386        })?;
387        let configured =
388            self.specs
389                .get(&call.name)
390                .cloned()
391                .ok_or_else(|| FunctionCallError::ToolNotFound {
392                    name: call.name.clone(),
393                })?;
394
395        let payload_kind = tool_payload_kind(&call.payload);
396        let expected = handler.kind();
397        if !handler.matches_kind(payload_kind) {
398            return Err(FunctionCallError::KindMismatch {
399                expected,
400                got: payload_kind,
401            });
402        }
403        if handler.is_mutating() && !allow_mutating {
404            return Err(FunctionCallError::MutatingToolRejected { name: call.name });
405        }
406
407        let invocation = ToolInvocation {
408            call_id: call
409                .raw_tool_call_id
410                .clone()
411                .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
412            tool_name: call.name.clone(),
413            payload: call.payload,
414            source: call.source,
415        };
416
417        let _guard = self
418            .runtime
419            .acquire(configured.supports_parallel_tool_calls)
420            .await;
421
422        TOOL_EXECUTION_LOCK_HELD
423            .scope(
424                (),
425                self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation),
426            )
427            .await
428    }
429
430    async fn execute_with_timeout(
431        &self,
432        handler: Arc<dyn ToolHandler>,
433        timeout_ms: Option<u64>,
434        invocation: ToolInvocation,
435    ) -> std::result::Result<ToolOutput, FunctionCallError> {
436        if let Some(timeout_ms) = timeout_ms {
437            let name = invocation.tool_name.clone();
438            match tokio::time::timeout(
439                Duration::from_millis(timeout_ms),
440                handler.handle(invocation),
441            )
442            .await
443            {
444                Ok(result) => result,
445                Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
446            }
447        } else {
448            handler.handle(invocation).await
449        }
450    }
451}
452
453fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
454    match payload {
455        ToolPayload::Mcp { .. } => ToolKind::Mcp,
456        ToolPayload::Function { .. }
457        | ToolPayload::Custom { .. }
458        | ToolPayload::LocalShell { .. } => ToolKind::Function,
459    }
460}
461
462#[cfg(test)]
463mod tests {
464    use serde_json::json;
465
466    use super::*;
467
468    #[test]
469    fn tool_result_json_round_trips_content() {
470        let result = ToolResult::json(&json!({"ok": true})).expect("json");
471        assert!(result.success);
472        assert!(result.content.contains("\"ok\": true"));
473    }
474
475    #[test]
476    fn helper_extractors_validate_shape() {
477        let input = json!({"name": "demo", "count": 7, "enabled": true});
478        assert_eq!(required_str(&input, "name").expect("name"), "demo");
479        assert_eq!(optional_u64(&input, "count", 0), 7);
480        assert!(optional_bool(&input, "enabled", false));
481        assert!(matches!(
482            required_u64(&input, "name"),
483            Err(ToolError::MissingField { .. })
484        ));
485    }
486
487    #[test]
488    fn required_str_reports_provided_fields_on_missing_required_field() {
489        let input = json!({"path": "src/lib.rs", "content": "new body"});
490        let err = required_str(&input, "replace").expect_err("replace is missing");
491        let message = err.to_string();
492        assert!(message.contains("missing required field 'replace'"));
493        assert!(message.contains("Input provided:"));
494        assert!(message.contains("path"));
495        assert!(message.contains("content"));
496    }
497
498    #[test]
499    fn tool_error_display_matches_legacy_text() {
500        let err = ToolError::missing_field("path");
501        assert_eq!(
502            err.to_string(),
503            "Failed to validate input: missing required field 'path'"
504        );
505    }
506}