1use std::sync::Arc;
21
22use futures::future::join_all;
23use tracing::{debug, error, warn};
24
25use cortexai_core::{AgentId, ToolCall, ToolResult};
26use cortexai_tools::{ExecutionContext, ToolRegistry};
27
28use crate::approvals::{ApprovalDecision, ApprovalHandler, ApprovalRequest, AutoApproveHandler};
29use crate::scope_guard::ScopeGuard;
30use crate::trace::{ExecutionTrace, ToolCallTrace};
31
32#[derive(Debug)]
34pub enum ExecutionOutcome {
35 Success(ToolResult),
37
38 Denied {
40 tool_call_id: String,
41 reason: String,
42 },
43
44 Skipped { tool_call_id: String },
46
47 ApprovalError { tool_call_id: String, error: String },
49}
50
51impl ExecutionOutcome {
52 pub fn into_tool_result(self) -> ToolResult {
53 match self {
54 Self::Success(result) => result,
55 Self::Denied {
56 tool_call_id,
57 reason,
58 } => ToolResult::failure(tool_call_id, format!("Denied: {}", reason)),
59 Self::Skipped { tool_call_id } => {
60 ToolResult::failure(tool_call_id, "Tool execution skipped".to_string())
61 }
62 Self::ApprovalError {
63 tool_call_id,
64 error,
65 } => ToolResult::failure(tool_call_id, format!("Approval error: {}", error)),
66 }
67 }
68
69 pub fn is_success(&self) -> bool {
70 matches!(self, Self::Success(_))
71 }
72}
73
/// Tuning knobs for [`ToolExecutor`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorConfig {
    /// Upper bound on how many tool calls from one batch run at the same time.
    pub max_concurrency: usize,

    /// Time limit, in seconds, applied to each tool's `execute` step.
    pub timeout_seconds: u64,

    /// Whether the remaining calls of a batch should still run after one fails.
    ///
    /// NOTE(review): `ToolExecutor` does not currently read this flag; every call in a
    /// batch is attempted regardless of earlier failures.
    pub continue_on_failure: bool,
}

