1use astrid_approval::manager::ApprovalHandler;
6use astrid_approval::request::{
7 ApprovalDecision as InternalApprovalDecision, ApprovalRequest as InternalApprovalRequest,
8 ApprovalResponse as InternalApprovalResponse,
9};
10use astrid_approval::{SecurityInterceptor, SecurityPolicy, SensitiveAction};
11use astrid_audit::{AuditAction, AuditLog, AuditOutcome, AuthorizationProof};
12use astrid_capabilities::AuditEntryId;
13use astrid_core::{
14 ApprovalDecision, ApprovalOption, ApprovalRequest, Frontend, RiskLevel, SessionId,
15};
16use astrid_crypto::KeyPair;
17use astrid_hooks::result::HookContext;
18use astrid_hooks::{HookEvent, HookManager};
19use astrid_llm::{LlmProvider, LlmToolDefinition, Message, StreamEvent, ToolCall, ToolCallResult};
20use astrid_mcp::McpClient;
21use astrid_tools::{ToolContext, ToolRegistry, truncate_output};
22use astrid_workspace::{
23 EscapeDecision, EscapeRequest, PathCheck, WorkspaceBoundary, WorkspaceConfig,
24};
25use async_trait::async_trait;
26use futures::StreamExt;
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use tracing::{debug, error, info, warn};
30
31use crate::context::ContextManager;
32use crate::error::{RuntimeError, RuntimeResult};
33use crate::session::AgentSession;
34use crate::store::SessionStore;
35use crate::subagent::SubAgentPool;
36use crate::subagent_executor::{DEFAULT_SUBAGENT_TIMEOUT, SubAgentExecutor};
37
38const DEFAULT_MAX_CONTEXT_TOKENS: usize = 100_000;
40const DEFAULT_KEEP_RECENT_COUNT: usize = 10;
42
43const DEFAULT_MAX_CONCURRENT_SUBAGENTS: usize = 4;
45const DEFAULT_MAX_SUBAGENT_DEPTH: usize = 3;
47
48#[derive(Debug, Clone)]
50pub struct RuntimeConfig {
51 pub max_context_tokens: usize,
53 pub system_prompt: String,
55 pub auto_summarize: bool,
57 pub keep_recent_count: usize,
59 pub workspace: WorkspaceConfig,
61 pub max_concurrent_subagents: usize,
63 pub max_subagent_depth: usize,
65 pub default_subagent_timeout: std::time::Duration,
67}
68
69impl Default for RuntimeConfig {
70 fn default() -> Self {
71 let workspace_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
72 Self {
73 max_context_tokens: DEFAULT_MAX_CONTEXT_TOKENS,
74 system_prompt: String::new(),
75 auto_summarize: true,
76 keep_recent_count: DEFAULT_KEEP_RECENT_COUNT,
77 workspace: WorkspaceConfig::new(workspace_root),
78 max_concurrent_subagents: DEFAULT_MAX_CONCURRENT_SUBAGENTS,
79 max_subagent_depth: DEFAULT_MAX_SUBAGENT_DEPTH,
80 default_subagent_timeout: DEFAULT_SUBAGENT_TIMEOUT,
81 }
82 }
83}
84
85pub struct AgentRuntime<P: LlmProvider> {
87 llm: Arc<P>,
89 mcp: McpClient,
91 audit: Arc<AuditLog>,
93 sessions: SessionStore,
95 crypto: Arc<KeyPair>,
97 config: RuntimeConfig,
99 context: ContextManager,
101 boundary: WorkspaceBoundary,
103 hooks: Arc<HookManager>,
105 tool_registry: ToolRegistry,
107 shared_cwd: Arc<tokio::sync::RwLock<PathBuf>>,
109 security_policy: SecurityPolicy,
111 subagent_pool: Arc<SubAgentPool>,
113 self_arc: tokio::sync::RwLock<Option<std::sync::Weak<Self>>>,
115}
116
117impl<P: LlmProvider + 'static> AgentRuntime<P> {
118 #[must_use]
120 pub fn new(
121 llm: P,
122 mcp: McpClient,
123 audit: AuditLog,
124 sessions: SessionStore,
125 crypto: KeyPair,
126 config: RuntimeConfig,
127 ) -> Self {
128 let context =
129 ContextManager::new(config.max_context_tokens).keep_recent(config.keep_recent_count);
130 let boundary = WorkspaceBoundary::new(config.workspace.clone());
131
132 let tool_registry = ToolRegistry::with_defaults();
133 let shared_cwd = Arc::new(tokio::sync::RwLock::new(config.workspace.root.clone()));
134 let subagent_pool = Arc::new(SubAgentPool::new(
135 config.max_concurrent_subagents,
136 config.max_subagent_depth,
137 ));
138
139 info!(
140 workspace_root = %config.workspace.root.display(),
141 workspace_mode = ?config.workspace.mode,
142 max_concurrent_subagents = config.max_concurrent_subagents,
143 max_subagent_depth = config.max_subagent_depth,
144 "Workspace boundary initialized"
145 );
146
147 Self {
148 llm: Arc::new(llm),
149 mcp,
150 audit: Arc::new(audit),
151 sessions,
152 crypto: Arc::new(crypto),
153 config,
154 context,
155 boundary,
156 hooks: Arc::new(HookManager::new()),
157 tool_registry,
158 shared_cwd,
159 security_policy: SecurityPolicy::default(),
160 subagent_pool,
161 self_arc: tokio::sync::RwLock::new(None),
162 }
163 }
164
165 #[must_use]
171 pub fn new_arc(
172 llm: P,
173 mcp: McpClient,
174 audit: AuditLog,
175 sessions: SessionStore,
176 crypto: KeyPair,
177 config: RuntimeConfig,
178 hooks: Option<HookManager>,
179 ) -> Arc<Self> {
180 Arc::new_cyclic(|weak| {
181 let mut runtime = Self::new(llm, mcp, audit, sessions, crypto, config);
182 if let Some(hook_manager) = hooks {
183 runtime.hooks = Arc::new(hook_manager);
184 }
185 runtime.self_arc = tokio::sync::RwLock::new(Some(weak.clone()));
187 runtime
188 })
189 }
190
191 #[must_use]
201 pub fn create_session(&self, workspace_override: Option<&Path>) -> AgentSession {
202 let workspace_root = workspace_override.unwrap_or(&self.config.workspace.root);
203
204 let system_prompt = if self.config.system_prompt.is_empty() {
205 astrid_tools::build_system_prompt(workspace_root)
206 } else {
207 self.config.system_prompt.clone()
208 };
209
210 let session = AgentSession::new(self.crypto.key_id(), system_prompt);
211 info!(session_id = %session.id, "Created new session");
212 session
213 }
214
215 pub fn save_session(&self, session: &AgentSession) -> RuntimeResult<()> {
221 self.sessions.save(session)
222 }
223
224 pub fn load_session(&self, id: &SessionId) -> RuntimeResult<Option<AgentSession>> {
230 self.sessions.load(id)
231 }
232
233 pub fn list_sessions(&self) -> RuntimeResult<Vec<crate::store::SessionSummary>> {
239 self.sessions.list_with_metadata()
240 }
241
242 #[allow(clippy::too_many_lines)]
255 pub async fn run_turn_streaming<F: Frontend + 'static>(
256 &self,
257 session: &mut AgentSession,
258 input: &str,
259 frontend: Arc<F>,
260 ) -> RuntimeResult<()> {
261 let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
263 frontend: Arc::clone(&frontend),
264 });
265 session.approval_manager.register_handler(handler).await;
266
267 session.add_message(Message::user(input));
269
270 {
272 let ctx = self
273 .build_hook_context(session, HookEvent::UserPrompt)
274 .with_data("input", serde_json::json!(input));
275 let result = self.hooks.trigger_simple(HookEvent::UserPrompt, ctx).await;
276 if let astrid_hooks::HookResult::Block { reason } = result {
277 return Err(RuntimeError::ApprovalDenied { reason });
278 }
279 if let astrid_hooks::HookResult::ContinueWith { modifications } = &result {
280 debug!(?modifications, "UserPrompt hook modified context");
281 }
282 }
283
284 {
286 let _ = self.audit.append(
287 session.id.clone(),
288 AuditAction::LlmRequest {
289 model: self.llm.model().to_string(),
290 input_tokens: session.token_count,
291 output_tokens: 0,
292 },
293 AuthorizationProof::System {
294 reason: "user input".to_string(),
295 },
296 AuditOutcome::success(),
297 );
298 }
299
300 if self.config.auto_summarize && self.context.needs_summarization(session) {
302 frontend.show_status("Summarizing context...");
303 let result = self.context.summarize(session, self.llm.as_ref()).await?;
304
305 {
307 let _ = self.audit.append(
308 session.id.clone(),
309 AuditAction::ContextSummarized {
310 evicted_count: result.messages_evicted,
311 tokens_freed: result.tokens_freed,
312 },
313 AuthorizationProof::System {
314 reason: "context overflow".to_string(),
315 },
316 AuditOutcome::success(),
317 );
318 }
319 }
320
321 let tool_ctx = ToolContext::with_shared_cwd(
323 self.config.workspace.root.clone(),
324 Arc::clone(&self.shared_cwd),
325 );
326
327 self.inject_subagent_spawner(&tool_ctx, session, &frontend, None)
329 .await;
330
331 let loop_result = self.run_loop(session, &*frontend, &tool_ctx).await;
333
334 loop_result?;
335
336 self.sessions.save(session)?;
337 Ok(())
338 }
339
340 pub async fn run_subagent_turn<F: Frontend + 'static>(
352 &self,
353 session: &mut AgentSession,
354 prompt: &str,
355 frontend: Arc<F>,
356 parent_subagent_id: Option<crate::subagent::SubAgentId>,
357 ) -> RuntimeResult<()> {
358 let handler: Arc<dyn ApprovalHandler> = Arc::new(FrontendApprovalHandler {
360 frontend: Arc::clone(&frontend),
361 });
362 session.approval_manager.register_handler(handler).await;
363
364 session.add_message(Message::user(prompt));
366
367 {
369 let _ = self.audit.append(
370 session.id.clone(),
371 AuditAction::LlmRequest {
372 model: self.llm.model().to_string(),
373 input_tokens: session.token_count,
374 output_tokens: 0,
375 },
376 AuthorizationProof::System {
377 reason: "sub-agent prompt".to_string(),
378 },
379 AuditOutcome::success(),
380 );
381 }
382
383 let tool_ctx = ToolContext::with_shared_cwd(
385 self.config.workspace.root.clone(),
386 Arc::clone(&self.shared_cwd),
387 );
388
389 self.inject_subagent_spawner(&tool_ctx, session, &frontend, parent_subagent_id)
391 .await;
392
393 self.run_loop(session, &*frontend, &tool_ctx).await
396 }
397
398 #[allow(clippy::too_many_lines)]
402 async fn run_loop<F: Frontend>(
403 &self,
404 session: &mut AgentSession,
405 frontend: &F,
406 tool_ctx: &ToolContext,
407 ) -> RuntimeResult<()> {
408 loop {
409 let mut llm_tools: Vec<LlmToolDefinition> = self.tool_registry.all_definitions();
411
412 let mcp_tools = self.mcp.list_tools().await?;
413 llm_tools.extend(mcp_tools.iter().map(|t| {
414 LlmToolDefinition::new(format!("{}:{}", &t.server, &t.name))
415 .with_description(t.description.clone().unwrap_or_default())
416 .with_schema(t.input_schema.clone())
417 }));
418
419 let mut stream = self
421 .llm
422 .stream(&session.messages, &llm_tools, &session.system_prompt)
423 .await?;
424
425 let mut response_text = String::new();
426 let mut tool_calls: Vec<ToolCall> = Vec::new();
427 let mut current_tool_args = String::new();
428
429 while let Some(event) = stream.next().await {
430 match event? {
431 StreamEvent::TextDelta(text) => {
432 frontend.show_status(&text);
433 response_text.push_str(&text);
434 },
435 StreamEvent::ToolCallStart { id, name } => {
436 tool_calls.push(ToolCall::new(id, name));
437 current_tool_args.clear();
438 },
439 StreamEvent::ToolCallDelta { id: _, args_delta } => {
440 current_tool_args.push_str(&args_delta);
441 },
442 StreamEvent::ToolCallEnd { id } => {
443 if let Some(call) = tool_calls.iter_mut().find(|c| c.id == id)
445 && let Ok(args) = serde_json::from_str(¤t_tool_args)
446 {
447 call.arguments = args;
448 }
449 current_tool_args.clear();
450 },
451 StreamEvent::Usage {
452 input_tokens,
453 output_tokens,
454 } => {
455 debug!(input = input_tokens, output = output_tokens, "Token usage");
456 let cost = tokens_to_usd(input_tokens, output_tokens);
458 session.budget_tracker.record_cost(cost);
459 if let Some(ref ws_budget) = session.workspace_budget_tracker {
461 ws_budget.record_cost(cost);
462 }
463 },
464 StreamEvent::ReasoningDelta(_) => {
465 },
467 StreamEvent::Done => break,
468 StreamEvent::Error(e) => {
469 error!(error = %e, "Stream error");
470 return Err(RuntimeError::LlmError(
471 astrid_llm::LlmError::StreamingError(e),
472 ));
473 },
474 }
475 }
476
477 if !tool_calls.is_empty() {
479 session.add_message(Message::assistant_with_tools(tool_calls.clone()));
481
482 for call in &tool_calls {
484 frontend.tool_started(&call.id, &call.name, &call.arguments);
485 let result = self
486 .execute_tool_call(session, call, frontend, tool_ctx)
487 .await?;
488 frontend.tool_completed(&call.id, &result.content, result.is_error);
489 session.add_message(Message::tool_result(result));
490 session.metadata.tool_call_count =
491 session.metadata.tool_call_count.saturating_add(1);
492 }
493
494 continue;
496 }
497
498 if !response_text.is_empty() {
500 session.add_message(Message::assistant(&response_text));
501 return Ok(());
502 }
503
504 break;
506 }
507
508 Ok(())
509 }
510
511 #[allow(clippy::too_many_lines)]
513 async fn execute_tool_call<F: Frontend>(
514 &self,
515 session: &mut AgentSession,
516 call: &ToolCall,
517 frontend: &F,
518 tool_ctx: &ToolContext,
519 ) -> RuntimeResult<ToolCallResult> {
520 if ToolRegistry::is_builtin(&call.name) {
522 return self
523 .execute_builtin_tool(session, call, frontend, tool_ctx)
524 .await;
525 }
526
527 let (server, tool) = call.parse_name().ok_or_else(|| {
528 RuntimeError::McpError(astrid_mcp::McpError::ToolNotFound {
529 server: "unknown".to_string(),
530 tool: call.name.clone(),
531 })
532 })?;
533
534 if let Err(tool_error) = self
536 .check_workspace_boundaries(session, call, server, tool, frontend)
537 .await
538 {
539 return Ok(tool_error);
540 }
541
542 {
544 let ctx = self
545 .build_hook_context(session, HookEvent::PreToolCall)
546 .with_data("tool_name", serde_json::json!(tool))
547 .with_data("server_name", serde_json::json!(server))
548 .with_data("arguments", call.arguments.clone());
549 let result = self.hooks.trigger_simple(HookEvent::PreToolCall, ctx).await;
550 if let astrid_hooks::HookResult::Block { reason } = result {
551 return Ok(ToolCallResult::error(&call.id, reason));
552 }
553 if let astrid_hooks::HookResult::ContinueWith { modifications } = &result {
554 debug!(?modifications, "PreToolCall hook modified context");
555 }
556 }
557
558 let action = classify_tool_call(server, tool, &call.arguments);
560
561 let interceptor = self.build_interceptor(session);
563 let tool_result = match interceptor
564 .intercept(&action, &format!("MCP tool call to {server}:{tool}"), None)
565 .await
566 {
567 Ok(intercept_result) => {
568 if let Some(warning) = &intercept_result.budget_warning {
570 frontend.show_status(&format!(
571 "Budget warning: ${:.2}/${:.2} spent ({:.0}%)",
572 warning.current_spend, warning.session_max, warning.percent_used
573 ));
574 }
575 let result = self
577 .mcp
578 .call_tool(server, tool, call.arguments.clone())
579 .await?;
580 ToolCallResult::success(&call.id, result.text_content())
581 },
582 Err(e) => ToolCallResult::error(&call.id, e.to_string()),
583 };
584
585 {
587 let hook_event = if tool_result.is_error {
588 HookEvent::ToolError
589 } else {
590 HookEvent::PostToolCall
591 };
592 let ctx = self
593 .build_hook_context(session, hook_event)
594 .with_data("tool_name", serde_json::json!(tool))
595 .with_data("server_name", serde_json::json!(server))
596 .with_data("is_error", serde_json::json!(tool_result.is_error));
597 let _ = self.hooks.trigger_simple(hook_event, ctx).await;
598 }
599
600 Ok(tool_result)
601 }
602
603 #[must_use]
605 pub fn config(&self) -> &RuntimeConfig {
606 &self.config
607 }
608
609 #[must_use]
611 pub fn audit(&self) -> &Arc<AuditLog> {
612 &self.audit
613 }
614
615 #[must_use]
617 pub fn mcp(&self) -> &McpClient {
618 &self.mcp
619 }
620
621 #[must_use]
623 pub fn key_id(&self) -> [u8; 8] {
624 self.crypto.key_id()
625 }
626
627 #[must_use]
629 pub fn boundary(&self) -> &WorkspaceBoundary {
630 &self.boundary
631 }
632
633 #[must_use]
635 pub fn with_hooks(mut self, hooks: HookManager) -> Self {
636 self.hooks = Arc::new(hooks);
637 self
638 }
639
640 #[must_use]
642 pub fn hooks(&self) -> &Arc<HookManager> {
643 &self.hooks
644 }
645
646 #[must_use]
648 pub fn subagent_pool(&self) -> &Arc<SubAgentPool> {
649 &self.subagent_pool
650 }
651
652 pub async fn set_self_arc(self: &Arc<Self>) {
666 *self.self_arc.write().await = Some(Arc::downgrade(self));
667 }
668
669 async fn inject_subagent_spawner<F: Frontend + 'static>(
673 &self,
674 tool_ctx: &ToolContext,
675 session: &AgentSession,
676 frontend: &Arc<F>,
677 parent_subagent_id: Option<crate::subagent::SubAgentId>,
678 ) {
679 let self_arc = {
680 let guard = self.self_arc.read().await;
681 guard.as_ref().and_then(std::sync::Weak::upgrade)
682 };
683
684 if let Some(runtime_arc) = self_arc {
685 let executor = SubAgentExecutor::new(
686 runtime_arc,
687 Arc::clone(&self.subagent_pool),
688 Arc::clone(frontend),
689 session.user_id,
690 parent_subagent_id,
691 session.id.clone(),
692 Arc::clone(&session.allowance_store),
693 Arc::clone(&session.capabilities),
694 Arc::clone(&session.budget_tracker),
695 self.config.default_subagent_timeout,
696 );
697 tool_ctx
698 .set_subagent_spawner(Some(Arc::new(executor)))
699 .await;
700 } else {
701 debug!("No self_arc set — sub-agent spawning disabled for this turn");
702 }
703 }
704
705 #[allow(clippy::unused_self)]
707 fn build_hook_context(&self, session: &AgentSession, event: HookEvent) -> HookContext {
708 let mut uuid_bytes = [0u8; 16];
710 uuid_bytes[..8].copy_from_slice(&session.user_id);
711 let user_uuid = uuid::Uuid::from_bytes(uuid_bytes);
712
713 HookContext::new(event)
714 .with_session(session.id.0)
715 .with_user(user_uuid)
716 }
717
718 fn build_interceptor(&self, session: &AgentSession) -> SecurityInterceptor {
723 SecurityInterceptor::new(
724 Arc::clone(&session.capabilities),
725 Arc::clone(&session.approval_manager),
726 self.security_policy.clone(),
727 Arc::clone(&session.budget_tracker),
728 Arc::clone(&self.audit),
729 Arc::clone(&self.crypto),
730 session.id.clone(),
731 Arc::clone(&session.allowance_store),
732 Some(self.config.workspace.root.clone()),
733 session.workspace_budget_tracker.clone(),
734 )
735 }
736
737 async fn execute_builtin_tool<F: Frontend>(
739 &self,
740 session: &mut AgentSession,
741 call: &ToolCall,
742 frontend: &F,
743 tool_ctx: &ToolContext,
744 ) -> RuntimeResult<ToolCallResult> {
745 let tool_name = &call.name;
746
747 let Some(tool) = self.tool_registry.get(tool_name) else {
748 return Ok(ToolCallResult::error(
749 &call.id,
750 format!("Unknown built-in tool: {tool_name}"),
751 ));
752 };
753
754 if let Err(tool_error) = self
756 .check_workspace_boundaries(session, call, "builtin", tool_name, frontend)
757 .await
758 {
759 return Ok(tool_error);
760 }
761
762 {
764 let ctx = self
765 .build_hook_context(session, HookEvent::PreToolCall)
766 .with_data("tool_name", serde_json::json!(tool_name))
767 .with_data("server_name", serde_json::json!("builtin"))
768 .with_data("arguments", call.arguments.clone());
769 let result = self.hooks.trigger_simple(HookEvent::PreToolCall, ctx).await;
770 if let astrid_hooks::HookResult::Block { reason } = result {
771 return Ok(ToolCallResult::error(&call.id, reason));
772 }
773 }
774
775 let action = classify_builtin_tool_call(tool_name, &call.arguments);
777 let interceptor = self.build_interceptor(session);
778 match interceptor
779 .intercept(&action, &format!("Built-in tool: {tool_name}"), None)
780 .await
781 {
782 Ok(intercept_result) => {
783 if let Some(warning) = &intercept_result.budget_warning {
785 frontend.show_status(&format!(
786 "Budget warning: ${:.2}/${:.2} spent ({:.0}%)",
787 warning.current_spend, warning.session_max, warning.percent_used
788 ));
789 }
790 },
791 Err(e) => return Ok(ToolCallResult::error(&call.id, e.to_string())),
792 }
793
794 let tool_result = match tool.execute(call.arguments.clone(), tool_ctx).await {
796 Ok(output) => {
797 let output = truncate_output(output);
798 ToolCallResult::success(&call.id, output)
799 },
800 Err(e) => ToolCallResult::error(&call.id, e.to_string()),
801 };
802
803 {
805 let hook_event = if tool_result.is_error {
806 HookEvent::ToolError
807 } else {
808 HookEvent::PostToolCall
809 };
810 let ctx = self
811 .build_hook_context(session, hook_event)
812 .with_data("tool_name", serde_json::json!(tool_name))
813 .with_data("server_name", serde_json::json!("builtin"))
814 .with_data("is_error", serde_json::json!(tool_result.is_error));
815 let _ = self.hooks.trigger_simple(hook_event, ctx).await;
816 }
817
818 Ok(tool_result)
819 }
820
821 #[allow(clippy::too_many_lines)]
825 async fn check_workspace_boundaries<F: Frontend>(
826 &self,
827 session: &mut AgentSession,
828 call: &ToolCall,
829 server: &str,
830 tool: &str,
831 frontend: &F,
832 ) -> Result<(), ToolCallResult> {
833 let paths = extract_paths_from_args(&call.arguments);
834 if paths.is_empty() {
835 return Ok(());
836 }
837
838 for path in &paths {
839 if session.escape_handler.is_allowed(path) {
841 debug!(path = %path.display(), "Path already approved by escape handler");
842 continue;
843 }
844
845 let check = self.boundary.check(path);
846 match check {
847 PathCheck::Allowed | PathCheck::AutoAllowed => {},
848 PathCheck::NeverAllowed => {
849 warn!(
850 path = %path.display(),
851 tool = %format!("{server}:{tool}"),
852 "Access to protected path blocked"
853 );
854
855 {
857 let _ = self.audit.append(
858 session.id.clone(),
859 AuditAction::ApprovalDenied {
860 action: format!("{server}:{tool} -> {}", path.display()),
861 reason: Some("protected system path".to_string()),
862 },
863 AuthorizationProof::System {
864 reason: "workspace boundary: never-allowed path".to_string(),
865 },
866 AuditOutcome::failure("protected path"),
867 );
868 }
869
870 return Err(ToolCallResult::error(
871 &call.id,
872 format!(
873 "Access to {} is blocked — this is a protected system path",
874 path.display()
875 ),
876 ));
877 },
878 PathCheck::RequiresApproval => {
879 let escape_request = EscapeRequest::new(
880 path.clone(),
881 infer_operation(tool),
882 format!(
883 "Tool {server}:{tool} wants to access {} outside the workspace",
884 path.display()
885 ),
886 )
887 .with_tool(tool)
888 .with_server(server);
889
890 let approval_request = ApprovalRequest::new(
892 format!("workspace-escape:{server}:{tool}"),
893 format!(
894 "Allow {} {} outside workspace?\n Path: {}",
895 tool,
896 escape_request.operation,
897 path.display()
898 ),
899 )
900 .with_risk_level(risk_level_for_operation(escape_request.operation))
901 .with_resource(path.display().to_string());
902
903 let decision =
904 frontend
905 .request_approval(approval_request)
906 .await
907 .map_err(|_| {
908 ToolCallResult::error(
909 &call.id,
910 "Failed to request workspace escape approval",
911 )
912 })?;
913
914 let escape_decision = match decision.decision {
916 ApprovalOption::AllowOnce => EscapeDecision::AllowOnce,
917 ApprovalOption::AllowSession | ApprovalOption::AllowWorkspace => {
918 EscapeDecision::AllowSession
919 },
920 ApprovalOption::AllowAlways => EscapeDecision::AllowAlways,
921 ApprovalOption::Deny => EscapeDecision::Deny,
922 };
923
924 session
926 .escape_handler
927 .process_decision(&escape_request, escape_decision);
928
929 if escape_decision.is_allowed() {
931 let _ = self.audit.append(
932 session.id.clone(),
933 AuditAction::ApprovalGranted {
934 action: format!("{server}:{tool}"),
935 resource: Some(path.display().to_string()),
936 scope: match decision.decision {
937 ApprovalOption::AllowSession => {
938 astrid_audit::ApprovalScope::Session
939 },
940 ApprovalOption::AllowWorkspace => {
941 astrid_audit::ApprovalScope::Workspace
942 },
943 ApprovalOption::AllowAlways => {
944 astrid_audit::ApprovalScope::Always
945 },
946 ApprovalOption::AllowOnce | ApprovalOption::Deny => {
947 astrid_audit::ApprovalScope::Once
948 },
949 },
950 },
951 AuthorizationProof::UserApproval {
952 user_id: session.user_id,
953 approval_entry_id: AuditEntryId::new(),
954 },
955 AuditOutcome::success(),
956 );
957 } else {
958 let _ = self.audit.append(
959 session.id.clone(),
960 AuditAction::ApprovalDenied {
961 action: format!("{server}:{tool} -> {}", path.display()),
962 reason: Some(
963 decision
964 .reason
965 .clone()
966 .unwrap_or_else(|| "user denied".to_string()),
967 ),
968 },
969 AuthorizationProof::UserApproval {
970 user_id: session.user_id,
971 approval_entry_id: AuditEntryId::new(),
972 },
973 AuditOutcome::failure("user denied workspace escape"),
974 );
975 }
976
977 if !escape_decision.is_allowed() {
978 return Err(ToolCallResult::error(
979 &call.id,
980 decision.reason.unwrap_or_else(|| {
981 format!("Access to {} denied — outside workspace", path.display())
982 }),
983 ));
984 }
985
986 info!(
987 path = %path.display(),
988 decision = ?escape_decision,
989 "Workspace escape approved"
990 );
991 },
992 }
993 }
994
995 Ok(())
996 }
997}
998
999fn extract_paths_from_args(args: &serde_json::Value) -> Vec<PathBuf> {
1003 const PATH_KEYS: &[&str] = &[
1005 "path",
1006 "file",
1007 "file_path",
1008 "filepath",
1009 "filename",
1010 "directory",
1011 "dir",
1012 "target",
1013 "source",
1014 "destination",
1015 "src",
1016 "dst",
1017 "input",
1018 "output",
1019 "uri",
1020 "url",
1021 "cwd",
1022 "working_directory",
1023 ];
1024
1025 let mut paths = Vec::new();
1026
1027 if let Some(obj) = args.as_object() {
1028 for (key, value) in obj {
1029 let key_lower = key.to_lowercase();
1030 if let Some(s) = value.as_str()
1031 && PATH_KEYS.contains(&key_lower.as_str())
1032 && let Some(path) = try_extract_path(s)
1033 {
1034 paths.push(path);
1035 }
1036 }
1037 }
1038
1039 paths
1040}
1041
/// Interpret a string argument as a filesystem path, if it looks like one.
///
/// Recognised forms:
/// - `file://` URIs (the scheme is stripped);
/// - absolute paths (`/…`);
/// - home paths (`~` or `~/…`), expanded against `$HOME` (or `%USERPROFILE%`)
///   when set. Left unexpanded, they would be checked as a relative directory
///   literally named `~`;
/// - explicitly relative paths (`.`, `./…`, `../…`);
/// - any other value containing a `..` component (`..`, `src/../../etc`),
///   because such a value can climb out of the workspace.
///
/// Other URIs (e.g. `https://…`) and bare words return `None`.
fn try_extract_path(value: &str) -> Option<PathBuf> {
    if let Some(stripped) = value.strip_prefix("file://") {
        return Some(PathBuf::from(stripped));
    }

    // Any other scheme is not a local path.
    if value.contains("://") {
        return None;
    }

    if value == "~" || value.starts_with("~/") {
        return Some(expand_home(value));
    }

    if value.starts_with('/') || value == "." || value.starts_with("./") {
        return Some(PathBuf::from(value));
    }

    let climbs_out = Path::new(value)
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir));
    climbs_out.then(|| PathBuf::from(value))
}

