1use std::collections::HashMap;
9use std::time::Instant;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use thiserror::Error;
15use tokio::sync::Mutex;
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum ToolCapability {
21 ShellExec,
22 FileRead,
23 FileWrite,
24 GitRead,
25 GitWrite,
26 NetworkFetch,
27 Custom(String),
28}
29
30impl ToolCapability {
31 fn as_policy_key(&self) -> String {
32 match self {
33 Self::ShellExec => "shell_exec".to_string(),
34 Self::FileRead => "file_read".to_string(),
35 Self::FileWrite => "file_write".to_string(),
36 Self::GitRead => "git_read".to_string(),
37 Self::GitWrite => "git_write".to_string(),
38 Self::NetworkFetch => "network_fetch".to_string(),
39 Self::Custom(name) => format!("custom:{name}"),
40 }
41 }
42}
43
44impl std::fmt::Display for ToolCapability {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "{}", self.as_policy_key())
47 }
48}
49
50#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
52pub struct JsonFieldSchema {
53 pub required_fields: Vec<String>,
54}
55
56impl JsonFieldSchema {
57 pub fn required<const N: usize>(fields: [&str; N]) -> Self {
58 Self {
59 required_fields: fields.iter().map(|f| (*f).to_string()).collect(),
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66pub struct ToolSpec {
67 pub name: String,
68 pub capability: ToolCapability,
69 pub input_schema: JsonFieldSchema,
70 pub output_schema: JsonFieldSchema,
71}
72
73#[derive(Debug, Clone, Default)]
75pub struct ToolRegistry {
76 tools: HashMap<String, ToolSpec>,
77}
78
79impl ToolRegistry {
80 pub fn register(&mut self, spec: ToolSpec) -> Result<(), ToolExecutionError> {
81 if self.tools.contains_key(&spec.name) {
82 return Err(ToolExecutionError::DuplicateTool {
83 tool_name: spec.name,
84 });
85 }
86 self.tools.insert(spec.name.clone(), spec);
87 Ok(())
88 }
89
90 pub fn get(&self, name: &str) -> Option<&ToolSpec> {
91 self.tools.get(name)
92 }
93}
94
95#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
97#[serde(rename_all = "snake_case")]
98pub enum PolicyAction {
99 Allow,
100 Deny,
101 RequireApproval,
102}
103
104#[derive(Debug, Clone, Default)]
106pub struct PolicyMatrix {
107 by_capability: HashMap<ToolCapability, PolicyAction>,
108 by_tool: HashMap<String, PolicyAction>,
109}
110
111impl PolicyMatrix {
112 pub fn safe_defaults() -> Self {
114 Self::default()
115 .with_capability(ToolCapability::ShellExec, PolicyAction::RequireApproval)
116 .with_capability(ToolCapability::FileWrite, PolicyAction::RequireApproval)
117 .with_capability(ToolCapability::GitWrite, PolicyAction::RequireApproval)
118 .with_capability(ToolCapability::NetworkFetch, PolicyAction::RequireApproval)
119 }
120
121 pub fn with_capability(mut self, capability: ToolCapability, action: PolicyAction) -> Self {
122 self.by_capability.insert(capability, action);
123 self
124 }
125
126 pub fn with_tool_action(mut self, tool_name: impl Into<String>, action: PolicyAction) -> Self {
127 self.by_tool.insert(tool_name.into(), action);
128 self
129 }
130
131 fn action_for(&self, tool: &ToolSpec) -> PolicyAction {
132 if let Some(action) = self.by_tool.get(&tool.name) {
133 *action
134 } else if let Some(action) = self.by_capability.get(&tool.capability) {
135 *action
136 } else {
137 PolicyAction::Allow
138 }
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
144pub struct ToolInvocation {
145 pub name: String,
146 pub input: Value,
147}
148
149impl ToolInvocation {
150 pub fn new(name: impl Into<String>, input: Value) -> Self {
151 Self {
152 name: name.into(),
153 input,
154 }
155 }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
160pub struct ToolExecutionConfig {
161 pub timeout_ms: u64,
162 pub max_retries: u32,
163 pub circuit_breaker_threshold: u32,
164}
165
166impl Default for ToolExecutionConfig {
167 fn default() -> Self {
168 Self {
169 timeout_ms: 5_000,
170 max_retries: 0,
171 circuit_breaker_threshold: 3,
172 }
173 }
174}
175
176#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
178#[serde(rename_all = "snake_case")]
179pub enum SchemaStage {
180 Input,
181 Output,
182}
183
184#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
186#[serde(rename_all = "snake_case")]
187pub enum ToolCallStatus {
188 Succeeded,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
193pub struct ToolTelemetry {
194 pub run_id: Option<String>,
195 pub tool_name: String,
196 pub retries: u32,
197 pub duration_ms: u128,
198 pub status: ToolCallStatus,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
203pub struct ToolExecutionReport {
204 pub output: Value,
205 pub telemetry: ToolTelemetry,
206}
207
208#[derive(Debug, Error, PartialEq, Eq)]
210pub enum ToolExecutionError {
211 #[error("unknown tool: {tool_name}")]
212 UnknownTool { tool_name: String },
213
214 #[error("duplicate tool registration: {tool_name}")]
215 DuplicateTool { tool_name: String },
216
217 #[error("policy denied tool '{tool_name}': {reason}")]
218 PolicyDenied { tool_name: String, reason: String },
219
220 #[error("approval required for tool '{tool_name}': {reason}")]
221 ApprovalRequired { tool_name: String, reason: String },
222
223 #[error("schema violation for tool '{tool_name}' ({stage:?}): missing field '{field}'")]
224 SchemaViolation {
225 tool_name: String,
226 stage: SchemaStage,
227 field: String,
228 },
229
230 #[error("tool '{tool_name}' timed out after {timeout_ms}ms")]
231 Timeout { tool_name: String, timeout_ms: u64 },
232
233 #[error("tool '{tool_name}' adapter error: {message}")]
234 Adapter { tool_name: String, message: String },
235
236 #[error("circuit breaker open for tool '{tool_name}' (failures={failures})")]
237 CircuitOpen { tool_name: String, failures: u32 },
238}
239
240#[async_trait]
242pub trait ToolAdapter: Send + Sync + 'static {
243 async fn call(&self, tool_name: &str, input: &Value) -> std::result::Result<Value, String>;
244}
245
246pub struct ToolExecutor<A: ToolAdapter> {
248 registry: ToolRegistry,
249 policy: PolicyMatrix,
250 adapter: A,
251 config: ToolExecutionConfig,
252 failure_counts: Mutex<HashMap<String, u32>>,
253}
254
255impl<A: ToolAdapter> ToolExecutor<A> {
256 pub fn new(
257 registry: ToolRegistry,
258 policy: PolicyMatrix,
259 adapter: A,
260 config: ToolExecutionConfig,
261 ) -> Self {
262 Self {
263 registry,
264 policy,
265 adapter,
266 config,
267 failure_counts: Mutex::new(HashMap::new()),
268 }
269 }
270
271 pub fn new_with_safe_defaults(
273 registry: ToolRegistry,
274 adapter: A,
275 config: ToolExecutionConfig,
276 ) -> Self {
277 Self::new(registry, PolicyMatrix::safe_defaults(), adapter, config)
278 }
279
280 pub async fn execute(
281 &self,
282 call: ToolInvocation,
283 run_id: Option<String>,
284 ) -> Result<ToolExecutionReport, ToolExecutionError> {
285 let started = Instant::now();
286
287 let spec =
288 self.registry
289 .get(&call.name)
290 .ok_or_else(|| ToolExecutionError::UnknownTool {
291 tool_name: call.name.clone(),
292 })?;
293
294 match self.policy.action_for(spec) {
295 PolicyAction::Allow => {}
296 PolicyAction::Deny => {
297 return Err(ToolExecutionError::PolicyDenied {
298 tool_name: call.name.clone(),
299 reason: format!("capability '{}' is denied", spec.capability.as_policy_key()),
300 });
301 }
302 PolicyAction::RequireApproval => {
303 return Err(ToolExecutionError::ApprovalRequired {
304 tool_name: call.name.clone(),
305 reason: format!(
306 "capability '{}' requires explicit approval",
307 spec.capability.as_policy_key()
308 ),
309 });
310 }
311 }
312
313 validate_schema(
314 &call.name,
315 SchemaStage::Input,
316 &spec.input_schema,
317 &call.input,
318 )?;
319
320 let current_failures = self.current_failure_count(&call.name).await;
321 if self.config.circuit_breaker_threshold > 0
322 && current_failures >= self.config.circuit_breaker_threshold
323 {
324 return Err(ToolExecutionError::CircuitOpen {
325 tool_name: call.name.clone(),
326 failures: current_failures,
327 });
328 }
329
330 let mut retries = 0u32;
331 let max_attempts = self.config.max_retries + 1;
332 for attempt in 0..max_attempts {
333 let timeout = tokio::time::Duration::from_millis(self.config.timeout_ms);
334 let call_result =
335 tokio::time::timeout(timeout, self.adapter.call(&call.name, &call.input)).await;
336
337 match call_result {
338 Err(_) => {
339 if attempt < self.config.max_retries {
340 retries += 1;
341 continue;
342 }
343 self.increment_failure(&call.name).await;
344 return Err(ToolExecutionError::Timeout {
345 tool_name: call.name.clone(),
346 timeout_ms: self.config.timeout_ms,
347 });
348 }
349 Ok(Err(message)) => {
350 if attempt < self.config.max_retries {
351 retries += 1;
352 continue;
353 }
354 self.increment_failure(&call.name).await;
355 return Err(ToolExecutionError::Adapter {
356 tool_name: call.name.clone(),
357 message,
358 });
359 }
360 Ok(Ok(output)) => {
361 validate_schema(
362 &call.name,
363 SchemaStage::Output,
364 &spec.output_schema,
365 &output,
366 )?;
367 self.reset_failure(&call.name).await;
368 return Ok(ToolExecutionReport {
369 output,
370 telemetry: ToolTelemetry {
371 run_id,
372 tool_name: call.name,
373 retries,
374 duration_ms: started.elapsed().as_millis(),
375 status: ToolCallStatus::Succeeded,
376 },
377 });
378 }
379 }
380 }
381
382 Err(ToolExecutionError::Adapter {
383 tool_name: call.name,
384 message: "unreachable execution state".to_string(),
385 })
386 }
387
388 async fn current_failure_count(&self, tool_name: &str) -> u32 {
389 let guard = self.failure_counts.lock().await;
390 *guard.get(tool_name).unwrap_or(&0)
391 }
392
393 async fn increment_failure(&self, tool_name: &str) {
394 let mut guard = self.failure_counts.lock().await;
395 let count = guard.entry(tool_name.to_string()).or_insert(0);
396 *count += 1;
397 }
398
399 async fn reset_failure(&self, tool_name: &str) {
400 let mut guard = self.failure_counts.lock().await;
401 guard.insert(tool_name.to_string(), 0);
402 }
403}
404
405fn validate_schema(
406 tool_name: &str,
407 stage: SchemaStage,
408 schema: &JsonFieldSchema,
409 payload: &Value,
410) -> Result<(), ToolExecutionError> {
411 for field in &schema.required_fields {
412 if payload.get(field).is_none() {
413 return Err(ToolExecutionError::SchemaViolation {
414 tool_name: tool_name.to_string(),
415 stage,
416 field: field.clone(),
417 });
418 }
419 }
420 Ok(())
421}