Skip to main content

aivcs_core/
tooling.rs

1//! Tooling and sandbox policy core.
2//!
3//! Provides a deterministic execution layer for tool calls with:
4//! - capability-scoped policy checks
5//! - input/output JSON field validation
6//! - timeout, retry, and circuit-breaker controls
7
8use std::collections::HashMap;
9use std::time::Instant;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use thiserror::Error;
15use tokio::sync::Mutex;
16
17/// Capability class required by a tool.
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum ToolCapability {
21    ShellExec,
22    FileRead,
23    FileWrite,
24    GitRead,
25    GitWrite,
26    NetworkFetch,
27    Custom(String),
28}
29
30impl ToolCapability {
31    fn as_policy_key(&self) -> String {
32        match self {
33            Self::ShellExec => "shell_exec".to_string(),
34            Self::FileRead => "file_read".to_string(),
35            Self::FileWrite => "file_write".to_string(),
36            Self::GitRead => "git_read".to_string(),
37            Self::GitWrite => "git_write".to_string(),
38            Self::NetworkFetch => "network_fetch".to_string(),
39            Self::Custom(name) => format!("custom:{name}"),
40        }
41    }
42}
43
44impl std::fmt::Display for ToolCapability {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "{}", self.as_policy_key())
47    }
48}
49
50/// Minimal JSON schema: required top-level fields.
51#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
52pub struct JsonFieldSchema {
53    pub required_fields: Vec<String>,
54}
55
56impl JsonFieldSchema {
57    pub fn required<const N: usize>(fields: [&str; N]) -> Self {
58        Self {
59            required_fields: fields.iter().map(|f| (*f).to_string()).collect(),
60        }
61    }
62}
63
64/// Tool spec in the capability registry.
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66pub struct ToolSpec {
67    pub name: String,
68    pub capability: ToolCapability,
69    pub input_schema: JsonFieldSchema,
70    pub output_schema: JsonFieldSchema,
71}
72
73/// In-memory capability registry.
74#[derive(Debug, Clone, Default)]
75pub struct ToolRegistry {
76    tools: HashMap<String, ToolSpec>,
77}
78
79impl ToolRegistry {
80    pub fn register(&mut self, spec: ToolSpec) -> Result<(), ToolExecutionError> {
81        if self.tools.contains_key(&spec.name) {
82            return Err(ToolExecutionError::DuplicateTool {
83                tool_name: spec.name,
84            });
85        }
86        self.tools.insert(spec.name.clone(), spec);
87        Ok(())
88    }
89
90    pub fn get(&self, name: &str) -> Option<&ToolSpec> {
91        self.tools.get(name)
92    }
93}
94
95/// Policy action for a capability or tool.
96#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
97#[serde(rename_all = "snake_case")]
98pub enum PolicyAction {
99    Allow,
100    Deny,
101    RequireApproval,
102}
103
104/// Policy matrix for capability and tool-level controls.
105#[derive(Debug, Clone, Default)]
106pub struct PolicyMatrix {
107    by_capability: HashMap<ToolCapability, PolicyAction>,
108    by_tool: HashMap<String, PolicyAction>,
109}
110
111impl PolicyMatrix {
112    /// Secure baseline for high-risk operations.
113    pub fn safe_defaults() -> Self {
114        Self::default()
115            .with_capability(ToolCapability::ShellExec, PolicyAction::RequireApproval)
116            .with_capability(ToolCapability::FileWrite, PolicyAction::RequireApproval)
117            .with_capability(ToolCapability::GitWrite, PolicyAction::RequireApproval)
118            .with_capability(ToolCapability::NetworkFetch, PolicyAction::RequireApproval)
119    }
120
121    pub fn with_capability(mut self, capability: ToolCapability, action: PolicyAction) -> Self {
122        self.by_capability.insert(capability, action);
123        self
124    }
125
126    pub fn with_tool_action(mut self, tool_name: impl Into<String>, action: PolicyAction) -> Self {
127        self.by_tool.insert(tool_name.into(), action);
128        self
129    }
130
131    fn action_for(&self, tool: &ToolSpec) -> PolicyAction {
132        if let Some(action) = self.by_tool.get(&tool.name) {
133            *action
134        } else if let Some(action) = self.by_capability.get(&tool.capability) {
135            *action
136        } else {
137            PolicyAction::Allow
138        }
139    }
140}
141
142/// Tool call request.
143#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
144pub struct ToolInvocation {
145    pub name: String,
146    pub input: Value,
147}
148
149impl ToolInvocation {
150    pub fn new(name: impl Into<String>, input: Value) -> Self {
151        Self {
152            name: name.into(),
153            input,
154        }
155    }
156}
157
158/// Timeout/retry/circuit-breaker execution controls.
159#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
160pub struct ToolExecutionConfig {
161    pub timeout_ms: u64,
162    pub max_retries: u32,
163    pub circuit_breaker_threshold: u32,
164}
165
166impl Default for ToolExecutionConfig {
167    fn default() -> Self {
168        Self {
169            timeout_ms: 5_000,
170            max_retries: 0,
171            circuit_breaker_threshold: 3,
172        }
173    }
174}
175
176/// Input or output validation stage.
177#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
178#[serde(rename_all = "snake_case")]
179pub enum SchemaStage {
180    Input,
181    Output,
182}
183
184/// Tool execution status for observability.
185#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
186#[serde(rename_all = "snake_case")]
187pub enum ToolCallStatus {
188    Succeeded,
189}
190
191/// Telemetry emitted for a successful tool execution.
192#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
193pub struct ToolTelemetry {
194    pub run_id: Option<String>,
195    pub tool_name: String,
196    pub retries: u32,
197    pub duration_ms: u128,
198    pub status: ToolCallStatus,
199}
200
201/// Successful tool execution output + telemetry.
202#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
203pub struct ToolExecutionReport {
204    pub output: Value,
205    pub telemetry: ToolTelemetry,
206}
207
208/// Execution failure taxonomy.
209#[derive(Debug, Error, PartialEq, Eq)]
210pub enum ToolExecutionError {
211    #[error("unknown tool: {tool_name}")]
212    UnknownTool { tool_name: String },
213
214    #[error("duplicate tool registration: {tool_name}")]
215    DuplicateTool { tool_name: String },
216
217    #[error("policy denied tool '{tool_name}': {reason}")]
218    PolicyDenied { tool_name: String, reason: String },
219
220    #[error("approval required for tool '{tool_name}': {reason}")]
221    ApprovalRequired { tool_name: String, reason: String },
222
223    #[error("schema violation for tool '{tool_name}' ({stage:?}): missing field '{field}'")]
224    SchemaViolation {
225        tool_name: String,
226        stage: SchemaStage,
227        field: String,
228    },
229
230    #[error("tool '{tool_name}' timed out after {timeout_ms}ms")]
231    Timeout { tool_name: String, timeout_ms: u64 },
232
233    #[error("tool '{tool_name}' adapter error: {message}")]
234    Adapter { tool_name: String, message: String },
235
236    #[error("circuit breaker open for tool '{tool_name}' (failures={failures})")]
237    CircuitOpen { tool_name: String, failures: u32 },
238}
239
240/// Adapter contract for actual tool invocation.
241#[async_trait]
242pub trait ToolAdapter: Send + Sync + 'static {
243    async fn call(&self, tool_name: &str, input: &Value) -> std::result::Result<Value, String>;
244}
245
246/// Policy-aware tool executor.
247pub struct ToolExecutor<A: ToolAdapter> {
248    registry: ToolRegistry,
249    policy: PolicyMatrix,
250    adapter: A,
251    config: ToolExecutionConfig,
252    failure_counts: Mutex<HashMap<String, u32>>,
253}
254
255impl<A: ToolAdapter> ToolExecutor<A> {
256    pub fn new(
257        registry: ToolRegistry,
258        policy: PolicyMatrix,
259        adapter: A,
260        config: ToolExecutionConfig,
261    ) -> Self {
262        Self {
263            registry,
264            policy,
265            adapter,
266            config,
267            failure_counts: Mutex::new(HashMap::new()),
268        }
269    }
270
271    /// Convenience constructor that applies secure policy defaults.
272    pub fn new_with_safe_defaults(
273        registry: ToolRegistry,
274        adapter: A,
275        config: ToolExecutionConfig,
276    ) -> Self {
277        Self::new(registry, PolicyMatrix::safe_defaults(), adapter, config)
278    }
279
280    pub async fn execute(
281        &self,
282        call: ToolInvocation,
283        run_id: Option<String>,
284    ) -> Result<ToolExecutionReport, ToolExecutionError> {
285        let started = Instant::now();
286
287        let spec =
288            self.registry
289                .get(&call.name)
290                .ok_or_else(|| ToolExecutionError::UnknownTool {
291                    tool_name: call.name.clone(),
292                })?;
293
294        match self.policy.action_for(spec) {
295            PolicyAction::Allow => {}
296            PolicyAction::Deny => {
297                return Err(ToolExecutionError::PolicyDenied {
298                    tool_name: call.name.clone(),
299                    reason: format!("capability '{}' is denied", spec.capability.as_policy_key()),
300                });
301            }
302            PolicyAction::RequireApproval => {
303                return Err(ToolExecutionError::ApprovalRequired {
304                    tool_name: call.name.clone(),
305                    reason: format!(
306                        "capability '{}' requires explicit approval",
307                        spec.capability.as_policy_key()
308                    ),
309                });
310            }
311        }
312
313        validate_schema(
314            &call.name,
315            SchemaStage::Input,
316            &spec.input_schema,
317            &call.input,
318        )?;
319
320        let current_failures = self.current_failure_count(&call.name).await;
321        if self.config.circuit_breaker_threshold > 0
322            && current_failures >= self.config.circuit_breaker_threshold
323        {
324            return Err(ToolExecutionError::CircuitOpen {
325                tool_name: call.name.clone(),
326                failures: current_failures,
327            });
328        }
329
330        let mut retries = 0u32;
331        let max_attempts = self.config.max_retries + 1;
332        for attempt in 0..max_attempts {
333            let timeout = tokio::time::Duration::from_millis(self.config.timeout_ms);
334            let call_result =
335                tokio::time::timeout(timeout, self.adapter.call(&call.name, &call.input)).await;
336
337            match call_result {
338                Err(_) => {
339                    if attempt < self.config.max_retries {
340                        retries += 1;
341                        continue;
342                    }
343                    self.increment_failure(&call.name).await;
344                    return Err(ToolExecutionError::Timeout {
345                        tool_name: call.name.clone(),
346                        timeout_ms: self.config.timeout_ms,
347                    });
348                }
349                Ok(Err(message)) => {
350                    if attempt < self.config.max_retries {
351                        retries += 1;
352                        continue;
353                    }
354                    self.increment_failure(&call.name).await;
355                    return Err(ToolExecutionError::Adapter {
356                        tool_name: call.name.clone(),
357                        message,
358                    });
359                }
360                Ok(Ok(output)) => {
361                    validate_schema(
362                        &call.name,
363                        SchemaStage::Output,
364                        &spec.output_schema,
365                        &output,
366                    )?;
367                    self.reset_failure(&call.name).await;
368                    return Ok(ToolExecutionReport {
369                        output,
370                        telemetry: ToolTelemetry {
371                            run_id,
372                            tool_name: call.name,
373                            retries,
374                            duration_ms: started.elapsed().as_millis(),
375                            status: ToolCallStatus::Succeeded,
376                        },
377                    });
378                }
379            }
380        }
381
382        Err(ToolExecutionError::Adapter {
383            tool_name: call.name,
384            message: "unreachable execution state".to_string(),
385        })
386    }
387
388    async fn current_failure_count(&self, tool_name: &str) -> u32 {
389        let guard = self.failure_counts.lock().await;
390        *guard.get(tool_name).unwrap_or(&0)
391    }
392
393    async fn increment_failure(&self, tool_name: &str) {
394        let mut guard = self.failure_counts.lock().await;
395        let count = guard.entry(tool_name.to_string()).or_insert(0);
396        *count += 1;
397    }
398
399    async fn reset_failure(&self, tool_name: &str) {
400        let mut guard = self.failure_counts.lock().await;
401        guard.insert(tool_name.to_string(), 0);
402    }
403}
404
405fn validate_schema(
406    tool_name: &str,
407    stage: SchemaStage,
408    schema: &JsonFieldSchema,
409    payload: &Value,
410) -> Result<(), ToolExecutionError> {
411    for field in &schema.required_fields {
412        if payload.get(field).is_none() {
413            return Err(ToolExecutionError::SchemaViolation {
414                tool_name: tool_name.to_string(),
415                stage,
416                field: field.clone(),
417            });
418        }
419    }
420    Ok(())
421}