use super::Tool;
use crate::context::TenantContext;
use crate::kernel::{ExecutionId, StepId};
use crate::policy::{PolicyAction, PolicyContext, PolicyDecision, PolicyEvaluator, ToolPolicy};
use crate::streaming::{EventEmitter, StreamEvent};
use serde_json::Value;
use std::sync::Arc;
/// Ways a call through [`ToolExecutor::execute`] can fail.
#[derive(Debug, thiserror::Error)]
pub enum ToolExecutionError {
    /// The tool policy returned a deny decision; the tool was never invoked.
    /// `reason` carries the explanation supplied by the policy.
    #[error("Tool execution denied: {reason}")]
    PolicyDenied {
        reason: String,
    },

    /// The policy permitted the call but the tool itself returned an error.
    #[error("Tool execution error: {0}")]
    ExecutionFailed(#[from] anyhow::Error),
}
/// Per-invocation data supplied alongside a tool call: which execution/step it belongs to,
/// on whose behalf it runs, and arbitrary string metadata.
#[derive(Debug, Clone)]
pub struct ToolExecutionContext {
    /// Execution this tool call is part of.
    pub execution_id: ExecutionId,
    /// Step within the execution, when the call is associated with one.
    pub step_id: Option<StepId>,
    /// Tenant (and optional user) the call is made for.
    pub tenant: TenantContext,
    /// Free-form key/value pairs; `ToolExecutor::execute` copies these into the
    /// `PolicyContext` it evaluates.
    pub metadata: std::collections::HashMap<String, String>,
}
impl ToolExecutionContext {
    /// Creates a context for `execution_id` on behalf of `tenant`, with no step attached
    /// and an empty metadata map.
    pub fn new(execution_id: ExecutionId, tenant: TenantContext) -> Self {
        let metadata = std::collections::HashMap::new();
        Self {
            execution_id,
            tenant,
            step_id: None,
            metadata,
        }
    }

    /// Consumes the context and returns it tagged with `step_id` (overwriting any
    /// previously set step). All other fields are carried over unchanged.
    pub fn with_step(self, step_id: StepId) -> Self {
        Self {
            step_id: Some(step_id),
            ..self
        }
    }
}
/// Runs tools only after checking each invocation against a [`ToolPolicy`], optionally
/// reporting every policy decision through an [`EventEmitter`].
pub struct ToolExecutor {
    /// Policy consulted before each invocation; held in an `Arc` so it can be shared.
    policy: Arc<ToolPolicy>,
    /// Sink for policy-decision events; `None` means no events are emitted.
    emitter: Option<EventEmitter>,
}
impl ToolExecutor {
pub fn new(policy: ToolPolicy) -> Self {
Self {
policy: Arc::new(policy),
emitter: None,
}
}
pub fn with_shared_policy(policy: Arc<ToolPolicy>) -> Self {
Self {
policy,
emitter: None,
}
}
pub fn with_emitter(mut self, emitter: EventEmitter) -> Self {
self.emitter = Some(emitter);
self
}
pub fn set_emitter(&mut self, emitter: EventEmitter) {
self.emitter = Some(emitter);
}
pub async fn execute(
&self,
tool: &dyn Tool,
args: Value,
ctx: &ToolExecutionContext,
) -> Result<Value, ToolExecutionError> {
let policy_ctx = PolicyContext {
tenant_id: Some(ctx.tenant.tenant_id().as_str().to_string()),
user_id: ctx.tenant.user_id().map(|u| u.as_str().to_string()),
action: PolicyAction::InvokeTool {
tool_name: tool.name().to_string(),
},
metadata: ctx.metadata.clone(),
};
let tool_name = tool.name().to_string();
match self.policy.evaluate(&policy_ctx) {
PolicyDecision::Allow => {
if let Some(emitter) = &self.emitter {
emitter.emit(StreamEvent::policy_decision_allow(
&ctx.execution_id,
ctx.step_id.as_ref(),
&tool_name,
));
}
tool.execute(args).await.map_err(ToolExecutionError::from)
}
PolicyDecision::Deny { reason } => {
if let Some(emitter) = &self.emitter {
emitter.emit(StreamEvent::policy_decision_deny(
&ctx.execution_id,
ctx.step_id.as_ref(),
&tool_name,
&reason,
));
}
Err(ToolExecutionError::PolicyDenied { reason })
}
PolicyDecision::Warn { message } => {
if let Some(emitter) = &self.emitter {
emitter.emit(StreamEvent::policy_decision_warn(
&ctx.execution_id,
ctx.step_id.as_ref(),
&tool_name,
&message,
));
}
tracing::warn!(tool = tool.name(), message = %message, "Tool policy warning");
tool.execute(args).await.map_err(ToolExecutionError::from)
}
}
}
pub async fn execute_sequence(
&self,
tools: &[(Arc<dyn Tool>, Value)],
ctx: &ToolExecutionContext,
) -> Result<Vec<Value>, ToolExecutionError> {
let mut results = Vec::new();
for (tool, args) in tools {
let result = self.execute(tool.as_ref(), args.clone(), ctx).await?;
results.push(result);
}
Ok(results)
}
pub fn is_allowed(&self, tool_name: &str, ctx: &ToolExecutionContext) -> bool {
let policy_ctx = PolicyContext {
tenant_id: Some(ctx.tenant.tenant_id().as_str().to_string()),
user_id: ctx.tenant.user_id().map(|u| u.as_str().to_string()),
action: PolicyAction::InvokeTool {
tool_name: tool_name.to_string(),
},
metadata: std::collections::HashMap::new(),
};
matches!(
self.policy.evaluate(&policy_ctx),
PolicyDecision::Allow | PolicyDecision::Warn { .. }
)
}
pub fn get_permissions(&self, tool_name: &str) -> &crate::policy::ToolPermissions {
self.policy.get_permissions(tool_name)
}
pub fn policy(&self) -> &ToolPolicy {
&self.policy
}
}
impl Default for ToolExecutor {
fn default() -> Self {
Self::new(ToolPolicy::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::TenantId;
    use async_trait::async_trait;

    /// Test double whose `execute` echoes its arguments back unchanged.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool for testing"
        }

        async fn execute(&self, args: Value) -> anyhow::Result<Value> {
            Ok(args)
        }
    }

