1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use chrono::{DateTime, Duration, Utc};
5use serde::{Deserialize, Serialize};
6use tracing::{debug, warn};
7use uuid::Uuid;
8
9use roboticus_core::config::ApprovalsConfig;
10use roboticus_core::{InputAuthority, Result, RoboticusError};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum ToolClassification {
14 Safe,
15 Gated,
16 Blocked,
17}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum ApprovalStatus {
21 Pending,
22 Approved,
23 Denied,
24 TimedOut,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ApprovalRequest {
29 pub id: String,
30 pub tool_name: String,
31 pub tool_input: String,
32 pub session_id: Option<String>,
33 pub turn_id: Option<String>,
37 #[serde(default = "default_requested_authority")]
38 pub requested_authority: InputAuthority,
39 pub status: ApprovalStatus,
40 pub decided_by: Option<String>,
41 pub decided_at: Option<DateTime<Utc>>,
42 pub timeout_at: DateTime<Utc>,
43 pub created_at: DateTime<Utc>,
44}
45
46fn default_requested_authority() -> InputAuthority {
47 InputAuthority::External
48}
49
50pub struct ApprovalManager {
51 config: ApprovalsConfig,
52 pending: Arc<Mutex<HashMap<String, ApprovalRequest>>>,
53}
54
55impl ApprovalManager {
56 pub fn new(config: ApprovalsConfig) -> Self {
57 Self {
58 config,
59 pending: Arc::new(Mutex::new(HashMap::new())),
60 }
61 }
62
63 pub fn classify_tool(&self, tool_name: &str) -> ToolClassification {
64 if self.config.blocked_tools.iter().any(|t| t == tool_name) {
65 ToolClassification::Blocked
66 } else if self.config.gated_tools.iter().any(|t| t == tool_name) {
67 ToolClassification::Gated
68 } else {
69 ToolClassification::Safe
70 }
71 }
72
73 pub fn check_tool(&self, tool_name: &str) -> Result<ToolClassification> {
74 if !self.config.enabled {
75 return Ok(ToolClassification::Safe);
76 }
77
78 let classification = self.classify_tool(tool_name);
79
80 if classification == ToolClassification::Blocked {
81 return Err(RoboticusError::Tool {
82 tool: tool_name.to_string(),
83 message: "tool is blocked by policy".into(),
84 });
85 }
86
87 Ok(classification)
88 }
89
90 pub fn request_approval(
91 &self,
92 tool_name: &str,
93 tool_input: &str,
94 session_id: Option<&str>,
95 turn_id: Option<&str>,
96 requested_authority: InputAuthority,
97 ) -> Result<ApprovalRequest> {
98 let id = Uuid::new_v4().to_string();
99 let timeout_at = Utc::now() + Duration::seconds(self.config.timeout_seconds as i64);
100
101 let request = ApprovalRequest {
102 id: id.clone(),
103 tool_name: tool_name.to_string(),
104 tool_input: tool_input.to_string(),
105 session_id: session_id.map(|s| s.to_string()),
106 turn_id: turn_id.map(|s| s.to_string()),
107 requested_authority,
108 status: ApprovalStatus::Pending,
109 decided_by: None,
110 decided_at: None,
111 timeout_at,
112 created_at: Utc::now(),
113 };
114
115 debug!(id = %id, tool = tool_name, "approval requested");
116
117 let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
118 pending.insert(id, request.clone());
119
120 Ok(request)
121 }
122
123 pub fn approve(&self, request_id: &str, decided_by: &str) -> Result<ApprovalRequest> {
124 let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
125 let request = pending
126 .get_mut(request_id)
127 .ok_or_else(|| RoboticusError::Tool {
128 tool: "approvals".into(),
129 message: format!("request {request_id} not found"),
130 })?;
131
132 if request.status != ApprovalStatus::Pending {
133 return Err(RoboticusError::Tool {
134 tool: "approvals".into(),
135 message: format!("request {request_id} is already {:?}", request.status),
136 });
137 }
138
139 request.status = ApprovalStatus::Approved;
140 request.decided_by = Some(decided_by.to_string());
141 request.decided_at = Some(Utc::now());
142
143 debug!(id = request_id, by = decided_by, "approval granted");
144 Ok(request.clone())
145 }
146
147 pub fn deny(&self, request_id: &str, decided_by: &str) -> Result<ApprovalRequest> {
148 let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
149 let request = pending
150 .get_mut(request_id)
151 .ok_or_else(|| RoboticusError::Tool {
152 tool: "approvals".into(),
153 message: format!("request {request_id} not found"),
154 })?;
155
156 if request.status != ApprovalStatus::Pending {
157 return Err(RoboticusError::Tool {
158 tool: "approvals".into(),
159 message: format!("request {request_id} is already {:?}", request.status),
160 });
161 }
162
163 request.status = ApprovalStatus::Denied;
164 request.decided_by = Some(decided_by.to_string());
165 request.decided_at = Some(Utc::now());
166
167 warn!(id = request_id, by = decided_by, "approval denied");
168 Ok(request.clone())
169 }
170
171 pub fn get_request(&self, request_id: &str) -> Option<ApprovalRequest> {
172 let pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
173 pending.get(request_id).cloned()
174 }
175
176 pub fn list_pending(&self) -> Vec<ApprovalRequest> {
177 let pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
178 pending
179 .values()
180 .filter(|r| r.status == ApprovalStatus::Pending)
181 .cloned()
182 .collect()
183 }
184
185 pub fn list_all(&self) -> Vec<ApprovalRequest> {
186 let pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
187 pending.values().cloned().collect()
188 }
189
190 pub fn expire_timed_out(&self) -> Vec<String> {
191 let now = Utc::now();
192 let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
193 let mut expired = Vec::new();
194
195 for (id, request) in pending.iter_mut() {
196 if request.status == ApprovalStatus::Pending && now >= request.timeout_at {
197 request.status = ApprovalStatus::TimedOut;
198 expired.push(id.clone());
199 debug!(id = %id, tool = %request.tool_name, "approval timed out");
200 }
201 }
202
203 expired
204 }
205
206 pub fn clear_decided(&self) -> usize {
207 let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
208 let before = pending.len();
209 pending.retain(|_, r| r.status == ApprovalStatus::Pending);
210 before - pending.len()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy used by most tests: approvals enabled, two gated tools, one
    /// blocked tool, and a 60-second timeout.
    fn test_config() -> ApprovalsConfig {
        ApprovalsConfig {
            enabled: true,
            gated_tools: vec!["shell".into(), "write_file".into()],
            blocked_tools: vec!["rm_rf".into()],
            timeout_seconds: 60,
        }
    }

    /// Same lists as `test_config`, but with approvals switched off.
    fn disabled_config() -> ApprovalsConfig {
        ApprovalsConfig {
            enabled: false,
            ..test_config()
        }
    }

    /// Manager built from `test_config`.
    fn manager() -> ApprovalManager {
        ApprovalManager::new(test_config())
    }

    /// Files an `External` request with no session or turn id.
    fn request(mgr: &ApprovalManager, tool: &str, input: &str) -> ApprovalRequest {
        mgr.request_approval(tool, input, None, None, InputAuthority::External)
            .unwrap()
    }

    #[test]
    fn classify_safe_tool() {
        assert_eq!(
            manager().classify_tool("read_file"),
            ToolClassification::Safe
        );
    }

    #[test]
    fn classify_gated_tool() {
        let mgr = manager();
        for tool in ["shell", "write_file"] {
            assert_eq!(mgr.classify_tool(tool), ToolClassification::Gated);
        }
    }

    #[test]
    fn classify_blocked_tool() {
        assert_eq!(manager().classify_tool("rm_rf"), ToolClassification::Blocked);
    }

    #[test]
    fn check_tool_blocked_returns_error() {
        assert!(manager().check_tool("rm_rf").is_err());
    }

    #[test]
    fn check_tool_disabled_always_safe() {
        let mgr = ApprovalManager::new(disabled_config());
        // Both a gated and a blocked tool are waved through when disabled.
        for tool in ["shell", "rm_rf"] {
            assert_eq!(mgr.check_tool(tool).unwrap(), ToolClassification::Safe);
        }
    }

    #[test]
    fn request_approval_creates_pending() {
        let created = manager()
            .request_approval(
                "shell",
                "ls -la",
                Some("sess-1"),
                Some("turn-1"),
                InputAuthority::External,
            )
            .unwrap();

        assert_eq!(created.status, ApprovalStatus::Pending);
        assert_eq!(created.tool_name, "shell");
        assert_eq!(created.requested_authority, InputAuthority::External);
        assert!(created.decided_by.is_none());
    }

    #[test]
    fn request_approval_preserves_requested_authority() {
        let created = manager()
            .request_approval("shell", "ls", None, None, InputAuthority::Peer)
            .unwrap();
        assert_eq!(created.requested_authority, InputAuthority::Peer);
    }

    #[test]
    fn approve_request() {
        let mgr = manager();
        let id = request(&mgr, "shell", "ls").id;

        let approved = mgr.approve(&id, "admin").unwrap();

        assert_eq!(approved.status, ApprovalStatus::Approved);
        assert_eq!(approved.decided_by.as_deref(), Some("admin"));
    }

    #[test]
    fn deny_request() {
        let mgr = manager();
        let id = request(&mgr, "write_file", "{}").id;

        let denied = mgr.deny(&id, "admin").unwrap();

        assert_eq!(denied.status, ApprovalStatus::Denied);
    }

    #[test]
    fn approve_nonexistent_fails() {
        assert!(manager().approve("nonexistent", "admin").is_err());
    }

    #[test]
    fn double_approve_fails() {
        let mgr = manager();
        let id = request(&mgr, "shell", "cmd").id;

        mgr.approve(&id, "admin").unwrap();

        assert!(mgr.approve(&id, "admin2").is_err());
    }

    #[test]
    fn list_pending_filters() {
        let mgr = manager();
        request(&mgr, "shell", "1");
        let decided = request(&mgr, "write_file", "2");
        mgr.approve(&decided.id, "admin").unwrap();

        let still_pending = mgr.list_pending();

        assert_eq!(still_pending.len(), 1);
        assert_eq!(still_pending[0].tool_name, "shell");
    }

    #[test]
    fn expire_timed_out() {
        // A zero-second timeout puts the deadline at creation time.
        let mgr = ApprovalManager::new(ApprovalsConfig {
            timeout_seconds: 0,
            ..test_config()
        });
        request(&mgr, "shell", "cmd");
        std::thread::sleep(std::time::Duration::from_millis(10));

        let expired = mgr.expire_timed_out();

        assert_eq!(expired.len(), 1);
        assert_eq!(mgr.list_pending().len(), 0);
    }

    #[test]
    fn clear_decided() {
        let mgr = manager();
        request(&mgr, "shell", "1");
        let decided = request(&mgr, "write_file", "2");
        mgr.approve(&decided.id, "admin").unwrap();

        let removed = mgr.clear_decided();

        assert_eq!(removed, 1);
        assert_eq!(mgr.list_all().len(), 1);
    }
}