1use std::collections::{HashMap, HashSet};
7
8use tokio::sync::{mpsc, oneshot, Mutex};
9
10use super::ask_for_permissions::{PermissionCategory, PermissionRequest, PermissionResponse};
11use crate::controller::types::{ControllerEvent, TurnId};
12
/// Errors returned by the permission registry operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionError {
    /// No request is pending under the given tool-use id.
    NotFound,
    /// The request has already been answered.
    ///
    /// NOTE(review): no method of `PermissionRegistry` in this module returns this
    /// variant; confirm whether callers elsewhere still rely on it.
    AlreadyResponded,
    /// The response could not be delivered to the waiting receiver.
    SendFailed,
    /// The controller event announcing a request could not be sent.
    EventSendFailed,
}

impl std::fmt::Display for PermissionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed, human-readable message per variant.
        let message = match *self {
            Self::NotFound => "No pending permission request found",
            Self::AlreadyResponded => "Permission already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PermissionError {}
38
39#[derive(Debug, Clone)]
41pub struct PendingPermissionInfo {
42 pub tool_use_id: String,
44 pub session_id: i64,
46 pub request: PermissionRequest,
48 pub turn_id: Option<TurnId>,
50}
51
52#[derive(Debug, Clone, Hash, PartialEq, Eq)]
54pub struct PermissionGrant {
55 pub category: PermissionCategory,
57 pub action_pattern: Option<String>,
60}
61
62struct PendingPermission {
64 session_id: i64,
65 request: PermissionRequest,
66 turn_id: Option<TurnId>,
67 responder: oneshot::Sender<PermissionResponse>,
68}
69
70pub struct PermissionRegistry {
76 pending: Mutex<HashMap<String, PendingPermission>>,
78 session_grants: Mutex<HashMap<i64, HashSet<PermissionGrant>>>,
80 event_tx: mpsc::Sender<ControllerEvent>,
82}
83
84impl PermissionRegistry {
85 pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
90 Self {
91 pending: Mutex::new(HashMap::new()),
92 session_grants: Mutex::new(HashMap::new()),
93 event_tx,
94 }
95 }
96
97 pub async fn is_granted(&self, session_id: i64, request: &PermissionRequest) -> bool {
112 let grants = self.session_grants.lock().await;
113 if let Some(session_grants) = grants.get(&session_id) {
114 session_grants.iter().any(|grant| {
116 if grant.category != request.category {
117 return false;
118 }
119 match &grant.action_pattern {
121 None => true,
122 Some(pattern) => pattern == &request.action,
123 }
124 })
125 } else {
126 false
127 }
128 }
129
130 pub async fn register(
144 &self,
145 tool_use_id: String,
146 session_id: i64,
147 request: PermissionRequest,
148 turn_id: Option<TurnId>,
149 ) -> Result<oneshot::Receiver<PermissionResponse>, PermissionError> {
150 let (tx, rx) = oneshot::channel();
151
152 {
154 let mut pending = self.pending.lock().await;
155 pending.insert(
156 tool_use_id.clone(),
157 PendingPermission {
158 session_id,
159 request: request.clone(),
160 turn_id: turn_id.clone(),
161 responder: tx,
162 },
163 );
164 }
165
166 self.event_tx
168 .send(ControllerEvent::PermissionRequired {
169 session_id,
170 tool_use_id,
171 request,
172 turn_id,
173 })
174 .await
175 .map_err(|_| PermissionError::EventSendFailed)?;
176
177 Ok(rx)
178 }
179
180 pub async fn respond(
191 &self,
192 tool_use_id: &str,
193 response: PermissionResponse,
194 ) -> Result<(), PermissionError> {
195 let pending_permission = {
196 let mut pending = self.pending.lock().await;
197 pending
198 .remove(tool_use_id)
199 .ok_or(PermissionError::NotFound)?
200 };
201
202 if response.granted {
204 if let Some(ref scope) = response.scope {
205 use super::ask_for_permissions::PermissionScope;
206 match scope {
207 PermissionScope::Session => {
208 let mut grants = self.session_grants.lock().await;
210 let session_grants = grants
211 .entry(pending_permission.session_id)
212 .or_insert_with(HashSet::new);
213 session_grants.insert(PermissionGrant {
214 category: pending_permission.request.category.clone(),
215 action_pattern: Some(pending_permission.request.action.clone()),
216 });
217 }
218 PermissionScope::CategorySession => {
219 let mut grants = self.session_grants.lock().await;
221 let session_grants = grants
222 .entry(pending_permission.session_id)
223 .or_insert_with(HashSet::new);
224 session_grants.insert(PermissionGrant {
225 category: pending_permission.request.category.clone(),
226 action_pattern: None, });
228 }
229 PermissionScope::Once => {
230 }
232 }
233 }
234 }
235
236 pending_permission
237 .responder
238 .send(response)
239 .map_err(|_| PermissionError::SendFailed)
240 }
241
242 pub async fn cancel(&self, tool_use_id: &str) -> Result<(), PermissionError> {
253 let mut pending = self.pending.lock().await;
254 if pending.remove(tool_use_id).is_some() {
255 Ok(())
258 } else {
259 Err(PermissionError::NotFound)
260 }
261 }
262
263 pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingPermissionInfo> {
274 let pending = self.pending.lock().await;
275 pending
276 .iter()
277 .filter(|(_, perm)| perm.session_id == session_id)
278 .map(|(tool_use_id, perm)| PendingPermissionInfo {
279 tool_use_id: tool_use_id.clone(),
280 session_id: perm.session_id,
281 request: perm.request.clone(),
282 turn_id: perm.turn_id.clone(),
283 })
284 .collect()
285 }
286
287 pub async fn cancel_session(&self, session_id: i64) {
295 let mut pending = self.pending.lock().await;
296 pending.retain(|_, perm| perm.session_id != session_id);
297 }
299
300 pub async fn clear_session(&self, session_id: i64) {
307 self.cancel_session(session_id).await;
309
310 let mut grants = self.session_grants.lock().await;
312 grants.remove(&session_id);
313 }
314
315 pub async fn has_pending(&self, session_id: i64) -> bool {
323 let pending = self.pending.lock().await;
324 pending.values().any(|perm| perm.session_id == session_id)
325 }
326
327 pub async fn pending_count(&self) -> usize {
329 let pending = self.pending.lock().await;
330 pending.len()
331 }
332
333 pub async fn session_grants(&self, session_id: i64) -> HashSet<PermissionGrant> {
341 let grants = self.session_grants.lock().await;
342 grants.get(&session_id).cloned().unwrap_or_default()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_for_permissions::PermissionScope;

    /// A file-delete request used by most tests.
    fn create_test_request() -> PermissionRequest {
        PermissionRequest {
            action: "Delete file /tmp/foo.txt".to_string(),
            reason: Some("User requested cleanup".to_string()),
            resources: vec!["/tmp/foo.txt".to_string()],
            category: PermissionCategory::FileDelete,
        }
    }

    /// Same category as `create_test_request`, but a different action string.
    fn create_other_action_request() -> PermissionRequest {
        PermissionRequest {
            action: "Delete file /tmp/bar.txt".to_string(),
            ..create_test_request()
        }
    }

    fn create_grant_response() -> PermissionResponse {
        PermissionResponse {
            granted: true,
            scope: Some(PermissionScope::Session),
            message: None,
        }
    }

    fn create_deny_response() -> PermissionResponse {
        PermissionResponse {
            granted: false,
            scope: None,
            message: Some("Not allowed".to_string()),
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (event_tx, mut event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let response = create_grant_response();

        let rx = registry
            .register("tool_123".to_string(), 1, request.clone(), None)
            .await
            .unwrap();

        let event = event_rx.recv().await.unwrap();
        if let ControllerEvent::PermissionRequired {
            session_id,
            tool_use_id,
            ..
        } = event
        {
            assert_eq!(session_id, 1);
            assert_eq!(tool_use_id, "tool_123");
        } else {
            panic!("Expected PermissionRequired event");
        }

        registry
            .respond("tool_123", response.clone())
            .await
            .unwrap();

        let received = rx.await.unwrap();
        assert!(received.granted);
        assert_eq!(received.scope, Some(PermissionScope::Session));
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let response = create_grant_response();
        let result = registry.respond("nonexistent", response).await;

        assert_eq!(result, Err(PermissionError::NotFound));
    }

    #[tokio::test]
    async fn test_session_grant_caching() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        assert!(!registry.is_granted(1, &request).await);

        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();

        let _ = rx.await;

        assert!(registry.is_granted(1, &request).await);

        // Grants never leak across sessions.
        assert!(!registry.is_granted(2, &request).await);
    }

    #[tokio::test]
    async fn test_session_grant_is_action_specific() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let rx = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();
        let _ = rx.await;

        // A `Session` grant only covers the exact action that was approved.
        assert!(!registry.is_granted(1, &create_other_action_request()).await);
    }

    #[tokio::test]
    async fn test_category_session_grant_covers_category() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond(
                "tool_1",
                PermissionResponse {
                    granted: true,
                    scope: Some(PermissionScope::CategorySession),
                    message: None,
                },
            )
            .await
            .unwrap();
        let _ = rx.await;

        assert!(registry.is_granted(1, &request).await);
        assert!(registry.is_granted(1, &create_other_action_request()).await);
        assert!(!registry.is_granted(2, &request).await);
    }

    #[tokio::test]
    async fn test_once_scope_not_cached() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond(
                "tool_1",
                PermissionResponse {
                    granted: true,
                    scope: Some(PermissionScope::Once),
                    message: None,
                },
            )
            .await
            .unwrap();

        let _ = rx.await;

        assert!(!registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_denied_not_cached() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_deny_response()).await.unwrap();

        let _ = rx.await;

        assert!(!registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        let _ = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await;
        let _ = registry
            .register("tool_2".to_string(), 1, request.clone(), None)
            .await;
        let _ = registry
            .register("tool_3".to_string(), 2, request.clone(), None)
            .await;

        let pending = registry.pending_for_session(1).await;
        assert_eq!(pending.len(), 2);

        let pending = registry.pending_for_session(2).await;
        assert_eq!(pending.len(), 1);

        let pending = registry.pending_for_session(999).await;
        assert_eq!(pending.len(), 0);
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        let rx1 = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        let _rx2 = registry
            .register("tool_2".to_string(), 2, request, None)
            .await
            .unwrap();

        registry.cancel_session(1).await;

        // The cancelled request's responder was dropped.
        assert!(rx1.await.is_err());
        assert!(!registry.has_pending(1).await);
        // Other sessions are unaffected.
        assert!(registry.has_pending(2).await);
        assert_eq!(registry.pending_count().await, 1);
    }

    #[tokio::test]
    async fn test_clear_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        let rx1 = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();

        let _ = rx1.await;

        let rx2 = registry
            .register("tool_2".to_string(), 2, request.clone(), None)
            .await
            .unwrap();

        registry.clear_session(1).await;

        assert!(!registry.is_granted(1, &request).await);
        assert!(registry.session_grants(1).await.is_empty());

        // Clearing session 1 must not touch session 2's pending request.
        assert!(registry.has_pending(2).await);

        registry.respond("tool_2", create_grant_response()).await.unwrap();
        let received = rx2.await.unwrap();
        assert!(received.granted);
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        assert!(!registry.has_pending(1).await);

        let request = create_test_request();
        let _ = registry
            .register("tool_1".to_string(), 1, request, None)
            .await;

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        assert_eq!(registry.pending_count().await, 0);

        let request = create_test_request();
        let _ = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await;
        assert_eq!(registry.pending_count().await, 1);

        let _ = registry
            .register("tool_2".to_string(), 1, request, None)
            .await;
        assert_eq!(registry.pending_count().await, 2);
    }

    #[tokio::test]
    async fn test_cancel() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let rx = registry
            .register("tool_1".to_string(), 1, request, None)
            .await
            .unwrap();

        registry.cancel("tool_1").await.unwrap();

        // Dropping the responder makes the waiting receiver fail.
        assert!(rx.await.is_err());

        let result = registry.cancel("nonexistent").await;
        assert_eq!(result, Err(PermissionError::NotFound));
    }
}