/// Expand a leading `~` to the user's home directory.
///
/// Returns `value` unchanged when neither `HOME` nor `USERPROFILE` is set.
fn expand_home(value: &str) -> PathBuf {
    let rest = value
        .strip_prefix('~')
        .unwrap_or(value)
        .trim_start_matches('/');
    let home = std::env::var_os("HOME")
        .filter(|h| !h.is_empty())
        .or_else(|| std::env::var_os("USERPROFILE").filter(|h| !h.is_empty()));
    match home {
        Some(home) if rest.is_empty() => PathBuf::from(home),
        Some(home) => PathBuf::from(home).join(rest),
        None => PathBuf::from(value),
    }
}
1065
1066fn infer_operation(tool: &str) -> astrid_workspace::escape::EscapeOperation {
1068 use astrid_workspace::escape::EscapeOperation;
1069 let tool_lower = tool.to_lowercase();
1070
1071 if tool_lower.contains("read") || tool_lower.contains("get") || tool_lower.contains("cat") {
1072 EscapeOperation::Read
1073 } else if tool_lower.contains("write")
1074 || tool_lower.contains("set")
1075 || tool_lower.contains("put")
1076 || tool_lower.contains("edit")
1077 || tool_lower.contains("update")
1078 {
1079 EscapeOperation::Write
1080 } else if tool_lower.contains("create")
1081 || tool_lower.contains("mkdir")
1082 || tool_lower.contains("touch")
1083 || tool_lower.contains("new")
1084 {
1085 EscapeOperation::Create
1086 } else if tool_lower.contains("delete")
1087 || tool_lower.contains("remove")
1088 || tool_lower.contains("rm")
1089 {
1090 EscapeOperation::Delete
1091 } else if tool_lower.contains("exec")
1092 || tool_lower.contains("run")
1093 || tool_lower.contains("launch")
1094 {
1095 EscapeOperation::Execute
1096 } else if tool_lower.contains("list") || tool_lower.contains("ls") || tool_lower.contains("dir")
1097 {
1098 EscapeOperation::List
1099 } else {
1100 EscapeOperation::Read
1102 }
1103}
1104
1105fn risk_level_for_operation(operation: astrid_workspace::escape::EscapeOperation) -> RiskLevel {
1107 use astrid_workspace::escape::EscapeOperation;
1108 match operation {
1109 EscapeOperation::Read | EscapeOperation::List => RiskLevel::Medium,
1110 EscapeOperation::Write | EscapeOperation::Create => RiskLevel::High,
1111 EscapeOperation::Delete | EscapeOperation::Execute => RiskLevel::Critical,
1112 }
1113}
1114
1115fn classify_tool_call(server: &str, tool: &str, args: &serde_json::Value) -> SensitiveAction {
1117 let tool_lower = tool.to_lowercase();
1118
1119 if (tool_lower.contains("delete") || tool_lower.contains("remove"))
1121 && let Some(path) = args
1122 .get("path")
1123 .or_else(|| args.get("file"))
1124 .and_then(|v| v.as_str())
1125 {
1126 return SensitiveAction::FileDelete {
1127 path: path.to_string(),
1128 };
1129 }
1130
1131 if tool_lower.contains("exec") || tool_lower.contains("run") || tool_lower.contains("bash") {
1133 let command = args
1134 .get("command")
1135 .and_then(|v| v.as_str())
1136 .unwrap_or(tool)
1137 .to_string();
1138 let cmd_args = args
1139 .get("args")
1140 .and_then(|v| v.as_array())
1141 .map(|a| {
1142 a.iter()
1143 .filter_map(|v| v.as_str().map(String::from))
1144 .collect()
1145 })
1146 .unwrap_or_default();
1147 return SensitiveAction::ExecuteCommand {
1148 command,
1149 args: cmd_args,
1150 };
1151 }
1152
1153 if tool_lower.contains("write")
1155 && let Some(path) = args
1156 .get("path")
1157 .or_else(|| args.get("file_path"))
1158 .and_then(|v| v.as_str())
1159 && path.starts_with('/')
1160 {
1161 return SensitiveAction::FileWriteOutsideSandbox {
1162 path: path.to_string(),
1163 };
1164 }
1165
1166 SensitiveAction::McpToolCall {
1168 server: server.to_string(),
1169 tool: tool.to_string(),
1170 }
1171}
1172
1173fn to_frontend_request(internal: &InternalApprovalRequest) -> ApprovalRequest {
1175 ApprovalRequest::new(
1176 internal.action.action_type().to_string(),
1177 internal.action.summary(),
1178 )
1179 .with_risk_level(internal.assessment.level)
1180 .with_resource(format!("{}", internal.action))
1181}
1182
1183fn to_internal_response(
1185 request: &InternalApprovalRequest,
1186 decision: &ApprovalDecision,
1187) -> InternalApprovalResponse {
1188 let internal_decision = match decision.decision {
1189 ApprovalOption::AllowOnce => InternalApprovalDecision::Approve,
1190 ApprovalOption::AllowSession => InternalApprovalDecision::ApproveSession,
1191 ApprovalOption::AllowWorkspace => InternalApprovalDecision::ApproveWorkspace,
1192 ApprovalOption::AllowAlways => InternalApprovalDecision::ApproveAlways,
1193 ApprovalOption::Deny => InternalApprovalDecision::Deny {
1194 reason: decision
1195 .reason
1196 .clone()
1197 .unwrap_or_else(|| "denied by user".to_string()),
1198 },
1199 };
1200 InternalApprovalResponse::new(request.id.clone(), internal_decision)
1201}
1202
1203fn classify_builtin_tool_call(tool_name: &str, args: &serde_json::Value) -> SensitiveAction {
1208 match tool_name {
1209 "bash" => {
1210 let command = args
1211 .get("command")
1212 .and_then(|v| v.as_str())
1213 .unwrap_or("bash")
1214 .to_string();
1215 SensitiveAction::ExecuteCommand {
1216 command,
1217 args: Vec::new(),
1218 }
1219 },
1220 "write_file" | "edit_file" => {
1221 let path = args
1222 .get("file_path")
1223 .or_else(|| args.get("path"))
1224 .and_then(|v| v.as_str())
1225 .unwrap_or("unknown")
1226 .to_string();
1227 SensitiveAction::FileWriteOutsideSandbox { path }
1228 },
1229 "read_file" | "glob" | "grep" | "list_directory" => {
1230 let path = args
1231 .get("file_path")
1232 .or_else(|| args.get("path"))
1233 .or_else(|| args.get("pattern"))
1234 .and_then(|v| v.as_str())
1235 .unwrap_or(".")
1236 .to_string();
1237 SensitiveAction::FileRead { path }
1238 },
1239 other => SensitiveAction::McpToolCall {
1241 server: "builtin".to_string(),
1242 tool: other.to_string(),
1243 },
1244 }
1245}
1246
1247struct FrontendApprovalHandler<F: Frontend> {
1254 frontend: Arc<F>,
1255}
1256
1257#[async_trait]
1258impl<F: Frontend> ApprovalHandler for FrontendApprovalHandler<F> {
1259 async fn request_approval(
1260 &self,
1261 request: InternalApprovalRequest,
1262 ) -> Option<InternalApprovalResponse> {
1263 let frontend_request = to_frontend_request(&request);
1264 match self.frontend.request_approval(frontend_request).await {
1265 Ok(decision) => Some(to_internal_response(&request, &decision)),
1266 Err(_) => None,
1267 }
1268 }
1269
1270 fn is_available(&self) -> bool {
1271 true
1272 }
1273}
1274
/// Price in USD per 1,000 input tokens used for budget accounting.
const INPUT_RATE_PER_1K: f64 = 0.003;
/// Price in USD per 1,000 output tokens used for budget accounting.
const OUTPUT_RATE_PER_1K: f64 = 0.015;

