Skip to main content

sgr_agent/
context.rs

1//! Agent execution context — state and domain-specific data.
2
3use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Lifecycle state of an agent run.
///
/// Variant order is part of the public surface (it fixes the implicit
/// discriminants), so new variants should be appended, not inserted.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum AgentState {
    /// Run in progress; this is the state a fresh `AgentContext` starts in.
    Running,
    /// Run has completed.
    Completed,
    /// Run has failed.
    Failed,
    /// Run was cancelled.
    Cancelled,
    /// Run is waiting for input.
    WaitingInput,
}
16
17/// Shared context passed to tools during execution.
18#[derive(Debug, Clone)]
19pub struct AgentContext {
20    /// Current iteration (step number).
21    pub iteration: usize,
22    /// Agent state.
23    pub state: AgentState,
24    /// Working directory.
25    pub cwd: PathBuf,
26    /// Domain-specific data (extensible).
27    pub custom: HashMap<String, Value>,
28    /// Per-tool configuration overrides.
29    /// Key: tool name, Value: tool-specific config merged at execution time.
30    pub tool_configs: HashMap<String, Value>,
31    /// Directories the agent is allowed to write to (sandbox).
32    /// Empty = no restriction.
33    pub writable_roots: Vec<PathBuf>,
34}
35
36impl AgentContext {
37    pub fn new() -> Self {
38        Self {
39            iteration: 0,
40            state: AgentState::Running,
41            cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
42            custom: HashMap::new(),
43            tool_configs: HashMap::new(),
44            writable_roots: Vec::new(),
45        }
46    }
47
48    pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
49        self.cwd = cwd.into();
50        self
51    }
52
53    pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
54        self.writable_roots = roots;
55        self
56    }
57
58    /// Check if a path is writable under sandbox rules.
59    /// Returns true if writable_roots is empty (no sandbox) or path is under any root.
60    /// Canonicalizes paths to prevent traversal attacks (../ and symlinks).
61    pub fn is_writable(&self, path: &std::path::Path) -> bool {
62        if self.writable_roots.is_empty() {
63            return true;
64        }
65        let abs_path = if path.is_absolute() {
66            path.to_path_buf()
67        } else {
68            self.cwd.join(path)
69        };
70        // Canonicalize to resolve ".." and symlinks.
71        // If the path doesn't exist yet (new file), canonicalize the parent.
72        let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
73            // File doesn't exist — canonicalize parent, then append filename
74            if let Some(parent) = abs_path.parent()
75                && let Ok(canon_parent) = std::fs::canonicalize(parent)
76                && let Some(name) = abs_path.file_name()
77            {
78                return canon_parent.join(name);
79            }
80            abs_path.clone()
81        });
82        self.writable_roots.iter().any(|root| {
83            // Canonicalize the root too (resolve symlinks in root paths)
84            let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
85            resolved.starts_with(&canon_root)
86        })
87    }
88
89    /// Set a custom value.
90    pub fn set(&mut self, key: impl Into<String>, value: Value) {
91        self.custom.insert(key.into(), value);
92    }
93
94    /// Get a custom value.
95    pub fn get(&self, key: &str) -> Option<&Value> {
96        self.custom.get(key)
97    }
98
99    /// Set per-tool config.
100    pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
101        self.tool_configs.insert(tool_name.into(), config);
102    }
103
104    /// Get per-tool config.
105    pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
106        self.tool_configs.get(tool_name)
107    }
108
109    /// Get tool config merged with a base config.
110    /// Per-tool values override base values (shallow merge).
111    pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
112        match (base, self.tool_configs.get(tool_name)) {
113            (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
114                let mut merged = base_obj.clone();
115                for (k, v) in override_obj {
116                    merged.insert(k.clone(), v.clone());
117                }
118                Value::Object(merged)
119            }
120            (_, Some(override_val)) => override_val.clone(),
121            _ => base.clone(),
122        }
123    }
124}
125
126impl Default for AgentContext {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::path::Path;

    #[test]
    fn context_default_state() {
        let fresh = AgentContext::new();
        assert_eq!(fresh.iteration, 0);
        assert_eq!(fresh.state, AgentState::Running);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", json!("my-project"));
        assert_eq!(ctx.get("project"), Some(&json!("my-project")));
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let expected = PathBuf::from("/tmp/test");
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, expected);
    }

    #[test]
    fn tool_config_set_get() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", json!({ "timeout": 30 }));
        let bash = ctx.tool_config("bash").expect("bash config was just set");
        assert_eq!(bash["timeout"], 30);
        assert!(ctx.tool_config("read_file").is_none());
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", json!({ "timeout": 60, "shell": "zsh" }));

        let base = json!({ "timeout": 30, "cwd": "/tmp" });
        let merged = ctx.merged_tool_config("bash", &base);
        // The override wins on `timeout`, the base keeps `cwd`, and the
        // override contributes `shell`.
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_merge_no_override() {
        let ctx = AgentContext::new();
        let base = json!({ "timeout": 30 });
        assert_eq!(ctx.merged_tool_config("bash", &base), base);
    }

    #[test]
    fn writable_roots_empty_allows_all() {
        assert!(AgentContext::new().is_writable(Path::new("/any/path")));
    }

    #[test]
    fn writable_roots_restricts() {
        let project_root = PathBuf::from("/home/user/project");
        let ctx = AgentContext::new().with_writable_roots(vec![project_root]);
        assert!(ctx.is_writable(Path::new("/home/user/project/src/main.rs")));
        assert!(!ctx.is_writable(Path::new("/etc/passwd")));
    }

    #[test]
    fn writable_roots_relative_path() {
        let project = "/home/user/project";
        let ctx = AgentContext::new()
            .with_cwd(project)
            .with_writable_roots(vec![PathBuf::from(project)]);
        assert!(ctx.is_writable(Path::new("src/main.rs")));
        assert!(!ctx.is_writable(Path::new("/etc/passwd")));
    }
}