Skip to main content

zagens_tools/
lib.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Result;
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use tokio::sync::RwLock;
11use zagens_protocol::{ToolKind, ToolOutput, ToolPayload};
12
13mod dag_scheduler;
14mod policy_engine;
15mod resource_locks;
16mod tool_manifest;
17
18pub use dag_scheduler::{
19    DagPlanView, ScheduleResource, SchedulerShadowStats, build_execution_waves,
20    record_scheduler_shadow_diff, scheduler_shadow_stats, wave_parallel_eligible,
21};
22pub use policy_engine::{
23    ApprovalNeed, ParallelResourceKey, PolicyDecision, PolicyEngine, PolicyInput, PolicyPlanMeta,
24    PolicySessionMode, PolicyShadowStats, SandboxClass, policy_shadow_stats,
25    record_policy_shadow_diff,
26};
27pub use resource_locks::{ResourceLockMode, resource_lock_order, resource_lock_targets};
28pub use tool_manifest::{
29    Footprint, FootprintProvenance, ResourceSet, SpawnClass, ToolManifest,
30    derive_conservative_footprint,
31};
32
/// A property a tool can declare about itself, or that a caller can demand
/// of a tool.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ToolCapability {
    /// Only reads data; never changes any state.
    ReadOnly,
    /// Creates, modifies, or deletes files on disk.
    WritesFiles,
    /// Runs arbitrary shell commands.
    ExecutesCode,
    /// Performs network requests.
    Network,
    /// Is able to run inside a sandbox.
    Sandboxable,
    /// Must be approved by the user before it runs.
    RequiresApproval,
}
49
/// How strongly a tool call needs user sign-off before it runs.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ApprovalRequirement {
    /// No approval needed (the default); intended for safe, read-only work.
    #[default]
    Auto,
    /// Approval is recommended, but the user may skip it.
    Suggest,
    /// Explicit user approval is mandatory.
    Required,
}
61
/// Errors that can occur during tool execution.
///
/// Every variant's `Display` text begins with a "Failed to ..." phrase that
/// names the stage that failed (input validation, path resolution, execution,
/// tool lookup, or authorization).
#[derive(Debug, Clone)]
pub enum ToolError {
    /// The input was malformed; `message` explains how.
    InvalidInput { message: String },
    /// A required input field named `field` was absent.
    MissingField { field: String },
    /// `path` resolved to a location outside the workspace.
    PathEscape { path: PathBuf },
    /// The tool ran but failed; `message` describes the failure.
    ExecutionFailed { message: String },
    /// The operation exceeded its time limit of `seconds`.
    Timeout { seconds: u64 },
    /// The tool could not be located or is unavailable.
    NotAvailable { message: String },
    /// Execution was not authorized.
    PermissionDenied { message: String },
}

impl std::fmt::Display for ToolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidInput { message } => {
                write!(f, "Failed to validate input: {message}")
            }
            Self::MissingField { field } => {
                write!(
                    f,
                    "Failed to validate input: missing required field '{field}'"
                )
            }
            Self::PathEscape { path } => {
                write!(
                    f,
                    "Failed to resolve path '{}': path escapes workspace",
                    path.display()
                )
            }
            Self::ExecutionFailed { message } => {
                write!(f, "Failed to execute tool: {message}")
            }
            Self::Timeout { seconds } => {
                write!(
                    f,
                    "Failed to execute tool: operation timed out after {seconds}s"
                )
            }
            Self::NotAvailable { message } => {
                write!(f, "Failed to locate tool: {message}")
            }
            Self::PermissionDenied { message } => {
                write!(f, "Failed to authorize tool execution: {message}")
            }
        }
    }
}

impl std::error::Error for ToolError {}

impl ToolError {
    /// Build a [`ToolError::InvalidInput`] with the given message.
    #[must_use]
    pub fn invalid_input(msg: impl Into<String>) -> Self {
        Self::InvalidInput {
            message: msg.into(),
        }
    }

