1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use codewhale_protocol::{ToolKind, ToolOutput, ToolPayload};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
12
// Task-local marker that is present only while `ToolRegistry::dispatch` is
// running a handler inside `TOOL_EXECUTION_LOCK_HELD.scope(..)`.
//
// `ToolCallRuntime::acquire` probes it with `try_with`; when it is set, the
// current task already sits under an execution guard, so a nested dispatch from
// the same task returns `ToolExecutionGuard::Reentrant` instead of waiting on
// the lock again.
//
// NOTE(review): task-locals are not inherited by tasks created with
// `tokio::spawn`, so a nested dispatch issued from a spawned task will contend
// for the lock like any other caller.
tokio::task_local! {
    // Unit value: only the presence of the scope matters, not its contents.
    static TOOL_EXECUTION_LOCK_HELD: ();
}
16
/// Capability tags that describe what a tool may do.
///
/// NOTE(review): nothing in this part of the file reads these values. The
/// per-variant notes below are inferred from the names; confirm them against
/// the code that consumes `ToolCapability`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolCapability {
    /// Presumably: the tool only reads state and has no side effects.
    ReadOnly,
    /// Presumably: the tool may create or modify files.
    WritesFiles,
    /// Presumably: the tool may run arbitrary code or commands.
    ExecutesCode,
    /// Presumably: the tool may perform network access.
    Network,
    /// Presumably: the tool can be run inside a sandbox.
    Sandboxable,
    /// Presumably: the tool must be approved before running.
    RequiresApproval,
}
33
/// How much user approval a tool invocation needs.
///
/// `Auto` is the `Default` value (see the `#[default]` attribute).
/// NOTE(review): the meaning of each level is inferred from its name; confirm
/// it against the approval flow that consumes this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ApprovalRequirement {
    #[default]
    /// Presumably: runs without asking. This is the default.
    Auto,
    /// Presumably: approval is suggested but not enforced.
    Suggest,
    /// Presumably: approval is mandatory before running.
    Required,
}
45
/// Errors a tool implementation can report.
///
/// The `Display` text (generated by `thiserror`) always starts with a
/// "Failed to ..." prefix that names the failing phase.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ToolError {
    /// The tool input was malformed; `message` explains how.
    #[error("Failed to validate input: {message}")]
    InvalidInput { message: String },
    /// A required input field was absent (see `required_str` / `required_u64`).
    #[error("Failed to validate input: missing required field '{field}'")]
    MissingField { field: String },
    /// A path resolved to a location outside the workspace.
    #[error("Failed to resolve path '{}': path escapes workspace", path.display())]
    PathEscape { path: PathBuf },
    /// The tool ran but failed; `message` carries the reason.
    #[error("Failed to execute tool: {message}")]
    ExecutionFailed { message: String },
    /// The tool did not finish within `seconds` seconds.
    #[error("Failed to execute tool: operation timed out after {seconds}s")]
    Timeout { seconds: u64 },
    /// The requested tool could not be found or is unavailable.
    #[error("Failed to locate tool: {message}")]
    NotAvailable { message: String },
    /// The tool call was not authorized.
    #[error("Failed to authorize tool execution: {message}")]
    PermissionDenied { message: String },
}
64
65impl ToolError {
66 #[must_use]
67 pub fn invalid_input(msg: impl Into<String>) -> Self {
68 Self::InvalidInput {
69 message: msg.into(),
70 }
71 }
72
73 #[must_use]
74 pub fn missing_field(field: impl Into<String>) -> Self {
75 Self::MissingField {
76 field: field.into(),
77 }
78 }
79
80 #[must_use]
81 pub fn execution_failed(msg: impl Into<String>) -> Self {
82 Self::ExecutionFailed {
83 message: msg.into(),
84 }
85 }
86
87 #[must_use]
88 pub fn path_escape(path: impl Into<PathBuf>) -> Self {
89 Self::PathEscape { path: path.into() }
90 }
91
92 #[must_use]
93 pub fn not_available(msg: impl Into<String>) -> Self {
94 Self::NotAvailable {
95 message: msg.into(),
96 }
97 }
98
99 #[must_use]
100 pub fn permission_denied(msg: impl Into<String>) -> Self {
101 Self::PermissionDenied {
102 message: msg.into(),
103 }
104 }
105}
106
/// Outcome of a tool run, serializable for transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Output text on success, or the error message on failure.
    pub content: String,
    /// `true` when the tool succeeded.
    pub success: bool,
    /// Optional structured extras. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
}
118
119impl ToolResult {
120 #[must_use]
122 pub fn success(content: impl Into<String>) -> Self {
123 Self {
124 content: content.into(),
125 success: true,
126 metadata: None,
127 }
128 }
129
130 #[must_use]
132 pub fn error(message: impl Into<String>) -> Self {
133 Self {
134 content: message.into(),
135 success: false,
136 metadata: None,
137 }
138 }
139
140 pub fn json<T: Serialize>(value: &T) -> std::result::Result<Self, serde_json::Error> {
142 Ok(Self {
143 content: serde_json::to_string_pretty(value)?,
144 success: true,
145 metadata: None,
146 })
147 }
148
149 #[must_use]
151 pub fn with_metadata(mut self, metadata: Value) -> Self {
152 self.metadata = Some(metadata);
153 self
154 }
155}
156
157pub fn required_str<'a>(input: &'a Value, field: &str) -> std::result::Result<&'a str, ToolError> {
159 input.get(field).and_then(Value::as_str).ok_or_else(|| {
160 let provided: Vec<&str> = input
163 .as_object()
164 .map(|obj| obj.keys().map(|k| k.as_str()).collect())
165 .unwrap_or_default();
166 if provided.is_empty() {
167 ToolError::missing_field(field)
168 } else {
169 let hint = format!(
170 "missing required field '{field}'. Input provided: {}",
171 provided.join(", ")
172 );
173 ToolError::invalid_input(hint)
174 }
175 })
176}
177
178#[must_use]
180pub fn optional_str<'a>(input: &'a Value, field: &str) -> Option<&'a str> {
181 input.get(field).and_then(Value::as_str)
182}
183
184pub fn required_u64(input: &Value, field: &str) -> std::result::Result<u64, ToolError> {
186 input
187 .get(field)
188 .and_then(Value::as_u64)
189 .ok_or_else(|| ToolError::missing_field(field))
190}
191
192#[must_use]
194pub fn optional_u64(input: &Value, field: &str, default: u64) -> u64 {
195 input.get(field).and_then(Value::as_u64).unwrap_or(default)
196}
197
198#[must_use]
200pub fn optional_bool(input: &Value, field: &str, default: bool) -> bool {
201 input.get(field).and_then(Value::as_bool).unwrap_or(default)
202}
203
/// Static description of a tool, as supplied to `ToolRegistry::register`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpec {
    /// Unique tool name. It is used as the registry key and must match
    /// `ToolCall::name` for dispatch.
    pub name: String,
    /// Schema for the tool's input. NOTE(review): presumably a JSON Schema
    /// document; confirm.
    pub input_schema: Value,
    /// Schema for the tool's output. NOTE(review): presumably a JSON Schema
    /// document; confirm.
    pub output_schema: Value,
    /// Whether the tool may run concurrently with other parallel-safe tools.
    /// `register` copies it into `ConfiguredToolSpec`, and `dispatch` reads
    /// that copy.
    pub supports_parallel_tool_calls: bool,
    /// Optional execution timeout in milliseconds applied by `dispatch`.
    /// `None` means no timeout.
    pub timeout_ms: Option<u64>,
}
221
/// A registered [`ToolSpec`] together with its effective execution settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfiguredToolSpec {
    /// The spec as registered.
    pub spec: ToolSpec,
    /// Effective parallelism flag. `dispatch` takes the shared (read) execution
    /// lock when `true` and the exclusive (write) lock when `false`.
    pub supports_parallel_tool_calls: bool,
}
233
/// Where a tool call originated.
///
/// Serialized in snake_case: `"direct"` and `"js_repl"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallSource {
    /// Presumably: issued directly (e.g. by the model). Confirm at call sites.
    Direct,
    /// Presumably: issued from a JavaScript REPL session. Confirm at call sites.
    JsRepl,
}
243
/// A request to run a registered tool, as received by `ToolRegistry::dispatch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name of the registered tool to run.
    pub name: String,
    /// Call arguments. Their variant must match the handler's `ToolKind`.
    pub payload: ToolPayload,
    /// Origin of the call, forwarded unchanged into `ToolInvocation`.
    pub source: ToolCallSource,
    /// Caller-supplied call id. When `None`, `dispatch` generates
    /// `tool-call-<uuid v4>`.
    pub raw_tool_call_id: Option<String>,
}
259
260impl ToolCall {
261 pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
268 match &self.payload {
269 ToolPayload::LocalShell { params } => (
270 params.command.clone(),
271 params
272 .cwd
273 .clone()
274 .unwrap_or_else(|| fallback_cwd.to_string()),
275 "shell",
276 ),
277 _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
278 }
279 }
280}
281
/// A validated tool call, as handed to `ToolHandler::handle` by `dispatch`.
#[derive(Debug, Clone)]
pub struct ToolInvocation {
    /// `ToolCall::raw_tool_call_id`, or a generated `tool-call-<uuid v4>`.
    pub call_id: String,
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Call arguments, already checked against the handler's `ToolKind`.
    pub payload: ToolPayload,
    /// Origin of the call.
    pub source: ToolCallSource,
}
297
/// Errors returned by `ToolRegistry::dispatch` and `ToolHandler::handle`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FunctionCallError {
    /// No handler (or spec) is registered under `name`.
    ToolNotFound { name: String },
    /// The payload kind (`got`) does not match the handler's kind (`expected`).
    KindMismatch { expected: ToolKind, got: ToolKind },
    /// The handler is mutating and the caller passed `allow_mutating = false`.
    MutatingToolRejected { name: String },
    /// The handler did not finish within the spec's `timeout_ms`.
    TimedOut { name: String, timeout_ms: u64 },
    /// The call was cancelled. Not constructed in the code shown here;
    /// presumably produced by handlers or callers.
    Cancelled { name: String },
    /// The handler failed with `error`. Not constructed in the code shown here;
    /// presumably produced by handlers.
    ExecutionFailed { name: String, error: String },
}
320
/// Implementation behind a registered tool.
#[async_trait]
pub trait ToolHandler: Send + Sync {
    /// The payload kind this handler accepts.
    fn kind(&self) -> ToolKind;

