use astrid_approval::manager::ApprovalHandler;
use astrid_approval::request::{
ApprovalDecision as InternalApprovalDecision, ApprovalRequest as InternalApprovalRequest,
ApprovalResponse as InternalApprovalResponse,
};
use astrid_approval::{SecurityInterceptor, SecurityPolicy, SensitiveAction};
use astrid_audit::{AuditAction, AuditLog, AuditOutcome, AuthorizationProof};
use astrid_capabilities::AuditEntryId;
use astrid_core::{
ApprovalDecision, ApprovalOption, ApprovalRequest, Frontend, RiskLevel, SessionId,
};
use astrid_crypto::KeyPair;
use astrid_hooks::result::HookContext;
use astrid_hooks::{HookEvent, HookManager};
use astrid_llm::{LlmProvider, LlmToolDefinition, Message, StreamEvent, ToolCall, ToolCallResult};
use astrid_mcp::McpClient;
use astrid_tools::{ToolContext, ToolRegistry, truncate_output};
use astrid_workspace::{
EscapeDecision, EscapeRequest, PathCheck, WorkspaceBoundary, WorkspaceConfig,
};
use async_trait::async_trait;
use futures::StreamExt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tracing::{debug, error, info, warn};
use crate::context::ContextManager;
use crate::error::{RuntimeError, RuntimeResult};
use crate::session::AgentSession;
use crate::store::SessionStore;
use crate::subagent::SubAgentPool;
use crate::subagent_executor::{DEFAULT_SUBAGENT_TIMEOUT, SubAgentExecutor};
/// Default for `RuntimeConfig::max_context_tokens`.
const DEFAULT_MAX_CONTEXT_TOKENS: usize = 100_000;
/// Default for `RuntimeConfig::keep_recent_count`.
const DEFAULT_KEEP_RECENT_COUNT: usize = 10;
/// Default for `RuntimeConfig::max_concurrent_subagents`.
const DEFAULT_MAX_CONCURRENT_SUBAGENTS: usize = 4;
/// Default for `RuntimeConfig::max_subagent_depth`.
const DEFAULT_MAX_SUBAGENT_DEPTH: usize = 3;
/// Tunable settings for an `AgentRuntime`.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
/// Token limit handed to the `ContextManager` when the runtime is constructed.
pub max_context_tokens: usize,
/// System prompt for new sessions; when empty, `create_session` builds one
/// from the workspace root instead.
pub system_prompt: String,
/// When true, `run_turn_streaming` summarizes the session context whenever
/// the context manager reports that summarization is needed.
pub auto_summarize: bool,
/// Passed to `ContextManager::keep_recent` when the runtime is constructed.
pub keep_recent_count: usize,
/// Workspace settings; used to build the `WorkspaceBoundary` and as the root
/// for tool contexts.
pub workspace: WorkspaceConfig,
/// First argument to `SubAgentPool::new` — presumably the concurrency cap;
/// confirm against the pool implementation.
pub max_concurrent_subagents: usize,
/// Second argument to `SubAgentPool::new` — presumably the nesting-depth cap;
/// confirm against the pool implementation.
pub max_subagent_depth: usize,
/// Timeout handed to every `SubAgentExecutor` the runtime creates.
pub default_subagent_timeout: std::time::Duration,
}
impl Default for RuntimeConfig {
    /// Builds the stock configuration.
    ///
    /// The workspace is rooted at the process's current directory, falling back
    /// to `"."` when that cannot be determined. The system prompt is empty,
    /// auto-summarization is on, and every numeric limit uses its `DEFAULT_*`
    /// constant.
    fn default() -> Self {
        let workspace = WorkspaceConfig::new(
            std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
        );

        Self {
            workspace,
            system_prompt: String::new(),
            auto_summarize: true,
            max_context_tokens: DEFAULT_MAX_CONTEXT_TOKENS,
            keep_recent_count: DEFAULT_KEEP_RECENT_COUNT,
            max_concurrent_subagents: DEFAULT_MAX_CONCURRENT_SUBAGENTS,
            max_subagent_depth: DEFAULT_MAX_SUBAGENT_DEPTH,
            default_subagent_timeout: DEFAULT_SUBAGENT_TIMEOUT,
        }
    }
}
/// Coordinates an LLM provider, MCP and built-in tools, auditing, hooks,
/// workspace-boundary checks and sub-agents for agent sessions.
pub struct AgentRuntime<P: LlmProvider> {
/// LLM provider used for streaming completions and context summarization.
llm: Arc<P>,
/// Client used to list and call MCP tools.
mcp: McpClient,
/// Audit log that turn, summarization and approval events are appended to.
audit: Arc<AuditLog>,
/// Store used to save, load and list sessions.
sessions: SessionStore,
/// Key pair; its `key_id()` identifies the user of new sessions and it is
/// also handed to the security interceptor.
crypto: Arc<KeyPair>,
/// Settings supplied at construction.
config: RuntimeConfig,
/// Decides when to summarize a session's context and performs it.
context: ContextManager,
/// Classifies paths as allowed, never allowed, or requiring approval.
boundary: WorkspaceBoundary,
/// Hook manager triggered around user prompts and tool calls.
hooks: Arc<HookManager>,
/// Registry of built-in tools.
tool_registry: ToolRegistry,
/// Working directory handed to every `ToolContext`; starts at the workspace root.
shared_cwd: Arc<tokio::sync::RwLock<PathBuf>>,
/// Policy handed to each `SecurityInterceptor`.
security_policy: SecurityPolicy,
/// Pool passed to each `SubAgentExecutor`.
subagent_pool: Arc<SubAgentPool>,
/// Weak self-reference used to give sub-agent executors an `Arc` of this
/// runtime; `None` until set by `new_arc` or `set_self_arc`.
self_arc: tokio::sync::RwLock<Option<std::sync::Weak<Self>>>,
}
impl<P: LlmProvider + 'static> AgentRuntime<P> {
/// Builds a runtime from its collaborators and `config`.
///
/// The context manager, workspace boundary, shared working directory and
/// sub-agent pool are all derived from `config`. The runtime starts with a
/// fresh `HookManager`, the default `SecurityPolicy`, the default built-in
/// tool registry, and no self-reference (see `new_arc` / `set_self_arc`).
#[must_use]
pub fn new(
    llm: P,
    mcp: McpClient,
    audit: AuditLog,
    sessions: SessionStore,
    crypto: KeyPair,
    config: RuntimeConfig,
) -> Self {
    let workspace = &config.workspace;

    let context_manager =
        ContextManager::new(config.max_context_tokens).keep_recent(config.keep_recent_count);
    let workspace_boundary = WorkspaceBoundary::new(workspace.clone());
    let builtin_tools = ToolRegistry::with_defaults();
    // Tools start out in the workspace root.
    let cwd = Arc::new(tokio::sync::RwLock::new(workspace.root.clone()));
    let pool = Arc::new(SubAgentPool::new(
        config.max_concurrent_subagents,
        config.max_subagent_depth,
    ));

    info!(
        workspace_root = %workspace.root.display(),
        workspace_mode = ?workspace.mode,
        max_concurrent_subagents = config.max_concurrent_subagents,
        max_subagent_depth = config.max_subagent_depth,
        "Workspace boundary initialized"
    );

    Self {
        llm: Arc::new(llm),
        mcp,
        audit: Arc::new(audit),
        sessions,
        crypto: Arc::new(crypto),
        config,
        context: context_manager,
        boundary: workspace_boundary,
        hooks: Arc::new(HookManager::new()),
        tool_registry: builtin_tools,
        shared_cwd: cwd,
        security_policy: SecurityPolicy::default(),
        subagent_pool: pool,
        self_arc: tokio::sync::RwLock::new(None),
    }
}
/// Builds a runtime already wrapped in an `Arc`, with its weak self-reference
/// populated so sub-agent spawning works without calling `set_self_arc`.
///
/// When `hooks` is `Some`, it replaces the default hook manager.
#[must_use]
pub fn new_arc(
    llm: P,
    mcp: McpClient,
    audit: AuditLog,
    sessions: SessionStore,
    crypto: KeyPair,
    config: RuntimeConfig,
    hooks: Option<HookManager>,
) -> Arc<Self> {
    Arc::new_cyclic(|weak_self| {
        let base = Self::new(llm, mcp, audit, sessions, crypto, config);
        let mut runtime = match hooks {
            Some(custom_hooks) => base.with_hooks(custom_hooks),
            None => base,
        };
        runtime.self_arc = tokio::sync::RwLock::new(Some(std::sync::Weak::clone(weak_self)));
        runtime
    })
}
/// Creates a new session owned by this runtime's key id.
///
/// A non-empty configured system prompt is used verbatim. Otherwise a prompt
/// is built for `workspace_override`, or for the configured workspace root
/// when no override is given.
#[must_use]
pub fn create_session(&self, workspace_override: Option<&Path>) -> AgentSession {
    let workspace_root = workspace_override.unwrap_or(&self.config.workspace.root);

    let system_prompt = match self.config.system_prompt.as_str() {
        "" => astrid_tools::build_system_prompt(workspace_root),
        configured => configured.to_owned(),
    };

    let new_session = AgentSession::new(self.crypto.key_id(), system_prompt);
    info!(session_id = %new_session.id, "Created new session");
    new_session
}
/// Persists `session` through the session store.
///
/// # Errors
/// Propagates whatever error the store's `save` returns.
pub fn save_session(&self, session: &AgentSession) -> RuntimeResult<()> {
self.sessions.save(session)
}
/// Loads the session with the given `id` from the session store.
///
/// # Errors
/// Propagates whatever error the store's `load` returns.
pub fn load_session(&self, id: &SessionId) -> RuntimeResult<Option<AgentSession>> {
self.sessions.load(id)
}
/// Lists stored sessions as summaries, via the store's `list_with_metadata`.
///
/// # Errors
/// Propagates whatever error the store returns.
pub fn list_sessions(&self) -> RuntimeResult<Vec<crate::store::SessionSummary>> {
self.sessions.list_with_metadata()
}
/// Runs one user turn against `session`, reporting progress through `frontend`.
///
/// In order: registers a frontend-backed approval handler on the session,
/// appends `input` as a user message, fires the `UserPrompt` hook, appends an
/// `LlmRequest` audit entry, summarizes the context when enabled and needed,
/// runs the LLM/tool loop, and finally saves the session.
///
/// # Errors
/// Returns `RuntimeError::ApprovalDenied` when the `UserPrompt` hook blocks
/// the turn, and propagates errors from summarization, the LLM/tool loop and
/// the final save.
#[allow(clippy::too_many_lines)]
pub async fn run_turn_streaming<F: Frontend + 'static>(
&self,
session: &mut AgentSession,
input: &str,
frontend: Arc<F>,
) -> RuntimeResult<()> {
// Route approval requests raised during this turn to the frontend.
let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
frontend: Arc::clone(&frontend),
});
session.approval_manager.register_handler(handler).await;
// The user message is recorded before the hook runs, so it stays in the
// session even if the hook blocks the turn.
session.add_message(Message::user(input));
{
let ctx = self
.build_hook_context(session, HookEvent::UserPrompt)
.with_data("input", serde_json::json!(input));
let result = self.hooks.trigger_simple(HookEvent::UserPrompt, ctx).await;
if let astrid_hooks::HookResult::Block { reason } = result {
return Err(RuntimeError::ApprovalDenied { reason });
}
// Modifications returned by the hook are only logged, not applied.
if let astrid_hooks::HookResult::ContinueWith { modifications } = &result {
debug!(?modifications, "UserPrompt hook modified context");
}
}
{
// The result of the audit append is discarded.
let _ = self.audit.append(
session.id.clone(),
AuditAction::LlmRequest {
model: self.llm.model().to_string(),
input_tokens: session.token_count,
output_tokens: 0,
},
AuthorizationProof::System {
reason: "user input".to_string(),
},
AuditOutcome::success(),
);
}
if self.config.auto_summarize && self.context.needs_summarization(session) {
frontend.show_status("Summarizing context...");
let result = self.context.summarize(session, self.llm.as_ref()).await?;
{
let _ = self.audit.append(
session.id.clone(),
AuditAction::ContextSummarized {
evicted_count: result.messages_evicted,
tokens_freed: result.tokens_freed,
},
AuthorizationProof::System {
reason: "context overflow".to_string(),
},
AuditOutcome::success(),
);
}
}
// Tools run relative to the runtime-wide shared working directory.
let tool_ctx = ToolContext::with_shared_cwd(
self.config.workspace.root.clone(),
Arc::clone(&self.shared_cwd),
);
// Top-level turn: there is no parent sub-agent.
self.inject_subagent_spawner(&tool_ctx, session, &frontend, None)
.await;
let loop_result = self.run_loop(session, &*frontend, &tool_ctx).await;
// A failed loop returns here, before the save below, so a failed turn is
// not persisted by this method.
loop_result?;
self.sessions.save(session)?;
Ok(())
}
/// Runs one turn on behalf of a sub-agent.
///
/// Registers a frontend-backed approval handler, appends `prompt` as a user
/// message, records an `LlmRequest` audit entry and drives the LLM/tool loop.
/// This method fires no `UserPrompt` hook, performs no context summarization
/// and does not save the session.
///
/// # Errors
/// Propagates any error from the LLM/tool loop.
pub async fn run_subagent_turn<F: Frontend + 'static>(
    &self,
    session: &mut AgentSession,
    prompt: &str,
    frontend: Arc<F>,
    parent_subagent_id: Option<crate::subagent::SubAgentId>,
) -> RuntimeResult<()> {
    let approval_bridge = FrontendApprovalHandler {
        frontend: Arc::clone(&frontend),
    };
    let handler: Arc<dyn ApprovalHandler> = Arc::new(approval_bridge);
    session.approval_manager.register_handler(handler).await;

    session.add_message(Message::user(prompt));

    let request_action = AuditAction::LlmRequest {
        model: self.llm.model().to_string(),
        input_tokens: session.token_count,
        output_tokens: 0,
    };
    let system_proof = AuthorizationProof::System {
        reason: "sub-agent prompt".to_string(),
    };
    // The result of the audit append is discarded.
    let _ = self.audit.append(
        session.id.clone(),
        request_action,
        system_proof,
        AuditOutcome::success(),
    );

    let workspace_root = self.config.workspace.root.clone();
    let tool_ctx = ToolContext::with_shared_cwd(workspace_root, Arc::clone(&self.shared_cwd));
    self.inject_subagent_spawner(&tool_ctx, session, &frontend, parent_subagent_id)
        .await;

    self.run_loop(session, frontend.as_ref(), &tool_ctx).await
}
#[allow(clippy::too_many_lines)]
async fn run_loop<F: Frontend>(
&self,
session: &mut AgentSession,
frontend: &F,
tool_ctx: &ToolContext,
) -> RuntimeResult<()> {
loop {
let mut llm_tools: Vec<LlmToolDefinition> = self.tool_registry.all_definitions();
let mcp_tools = self.mcp.list_tools().await?;
llm_tools.extend(mcp_tools.iter().map(|t| {
LlmToolDefinition::new(format!("{}:{}", &t.server, &t.name))
.with_description(t.description.clone().unwrap_or_default())
.with_schema(t.input_schema.clone())
}));
let mut stream = self
.llm
.stream(&session.messages, &llm_tools, &session.system_prompt)
.await?;
let mut response_text = String::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
let mut current_tool_args = String::new();
while let Some(event) = stream.next().await {
match event? {
StreamEvent::TextDelta(text) => {
frontend.show_status(&text);
response_text.push_str(&text);
},
StreamEvent::ToolCallStart { id, name } => {
tool_calls.push(ToolCall::new(id, name));
current_tool_args.clear();
},
StreamEvent::ToolCallDelta { id: _, args_delta } => {
current_tool_args.push_str(&args_delta);
},
StreamEvent::ToolCallEnd { id } => {
if let Some(call) = tool_calls.iter_mut().find(|c| c.id == id)
&& let Ok(args) = serde_json::from_str(¤t_tool_args)
{
call.arguments = args;
}
current_tool_args.clear();
},
StreamEvent::Usage {
input_tokens,
output_tokens,
} => {
debug!(input = input_tokens, output = output_tokens, "Token usage");
let cost = tokens_to_usd(input_tokens, output_tokens);
session.budget_tracker.record_cost(cost);
if let Some(ref ws_budget) = session.workspace_budget_tracker {
ws_budget.record_cost(cost);
}
},
StreamEvent::ReasoningDelta(_) => {
},
StreamEvent::Done => break,
StreamEvent::Error(e) => {
error!(error = %e, "Stream error");
return Err(RuntimeError::LlmError(
astrid_llm::LlmError::StreamingError(e),
));
},
}
}
if !tool_calls.is_empty() {
session.add_message(Message::assistant_with_tools(tool_calls.clone()));
for call in &tool_calls {
frontend.tool_started(&call.id, &call.name, &call.arguments);
let result = self
.execute_tool_call(session, call, frontend, tool_ctx)
.await?;
frontend.tool_completed(&call.id, &result.content, result.is_error);
session.add_message(Message::tool_result(result));
session.metadata.tool_call_count =
session.metadata.tool_call_count.saturating_add(1);
}
continue;
}
if !response_text.is_empty() {
session.add_message(Message::assistant(&response_text));
return Ok(());
}
break;
}
Ok(())
}
#[allow(clippy::too_many_lines)]
async fn execute_tool_call<F: Frontend>(
&self,
session: &mut AgentSession,
call: &ToolCall,
frontend: &F,
tool_ctx: &ToolContext,
) -> RuntimeResult<ToolCallResult> {
if ToolRegistry::is_builtin(&call.name) {
return self
.execute_builtin_tool(session, call, frontend, tool_ctx)
.await;
}
let (server, tool) = call.parse_name().ok_or_else(|| {
RuntimeError::McpError(astrid_mcp::McpError::ToolNotFound {
server: "unknown".to_string(),
tool: call.name.clone(),
})
})?;
if let Err(tool_error) = self
.check_workspace_boundaries(session, call, server, tool, frontend)
.await
{
return Ok(tool_error);
}
{
let ctx = self
.build_hook_context(session, HookEvent::PreToolCall)
.with_data("tool_name", serde_json::json!(tool))
.with_data("server_name", serde_json::json!(server))
.with_data("arguments", call.arguments.clone());
let result = self.hooks.trigger_simple(HookEvent::PreToolCall, ctx).await;
if let astrid_hooks::HookResult::Block { reason } = result {
return Ok(ToolCallResult::error(&call.id, reason));
}
if let astrid_hooks::HookResult::ContinueWith { modifications } = &result {
debug!(?modifications, "PreToolCall hook modified context");
}
}
let action = classify_tool_call(server, tool, &call.arguments);
let interceptor = self.build_interceptor(session);
let tool_result = match interceptor
.intercept(&action, &format!("MCP tool call to {server}:{tool}"), None)
.await
{
Ok(intercept_result) => {
if let Some(warning) = &intercept_result.budget_warning {
frontend.show_status(&format!(
"Budget warning: ${:.2}/${:.2} spent ({:.0}%)",
warning.current_spend, warning.session_max, warning.percent_used
));
}
let result = self
.mcp
.call_tool(server, tool, call.arguments.clone())
.await?;
ToolCallResult::success(&call.id, result.text_content())
},
Err(e) => ToolCallResult::error(&call.id, e.to_string()),
};
{
let hook_event = if tool_result.is_error {
HookEvent::ToolError
} else {
HookEvent::PostToolCall
};
let ctx = self
.build_hook_context(session, hook_event)
.with_data("tool_name", serde_json::json!(tool))
.with_data("server_name", serde_json::json!(server))
.with_data("is_error", serde_json::json!(tool_result.is_error));
let _ = self.hooks.trigger_simple(hook_event, ctx).await;
}
Ok(tool_result)
}
/// Returns the configuration this runtime was built with.
#[must_use]
pub fn config(&self) -> &RuntimeConfig {
&self.config
}
/// Returns the shared audit log.
#[must_use]
pub fn audit(&self) -> &Arc<AuditLog> {
&self.audit
}
/// Returns the MCP client.
#[must_use]
pub fn mcp(&self) -> &McpClient {
&self.mcp
}
/// Returns the 8-byte id of this runtime's key pair.
#[must_use]
pub fn key_id(&self) -> [u8; 8] {
self.crypto.key_id()
}
/// Returns the workspace boundary used for path checks.
#[must_use]
pub fn boundary(&self) -> &WorkspaceBoundary {
&self.boundary
}
/// Replaces the hook manager, consuming and returning the runtime.
#[must_use]
pub fn with_hooks(mut self, hooks: HookManager) -> Self {
self.hooks = Arc::new(hooks);
self
}
/// Returns the shared hook manager.
#[must_use]
pub fn hooks(&self) -> &Arc<HookManager> {
&self.hooks
}
/// Returns the shared sub-agent pool.
#[must_use]
pub fn subagent_pool(&self) -> &Arc<SubAgentPool> {
&self.subagent_pool
}
/// Stores a weak reference to this `Arc` inside the runtime.
///
/// `inject_subagent_spawner` needs it to create sub-agent executors; without
/// it, sub-agent spawning is disabled. Runtimes built with `new_arc` already
/// have it set.
pub async fn set_self_arc(self: &Arc<Self>) {
*self.self_arc.write().await = Some(Arc::downgrade(self));
}
/// Installs a sub-agent spawner on `tool_ctx` for the current turn.
///
/// The spawner is a `SubAgentExecutor` that shares the session's allowance
/// store, capabilities and budget tracker. If the runtime's self-reference is
/// unset or can no longer be upgraded, nothing is installed and a debug
/// message is logged.
async fn inject_subagent_spawner<F: Frontend + 'static>(
    &self,
    tool_ctx: &ToolContext,
    session: &AgentSession,
    frontend: &Arc<F>,
    parent_subagent_id: Option<crate::subagent::SubAgentId>,
) {
    // Resolve the self-reference in its own scope so the read guard is
    // released before anything else is awaited.
    let upgraded = {
        let slot = self.self_arc.read().await;
        slot.as_ref().and_then(std::sync::Weak::upgrade)
    };

    let Some(runtime) = upgraded else {
        debug!("No self_arc set — sub-agent spawning disabled for this turn");
        return;
    };

    let spawner = SubAgentExecutor::new(
        runtime,
        Arc::clone(&self.subagent_pool),
        Arc::clone(frontend),
        session.user_id,
        parent_subagent_id,
        session.id.clone(),
        Arc::clone(&session.allowance_store),
        Arc::clone(&session.capabilities),
        Arc::clone(&session.budget_tracker),
        self.config.default_subagent_timeout,
    );
    tool_ctx.set_subagent_spawner(Some(Arc::new(spawner))).await;
}
/// Builds the hook context for `event`, tagged with the session id and a
/// UUID derived from the session's user id.
///
/// The user id occupies the first eight bytes of the UUID; the remaining
/// eight bytes are zero.
#[allow(clippy::unused_self)]
fn build_hook_context(&self, session: &AgentSession, event: HookEvent) -> HookContext {
    let mut padded_user_id = [0u8; 16];
    let (prefix, _zero_tail) = padded_user_id.split_at_mut(8);
    prefix.copy_from_slice(&session.user_id);

    HookContext::new(event)
        .with_session(session.id.0)
        .with_user(uuid::Uuid::from_bytes(padded_user_id))
}
/// Builds a `SecurityInterceptor` for `session`.
///
/// It combines the session's own state (capabilities, approval manager,
/// budget trackers, allowance store) with the runtime-wide security policy,
/// audit log, key pair and workspace root.
fn build_interceptor(&self, session: &AgentSession) -> SecurityInterceptor {
    let capabilities = Arc::clone(&session.capabilities);
    let approvals = Arc::clone(&session.approval_manager);
    let policy = self.security_policy.clone();
    let session_budget = Arc::clone(&session.budget_tracker);
    let audit_log = Arc::clone(&self.audit);
    let signing_key = Arc::clone(&self.crypto);
    let session_id = session.id.clone();
    let allowances = Arc::clone(&session.allowance_store);
    let workspace_root = Some(self.config.workspace.root.clone());
    let workspace_budget = session.workspace_budget_tracker.clone();

    SecurityInterceptor::new(
        capabilities,
        approvals,
        policy,
        session_budget,
        audit_log,
        signing_key,
        session_id,
        allowances,
        workspace_root,
        workspace_budget,
    )
}
/// Executes a built-in tool call and returns its result.
///
/// The call passes, in order, through: registry lookup, the workspace
/// boundary check (with `"builtin"` as the server name), the `PreToolCall`
/// hook, and the security interceptor. Only then is the tool executed, after
/// which `PostToolCall` or `ToolError` fires. Every refusal or failure is
/// returned as an error `ToolCallResult` inside `Ok`; this function has no
/// path that returns `Err`.
async fn execute_builtin_tool<F: Frontend>(
&self,
session: &mut AgentSession,
call: &ToolCall,
frontend: &F,
tool_ctx: &ToolContext,
) -> RuntimeResult<ToolCallResult> {
let tool_name = &call.name;
let Some(tool) = self.tool_registry.get(tool_name) else {
return Ok(ToolCallResult::error(
&call.id,
format!("Unknown built-in tool: {tool_name}"),
));
};
// On denial, the boundary check hands back a ready-made tool error.
if let Err(tool_error) = self
.check_workspace_boundaries(session, call, "builtin", tool_name, frontend)
.await
{
return Ok(tool_error);
}
{
let ctx = self
.build_hook_context(session, HookEvent::PreToolCall)
.with_data("tool_name", serde_json::json!(tool_name))
.with_data("server_name", serde_json::json!("builtin"))
.with_data("arguments", call.arguments.clone());
let result = self.hooks.trigger_simple(HookEvent::PreToolCall, ctx).await;
// Only a `Block` result is acted upon here; other results are ignored.
if let astrid_hooks::HookResult::Block { reason } = result {
return Ok(ToolCallResult::error(&call.id, reason));
}
}
let action = classify_builtin_tool_call(tool_name, &call.arguments);
let interceptor = self.build_interceptor(session);
match interceptor
.intercept(&action, &format!("Built-in tool: {tool_name}"), None)
.await
{
Ok(intercept_result) => {
if let Some(warning) = &intercept_result.budget_warning {
frontend.show_status(&format!(
"Budget warning: ${:.2}/${:.2} spent ({:.0}%)",
warning.current_spend, warning.session_max, warning.percent_used
));
}
},
// An interceptor rejection is reported as a tool error; no post-call hook fires.
Err(e) => return Ok(ToolCallResult::error(&call.id, e.to_string())),
}
let tool_result = match tool.execute(call.arguments.clone(), tool_ctx).await {
Ok(output) => {
// Successful output goes through `truncate_output` before being returned.
let output = truncate_output(output);
ToolCallResult::success(&call.id, output)
},
Err(e) => ToolCallResult::error(&call.id, e.to_string()),
};
{
let hook_event = if tool_result.is_error {
HookEvent::ToolError
} else {
HookEvent::PostToolCall
};
let ctx = self
.build_hook_context(session, hook_event)
.with_data("tool_name", serde_json::json!(tool_name))
.with_data("server_name", serde_json::json!("builtin"))
.with_data("is_error", serde_json::json!(tool_result.is_error));
// The post-call hook result is discarded.
let _ = self.hooks.trigger_simple(hook_event, ctx).await;
}
Ok(tool_result)
}
/// Checks every path found in the call's arguments against the workspace
/// boundary, asking the user for approval where required.
///
/// Paths are taken from `extract_paths_from_args`. For each path:
/// - paths already allowed by the session's escape handler are skipped;
/// - `Allowed` / `AutoAllowed` paths pass;
/// - `NeverAllowed` paths are audited and rejected;
/// - `RequiresApproval` paths trigger a frontend approval request, whose
///   decision is recorded in the escape handler and the audit log.
///
/// Returns `Ok(())` when every path may be accessed. Otherwise returns
/// `Err` carrying the error `ToolCallResult` to hand back for this call;
/// checking stops at the first rejected path.
#[allow(clippy::too_many_lines)]
async fn check_workspace_boundaries<F: Frontend>(
&self,
session: &mut AgentSession,
call: &ToolCall,
server: &str,
tool: &str,
frontend: &F,
) -> Result<(), ToolCallResult> {
let paths = extract_paths_from_args(&call.arguments);
if paths.is_empty() {
return Ok(());
}
for path in &paths {
if session.escape_handler.is_allowed(path) {
debug!(path = %path.display(), "Path already approved by escape handler");
continue;
}
let check = self.boundary.check(path);
match check {
PathCheck::Allowed | PathCheck::AutoAllowed => {},
PathCheck::NeverAllowed => {
warn!(
path = %path.display(),
tool = %format!("{server}:{tool}"),
"Access to protected path blocked"
);
{
// The result of the audit append is discarded.
let _ = self.audit.append(
session.id.clone(),
AuditAction::ApprovalDenied {
action: format!("{server}:{tool} -> {}", path.display()),
reason: Some("protected system path".to_string()),
},
AuthorizationProof::System {
reason: "workspace boundary: never-allowed path".to_string(),
},
AuditOutcome::failure("protected path"),
);
}
return Err(ToolCallResult::error(
&call.id,
format!(
"Access to {} is blocked — this is a protected system path",
path.display()
),
));
},
PathCheck::RequiresApproval => {
// The operation is guessed from the tool name; it determines the
// risk level shown to the user.
let escape_request = EscapeRequest::new(
path.clone(),
infer_operation(tool),
format!(
"Tool {server}:{tool} wants to access {} outside the workspace",
path.display()
),
)
.with_tool(tool)
.with_server(server);
let approval_request = ApprovalRequest::new(
format!("workspace-escape:{server}:{tool}"),
format!(
"Allow {} {} outside workspace?\n Path: {}",
tool,
escape_request.operation,
path.display()
),
)
.with_risk_level(risk_level_for_operation(escape_request.operation))
.with_resource(path.display().to_string());
// A frontend failure is reported as a tool error; the underlying
// error value is dropped.
let decision =
frontend
.request_approval(approval_request)
.await
.map_err(|_| {
ToolCallResult::error(
&call.id,
"Failed to request workspace escape approval",
)
})?;
// `AllowWorkspace` is treated as `AllowSession` by the escape handler,
// while the audit entry below still records the `Workspace` scope.
let escape_decision = match decision.decision {
ApprovalOption::AllowOnce => EscapeDecision::AllowOnce,
ApprovalOption::AllowSession | ApprovalOption::AllowWorkspace => {
EscapeDecision::AllowSession
},
ApprovalOption::AllowAlways => EscapeDecision::AllowAlways,
ApprovalOption::Deny => EscapeDecision::Deny,
};
session
.escape_handler
.process_decision(&escape_request, escape_decision);
if escape_decision.is_allowed() {
let _ = self.audit.append(
session.id.clone(),
AuditAction::ApprovalGranted {
action: format!("{server}:{tool}"),
resource: Some(path.display().to_string()),
scope: match decision.decision {
ApprovalOption::AllowSession => {
astrid_audit::ApprovalScope::Session
},
ApprovalOption::AllowWorkspace => {
astrid_audit::ApprovalScope::Workspace
},
ApprovalOption::AllowAlways => {
astrid_audit::ApprovalScope::Always
},
// `Deny` maps to `EscapeDecision::Deny` above; this arm is
// presumably never taken for it inside the allowed branch
// (depends on `is_allowed()`) and it is listed for exhaustiveness.
ApprovalOption::AllowOnce | ApprovalOption::Deny => {
astrid_audit::ApprovalScope::Once
},
},
},
// NOTE(review): the approval entry id is freshly generated here
// rather than taken from an existing audit entry — confirm intended.
AuthorizationProof::UserApproval {
user_id: session.user_id,
approval_entry_id: AuditEntryId::new(),
},
AuditOutcome::success(),
);
} else {
let _ = self.audit.append(
session.id.clone(),
AuditAction::ApprovalDenied {
action: format!("{server}:{tool} -> {}", path.display()),
reason: Some(
decision
.reason
.clone()
.unwrap_or_else(|| "user denied".to_string()),
),
},
AuthorizationProof::UserApproval {
user_id: session.user_id,
approval_entry_id: AuditEntryId::new(),
},
AuditOutcome::failure("user denied workspace escape"),
);
}
// A denial ends the check; the user's stated reason, if any, becomes
// the tool error message.
if !escape_decision.is_allowed() {
return Err(ToolCallResult::error(
&call.id,
decision.reason.unwrap_or_else(|| {
format!("Access to {} denied — outside workspace", path.display())
}),
));
}
info!(
path = %path.display(),
decision = ?escape_decision,
"Workspace escape approved"
);
},
}
}
Ok(())
}
}
fn extract_paths_from_args(args: &serde_json::Value) -> Vec<PathBuf> {
const PATH_KEYS: &[&str] = &[
"path",
"file",
"file_path",
"filepath",
"filename",
"directory",
"dir",
"target",
"source",
"destination",
"src",
"dst",
"input",
"output",
"uri",
"url",
"cwd",
"working_directory",
];
let mut paths = Vec::new();
if let Some(obj) = args.as_object() {
for (key, value) in obj {
let key_lower = key.to_lowercase();
if let Some(s) = value.as_str()
&& PATH_KEYS.contains(&key_lower.as_str())
&& let Some(path) = try_extract_path(s)
{
paths.push(path);
}
}
}
paths
}
/// Interprets a tool argument value as a filesystem path, if it looks like one.
///
/// Recognised forms:
/// - `file://` URIs (scheme matched case-insensitively); a `localhost`
///   authority is dropped so `file://localhost/tmp/x` yields `/tmp/x`;
/// - values starting with `/`, `~/`, `./` or `../`, and the bare values `~`
///   and `.`;
/// - any other value containing a `..` component (e.g. `docs/../../etc/passwd`).
///   Such a value can leave the workspace just like `../x`, so it must not skip
///   the boundary check.
///
/// Any other URL (`scheme://...`) and any plain word return `None`.
fn try_extract_path(value: &str) -> Option<PathBuf> {
    const FILE_SCHEME: &str = "file://";

    // `get` (rather than slicing) returns `None` instead of panicking when the
    // value is shorter than the scheme or the cut is not a char boundary.
    if let Some(scheme) = value.get(..FILE_SCHEME.len()) {
        if scheme.eq_ignore_ascii_case(FILE_SCHEME) {
            let rest = &value[FILE_SCHEME.len()..];
            let local = rest
                .strip_prefix("localhost")
                .filter(|path| path.starts_with('/'))
                .unwrap_or(rest);
            return Some(PathBuf::from(local));
        }
    }

    // Any other URL is not a local path.
    if value.contains("://") {
        return None;
    }

    let has_path_prefix = value.starts_with('/')
        || value.starts_with("~/")
        || value.starts_with("./")
        || value.starts_with("../")
        || matches!(value, "~" | ".");

    // Catches traversal hidden inside an otherwise relative value.
    let traverses_upward = Path::new(value)
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir));

    if has_path_prefix || traverses_upward {
        Some(PathBuf::from(value))
    } else {
        None
    }
}
/// Guesses the kind of filesystem operation a tool performs from its name.
///
/// The lowercased name is tested for keyword substrings, group by group, in
/// this fixed order: read, write, create, delete, execute, list. The first
/// group with a hit wins, and a name with no hit is treated as a read.
/// Matching is by substring, so a keyword embedded in a longer word also counts.
fn infer_operation(tool: &str) -> astrid_workspace::escape::EscapeOperation {
    use astrid_workspace::escape::EscapeOperation;

    let name = tool.to_lowercase();
    let mentions_any =
        |keywords: &[&str]| keywords.iter().any(|keyword| name.contains(*keyword));

    if mentions_any(&["read", "get", "cat"]) {
        return EscapeOperation::Read;
    }
    if mentions_any(&["write", "set", "put", "edit", "update"]) {
        return EscapeOperation::Write;
    }
    if mentions_any(&["create", "mkdir", "touch", "new"]) {
        return EscapeOperation::Create;
    }
    if mentions_any(&["delete", "remove", "rm"]) {
        return EscapeOperation::Delete;
    }
    if mentions_any(&["exec", "run", "launch"]) {
        return EscapeOperation::Execute;
    }
    if mentions_any(&["list", "ls", "dir"]) {
        return EscapeOperation::List;
    }

    // Unrecognised names default to a read.
    EscapeOperation::Read
}
/// Maps an escape operation to the risk level shown in the approval prompt.
///
/// Read and list are medium risk, write and create are high, and delete and
/// execute are critical.
fn risk_level_for_operation(operation: astrid_workspace::escape::EscapeOperation) -> RiskLevel {
    use astrid_workspace::escape::EscapeOperation;

    match operation {
        EscapeOperation::Read => RiskLevel::Medium,
        EscapeOperation::List => RiskLevel::Medium,
        EscapeOperation::Write => RiskLevel::High,
        EscapeOperation::Create => RiskLevel::High,
        EscapeOperation::Delete => RiskLevel::Critical,
        EscapeOperation::Execute => RiskLevel::Critical,
    }
}
/// Classifies an MCP tool call as a `SensitiveAction` for the security interceptor.
///
/// The checks run in order against the lowercased tool name:
/// 1. name contains `delete`/`remove` and a string `path` (or `file`) argument
///    exists → `FileDelete`;
/// 2. name contains `exec`/`run`/`bash` → `ExecuteCommand`, using the `command`
///    argument (or the tool name) and any string entries of `args`;
/// 3. name contains `write` and the string `path` (or `file_path`) argument
///    starts with `/` → `FileWriteOutsideSandbox`;
/// 4. otherwise → a generic `McpToolCall`.
fn classify_tool_call(server: &str, tool: &str, args: &serde_json::Value) -> SensitiveAction {
    /// Returns the value of the first of `keys` present in `args`, if it is a string.
    fn string_arg<'a>(args: &'a serde_json::Value, keys: &[&str]) -> Option<&'a str> {
        keys.iter()
            .find_map(|key| args.get(*key))
            .and_then(serde_json::Value::as_str)
    }

    let name = tool.to_lowercase();
    let mentions = |needle: &str| name.contains(needle);

    if mentions("delete") || mentions("remove") {
        if let Some(path) = string_arg(args, &["path", "file"]) {
            return SensitiveAction::FileDelete {
                path: path.to_string(),
            };
        }
    }

    if mentions("exec") || mentions("run") || mentions("bash") {
        let command = string_arg(args, &["command"]).unwrap_or(tool).to_string();
        // Non-string entries of `args` are skipped.
        let cmd_args = args
            .get("args")
            .and_then(|value| value.as_array())
            .map(|items| {
                items
                    .iter()
                    .filter_map(|item| item.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default();
        return SensitiveAction::ExecuteCommand {
            command,
            args: cmd_args,
        };
    }

    if mentions("write") {
        let absolute_target =
            string_arg(args, &["path", "file_path"]).filter(|path| path.starts_with('/'));
        if let Some(path) = absolute_target {
            return SensitiveAction::FileWriteOutsideSandbox {
                path: path.to_string(),
            };
        }
    }

    SensitiveAction::McpToolCall {
        server: server.to_string(),
        tool: tool.to_string(),
    }
}
/// Converts an internal approval request into the frontend-facing request type.
///
/// The action's type string and summary become the two constructor arguments,
/// the assessed risk level is carried over, and the action's `Display` output
/// is used as the resource.
fn to_frontend_request(internal: &InternalApprovalRequest) -> ApprovalRequest {
    let action = &internal.action;
    let action_type = action.action_type().to_string();

    ApprovalRequest::new(action_type, action.summary())
        .with_risk_level(internal.assessment.level)
        .with_resource(action.to_string())
}
/// Converts the frontend's decision into an internal approval response for `request`.
///
/// Each allow option maps to its internal counterpart. A denial carries the
/// user's reason, or `"denied by user"` when none was given.
fn to_internal_response(
    request: &InternalApprovalRequest,
    decision: &ApprovalDecision,
) -> InternalApprovalResponse {
    let verdict = match decision.decision {
        ApprovalOption::Deny => {
            let reason = decision
                .reason
                .as_deref()
                .unwrap_or("denied by user")
                .to_owned();
            InternalApprovalDecision::Deny { reason }
        },
        ApprovalOption::AllowOnce => InternalApprovalDecision::Approve,
        ApprovalOption::AllowSession => InternalApprovalDecision::ApproveSession,
        ApprovalOption::AllowWorkspace => InternalApprovalDecision::ApproveWorkspace,
        ApprovalOption::AllowAlways => InternalApprovalDecision::ApproveAlways,
    };

    InternalApprovalResponse::new(request.id.clone(), verdict)
}
/// Classifies a built-in tool call as a `SensitiveAction` for the security interceptor.
///
/// - `bash` → `ExecuteCommand` with the `command` argument (default `"bash"`)
///   and no separate args;
/// - `write_file` / `edit_file` → `FileWriteOutsideSandbox` on `file_path` or
///   `path` (default `"unknown"`);
/// - `read_file` / `glob` / `grep` / `list_directory` → `FileRead` on
///   `file_path`, `path` or `pattern` (default `"."`);
/// - anything else → a generic `McpToolCall` on the `"builtin"` server.
fn classify_builtin_tool_call(tool_name: &str, args: &serde_json::Value) -> SensitiveAction {
    /// Returns the value of the first of `keys` present in `args`, if it is a string.
    fn string_arg<'a>(args: &'a serde_json::Value, keys: &[&str]) -> Option<&'a str> {
        keys.iter()
            .find_map(|key| args.get(*key))
            .and_then(serde_json::Value::as_str)
    }

    match tool_name {
        "bash" => SensitiveAction::ExecuteCommand {
            command: string_arg(args, &["command"]).unwrap_or("bash").to_string(),
            args: Vec::new(),
        },
        "write_file" | "edit_file" => SensitiveAction::FileWriteOutsideSandbox {
            path: string_arg(args, &["file_path", "path"])
                .unwrap_or("unknown")
                .to_string(),
        },
        "read_file" | "glob" | "grep" | "list_directory" => SensitiveAction::FileRead {
            path: string_arg(args, &["file_path", "path", "pattern"])
                .unwrap_or(".")
                .to_string(),
        },
        other => SensitiveAction::McpToolCall {
            server: "builtin".to_string(),
            tool: other.to_string(),
        },
    }
}
/// Approval handler that forwards internal approval requests to a `Frontend`.
struct FrontendApprovalHandler<F: Frontend> {
/// Frontend that is asked for each approval decision.
frontend: Arc<F>,
}
#[async_trait]
impl<F: Frontend> ApprovalHandler for FrontendApprovalHandler<F> {
    /// Asks the frontend to decide on `request`.
    ///
    /// Returns the converted decision, or `None` when the frontend call fails
    /// (the error itself is dropped).
    async fn request_approval(
        &self,
        request: InternalApprovalRequest,
    ) -> Option<InternalApprovalResponse> {
        let outcome = self
            .frontend
            .request_approval(to_frontend_request(&request))
            .await;
        outcome
            .ok()
            .map(|decision| to_internal_response(&request, &decision))
    }