    /// Build a [`ToolError::MissingField`] for the named field.
    #[must_use]
    pub fn missing_field(field: impl Into<String>) -> Self {
        Self::MissingField {
            field: field.into(),
        }
    }

    /// Build a [`ToolError::ExecutionFailed`] with the given message.
    #[must_use]
    pub fn execution_failed(msg: impl Into<String>) -> Self {
        Self::ExecutionFailed {
            message: msg.into(),
        }
    }

    /// Build a [`ToolError::PathEscape`] for the offending path.
    #[must_use]
    pub fn path_escape(path: impl Into<PathBuf>) -> Self {
        Self::PathEscape { path: path.into() }
    }

    /// Build a [`ToolError::Timeout`] for an operation that ran longer than
    /// `seconds`.
    ///
    /// Added so every variant has a matching constructor.
    #[must_use]
    pub fn timeout(seconds: u64) -> Self {
        Self::Timeout { seconds }
    }

    /// Build a [`ToolError::NotAvailable`] with the given message.
    #[must_use]
    pub fn not_available(msg: impl Into<String>) -> Self {
        Self::NotAvailable {
            message: msg.into(),
        }
    }

    /// Build a [`ToolError::PermissionDenied`] with the given message.
    #[must_use]
    pub fn permission_denied(msg: impl Into<String>) -> Self {
        Self::PermissionDenied {
            message: msg.into(),
        }
    }
}
155
156/// Result of a tool execution.
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct ToolResult {
159    /// The output content, which may be JSON or plain text.
160    pub content: String,
161    /// Whether the execution was successful.
162    pub success: bool,
163    /// Optional structured metadata.
164    #[serde(skip_serializing_if = "Option::is_none")]
165    pub metadata: Option<Value>,
166}
167
168impl ToolResult {
169    /// Create a successful result with content.
170    #[must_use]
171    pub fn success(content: impl Into<String>) -> Self {
172        Self {
173            content: content.into(),
174            success: true,
175            metadata: None,
176        }
177    }
178
179    /// Create an error result with message.
180    #[must_use]
181    pub fn error(message: impl Into<String>) -> Self {
182        Self {
183            content: message.into(),
184            success: false,
185            metadata: None,
186        }
187    }
188
189    /// Create a successful result from JSON.
190    pub fn json<T: Serialize>(value: &T) -> std::result::Result<Self, serde_json::Error> {
191        Ok(Self {
192            content: serde_json::to_string_pretty(value)?,
193            success: true,
194            metadata: None,
195        })
196    }
197
198    /// Add metadata to the result.
199    #[must_use]
200    pub fn with_metadata(mut self, metadata: Value) -> Self {
201        self.metadata = Some(metadata);
202        self
203    }
204}
205
206/// Helper to extract a required string field from JSON input.
207pub fn required_str<'a>(input: &'a Value, field: &str) -> std::result::Result<&'a str, ToolError> {
208    input.get(field).and_then(Value::as_str).ok_or_else(|| {
209        // When the field is missing, list the fields the caller *did*
210        // supply so the model can spot the mismatch without a retry.
211        let provided: Vec<&str> = input
212            .as_object()
213            .map(|obj| obj.keys().map(|k| k.as_str()).collect())
214            .unwrap_or_default();
215        if provided.is_empty() {
216            ToolError::missing_field(field)
217        } else {
218            let hint = format!(
219                "missing required field '{field}'. Input provided: {}",
220                provided.join(", ")
221            );
222            ToolError::invalid_input(hint)
223        }
224    })
225}
226
227/// Helper to extract an optional string field from JSON input.
228#[must_use]
229pub fn optional_str<'a>(input: &'a Value, field: &str) -> Option<&'a str> {
230    input.get(field).and_then(Value::as_str)
231}
232
233/// Helper to extract a required u64 field from JSON input.
234pub fn required_u64(input: &Value, field: &str) -> std::result::Result<u64, ToolError> {
235    input
236        .get(field)
237        .and_then(Value::as_u64)
238        .ok_or_else(|| ToolError::missing_field(field))
239}
240
241/// Helper to extract an optional u64 field with default.
242#[must_use]
243pub fn optional_u64(input: &Value, field: &str, default: u64) -> u64 {
244    input.get(field).and_then(Value::as_u64).unwrap_or(default)
245}
246
247/// Helper to extract an optional bool field with default.
248#[must_use]
249pub fn optional_bool(input: &Value, field: &str, default: bool) -> bool {
250    input.get(field).and_then(Value::as_bool).unwrap_or(default)
251}
252
253#[derive(Debug, Clone, Serialize, Deserialize)]
254pub struct ToolSpec {
255    pub name: String,
256    pub input_schema: Value,
257    pub output_schema: Value,
258    pub supports_parallel_tool_calls: bool,
259    pub timeout_ms: Option<u64>,
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct ConfiguredToolSpec {
264    pub spec: ToolSpec,
265    pub supports_parallel_tool_calls: bool,
266}
267
268#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
269#[serde(rename_all = "snake_case")]
270pub enum ToolCallSource {
271    Direct,
272    JsRepl,
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
276pub struct ToolCall {
277    pub name: String,
278    pub payload: ToolPayload,
279    pub source: ToolCallSource,
280    pub raw_tool_call_id: Option<String>,
281}
282
283impl ToolCall {
284    pub fn execution_subject(&self, fallback_cwd: &str) -> (String, String, &'static str) {
285        match &self.payload {
286            ToolPayload::LocalShell { params } => (
287                params.command.clone(),
288                params
289                    .cwd
290                    .clone()
291                    .unwrap_or_else(|| fallback_cwd.to_string()),
292                "shell",
293            ),
294            _ => (self.name.clone(), fallback_cwd.to_string(), "tool"),
295        }
296    }
297}
298
299#[derive(Debug, Clone)]
300pub struct ToolInvocation {
301    pub call_id: String,
302    pub tool_name: String,
303    pub payload: ToolPayload,
304    pub source: ToolCallSource,
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub enum FunctionCallError {
309    ToolNotFound { name: String },
310    KindMismatch { expected: ToolKind, got: ToolKind },
311    MutatingToolRejected { name: String },
312    TimedOut { name: String, timeout_ms: u64 },
313    Cancelled { name: String },
314    ExecutionFailed { name: String, error: String },
315}
316
317#[async_trait]
318pub trait ToolHandler: Send + Sync {
319    fn kind(&self) -> ToolKind;
320    fn matches_kind(&self, kind: ToolKind) -> bool {
321        self.kind() == kind
322    }
323    fn is_mutating(&self) -> bool {
324        false
325    }
326    async fn handle(
327        &self,
328        invocation: ToolInvocation,
329    ) -> std::result::Result<ToolOutput, FunctionCallError>;
330}
331
332#[derive(Debug, Default)]
333pub struct ToolCallRuntime {
334    pub parallel_execution: Arc<RwLock<()>>,
335}
336
337#[derive(Default)]
338pub struct ToolRegistry {
339    handlers: HashMap<String, Arc<dyn ToolHandler>>,
340    specs: HashMap<String, ConfiguredToolSpec>,
341    runtime: ToolCallRuntime,
342}
343
344impl ToolRegistry {
345    pub fn register(&mut self, spec: ToolSpec, handler: Arc<dyn ToolHandler>) -> Result<()> {
346        let name = spec.name.clone();
347        self.specs.insert(
348            name.clone(),
349            ConfiguredToolSpec {
350                supports_parallel_tool_calls: spec.supports_parallel_tool_calls,
351                spec,
352            },
353        );
354        self.handlers.insert(name, handler);
355        Ok(())
356    }
357
358    pub fn list_specs(&self) -> Vec<ConfiguredToolSpec> {
359        self.specs.values().cloned().collect()
360    }
361
362    pub async fn dispatch(
363        &self,
364        call: ToolCall,
365        allow_mutating: bool,
366    ) -> std::result::Result<ToolOutput, FunctionCallError> {
367        let handler = self.handlers.get(&call.name).cloned().ok_or_else(|| {
368            FunctionCallError::ToolNotFound {
369                name: call.name.clone(),
370            }
371        })?;
372        let configured =
373            self.specs
374                .get(&call.name)
375                .cloned()
376                .ok_or_else(|| FunctionCallError::ToolNotFound {
377                    name: call.name.clone(),
378                })?;
379
380        let payload_kind = tool_payload_kind(&call.payload);
381        let expected = handler.kind();
382        if !handler.matches_kind(payload_kind) {
383            return Err(FunctionCallError::KindMismatch {
384                expected,
385                got: payload_kind,
386            });
387        }
388        if handler.is_mutating() && !allow_mutating {
389            return Err(FunctionCallError::MutatingToolRejected { name: call.name });
390        }
391
392        let invocation = ToolInvocation {
393            call_id: call
394                .raw_tool_call_id
395                .clone()
396                .unwrap_or_else(|| format!("tool-call-{}", uuid::Uuid::new_v4())),
397            tool_name: call.name.clone(),
398            payload: call.payload,
399            source: call.source,
400        };
401
402        if configured.supports_parallel_tool_calls {
403            let _guard = self.runtime.parallel_execution.read().await;
404            self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
405                .await
406        } else {
407            let _guard = self.runtime.parallel_execution.write().await;
408            self.execute_with_timeout(handler, configured.spec.timeout_ms, invocation)
409                .await
410        }
411    }
412
413    async fn execute_with_timeout(
414        &self,
415        handler: Arc<dyn ToolHandler>,
416        timeout_ms: Option<u64>,
417        invocation: ToolInvocation,
418    ) -> std::result::Result<ToolOutput, FunctionCallError> {
419        if let Some(timeout_ms) = timeout_ms {
420            let name = invocation.tool_name.clone();
421            match tokio::time::timeout(
422                Duration::from_millis(timeout_ms),
423                handler.handle(invocation),
424            )
425            .await
426            {
427                Ok(result) => result,
428                Err(_) => Err(FunctionCallError::TimedOut { name, timeout_ms }),
429            }
430        } else {
431            handler.handle(invocation).await
432        }
433    }
434}
435
436fn tool_payload_kind(payload: &ToolPayload) -> ToolKind {
437    match payload {
438        ToolPayload::Mcp { .. } => ToolKind::Mcp,
439        ToolPayload::Function { .. }
440        | ToolPayload::Custom { .. }
441        | ToolPayload::LocalShell { .. } => ToolKind::Function,
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tool_result_json_round_trips_content() {
        let payload = json!({"ok": true});
        let result = ToolResult::json(&payload).expect("json");
        assert!(result.success);
        assert!(result.content.contains("\"ok\": true"));
    }

    #[test]
    fn helper_extractors_validate_shape() {
        let input = json!({"name": "demo", "count": 7, "enabled": true});

        let name = required_str(&input, "name").expect("name");
        assert_eq!(name, "demo");
        assert_eq!(optional_u64(&input, "count", 0), 7);
        assert!(optional_bool(&input, "enabled", false));

        // A string where a u64 is required is reported as a missing field.
        let wrong_type = required_u64(&input, "name");
        assert!(matches!(wrong_type, Err(ToolError::MissingField { .. })));
    }

    #[test]
    fn required_str_reports_provided_fields_on_missing_required_field() {
        let input = json!({"path": "src/lib.rs", "content": "new body"});
        let message = required_str(&input, "replace")
            .expect_err("replace is missing")
            .to_string();
        for fragment in [
            "missing required field 'replace'",
            "Input provided:",
            "path",
            "content",
        ] {
            assert!(message.contains(fragment), "{fragment:?} not found in {message:?}");
        }
    }

    #[test]
    fn tool_error_display_matches_legacy_text() {
        let rendered = ToolError::missing_field("path").to_string();
        assert_eq!(
            rendered,
            "Failed to validate input: missing required field 'path'"
        );
    }
}