    /// Shorthand for a `MockTool` named `name`.
    fn mock(name: &str) -> MockTool {
        MockTool {
            name: name.to_string(),
        }
    }

    /// A fresh context for tenant "tenant_123" with no step and no metadata.
    fn tenant_ctx() -> ToolExecutionContext {
        ToolExecutionContext::new(
            ExecutionId::new(),
            TenantContext::new(TenantId::from("tenant_123")),
        )
    }

    #[tokio::test]
    async fn test_tool_execution_allowed() {
        let executor = ToolExecutor::new(ToolPolicy::new());
        let tool = mock("test_tool");
        let ctx = tenant_ctx();

        let outcome = executor
            .execute(&tool, Value::String("test".into()), &ctx)
            .await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_tool_execution_blocked() {
        let executor = ToolExecutor::new(ToolPolicy::new().block_tool("blocked_tool"));
        let tool = mock("blocked_tool");
        let ctx = tenant_ctx();

        let outcome = executor.execute(&tool, Value::Null, &ctx).await;

        assert!(matches!(
            outcome,
            Err(ToolExecutionError::PolicyDenied { .. })
        ));
    }

    #[tokio::test]
    async fn test_is_allowed() {
        let executor = ToolExecutor::new(ToolPolicy::new().block_tool("blocked_tool"));
        let ctx = tenant_ctx();

        assert!(executor.is_allowed("allowed_tool", &ctx));
        assert!(!executor.is_allowed("blocked_tool", &ctx));
    }

    #[tokio::test]
    async fn test_policy_decision_event_emission_allowed() {
        let sink = EventEmitter::new();
        let executor = ToolExecutor::new(ToolPolicy::new()).with_emitter(sink.clone());
        let tool = mock("test_tool");
        let ctx = tenant_ctx();

        let outcome = executor.execute(&tool, Value::Null, &ctx).await;
        assert!(outcome.is_ok());

        let emitted = sink.drain();
        assert_eq!(emitted.len(), 1);
        if let StreamEvent::PolicyDecision {
            decision,
            tool_name,
            ..
        } = &emitted[0]
        {
            assert_eq!(decision, "allow");
            assert_eq!(tool_name, "test_tool");
        } else {
            panic!("Expected PolicyDecision event");
        }
    }

    #[tokio::test]
    async fn test_policy_decision_event_emission_denied() {
        let sink = EventEmitter::new();
        let executor = ToolExecutor::new(ToolPolicy::new().block_tool("blocked_tool"))
            .with_emitter(sink.clone());
        let tool = mock("blocked_tool");
        let ctx = tenant_ctx();

        let outcome = executor.execute(&tool, Value::Null, &ctx).await;
        assert!(matches!(
            outcome,
            Err(ToolExecutionError::PolicyDenied { .. })
        ));

        let emitted = sink.drain();
        assert_eq!(emitted.len(), 1);
        if let StreamEvent::PolicyDecision {
            decision,
            tool_name,
            reason,
            ..
        } = &emitted[0]
        {
            assert_eq!(decision, "deny");
            assert_eq!(tool_name, "blocked_tool");
            assert!(reason.is_some());
        } else {
            panic!("Expected PolicyDecision event");
        }
    }
}