1use super::Tool;
16use crate::context::TenantContext;
17use crate::kernel::{ExecutionId, StepId};
18use crate::policy::{PolicyAction, PolicyContext, PolicyDecision, PolicyEvaluator, ToolPolicy};
19use crate::streaming::{EventEmitter, StreamEvent};
20use serde_json::Value;
21use std::sync::Arc;
22
23#[derive(Debug, thiserror::Error)]
25pub enum ToolExecutionError {
26 #[error("Tool execution denied: {reason}")]
27 PolicyDenied { reason: String },
28
29 #[error("Tool execution error: {0}")]
30 ExecutionFailed(#[from] anyhow::Error),
31}
32
33#[derive(Debug, Clone)]
35pub struct ToolExecutionContext {
36 pub execution_id: ExecutionId,
38 pub step_id: Option<StepId>,
40 pub tenant: TenantContext,
42 pub metadata: std::collections::HashMap<String, String>,
44}
45
46impl ToolExecutionContext {
47 pub fn new(execution_id: ExecutionId, tenant: TenantContext) -> Self {
49 Self {
50 execution_id,
51 step_id: None,
52 tenant,
53 metadata: std::collections::HashMap::new(),
54 }
55 }
56
57 pub fn with_step(mut self, step_id: StepId) -> Self {
59 self.step_id = Some(step_id);
60 self
61 }
62}
63
64pub struct ToolExecutor {
70 policy: Arc<ToolPolicy>,
71 emitter: Option<EventEmitter>,
73}
74
75impl ToolExecutor {
76 pub fn new(policy: ToolPolicy) -> Self {
78 Self {
79 policy: Arc::new(policy),
80 emitter: None,
81 }
82 }
83
84 pub fn with_shared_policy(policy: Arc<ToolPolicy>) -> Self {
86 Self {
87 policy,
88 emitter: None,
89 }
90 }
91
92 pub fn with_emitter(mut self, emitter: EventEmitter) -> Self {
94 self.emitter = Some(emitter);
95 self
96 }
97
98 pub fn set_emitter(&mut self, emitter: EventEmitter) {
100 self.emitter = Some(emitter);
101 }
102
103 pub async fn execute(
108 &self,
109 tool: &dyn Tool,
110 args: Value,
111 ctx: &ToolExecutionContext,
112 ) -> Result<Value, ToolExecutionError> {
113 let policy_ctx = PolicyContext {
115 tenant_id: Some(ctx.tenant.tenant_id().as_str().to_string()),
116 user_id: ctx.tenant.user_id().map(|u| u.as_str().to_string()),
117 action: PolicyAction::InvokeTool {
118 tool_name: tool.name().to_string(),
119 },
120 metadata: ctx.metadata.clone(),
121 };
122
123 let tool_name = tool.name().to_string();
124
125 match self.policy.evaluate(&policy_ctx) {
127 PolicyDecision::Allow => {
128 if let Some(emitter) = &self.emitter {
130 emitter.emit(StreamEvent::policy_decision_allow(
131 &ctx.execution_id,
132 ctx.step_id.as_ref(),
133 &tool_name,
134 ));
135 }
136 tool.execute(args).await.map_err(ToolExecutionError::from)
138 }
139 PolicyDecision::Deny { reason } => {
140 if let Some(emitter) = &self.emitter {
142 emitter.emit(StreamEvent::policy_decision_deny(
143 &ctx.execution_id,
144 ctx.step_id.as_ref(),
145 &tool_name,
146 &reason,
147 ));
148 }
149 Err(ToolExecutionError::PolicyDenied { reason })
150 }
151 PolicyDecision::Warn { message } => {
152 if let Some(emitter) = &self.emitter {
154 emitter.emit(StreamEvent::policy_decision_warn(
155 &ctx.execution_id,
156 ctx.step_id.as_ref(),
157 &tool_name,
158 &message,
159 ));
160 }
161 tracing::warn!(tool = tool.name(), message = %message, "Tool policy warning");
163 tool.execute(args).await.map_err(ToolExecutionError::from)
164 }
165 }
166 }
167
168 pub async fn execute_sequence(
170 &self,
171 tools: &[(Arc<dyn Tool>, Value)],
172 ctx: &ToolExecutionContext,
173 ) -> Result<Vec<Value>, ToolExecutionError> {
174 let mut results = Vec::new();
175 for (tool, args) in tools {
176 let result = self.execute(tool.as_ref(), args.clone(), ctx).await?;
177 results.push(result);
178 }
179 Ok(results)
180 }
181
182 pub fn is_allowed(&self, tool_name: &str, ctx: &ToolExecutionContext) -> bool {
184 let policy_ctx = PolicyContext {
185 tenant_id: Some(ctx.tenant.tenant_id().as_str().to_string()),
186 user_id: ctx.tenant.user_id().map(|u| u.as_str().to_string()),
187 action: PolicyAction::InvokeTool {
188 tool_name: tool_name.to_string(),
189 },
190 metadata: std::collections::HashMap::new(),
191 };
192
193 matches!(
194 self.policy.evaluate(&policy_ctx),
195 PolicyDecision::Allow | PolicyDecision::Warn { .. }
196 )
197 }
198
199 pub fn get_permissions(&self, tool_name: &str) -> &crate::policy::ToolPermissions {
201 self.policy.get_permissions(tool_name)
202 }
203
204 pub fn policy(&self) -> &ToolPolicy {
206 &self.policy
207 }
208}
209
210impl Default for ToolExecutor {
211 fn default() -> Self {
212 Self::new(ToolPolicy::default())
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::TenantId;
    use async_trait::async_trait;

    /// Echo tool: returns its arguments unchanged.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn named(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool for testing"
        }

        async fn execute(&self, args: Value) -> anyhow::Result<Value> {
            Ok(args)
        }
    }

    /// Fresh execution context for the fixed test tenant.
    fn tenant_ctx() -> ToolExecutionContext {
        ToolExecutionContext::new(
            ExecutionId::new(),
            TenantContext::new(TenantId::from("tenant_123")),
        )
    }

    #[tokio::test]
    async fn test_tool_execution_allowed() {
        let executor = ToolExecutor::new(ToolPolicy::new());
        let tool = MockTool::named("test_tool");
        let ctx = tenant_ctx();

        let outcome = executor
            .execute(&tool, Value::String("test".into()), &ctx)
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tool_execution_blocked() {
        let executor = ToolExecutor::new(ToolPolicy::new().block_tool("blocked_tool"));
        let tool = MockTool::named("blocked_tool");
        let ctx = tenant_ctx();

        let outcome = executor.execute(&tool, Value::Null, &ctx).await;
        assert!(matches!(
            outcome,
            Err(ToolExecutionError::PolicyDenied { .. })
        ));
    }

    #[tokio::test]
    async fn test_is_allowed() {
        let executor = ToolExecutor::new(ToolPolicy::new().block_tool("blocked_tool"));
        let ctx = tenant_ctx();

        assert!(executor.is_allowed("allowed_tool", &ctx));
        assert!(!executor.is_allowed("blocked_tool", &ctx));
    }

    #[tokio::test]
    async fn test_policy_decision_event_emission_allowed() {
        let emitter = EventEmitter::new();
        let executor = ToolExecutor::new(ToolPolicy::new()).with_emitter(emitter.clone());
        let tool = MockTool::named("test_tool");
        let ctx = tenant_ctx();

        let outcome = executor.execute(&tool, Value::Null, &ctx).await;
        assert!(outcome.is_ok());

        // Exactly one decision event should have been emitted for the call.
        let events = emitter.drain();
        assert_eq!(events.len(), 1);
        if let StreamEvent::PolicyDecision {
            decision,
            tool_name,
            ..
        } = &events[0]
        {
            assert_eq!(decision, "allow");
            assert_eq!(tool_name, "test_tool");
        } else {
            panic!("Expected PolicyDecision event");
        }
    }

    #[tokio::test]
    async fn test_policy_decision_event_emission_denied() {
        let emitter = EventEmitter::new();
        let executor = ToolExecutor::new(ToolPolicy::new().block_tool("blocked_tool"))
            .with_emitter(emitter.clone());
        let tool = MockTool::named("blocked_tool");
        let ctx = tenant_ctx();

        let outcome = executor.execute(&tool, Value::Null, &ctx).await;
        assert!(matches!(
            outcome,
            Err(ToolExecutionError::PolicyDenied { .. })
        ));

        // The denial must still be reported, and it must carry a reason.
        let events = emitter.drain();
        assert_eq!(events.len(), 1);
        if let StreamEvent::PolicyDecision {
            decision,
            tool_name,
            reason,
            ..
        } = &events[0]
        {
            assert_eq!(decision, "deny");
            assert_eq!(tool_name, "blocked_tool");
            assert!(reason.is_some());
        } else {
            panic!("Expected PolicyDecision event");
        }
    }
}