    /// Whether a payload of `kind` can be handled. `dispatch` rejects calls for
    /// which this returns `false`. The default requires an exact match with
    /// [`ToolHandler::kind`].
    fn matches_kind(&self, kind: ToolKind) -> bool {
        self.kind() == kind
    }

    /// Whether this tool mutates state. `dispatch` rejects mutating tools when
    /// called with `allow_mutating = false`. Defaults to `false`.
    fn is_mutating(&self) -> bool {
        false
    }

    /// Runs the tool. `dispatch` calls this while holding the runtime execution
    /// guard, under the spec's optional timeout.
    async fn handle(
        &self,
        invocation: ToolInvocation,
    ) -> std::result::Result<ToolOutput, FunctionCallError>;
}
350
351#[derive(Debug)]
357pub struct ToolCallRuntime {
358 execution_lock: Arc<RwLock<()>>,
359}
360
361impl Default for ToolCallRuntime {
362 fn default() -> Self {
363 Self {
364 execution_lock: Arc::new(RwLock::new(())),
365 }
366 }
367}
368
/// Proof that the caller may run a tool. Dropping it releases the lock.
#[derive(Debug)]
enum ToolExecutionGuard {
    // The wrapped guards are never read; they are held only so that dropping
    // this value releases the lock, hence `allow(dead_code)`.
    /// Shared (read) access for parallel-safe tools.
    Parallel(#[allow(dead_code)] OwnedRwLockReadGuard<()>),
    /// Exclusive (write) access for tools that must run alone.
    Serial(#[allow(dead_code)] OwnedRwLockWriteGuard<()>),
    /// No lock taken: the current task is already inside a guarded dispatch.
    Reentrant,
}
375
376impl ToolCallRuntime {
377 async fn acquire(&self, supports_parallel: bool) -> ToolExecutionGuard {
378 if TOOL_EXECUTION_LOCK_HELD.try_with(|_| ()).is_ok() {
379 return ToolExecutionGuard::Reentrant;
380 }
381
382 if supports_parallel {
383 ToolExecutionGuard::Parallel(self.execution_lock.clone().read_owned().await)
384 } else {
385 ToolExecutionGuard::Serial(self.execution_lock.clone().write_owned().await)
386 }
387 }
388}
389
/// Registry of tools: handlers and specs keyed by tool name, plus the runtime
/// that decides whether a call may run in parallel.
#[derive(Default)]
pub struct ToolRegistry {
    // Keyed by `ToolSpec::name`; `register` writes this map and `specs` together.
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
    specs: HashMap<String, ConfiguredToolSpec>,
    // Shared execution lock used by `dispatch`.
    runtime: ToolCallRuntime,
}
401
402impl ToolRegistry {
403 pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
409 let name = spec.name.clone();
410 self.specs.insert(
411 name.clone(),
412 ConfiguredToolSpec {
413 supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
414 spec,
415 },
416 );
417 self.handlers.insert(name, handler);
418 Ok(())
419 }
420
421 pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
423 self.specs.values().cloned().collect()
424 }
425
426 pub async fn dispatch(
434 &self,
435 call: ToolCall,
436 allow_mutating: bool,
437 ) -> std::result::Result<ToolOutput, FunctionCallError> {
438 let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
439 FunctionCallError::ToolNotFound {
440 name: call.name.clone(),
441 }
442 })?;
443 let configured =
444 self.specs
445 .get(&call.name)
446 .cloned()
447 .ok_or_else(|| FunctionCallError::ToolNotFound {
448 name: call.name.clone(),
449 })?;
450
451 let payload_kind = tool_payload_kind(&call.payload);
452 let expected = handler.kind();
453 if !handler.matches_kind(payload_kind) {
454 return Err(FunctionCallError::KindMismatch {
455 expected,
456 got: payload_kind,
457 });
458 }
459 if handler.is_mutating() && !allow_mutating {
460 return Err(FunctionCallError::MutatingToolRejected { name: call.name });
461 }
462
463 let invocation = ToolInvocation {
464 call_id: call
465 .raw_tool_call_id
466 .clone()
467 .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
468 tool_name: call.name.clone(),
469 payload: call.payload,
470 source: call.source,
471 };
472
473 let _guard = self
474 .runtime
475 .acquire(configured.supports_parallel_tool_calls)
476 .await;
477
478 TOOL_EXECUTION_LOCK_HELD
479 .scope(
480 (),
481 self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation),
482 )
483 .await
484 }
485
486 async fn execute_with_timeout(
487 &self,
488 handler: Arc<dyn ToolHandler>,
489 timeout_ms: Option<u64>,
490 invocation: ToolInvocation,
491 ) -> std::result::Result<ToolOutput, FunctionCallError> {
492 if let Some(timeout_ms) = timeout_ms {
493 let name = invocation.tool_name.clone();
494 match tokio::time::timeout(
495 Duration::from_millis(timeout_ms),
496 handler.handle(invocation),
497 )
498 .await
499 {
500 Ok(result) => result,
501 Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
502 }
503 } else {
504 handler.handle(invocation).await
505 }
506 }
507}
508
509fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
510 match payload {
511 ToolPayload::Mcp { .. } => ToolKind::Mcp,
512 ToolPayload::Function { .. }
513 | ToolPayload::Custom { .. }
514 | ToolPayload::LocalShell { .. } => ToolKind::Function,
515 }
516}
517
// Unit tests for the result/error helpers, the JSON field extractors, and
// `ToolCall::execution_subject`.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn tool_result_success_sets_plain_content() {
        let content = "operation completed successfully";
        let result = ToolResult::success(content);

        assert!(result.success);
        assert_eq!(result.content, content);
        assert!(result.metadata.is_none());
    }

    // Despite the name, only serialization is checked: the pretty-printed JSON
    // must contain the `"ok": true` entry.
    #[test]
    fn tool_result_json_round_trips_content() {
        let result = ToolResult::json(&json!({"ok": true})).expect("json");
        assert!(result.success);
        assert!(result.content.contains("\"ok\": true"));
    }

    #[test]
    fn helper_extractors_validate_shape() {
        let input = json!({"name": "demo", "count": 7, "enabled": true});
        assert_eq!(required_str(&input, "name").expect("name"), "demo");
        assert_eq!(optional_str(&input, "name"), Some("demo"));
        assert_eq!(optional_str(&input, "missing"), None);
        assert_eq!(optional_str(&input, "count"), None);
        assert_eq!(optional_str(&json!({"name": null}), "name"), None);
        assert_eq!(optional_u64(&input, "count", 0), 7);
        assert!(optional_bool(&input, "enabled", false));
        // A present-but-non-integer field is reported as MissingField.
        assert!(matches!(
            required_u64(&input, "name"),
            Err(ToolError::MissingField { .. })
        ));
    }

    #[test]
    fn required_u64_rejects_missing_or_non_integer_values() {
        assert!(matches!(
            required_u64(&json!({}), "count"),
            Err(ToolError::MissingField { .. })
        ));
        assert_eq!(required_u64(&json!({"count": 42}), "count").unwrap(), 42);
        assert_eq!(
            required_u64(&json!({"count": u64::MAX}), "count").unwrap(),
            u64::MAX
        );

        // Negative, fractional and string values are all treated as missing.
        for value in [json!(-1), json!(2.5), json!("42")] {
            assert!(matches!(
                required_u64(&json!({"count": value}), "count"),
                Err(ToolError::MissingField { .. })
            ));
        }
    }

    #[test]
    fn required_str_reports_provided_fields_on_missing_required_field() {
        let input = json!({"path": "src/lib.rs", "content": "new body"});
        let err = required_str(&input, "replace").expect_err("replace is missing");
        let message = err.to_string();
        assert!(message.contains("missing required field 'replace'"));
        assert!(message.contains("Input provided:"));
        assert!(message.contains("path"));
        assert!(message.contains("content"));
    }

    // Pins the exact `Display` text, which callers may match on.
    #[test]
    fn tool_error_display_matches_legacy_text() {
        let err = ToolError::missing_field("path");
        assert_eq!(
            err.to_string(),
            "Failed to validate input: missing required field 'path'"
        );
    }

    #[test]
    fn tool_error_missing_field_constructor() {
        let err = ToolError::missing_field("my_field");
        assert!(matches!(err, ToolError::MissingField { field } if field == "my_field"));
    }

    #[test]
    fn tool_error_not_available_displays_reason() {
        let err = ToolError::not_available("custom tool not found");

        assert!(matches!(err, ToolError::NotAvailable { .. }));
        assert_eq!(
            err.to_string(),
            "Failed to locate tool: custom tool not found"
        );
    }

    #[test]
    fn tool_error_permission_denied_displays_reason() {
        let err = ToolError::permission_denied("unauthorized user");

        assert!(matches!(err, ToolError::PermissionDenied { .. }));
        assert_eq!(
            err.to_string(),
            "Failed to authorize tool execution: unauthorized user"
        );
    }

    #[test]
    fn tool_error_execution_failed_displays_reason() {
        let err = ToolError::execution_failed("process crashed");

        assert!(
            matches!(err, ToolError::ExecutionFailed { ref message } if message == "process crashed")
        );
        assert_eq!(err.to_string(), "Failed to execute tool: process crashed");
    }

    #[test]
    fn tool_error_invalid_input_creates_correct_variant() {
        let err = ToolError::invalid_input("test invalid message");
        match err {
            ToolError::InvalidInput { message } => {
                assert_eq!(message, "test invalid message");
            }
            _ => panic!("Expected ToolError::InvalidInput, got {err:?}"),
        }
    }

    #[test]
    fn tool_error_path_escape_display() {
        let path = std::path::PathBuf::from("../outside");
        let err = ToolError::path_escape(path);
        assert_eq!(
            err.to_string(),
            "Failed to resolve path '../outside': path escapes workspace"
        );
    }

    // The three tests below cover both `execution_subject` branches and the
    // cwd fallback for shell payloads.
    #[test]
    fn tool_call_execution_subject_uses_local_shell_command_and_cwd() {
        let call = ToolCall {
            name: "shell".to_string(),
            payload: ToolPayload::LocalShell {
                params: codewhale_protocol::LocalShellParams {
                    command: "ls -l".to_string(),
                    cwd: Some("/custom/dir".to_string()),
                    timeout_ms: None,
                },
            },
            source: ToolCallSource::Direct,
            raw_tool_call_id: None,
        };

        assert_eq!(
            call.execution_subject("/fallback/dir"),
            ("ls -l".to_string(), "/custom/dir".to_string(), "shell")
        );
    }

    #[test]
    fn tool_call_execution_subject_falls_back_for_shell_without_cwd() {
        let call = ToolCall {
            name: "shell".to_string(),
            payload: ToolPayload::LocalShell {
                params: codewhale_protocol::LocalShellParams {
                    command: "echo hello".to_string(),
                    cwd: None,
                    timeout_ms: None,
                },
            },
            source: ToolCallSource::Direct,
            raw_tool_call_id: None,
        };

        assert_eq!(
            call.execution_subject("/fallback/dir"),
            (
                "echo hello".to_string(),
                "/fallback/dir".to_string(),
                "shell"
            )
        );
    }

    #[test]
    fn tool_call_execution_subject_uses_tool_name_for_non_shell_payloads() {
        let call = ToolCall {
            name: "my_tool".to_string(),
            payload: ToolPayload::Function {
                arguments: "{}".to_string(),
            },
            source: ToolCallSource::Direct,
            raw_tool_call_id: None,
        };

        assert_eq!(
            call.execution_subject("/fallback/dir"),
            ("my_tool".to_string(), "/fallback/dir".to_string(), "tool")
        );
    }
}