1use std::collections::{HashMap, HashSet};
7
8use tokio::sync::{mpsc, oneshot, Mutex};
9
10use super::ask_for_permissions::{PermissionCategory, PermissionRequest, PermissionResponse};
11use crate::controller::types::{ControllerEvent, TurnId};
12
/// Failures reported by the permission registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionError {
    /// No pending permission request exists for the given identifier.
    NotFound,
    /// The permission request has already been answered.
    AlreadyResponded,
    /// The response could not be handed to the waiting requester.
    SendFailed,
    /// The notification event could not be sent.
    EventSendFailed,
}

impl std::fmt::Display for PermissionError {
    /// Writes the human-readable message associated with the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every message is a fixed string, so look it up first and emit it once.
        let message = match self {
            Self::NotFound => "No pending permission request found",
            Self::AlreadyResponded => "Permission already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PermissionError {}
38
39#[derive(Debug, Clone)]
41pub struct PendingPermissionInfo {
42 pub tool_use_id: String,
44 pub session_id: i64,
46 pub request: PermissionRequest,
48 pub turn_id: Option<TurnId>,
50}
51
52#[derive(Debug, Clone, Hash, PartialEq, Eq)]
54pub struct PermissionGrant {
55 pub category: PermissionCategory,
57 pub action_pattern: String,
59}
60
61struct PendingPermission {
63 session_id: i64,
64 request: PermissionRequest,
65 turn_id: Option<TurnId>,
66 responder: oneshot::Sender<PermissionResponse>,
67}
68
69pub struct PermissionRegistry {
75 pending: Mutex<HashMap<String, PendingPermission>>,
77 session_grants: Mutex<HashMap<i64, HashSet<PermissionGrant>>>,
79 event_tx: mpsc::Sender<ControllerEvent>,
81}
82
83impl PermissionRegistry {
84 pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
89 Self {
90 pending: Mutex::new(HashMap::new()),
91 session_grants: Mutex::new(HashMap::new()),
92 event_tx,
93 }
94 }
95
96 pub async fn is_granted(&self, session_id: i64, request: &PermissionRequest) -> bool {
107 let grants = self.session_grants.lock().await;
108 if let Some(session_grants) = grants.get(&session_id) {
109 session_grants.iter().any(|grant| {
111 grant.category == request.category && grant.action_pattern == request.action
112 })
113 } else {
114 false
115 }
116 }
117
118 pub async fn register(
132 &self,
133 tool_use_id: String,
134 session_id: i64,
135 request: PermissionRequest,
136 turn_id: Option<TurnId>,
137 ) -> Result<oneshot::Receiver<PermissionResponse>, PermissionError> {
138 let (tx, rx) = oneshot::channel();
139
140 {
142 let mut pending = self.pending.lock().await;
143 pending.insert(
144 tool_use_id.clone(),
145 PendingPermission {
146 session_id,
147 request: request.clone(),
148 turn_id: turn_id.clone(),
149 responder: tx,
150 },
151 );
152 }
153
154 self.event_tx
156 .send(ControllerEvent::PermissionRequired {
157 session_id,
158 tool_use_id,
159 request,
160 turn_id,
161 })
162 .await
163 .map_err(|_| PermissionError::EventSendFailed)?;
164
165 Ok(rx)
166 }
167
168 pub async fn respond(
179 &self,
180 tool_use_id: &str,
181 response: PermissionResponse,
182 ) -> Result<(), PermissionError> {
183 let pending_permission = {
184 let mut pending = self.pending.lock().await;
185 pending
186 .remove(tool_use_id)
187 .ok_or(PermissionError::NotFound)?
188 };
189
190 if response.granted {
192 if let Some(ref scope) = response.scope {
193 if *scope == super::ask_for_permissions::PermissionScope::Session {
194 let mut grants = self.session_grants.lock().await;
195 let session_grants = grants
196 .entry(pending_permission.session_id)
197 .or_insert_with(HashSet::new);
198 session_grants.insert(PermissionGrant {
199 category: pending_permission.request.category.clone(),
200 action_pattern: pending_permission.request.action.clone(),
201 });
202 }
203 }
204 }
205
206 pending_permission
207 .responder
208 .send(response)
209 .map_err(|_| PermissionError::SendFailed)
210 }
211
212 pub async fn cancel(&self, tool_use_id: &str) -> Result<(), PermissionError> {
223 let mut pending = self.pending.lock().await;
224 if pending.remove(tool_use_id).is_some() {
225 Ok(())
228 } else {
229 Err(PermissionError::NotFound)
230 }
231 }
232
233 pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingPermissionInfo> {
244 let pending = self.pending.lock().await;
245 pending
246 .iter()
247 .filter(|(_, perm)| perm.session_id == session_id)
248 .map(|(tool_use_id, perm)| PendingPermissionInfo {
249 tool_use_id: tool_use_id.clone(),
250 session_id: perm.session_id,
251 request: perm.request.clone(),
252 turn_id: perm.turn_id.clone(),
253 })
254 .collect()
255 }
256
257 pub async fn cancel_session(&self, session_id: i64) {
265 let mut pending = self.pending.lock().await;
266 pending.retain(|_, perm| perm.session_id != session_id);
267 }
269
270 pub async fn clear_session(&self, session_id: i64) {
277 self.cancel_session(session_id).await;
279
280 let mut grants = self.session_grants.lock().await;
282 grants.remove(&session_id);
283 }
284
285 pub async fn has_pending(&self, session_id: i64) -> bool {
293 let pending = self.pending.lock().await;
294 pending.values().any(|perm| perm.session_id == session_id)
295 }
296
297 pub async fn pending_count(&self) -> usize {
299 let pending = self.pending.lock().await;
300 pending.len()
301 }
302
303 pub async fn session_grants(&self, session_id: i64) -> HashSet<PermissionGrant> {
311 let grants = self.session_grants.lock().await;
312 grants.get(&session_id).cloned().unwrap_or_default()
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_for_permissions::PermissionScope;

    /// Builds a registry together with the receiving end of its event channel.
    ///
    /// Callers must keep the receiver alive for as long as they register
    /// requests; otherwise sending the event fails.
    fn new_registry() -> (PermissionRegistry, mpsc::Receiver<ControllerEvent>) {
        let (event_tx, event_rx) = mpsc::channel(10);
        (PermissionRegistry::new(event_tx), event_rx)
    }

    /// A file-deletion request shared by all tests.
    fn create_test_request() -> PermissionRequest {
        PermissionRequest {
            action: "Delete file /tmp/foo.txt".to_string(),
            reason: Some("User requested cleanup".to_string()),
            resources: vec!["/tmp/foo.txt".to_string()],
            category: PermissionCategory::FileDelete,
        }
    }

    /// A positive response scoped to the whole session.
    fn create_grant_response() -> PermissionResponse {
        PermissionResponse {
            granted: true,
            scope: Some(PermissionScope::Session),
            message: None,
        }
    }

    /// A negative response carrying an explanatory message.
    fn create_deny_response() -> PermissionResponse {
        PermissionResponse {
            granted: false,
            scope: None,
            message: Some("Not allowed".to_string()),
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (registry, mut events) = new_registry();

        let receiver = registry
            .register("tool_123".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        // Registration must announce the request on the event channel.
        match events.recv().await.unwrap() {
            ControllerEvent::PermissionRequired {
                session_id,
                tool_use_id,
                ..
            } => {
                assert_eq!(session_id, 1);
                assert_eq!(tool_use_id, "tool_123");
            }
            _ => panic!("Expected PermissionRequired event"),
        }

        registry
            .respond("tool_123", create_grant_response())
            .await
            .unwrap();

        // The requester receives exactly the response that was given.
        let received = receiver.await.unwrap();
        assert!(received.granted);
        assert_eq!(received.scope, Some(PermissionScope::Session));
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (registry, _events) = new_registry();

        let outcome = registry
            .respond("nonexistent", create_grant_response())
            .await;

        assert_eq!(outcome, Err(PermissionError::NotFound));
    }

    #[tokio::test]
    async fn test_session_grant_caching() {
        let (registry, _events) = new_registry();
        let request = create_test_request();

        // Nothing is granted before any response has been given.
        assert!(!registry.is_granted(1, &request).await);

        let receiver = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond("tool_1", create_grant_response())
            .await
            .unwrap();
        let _ = receiver.await;

        // The grant is remembered for the granting session only.
        assert!(registry.is_granted(1, &request).await);
        assert!(!registry.is_granted(2, &request).await);
    }

    #[tokio::test]
    async fn test_once_scope_not_cached() {
        let (registry, _events) = new_registry();
        let request = create_test_request();

        let receiver = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();

        let once_response = PermissionResponse {
            granted: true,
            scope: Some(PermissionScope::Once),
            message: None,
        };
        registry.respond("tool_1", once_response).await.unwrap();
        let _ = receiver.await;

        // A one-off grant must not be remembered.
        assert!(!registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_denied_not_cached() {
        let (registry, _events) = new_registry();
        let request = create_test_request();

        let receiver = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond("tool_1", create_deny_response())
            .await
            .unwrap();
        let _ = receiver.await;

        // A denial must not be remembered as a grant.
        assert!(!registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (registry, _events) = new_registry();
        let request = create_test_request();

        // Two requests for session 1 and one for session 2.
        for (tool_use_id, session_id) in [("tool_1", 1), ("tool_2", 1), ("tool_3", 2)] {
            let _ = registry
                .register(tool_use_id.to_string(), session_id, request.clone(), None)
                .await;
        }

        for (session_id, expected) in [(1, 2), (2, 1), (999, 0)] {
            let listed = registry.pending_for_session(session_id).await;
            assert_eq!(listed.len(), expected);
        }
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (registry, _events) = new_registry();
        let request = create_test_request();

        // Session 1 obtains a session-scoped grant.
        let first = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond("tool_1", create_grant_response())
            .await
            .unwrap();
        let _ = first.await;

        // Session 2 has a request that is still pending.
        let second = registry
            .register("tool_2".to_string(), 2, request.clone(), None)
            .await
            .unwrap();

        registry.clear_session(1).await;

        // Clearing session 1 forgets its grant and leaves session 2 alone.
        assert!(!registry.is_granted(1, &request).await);
        assert!(registry.has_pending(2).await);

        registry
            .respond("tool_2", create_grant_response())
            .await
            .unwrap();
        let received = second.await.unwrap();
        assert!(received.granted);
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (registry, _events) = new_registry();

        assert!(!registry.has_pending(1).await);

        let _ = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await;

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (registry, _events) = new_registry();
        let request = create_test_request();

        assert_eq!(registry.pending_count().await, 0);

        let _ = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await;
        assert_eq!(registry.pending_count().await, 1);

        let _ = registry
            .register("tool_2".to_string(), 1, request, None)
            .await;
        assert_eq!(registry.pending_count().await, 2);
    }

    #[tokio::test]
    async fn test_cancel() {
        let (registry, _events) = new_registry();

        let receiver = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();

        registry.cancel("tool_1").await.unwrap();

        // Cancelling drops the sender, so the requester sees an error.
        assert!(receiver.await.is_err());

        // Cancelling an unknown request is reported as not found.
        let outcome = registry.cancel("nonexistent").await;
        assert_eq!(outcome, Err(PermissionError::NotFound));
    }
}