Skip to main content

claude_agent/tools/
context.rs

1//! Execution context for tool operations.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use crate::hooks::{HookContext, HookEvent, HookInput, HookManager};
8use crate::permissions::{PermissionResult, ToolLimits};
9use crate::security::bash::{BashAnalysis, SanitizedEnv};
10use crate::security::fs::SecureFileHandle;
11use crate::security::guard::SecurityGuard;
12use crate::security::path::SafePath;
13use crate::security::sandbox::{DomainCheck, SandboxResult};
14use crate::security::{ResourceLimits, SecurityContext, SecurityError};
15
16#[derive(Clone)]
17pub struct ExecutionContext {
18    security: Arc<SecurityContext>,
19    hooks: Option<HookManager>,
20    session_id: Option<String>,
21}
22
23impl ExecutionContext {
24    pub fn new(security: SecurityContext) -> Self {
25        Self {
26            security: Arc::new(security),
27            hooks: None,
28            session_id: None,
29        }
30    }
31
32    pub fn from_path(root: impl AsRef<Path>) -> Result<Self, SecurityError> {
33        let security = SecurityContext::new(root)?;
34        Ok(Self::new(security))
35    }
36
37    /// Create a permissive ExecutionContext that allows all operations.
38    ///
39    /// # Panics
40    /// Panics if the root filesystem cannot be accessed.
41    pub fn permissive() -> Self {
42        Self {
43            security: Arc::new(SecurityContext::permissive()),
44            hooks: None,
45            session_id: None,
46        }
47    }
48
49    pub fn hooks(mut self, hooks: HookManager, session_id: impl Into<String>) -> Self {
50        self.hooks = Some(hooks);
51        self.session_id = Some(session_id.into());
52        self
53    }
54
55    pub fn session_id(&self) -> Option<&str> {
56        self.session_id.as_deref()
57    }
58
59    pub async fn fire_hook(&self, event: HookEvent, input: HookInput) {
60        if let Some(ref hooks) = self.hooks {
61            let context = HookContext::new(input.session_id.clone()).cwd(self.root().to_path_buf());
62            if let Err(e) = hooks.execute(event, input, &context).await {
63                tracing::warn!(error = %e, "Hook execution failed");
64            }
65        }
66    }
67
68    pub fn root(&self) -> &Path {
69        self.security.root()
70    }
71
72    pub fn limits_for(&self, tool_name: &str) -> ToolLimits {
73        self.security
74            .policy
75            .permission
76            .limits(tool_name)
77            .cloned()
78            .unwrap_or_default()
79    }
80
81    pub fn resolve(&self, input: &str) -> Result<SafePath, SecurityError> {
82        self.security.fs.resolve(input)
83    }
84
85    pub fn resolve_with_limits(
86        &self,
87        input: &str,
88        limits: &ToolLimits,
89    ) -> Result<SafePath, SecurityError> {
90        self.security.fs.resolve_with_limits(input, limits)
91    }
92
93    pub fn resolve_for(&self, tool_name: &str, path: &str) -> Result<SafePath, SecurityError> {
94        let limits = self.limits_for(tool_name);
95        self.resolve_with_limits(path, &limits)
96    }
97
98    pub fn try_resolve_for(
99        &self,
100        tool_name: &str,
101        path: &str,
102    ) -> Result<SafePath, crate::types::ToolResult> {
103        self.resolve_for(tool_name, path)
104            .map_err(|e| crate::types::ToolResult::error(e.to_string()))
105    }
106
107    pub fn try_resolve_or_root_for(
108        &self,
109        tool_name: &str,
110        path: Option<&str>,
111    ) -> Result<std::path::PathBuf, crate::types::ToolResult> {
112        let limits = self.limits_for(tool_name);
113        self.resolve_or_root(path, &limits)
114            .map_err(|e| crate::types::ToolResult::error(e.to_string()))
115    }
116
117    pub fn resolve_or_root(
118        &self,
119        path: Option<&str>,
120        limits: &ToolLimits,
121    ) -> Result<std::path::PathBuf, SecurityError> {
122        match path {
123            Some(p) => self
124                .resolve_with_limits(p, limits)
125                .map(|sp| sp.as_path().to_path_buf()),
126            None => Ok(self.root().to_path_buf()),
127        }
128    }
129
130    pub fn open_read(&self, input: &str) -> Result<SecureFileHandle, SecurityError> {
131        self.security.fs.open_read(input)
132    }
133
134    pub fn open_write(&self, input: &str) -> Result<SecureFileHandle, SecurityError> {
135        self.security.fs.open_write(input)
136    }
137
138    pub fn is_within(&self, path: &Path) -> bool {
139        self.security.fs.is_within(path)
140    }
141
142    pub fn analyze_bash(&self, command: &str) -> BashAnalysis {
143        self.security.bash.analyze(command)
144    }
145
146    pub fn validate_bash(&self, command: &str) -> Result<BashAnalysis, String> {
147        self.security.bash.validate(command)
148    }
149
150    fn sanitized_env(&self) -> SanitizedEnv {
151        SanitizedEnv::from_current().working_dir(self.root())
152    }
153
154    pub fn resource_limits(&self) -> &ResourceLimits {
155        &self.security.limits
156    }
157
158    pub fn check_domain(&self, domain: &str) -> DomainCheck {
159        self.security.network.check(domain)
160    }
161
162    pub fn can_bypass_sandbox(&self) -> bool {
163        self.security.policy.can_bypass_sandbox()
164    }
165
166    pub fn is_sandboxed(&self) -> bool {
167        self.security.is_sandboxed()
168    }
169
170    pub fn should_auto_allow_bash(&self) -> bool {
171        self.security.should_auto_allow_bash()
172    }
173
174    pub fn wrap_command(&self, command: &str) -> SandboxResult<String> {
175        self.security.sandbox.wrap_command(command)
176    }
177
178    pub fn sandbox_env(&self) -> HashMap<String, String> {
179        self.security.sandbox.environment_vars()
180    }
181
182    pub fn sanitized_env_with_sandbox(&self) -> SanitizedEnv {
183        let sandbox_env = self.sandbox_env();
184        self.sanitized_env().vars(sandbox_env)
185    }
186
187    pub fn check_permission(&self, tool_name: &str, input: &serde_json::Value) -> PermissionResult {
188        self.security.policy.permission.check(tool_name, input)
189    }
190
191    pub fn validate_security(
192        &self,
193        tool_name: &str,
194        input: &serde_json::Value,
195    ) -> Result<(), String> {
196        SecurityGuard::validate(&self.security, tool_name, input).map_err(|e| e.to_string())
197    }
198}
199
200impl Default for ExecutionContext {
201    fn default() -> Self {
202        let security = SecurityContext::builder()
203            .build()
204            .unwrap_or_else(|_| SecurityContext::permissive());
205        Self::new(security)
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_execution_context_new() {
        let dir = tempdir().unwrap();
        let context = ExecutionContext::from_path(dir.path()).unwrap();
        assert!(context.is_within(&std::fs::canonicalize(dir.path()).unwrap()));
    }

    #[test]
    fn test_permissive_context() {
        let context = ExecutionContext::permissive();
        assert!(context.can_bypass_sandbox());
    }

    #[test]
    fn test_resolve() {
        let dir = tempdir().unwrap();
        let root = std::fs::canonicalize(dir.path()).unwrap();
        std::fs::write(root.join("test.txt"), "content").unwrap();

        let context = ExecutionContext::from_path(&root).unwrap();
        let path = context.resolve("test.txt").unwrap();
        assert_eq!(path.as_path(), root.join("test.txt"));
    }

    #[test]
    fn test_path_escape_blocked() {
        let dir = tempdir().unwrap();
        let context = ExecutionContext::from_path(dir.path()).unwrap();
        let result = context.resolve("../../../etc/passwd");
        assert!(result.is_err());
    }

    #[test]
    fn test_analyze_bash() {
        let context = ExecutionContext::default();
        let analysis = context.analyze_bash("cat /etc/passwd");
        assert!(analysis.paths.iter().any(|p| p.path == "/etc/passwd"));
    }

    #[test]
    fn test_new_context_has_no_session() {
        let dir = tempdir().unwrap();
        let context = ExecutionContext::from_path(dir.path()).unwrap();
        assert!(context.session_id().is_none());
        assert!(ExecutionContext::permissive().session_id().is_none());
    }

    #[test]
    fn test_resolve_or_root_defaults_to_root() {
        let dir = tempdir().unwrap();
        let context = ExecutionContext::from_path(dir.path()).unwrap();
        let resolved = context
            .resolve_or_root(None, &ToolLimits::default())
            .unwrap();
        assert_eq!(resolved.as_path(), context.root());
    }

    #[test]
    fn test_try_resolve_or_root_for_defaults_to_root() {
        let dir = tempdir().unwrap();
        let context = ExecutionContext::from_path(dir.path()).unwrap();
        let result = context.try_resolve_or_root_for("Glob", None);
        assert!(matches!(result, Ok(ref p) if p.as_path() == context.root()));
    }

    #[test]
    fn test_clone_shares_root() {
        let dir = tempdir().unwrap();
        let context = ExecutionContext::from_path(dir.path()).unwrap();
        let cloned = context.clone();
        assert_eq!(cloned.root(), context.root());
    }
}