Skip to main content

roboticus_agent/
approvals.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use chrono::{DateTime, Duration, Utc};
5use serde::{Deserialize, Serialize};
6use tracing::{debug, warn};
7use uuid::Uuid;
8
9use roboticus_core::config::ApprovalsConfig;
10use roboticus_core::{InputAuthority, Result, RoboticusError};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum ToolClassification {
14    Safe,
15    Gated,
16    Blocked,
17}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum ApprovalStatus {
21    Pending,
22    Approved,
23    Denied,
24    TimedOut,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ApprovalRequest {
29    pub id: String,
30    pub tool_name: String,
31    pub tool_input: String,
32    pub session_id: Option<String>,
33    /// The turn that triggered the approval. Used for replay after approval.
34    /// Previously conflated with `session_id`; now a dedicated field so the
35    /// audit trail correctly tracks both the owning session and the replay target.
36    pub turn_id: Option<String>,
37    #[serde(default = "default_requested_authority")]
38    pub requested_authority: InputAuthority,
39    pub status: ApprovalStatus,
40    pub decided_by: Option<String>,
41    pub decided_at: Option<DateTime<Utc>>,
42    pub timeout_at: DateTime<Utc>,
43    pub created_at: DateTime<Utc>,
44}
45
46fn default_requested_authority() -> InputAuthority {
47    InputAuthority::External
48}
49
50pub struct ApprovalManager {
51    config: ApprovalsConfig,
52    pending: Arc<Mutex<HashMap<String, ApprovalRequest>>>,
53}
54
55impl ApprovalManager {
56    pub fn new(config: ApprovalsConfig) -> Self {
57        Self {
58            config,
59            pending: Arc::new(Mutex::new(HashMap::new())),
60        }
61    }
62
63    pub fn classify_tool(&self, tool_name: &str) -> ToolClassification {
64        if self.config.blocked_tools.iter().any(|t| t == tool_name) {
65            ToolClassification::Blocked
66        } else if self.config.gated_tools.iter().any(|t| t == tool_name) {
67            ToolClassification::Gated
68        } else {
69            ToolClassification::Safe
70        }
71    }
72
73    pub fn check_tool(&self, tool_name: &str) -> Result<ToolClassification> {
74        if !self.config.enabled {
75            return Ok(ToolClassification::Safe);
76        }
77
78        let classification = self.classify_tool(tool_name);
79
80        if classification == ToolClassification::Blocked {
81            return Err(RoboticusError::Tool {
82                tool: tool_name.to_string(),
83                message: "tool is blocked by policy".into(),
84            });
85        }
86
87        Ok(classification)
88    }
89
90    pub fn request_approval(
91        &self,
92        tool_name: &str,
93        tool_input: &str,
94        session_id: Option<&str>,
95        turn_id: Option<&str>,
96        requested_authority: InputAuthority,
97    ) -> Result<ApprovalRequest> {
98        let id = Uuid::new_v4().to_string();
99        let timeout_at = Utc::now() + Duration::seconds(self.config.timeout_seconds as i64);
100
101        let request = ApprovalRequest {
102            id: id.clone(),
103            tool_name: tool_name.to_string(),
104            tool_input: tool_input.to_string(),
105            session_id: session_id.map(|s| s.to_string()),
106            turn_id: turn_id.map(|s| s.to_string()),
107            requested_authority,
108            status: ApprovalStatus::Pending,
109            decided_by: None,
110            decided_at: None,
111            timeout_at,
112            created_at: Utc::now(),
113        };
114
115        debug!(id = %id, tool = tool_name, "approval requested");
116
117        let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
118        pending.insert(id, request.clone());
119
120        Ok(request)
121    }
122
123    pub fn approve(&self, request_id: &str, decided_by: &str) -> Result<ApprovalRequest> {
124        let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
125        let request = pending
126            .get_mut(request_id)
127            .ok_or_else(|| RoboticusError::Tool {
128                tool: "approvals".into(),
129                message: format!("request {request_id} not found"),
130            })?;
131
132        if request.status != ApprovalStatus::Pending {
133            return Err(RoboticusError::Tool {
134                tool: "approvals".into(),
135                message: format!("request {request_id} is already {:?}", request.status),
136            });
137        }
138
139        request.status = ApprovalStatus::Approved;
140        request.decided_by = Some(decided_by.to_string());
141        request.decided_at = Some(Utc::now());
142
143        debug!(id = request_id, by = decided_by, "approval granted");
144        Ok(request.clone())
145    }
146
147    pub fn deny(&self, request_id: &str, decided_by: &str) -> Result<ApprovalRequest> {
148        let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
149        let request = pending
150            .get_mut(request_id)
151            .ok_or_else(|| RoboticusError::Tool {
152                tool: "approvals".into(),
153                message: format!("request {request_id} not found"),
154            })?;
155
156        if request.status != ApprovalStatus::Pending {
157            return Err(RoboticusError::Tool {
158                tool: "approvals".into(),
159                message: format!("request {request_id} is already {:?}", request.status),
160            });
161        }
162
163        request.status = ApprovalStatus::Denied;
164        request.decided_by = Some(decided_by.to_string());
165        request.decided_at = Some(Utc::now());
166
167        warn!(id = request_id, by = decided_by, "approval denied");
168        Ok(request.clone())
169    }
170
171    pub fn get_request(&self, request_id: &str) -> Option<ApprovalRequest> {
172        let pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
173        pending.get(request_id).cloned()
174    }
175
176    pub fn list_pending(&self) -> Vec<ApprovalRequest> {
177        let pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
178        pending
179            .values()
180            .filter(|r| r.status == ApprovalStatus::Pending)
181            .cloned()
182            .collect()
183    }
184
185    pub fn list_all(&self) -> Vec<ApprovalRequest> {
186        let pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
187        pending.values().cloned().collect()
188    }
189
190    pub fn expire_timed_out(&self) -> Vec<String> {
191        let now = Utc::now();
192        let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
193        let mut expired = Vec::new();
194
195        for (id, request) in pending.iter_mut() {
196            if request.status == ApprovalStatus::Pending && now >= request.timeout_at {
197                request.status = ApprovalStatus::TimedOut;
198                expired.push(id.clone());
199                debug!(id = %id, tool = %request.tool_name, "approval timed out");
200            }
201        }
202
203        expired
204    }
205
206    pub fn clear_decided(&self) -> usize {
207        let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
208        let before = pending.len();
209        pending.retain(|_, r| r.status == ApprovalStatus::Pending);
210        before - pending.len()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    // Policy used by most tests: two gated tools, one blocked tool, 60 s timeout.
    fn test_config() -> ApprovalsConfig {
        ApprovalsConfig {
            enabled: true,
            gated_tools: vec!["shell".into(), "write_file".into()],
            blocked_tools: vec!["rm_rf".into()],
            timeout_seconds: 60,
        }
    }

    fn disabled_config() -> ApprovalsConfig {
        ApprovalsConfig {
            enabled: false,
            ..test_config()
        }
    }

    #[test]
    fn classify_safe_tool() {
        let mgr = ApprovalManager::new(test_config());
        assert_eq!(mgr.classify_tool("read_file"), ToolClassification::Safe);
    }

    #[test]
    fn classify_gated_tool() {
        let mgr = ApprovalManager::new(test_config());
        assert_eq!(mgr.classify_tool("shell"), ToolClassification::Gated);
        assert_eq!(mgr.classify_tool("write_file"), ToolClassification::Gated);
    }

    #[test]
    fn classify_blocked_tool() {
        let mgr = ApprovalManager::new(test_config());
        assert_eq!(mgr.classify_tool("rm_rf"), ToolClassification::Blocked);
    }

    #[test]
    fn blocked_list_takes_precedence_over_gated() {
        let mgr = ApprovalManager::new(ApprovalsConfig {
            gated_tools: vec!["rm_rf".into()],
            ..test_config()
        });
        assert_eq!(mgr.classify_tool("rm_rf"), ToolClassification::Blocked);
    }

    #[test]
    fn check_tool_blocked_returns_error() {
        let mgr = ApprovalManager::new(test_config());
        let result = mgr.check_tool("rm_rf");
        assert!(result.is_err());
    }

    #[test]
    fn check_tool_passes_gated_and_safe_through() {
        let mgr = ApprovalManager::new(test_config());
        assert_eq!(mgr.check_tool("shell").unwrap(), ToolClassification::Gated);
        assert_eq!(mgr.check_tool("read_file").unwrap(), ToolClassification::Safe);
    }

    #[test]
    fn check_tool_disabled_always_safe() {
        let mgr = ApprovalManager::new(disabled_config());
        assert_eq!(mgr.check_tool("shell").unwrap(), ToolClassification::Safe);
        assert_eq!(mgr.check_tool("rm_rf").unwrap(), ToolClassification::Safe);
    }

    #[test]
    fn request_approval_creates_pending() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval(
                "shell",
                "ls -la",
                Some("sess-1"),
                Some("turn-1"),
                InputAuthority::External,
            )
            .unwrap();
        assert_eq!(req.status, ApprovalStatus::Pending);
        assert_eq!(req.tool_name, "shell");
        assert_eq!(req.requested_authority, InputAuthority::External);
        assert!(req.decided_by.is_none());
    }

    #[test]
    fn request_approval_preserves_requested_authority() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval("shell", "ls", None, None, InputAuthority::Peer)
            .unwrap();
        assert_eq!(req.requested_authority, InputAuthority::Peer);
    }

    // Regression guard: session_id and turn_id used to be conflated.
    #[test]
    fn request_approval_keeps_session_and_turn_separate() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval(
                "shell",
                "ls",
                Some("sess-1"),
                Some("turn-1"),
                InputAuthority::External,
            )
            .unwrap();
        assert_eq!(req.session_id.as_deref(), Some("sess-1"));
        assert_eq!(req.turn_id.as_deref(), Some("turn-1"));

        let stored = mgr.get_request(&req.id).unwrap();
        assert_eq!(stored.session_id.as_deref(), Some("sess-1"));
        assert_eq!(stored.turn_id.as_deref(), Some("turn-1"));
        assert!(stored.timeout_at >= stored.created_at);
    }

    #[test]
    fn get_request_unknown_is_none() {
        let mgr = ApprovalManager::new(test_config());
        assert!(mgr.get_request("nonexistent").is_none());
    }

    #[test]
    fn approve_request() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval("shell", "ls", None, None, InputAuthority::External)
            .unwrap();
        let approved = mgr.approve(&req.id, "admin").unwrap();
        assert_eq!(approved.status, ApprovalStatus::Approved);
        assert_eq!(approved.decided_by.as_deref(), Some("admin"));
    }

    #[test]
    fn deny_request() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval("write_file", "{}", None, None, InputAuthority::External)
            .unwrap();
        let denied = mgr.deny(&req.id, "admin").unwrap();
        assert_eq!(denied.status, ApprovalStatus::Denied);
    }

    #[test]
    fn approve_nonexistent_fails() {
        let mgr = ApprovalManager::new(test_config());
        let result = mgr.approve("nonexistent", "admin");
        assert!(result.is_err());
    }

    #[test]
    fn deny_nonexistent_fails() {
        let mgr = ApprovalManager::new(test_config());
        assert!(mgr.deny("nonexistent", "admin").is_err());
    }

    #[test]
    fn double_approve_fails() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval("shell", "cmd", None, None, InputAuthority::External)
            .unwrap();
        mgr.approve(&req.id, "admin").unwrap();
        let result = mgr.approve(&req.id, "admin2");
        assert!(result.is_err());
    }

    #[test]
    fn deny_after_approve_fails() {
        let mgr = ApprovalManager::new(test_config());
        let req = mgr
            .request_approval("shell", "cmd", None, None, InputAuthority::External)
            .unwrap();
        mgr.approve(&req.id, "admin").unwrap();
        assert!(mgr.deny(&req.id, "other").is_err());

        // The original decision must be left intact.
        let stored = mgr.get_request(&req.id).unwrap();
        assert_eq!(stored.status, ApprovalStatus::Approved);
        assert_eq!(stored.decided_by.as_deref(), Some("admin"));
    }

    #[test]
    fn list_pending_filters() {
        let mgr = ApprovalManager::new(test_config());
        mgr.request_approval("shell", "1", None, None, InputAuthority::External)
            .unwrap();
        let req2 = mgr
            .request_approval("write_file", "2", None, None, InputAuthority::External)
            .unwrap();
        mgr.approve(&req2.id, "admin").unwrap();

        let pending = mgr.list_pending();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].tool_name, "shell");
    }

    #[test]
    fn expire_timed_out() {
        let mgr = ApprovalManager::new(ApprovalsConfig {
            timeout_seconds: 0,
            ..test_config()
        });
        mgr.request_approval("shell", "cmd", None, None, InputAuthority::External)
            .unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let expired = mgr.expire_timed_out();
        assert_eq!(expired.len(), 1);
        assert_eq!(mgr.list_pending().len(), 0);
    }

    #[test]
    fn clear_decided() {
        let mgr = ApprovalManager::new(test_config());
        mgr.request_approval("shell", "1", None, None, InputAuthority::External)
            .unwrap();
        let req2 = mgr
            .request_approval("write_file", "2", None, None, InputAuthority::External)
            .unwrap();
        mgr.approve(&req2.id, "admin").unwrap();

        let cleared = mgr.clear_decided();
        assert_eq!(cleared, 1);
        assert_eq!(mgr.list_all().len(), 1);
    }

    #[test]
    fn clear_decided_removes_timed_out() {
        let mgr = ApprovalManager::new(ApprovalsConfig {
            timeout_seconds: 0,
            ..test_config()
        });
        mgr.request_approval("shell", "cmd", None, None, InputAuthority::External)
            .unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert_eq!(mgr.expire_timed_out().len(), 1);
        assert_eq!(mgr.clear_decided(), 1);
        assert!(mgr.list_all().is_empty());
    }
}