/// Estimate the USD cost of one LLM exchange from its token usage.
///
/// The same flat rates apply whatever model produced the usage.
// Token counts never approach 2^53, so the `usize` → `f64` casts are exact
// in practice.
#[allow(clippy::cast_precision_loss)]
fn tokens_to_usd(input_tokens: usize, output_tokens: usize) -> f64 {
    let priced = |tokens: usize, rate_per_1k: f64| (tokens as f64 / 1000.0) * rate_per_1k;
    priced(input_tokens, INPUT_RATE_PER_1K) + priced(output_tokens, OUTPUT_RATE_PER_1K)
}
1291
#[cfg(test)]
mod tests {
    use super::*;
    use astrid_workspace::escape::EscapeOperation;

    #[test]
    fn test_extract_paths_from_args() {
        let args = serde_json::json!({
            "path": "/home/user/file.txt",
            "content": "some data",
            "count": 42
        });
        assert_eq!(
            extract_paths_from_args(&args),
            vec![PathBuf::from("/home/user/file.txt")]
        );
    }

    #[test]
    fn test_extract_paths_ignores_non_path_values() {
        let args = serde_json::json!({
            "path": "not-a-path",
            "url": "https://example.com",
        });
        assert_eq!(extract_paths_from_args(&args), Vec::<PathBuf>::new());
    }