impl Default for ExecutorConfig {
    /// Four concurrent calls, a 30-second per-call timeout, and
    /// `continue_on_failure = true`.
    fn default() -> Self {
        Self {
            max_concurrency: 4,
            timeout_seconds: 30,
            continue_on_failure: true,
        }
    }
}
96
97pub struct ToolExecutor<H: ApprovalHandler = AutoApproveHandler> {
99 config: ExecutorConfig,
100 approval_handler: Arc<H>,
101 scope_guard: Option<Arc<ScopeGuard>>,
103 trace: Option<Arc<ExecutionTrace>>,
105}
106
107impl ToolExecutor<AutoApproveHandler> {
108 pub fn new(max_concurrency: usize) -> Self {
110 Self {
111 config: ExecutorConfig {
112 max_concurrency,
113 ..Default::default()
114 },
115 approval_handler: Arc::new(AutoApproveHandler::new()),
116 scope_guard: None,
117 trace: None,
118 }
119 }
120}
121
122impl<H: ApprovalHandler + 'static> ToolExecutor<H> {
123 pub fn with_approval_handler(config: ExecutorConfig, handler: H) -> Self {
125 Self {
126 config,
127 approval_handler: Arc::new(handler),
128 scope_guard: None,
129 trace: None,
130 }
131 }
132
133 pub fn with_shared_handler(config: ExecutorConfig, handler: Arc<H>) -> Self {
135 Self {
136 config,
137 approval_handler: handler,
138 scope_guard: None,
139 trace: None,
140 }
141 }
142
143 pub fn with_scope_guard(self, guard: ScopeGuard) -> Self {
147 Self {
148 scope_guard: Some(Arc::new(guard)),
149 ..self
150 }
151 }
152
153 pub fn with_trace(self, trace: Arc<ExecutionTrace>) -> Self {
157 Self {
158 trace: Some(trace),
159 ..self
160 }
161 }
162
163 pub async fn execute_tools(
165 &self,
166 tool_calls: &[ToolCall],
167 registry: &Arc<ToolRegistry>,
168 agent_id: &AgentId,
169 ) -> Vec<ToolResult> {
170 let outcomes = self
171 .execute_tools_with_outcomes(tool_calls, registry, agent_id)
172 .await;
173 outcomes.into_iter().map(|o| o.into_tool_result()).collect()
174 }
175
176 pub async fn execute_tools_with_outcomes(
178 &self,
179 tool_calls: &[ToolCall],
180 registry: &Arc<ToolRegistry>,
181 agent_id: &AgentId,
182 ) -> Vec<ExecutionOutcome> {
183 debug!(
184 "Executing {} tools with concurrency {}",
185 tool_calls.len(),
186 self.config.max_concurrency
187 );
188
189 let tasks: Vec<_> = tool_calls
191 .iter()
192 .map(|call| {
193 let registry = registry.clone();
194 let agent_id = agent_id.clone();
195 let call = call.clone();
196 let handler = self.approval_handler.clone();
197 let config = self.config.clone();
198 let scope_guard = self.scope_guard.clone();
199 let trace = self.trace.clone();
200
201 async move {
202 Self::execute_with_approval(
203 call,
204 registry,
205 agent_id,
206 handler,
207 config,
208 scope_guard,
209 trace,
210 )
211 .await
212 }
213 })
214 .collect();
215
216 let semaphore = Arc::new(tokio::sync::Semaphore::new(self.config.max_concurrency));
218 let tasks_with_semaphore: Vec<_> = tasks
219 .into_iter()
220 .map(|task| {
221 let semaphore = semaphore.clone();
222 async move {
223 let _permit = semaphore.acquire().await.unwrap();
224 task.await
225 }
226 })
227 .collect();
228
229 join_all(tasks_with_semaphore).await
230 }
231
232 async fn execute_with_approval(
233 call: ToolCall,
234 registry: Arc<ToolRegistry>,
235 agent_id: AgentId,
236 handler: Arc<H>,
237 config: ExecutorConfig,
238 scope_guard: Option<Arc<ScopeGuard>>,
239 trace: Option<Arc<ExecutionTrace>>,
240 ) -> ExecutionOutcome {
241 let tool_schema = registry.get(&call.name).map(|t| t.schema());
243
244 if let Some(guard) = &scope_guard {
246 if let Some(schema) = &tool_schema {
247 debug!("Scope check for tool '{}'", call.name);
248 match guard
249 .check_with_escalation(schema, &call, handler.as_ref())
250 .await
251 {
252 Ok(true) => {}
253 Ok(false) => {
254 warn!("Tool '{}' denied by scope guard", call.name);
255 return ExecutionOutcome::Denied {
256 tool_call_id: call.id,
257 reason: "Insufficient scope permissions".to_string(),
258 };
259 }
260 Err(e) => {
261 error!("Scope guard error for tool '{}': {}", call.name, e);
262 return ExecutionOutcome::ApprovalError {
263 tool_call_id: call.id,
264 error: e.to_string(),
265 };
266 }
267 }
268 }
269 }
270
271 if let Some(reason) = handler.check_requires_approval(&call, tool_schema.as_ref()) {
273 debug!("Tool '{}' requires approval: {}", call.name, reason);
274
275 let request = ApprovalRequest {
276 tool_call: call.clone(),
277 tool_schema,
278 reason: reason.clone(),
279 timestamp: chrono::Utc::now(),
280 context: None,
281 };
282
283 match handler.request_approval(request).await {
284 Ok(ApprovalDecision::Approved) => {
285 debug!("Tool '{}' approved", call.name);
286 }
287 Ok(ApprovalDecision::Modify { new_arguments }) => {
288 debug!("Tool '{}' approved with modifications", call.name);
289 let modified_call = ToolCall {
290 id: call.id.clone(),
291 name: call.name.clone(),
292 arguments: new_arguments,
293 };
294 return Self::execute_single_tool(
295 modified_call,
296 registry,
297 agent_id,
298 config,
299 trace,
300 )
301 .await;
302 }
303 Ok(ApprovalDecision::Denied { reason }) => {
304 warn!("Tool '{}' denied: {}", call.name, reason);
305 return ExecutionOutcome::Denied {
306 tool_call_id: call.id,
307 reason,
308 };
309 }
310 Ok(ApprovalDecision::Skip) => {
311 debug!("Tool '{}' skipped", call.name);
312 return ExecutionOutcome::Skipped {
313 tool_call_id: call.id,
314 };
315 }
316 Err(e) => {
317 error!("Approval error for tool '{}': {}", call.name, e);
318 return ExecutionOutcome::ApprovalError {
319 tool_call_id: call.id,
320 error: e.to_string(),
321 };
322 }
323 }
324 }
325
326 Self::execute_single_tool(call, registry, agent_id, config, trace).await
327 }
328
329 async fn execute_single_tool(
330 call: ToolCall,
331 registry: Arc<ToolRegistry>,
332 agent_id: AgentId,
333 config: ExecutorConfig,
334 trace: Option<Arc<ExecutionTrace>>,
335 ) -> ExecutionOutcome {
336 debug!("Executing tool: {} (call_id: {})", call.name, call.id);
337
338 let tool = match registry.get(&call.name) {
339 Some(tool) => tool,
340 None => {
341 error!("Tool not found: {}", call.name);
342 return ExecutionOutcome::Success(ToolResult::failure(
343 call.id,
344 format!("Tool '{}' not found", call.name),
345 ));
346 }
347 };
348
349 if let Err(e) = tool.validate(&call.arguments) {
351 error!("Tool validation failed: {}", e);
352 return ExecutionOutcome::Success(ToolResult::failure(call.id, e.to_string()));
353 }
354
355 let context = ExecutionContext::new(agent_id);
357
358 let exec_start = std::time::Instant::now();
360 let timeout_duration = std::time::Duration::from_secs(config.timeout_seconds);
361 let result = match tokio::time::timeout(
362 timeout_duration,
363 tool.execute(&context, call.arguments.clone()),
364 )
365 .await
366 {
367 Ok(Ok(data)) => {
368 debug!("Tool {} executed successfully", call.name);
369 let duration_ms = exec_start.elapsed().as_millis() as u64;
370 if let Some(ref t) = trace {
371 t.record_tool_call(ToolCallTrace {
372 tool_name: call.name.clone(),
373 input: call.arguments.clone(),
374 output: data.clone(),
375 duration_ms,
376 error: None,
377 });
378 }
379 ToolResult::success(call.id, data)
380 }
381 Ok(Err(e)) => {
382 error!("Tool {} execution failed: {}", call.name, e);
383 let duration_ms = exec_start.elapsed().as_millis() as u64;
384 if let Some(ref t) = trace {
385 t.record_tool_call(ToolCallTrace {
386 tool_name: call.name.clone(),
387 input: call.arguments.clone(),
388 output: serde_json::Value::Null,
389 duration_ms,
390 error: Some(e.to_string()),
391 });
392 }
393 ToolResult::failure(call.id, e.to_string())
394 }
395 Err(_) => {
396 error!("Tool {} timed out after {:?}", call.name, timeout_duration);
397 let duration_ms = exec_start.elapsed().as_millis() as u64;
398 let err_msg = format!("Tool execution timed out after {:?}", timeout_duration);
399 if let Some(ref t) = trace {
400 t.record_tool_call(ToolCallTrace {
401 tool_name: call.name.clone(),
402 input: call.arguments.clone(),
403 output: serde_json::Value::Null,
404 duration_ms,
405 error: Some(err_msg.clone()),
406 });
407 }
408 ToolResult::failure(call.id, err_msg)
409 }
410 };
411
412 ExecutionOutcome::Success(result)
413 }
414}
415
416impl<H: ApprovalHandler + Clone> Clone for ToolExecutor<H> {
418 fn clone(&self) -> Self {
419 Self {
420 config: self.config.clone(),
421 approval_handler: self.approval_handler.clone(),
422 scope_guard: self.scope_guard.clone(),
423 trace: self.trace.clone(),
424 }
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::approvals::{ApprovalConfig, ApprovalReason, TestApprovalHandler};
    use cortexai_core::{Tool, ToolSchema};
    use serde_json::json;
    use std::collections::HashMap;

    /// Safe tool that returns its arguments unchanged.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn schema(&self) -> ToolSchema {
            ToolSchema {
                name: "echo".to_string(),
                description: "Echoes input".to_string(),
                parameters: json!({"type": "object", "properties": {"message": {"type": "string"}}}),
                dangerous: false,
                metadata: HashMap::new(),
                required_scopes: vec![],
            }
        }

        async fn execute(
            &self,
            _ctx: &ExecutionContext,
            args: serde_json::Value,
        ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
            Ok(args)
        }
    }

    /// Tool flagged `dangerous: true`, so approval configs that gate dangerous tools
    /// apply to it.
    struct DangerousTool;

    #[async_trait::async_trait]
    impl Tool for DangerousTool {
        fn schema(&self) -> ToolSchema {
            ToolSchema {
                name: "dangerous".to_string(),
                description: "A dangerous tool".to_string(),
                parameters: json!({"type": "object"}),
                dangerous: true,
                metadata: HashMap::new(),
                required_scopes: vec![],
            }
        }

        async fn execute(
            &self,
            _ctx: &ExecutionContext,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
            Ok(json!({"result": "executed"}))
        }
    }

    #[tokio::test]
    async fn test_executor_basic() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool));
        let registry = Arc::new(registry);

        let executor = ToolExecutor::new(4);
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "echo".to_string(),
            arguments: json!({"message": "hello"}),
        }];

        let results = executor.execute_tools(&calls, &registry, &agent_id).await;

        assert_eq!(results.len(), 1);
        assert!(results[0].success);
    }

    #[tokio::test]
    async fn test_executor_with_approval_deny() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DangerousTool));
        let registry = Arc::new(registry);

        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();

        let handler = TestApprovalHandler::with_config(config);
        handler.set_decision("dangerous", ApprovalDecision::denied("Not allowed"));

        let executor = ToolExecutor::with_approval_handler(ExecutorConfig::default(), handler);
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "dangerous".to_string(),
            arguments: json!({}),
        }];

        let outcomes = executor
            .execute_tools_with_outcomes(&calls, &registry, &agent_id)
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(matches!(outcomes[0], ExecutionOutcome::Denied { .. }));
    }

    #[tokio::test]
    async fn test_executor_with_approval_approve() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DangerousTool));
        let registry = Arc::new(registry);

        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();

        // No decision is scripted, so the test handler falls back to its default.
        let handler = TestApprovalHandler::with_config(config);
        let executor = ToolExecutor::with_approval_handler(ExecutorConfig::default(), handler);
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "dangerous".to_string(),
            arguments: json!({}),
        }];

        let outcomes = executor
            .execute_tools_with_outcomes(&calls, &registry, &agent_id)
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].is_success());
    }

    #[tokio::test]
    async fn test_executor_approval_request_recorded() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DangerousTool));
        let registry = Arc::new(registry);

        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();

        // Shared so the test can inspect the recorded requests after execution.
        let handler = Arc::new(TestApprovalHandler::with_config(config));

        let executor =
            ToolExecutor::with_shared_handler(ExecutorConfig::default(), handler.clone());
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "dangerous".to_string(),
            arguments: json!({"key": "value"}),
        }];

        executor.execute_tools(&calls, &registry, &agent_id).await;

        assert!(handler.was_approval_requested("dangerous"));
        assert_eq!(handler.request_count(), 1);

        let requests = handler.get_requests();
        assert_eq!(requests[0].tool_call.name, "dangerous");
        assert!(matches!(requests[0].reason, ApprovalReason::DangerousTool));
    }

    #[tokio::test]
    async fn test_executor_skip() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DangerousTool));
        let registry = Arc::new(registry);

        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();

        let handler = TestApprovalHandler::with_config(config);
        handler.set_decision("dangerous", ApprovalDecision::Skip);

        let executor = ToolExecutor::with_approval_handler(ExecutorConfig::default(), handler);
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "dangerous".to_string(),
            arguments: json!({}),
        }];

        let outcomes = executor
            .execute_tools_with_outcomes(&calls, &registry, &agent_id)
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(matches!(outcomes[0], ExecutionOutcome::Skipped { .. }));
    }

    #[tokio::test]
    async fn test_executor_safe_tool_no_approval() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool));
        let registry = Arc::new(registry);

        let config = ApprovalConfig::builder()
            .require_approval_for_dangerous(true)
            .build();

        let handler = Arc::new(TestApprovalHandler::with_config(config));

        let executor =
            ToolExecutor::with_shared_handler(ExecutorConfig::default(), handler.clone());
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "echo".to_string(),
            arguments: json!({"message": "hi"}),
        }];

        let results = executor.execute_tools(&calls, &registry, &agent_id).await;

        assert_eq!(results.len(), 1);
        assert!(results[0].success);

        // A non-dangerous tool must not trigger an approval request.
        assert_eq!(handler.request_count(), 0);
    }

    /// Non-dangerous tool that declares the `fs:write` scope.
    struct ScopedWriteTool;

    #[async_trait::async_trait]
    impl Tool for ScopedWriteTool {
        fn schema(&self) -> ToolSchema {
            ToolSchema {
                name: "write_file".to_string(),
                description: "Writes a file (requires fs:write scope)".to_string(),
                parameters: json!({"type": "object"}),
                dangerous: false,
                metadata: HashMap::new(),
                required_scopes: vec!["fs:write".to_string()],
            }
        }

        async fn execute(
            &self,
            _ctx: &ExecutionContext,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, cortexai_core::errors::ToolError> {
            Ok(json!({"written": true}))
        }
    }

    #[tokio::test]
    async fn test_executor_scope_guard_allows_granted_scope() {
        use crate::scope_guard::{ScopeGuard, ScopePolicy};

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(ScopedWriteTool));
        let registry = Arc::new(registry);

        let policy = ScopePolicy::new().grant("fs:write");
        let guard = ScopeGuard::new(policy);

        let executor = ToolExecutor::new(4).with_scope_guard(guard);
        let agent_id = AgentId::new("agent-with-write");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "write_file".to_string(),
            arguments: json!({}),
        }];

        let outcomes = executor
            .execute_tools_with_outcomes(&calls, &registry, &agent_id)
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].is_success(), "should succeed with fs:write scope");
    }

    #[tokio::test]
    async fn test_executor_scope_guard_denies_missing_scope() {
        use crate::scope_guard::{ScopeGuard, ScopePolicy};

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(ScopedWriteTool));
        let registry = Arc::new(registry);

        // Only `fs:read` is granted; the tool needs `fs:write`.
        let policy = ScopePolicy::new().grant("fs:read");
        let guard = ScopeGuard::new(policy);

        // `check_with_escalation` receives the approval handler; its scripted denial
        // should surface as `ExecutionOutcome::Denied`.
        let config = ApprovalConfig::default();
        let handler = TestApprovalHandler::with_config(config);
        handler.set_decision("write_file", ApprovalDecision::denied("Scope not granted"));

        let executor = ToolExecutor::with_approval_handler(ExecutorConfig::default(), handler)
            .with_scope_guard(guard);
        let agent_id = AgentId::new("agent-read-only");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "write_file".to_string(),
            arguments: json!({}),
        }];

        let outcomes = executor
            .execute_tools_with_outcomes(&calls, &registry, &agent_id)
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(
            matches!(outcomes[0], ExecutionOutcome::Denied { .. }),
            "should be denied without fs:write scope"
        );
    }

    #[tokio::test]
    async fn test_executor_records_trace_when_present() {
        use crate::trace::ExecutionTrace;

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool));
        let registry = Arc::new(registry);

        let trace = Arc::new(ExecutionTrace::new());
        let executor = ToolExecutor::new(4).with_trace(trace.clone());
        let agent_id = AgentId::new("test");

        let calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "echo".to_string(),
                arguments: json!({"message": "hello"}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "echo".to_string(),
                arguments: json!({"message": "world"}),
            },
        ];

        let results = executor.execute_tools(&calls, &registry, &agent_id).await;
        assert_eq!(results.len(), 2);
        assert!(results[0].success);
        assert!(results[1].success);

        // One trace entry per executed call, none carrying an error.
        let finalized = trace.finalize("done".to_string());
        assert_eq!(finalized.tool_calls.len(), 2);
        assert!(finalized.tool_calls.iter().all(|t| t.tool_name == "echo"));
        assert!(finalized.tool_calls.iter().all(|t| t.error.is_none()));
    }

    #[tokio::test]
    async fn test_executor_no_trace_still_works() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool));
        let registry = Arc::new(registry);

        let executor = ToolExecutor::new(4);
        let agent_id = AgentId::new("test");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "echo".to_string(),
            arguments: json!({"message": "hello"}),
        }];

        let results = executor.execute_tools(&calls, &registry, &agent_id).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].success);
    }

    #[tokio::test]
    async fn test_executor_no_scope_guard_allows_all() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(ScopedWriteTool));
        let registry = Arc::new(registry);

        let executor = ToolExecutor::new(4);
        let agent_id = AgentId::new("agent-any");

        let calls = vec![ToolCall {
            id: "1".to_string(),
            name: "write_file".to_string(),
            arguments: json!({}),
        }];

        let outcomes = executor
            .execute_tools_with_outcomes(&calls, &registry, &agent_id)
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].is_success(), "no scope guard = allow all");
    }
}