    /// This handler always reports itself as available.
    fn is_available(&self) -> bool {
        true
    }
}
/// USD charged per 1,000 input tokens.
const INPUT_RATE_PER_1K: f64 = 0.003;
/// USD charged per 1,000 output tokens.
const OUTPUT_RATE_PER_1K: f64 = 0.015;

/// Estimates the USD cost of a request from its token counts, using the flat
/// per-1,000-token rates `INPUT_RATE_PER_1K` and `OUTPUT_RATE_PER_1K`.
#[allow(clippy::cast_precision_loss)]
fn tokens_to_usd(input_tokens: usize, output_tokens: usize) -> f64 {
    let priced = |tokens: usize, rate_per_1k: f64| (tokens as f64 / 1000.0) * rate_per_1k;
    priced(input_tokens, INPUT_RATE_PER_1K) + priced(output_tokens, OUTPUT_RATE_PER_1K)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only string values under recognised path keys are extracted.
    #[test]
    fn test_extract_paths_from_args() {
        let args = serde_json::json!({
            "path": "/home/user/file.txt",
            "content": "some data",
            "count": 42
        });
        let paths = extract_paths_from_args(&args);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], PathBuf::from("/home/user/file.txt"));
    }

    /// Plain words and non-file URLs are not paths, even under path keys.
    #[test]
    fn test_extract_paths_ignores_non_path_values() {
        let args = serde_json::json!({
            "path": "not-a-path",
            "url": "https://example.com",
        });
        let paths = extract_paths_from_args(&args);
        assert!(paths.is_empty());
    }

    /// `file://` URIs are reduced to their path part.
    #[test]
    fn test_extract_paths_file_uri() {
        let args = serde_json::json!({
            "uri": "file:///tmp/test.txt"
        });
        let paths = extract_paths_from_args(&args);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], PathBuf::from("/tmp/test.txt"));
    }

    /// Relative paths are returned verbatim. Map iteration order is not
    /// assumed, so membership is asserted rather than position.
    #[test]
    fn test_extract_paths_relative() {
        let args = serde_json::json!({
            "file": "./src/main.rs",
            "dir": "../other"
        });
        let paths = extract_paths_from_args(&args);
        assert_eq!(paths.len(), 2);
        assert!(paths.contains(&PathBuf::from("./src/main.rs")));
        assert!(paths.contains(&PathBuf::from("../other")));
    }

    /// Keys are matched case-insensitively.
    #[test]
    fn test_extract_paths_key_is_case_insensitive() {
        let args = serde_json::json!({
            "File_Path": "/tmp/mixed-case.txt"
        });
        let paths = extract_paths_from_args(&args);
        assert_eq!(paths, vec![PathBuf::from("/tmp/mixed-case.txt")]);
    }

    /// Arguments that are not a JSON object yield no paths.
    #[test]
    fn test_extract_paths_non_object_args() {
        assert!(extract_paths_from_args(&serde_json::json!("/etc/passwd")).is_empty());
        assert!(extract_paths_from_args(&serde_json::json!(null)).is_empty());
        assert!(extract_paths_from_args(&serde_json::json!(["/etc/passwd"])).is_empty());
    }

    #[test]
    fn test_infer_operation() {
        use astrid_workspace::escape::EscapeOperation;
        assert_eq!(infer_operation("read_file"), EscapeOperation::Read);
        assert_eq!(infer_operation("write_file"), EscapeOperation::Write);
        assert_eq!(infer_operation("create_directory"), EscapeOperation::Create);
        assert_eq!(infer_operation("delete_file"), EscapeOperation::Delete);
        assert_eq!(infer_operation("execute_command"), EscapeOperation::Execute);
        assert_eq!(infer_operation("list_files"), EscapeOperation::List);
        assert_eq!(infer_operation("unknown_tool"), EscapeOperation::Read);
    }

    /// Every operation is covered so a changed mapping cannot slip through.
    #[test]
    fn test_risk_level_for_operation() {
        use astrid_workspace::escape::EscapeOperation;
        assert_eq!(
            risk_level_for_operation(EscapeOperation::Read),
            RiskLevel::Medium
        );
        assert_eq!(
            risk_level_for_operation(EscapeOperation::List),
            RiskLevel::Medium
        );
        assert_eq!(
            risk_level_for_operation(EscapeOperation::Write),
            RiskLevel::High
        );
        assert_eq!(
            risk_level_for_operation(EscapeOperation::Create),
            RiskLevel::High
        );
        assert_eq!(
            risk_level_for_operation(EscapeOperation::Delete),
            RiskLevel::Critical
        );
        assert_eq!(
            risk_level_for_operation(EscapeOperation::Execute),
            RiskLevel::Critical
        );
    }

    /// Costs scale linearly with the per-1,000-token rates.
    #[test]
    fn test_tokens_to_usd() {
        assert!(tokens_to_usd(0, 0).abs() < f64::EPSILON);
        assert!((tokens_to_usd(1000, 1000) - 0.018).abs() < 1e-12);
    }
}