    #[test]
    fn test_extract_paths_file_uri() {
        let args = serde_json::json!({
            "uri": "file:///tmp/test.txt"
        });
        assert_eq!(
            extract_paths_from_args(&args),
            vec![PathBuf::from("/tmp/test.txt")]
        );
    }

    #[test]
    fn test_extract_paths_relative() {
        let args = serde_json::json!({
            "file": "./src/main.rs",
            "dir": "../other"
        });
        assert_eq!(extract_paths_from_args(&args).len(), 2);
    }

    #[test]
    fn test_infer_operation() {
        let cases = [
            ("read_file", EscapeOperation::Read),
            ("write_file", EscapeOperation::Write),
            ("create_directory", EscapeOperation::Create),
            ("delete_file", EscapeOperation::Delete),
            ("execute_command", EscapeOperation::Execute),
            ("list_files", EscapeOperation::List),
            ("unknown_tool", EscapeOperation::Read),
        ];
        for (tool, expected) in cases {
            assert_eq!(infer_operation(tool), expected, "tool name: {tool}");
        }
    }

    #[test]
    fn test_risk_level_for_operation() {
        let cases = [
            (EscapeOperation::Read, RiskLevel::Medium),
            (EscapeOperation::Write, RiskLevel::High),
            (EscapeOperation::Delete, RiskLevel::Critical),
        ];
        for (operation, expected) in cases {
            assert_eq!(risk_level_for_operation(operation), expected);
        }
    }
}