Skip to main content

agent_core/controller/tools/
permission_registry.rs

1//! Permission registry for managing permission requests and session-level grants.
2//!
3//! This module provides a registry for tools that need to request user permission,
4//! such as the AskForPermissions tool.
5
6use std::collections::{HashMap, HashSet};
7
8use tokio::sync::{mpsc, oneshot, Mutex};
9
10use super::ask_for_permissions::{PermissionCategory, PermissionRequest, PermissionResponse};
11use crate::controller::types::{ControllerEvent, TurnId};
12
/// Errors returned by [`PermissionRegistry`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionError {
    /// No pending permission request exists for the given `tool_use_id`.
    NotFound,
    /// The permission request was already responded to.
    ///
    /// Note: no `PermissionRegistry` method in this module constructs this
    /// variant. `respond` removes the request when answering it, so a second
    /// answer for the same `tool_use_id` yields `NotFound` instead.
    AlreadyResponded,
    /// The response could not be delivered because the waiting side's
    /// channel is closed.
    SendFailed,
    /// The event notifying the controller could not be sent.
    EventSendFailed,
}

impl std::fmt::Display for PermissionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed, human-readable message per variant. Written verbatim, so
        // width/alignment flags are ignored, as with a plain `write!`.
        let text = match self {
            Self::NotFound => "No pending permission request found",
            Self::AlreadyResponded => "Permission already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(text)
    }
}

impl std::error::Error for PermissionError {}
38
/// Information about a pending permission request, for UI display.
///
/// This is a cloned copy of the stored request data without the response
/// channel. To answer the request, pass `tool_use_id` to
/// `PermissionRegistry::respond` (or `PermissionRegistry::cancel`).
#[derive(Debug, Clone)]
pub struct PendingPermissionInfo {
    /// Tool use ID identifying this permission request.
    pub tool_use_id: String,
    /// Session that registered this permission request.
    pub session_id: i64,
    /// The permission request details (action, category, ...).
    pub request: PermissionRequest,
    /// Turn ID supplied when the request was registered, if any.
    pub turn_id: Option<TurnId>,
}
51
/// A grant that the user approved for the rest of a session.
///
/// Grants are stored per session in a `HashSet`, so identical grants are
/// recorded only once.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct PermissionGrant {
    /// Category of the permission; must equal the category of a future
    /// request for this grant to apply.
    pub category: PermissionCategory,
    /// The action that was granted, compared with exact string equality
    /// (no wildcard/glob matching) against future requests' `action`.
    /// None means this is a category-wide grant ("allow all" for this category).
    pub action_pattern: Option<String>,
}
61
/// Internal state for a pending permission request.
struct PendingPermission {
    /// Session that registered the request; used for per-session queries
    /// and for `cancel_session`.
    session_id: i64,
    /// Request details; cloned into `PendingPermissionInfo` for the UI.
    request: PermissionRequest,
    /// Turn ID supplied at registration, if any.
    turn_id: Option<TurnId>,
    /// Sender half of the channel the waiting tool awaits. Sending delivers
    /// the user's response; dropping it (cancel) makes the receiver fail
    /// with a `RecvError`.
    responder: oneshot::Sender<PermissionResponse>,
}
69
/// Registry for managing permission requests and session-level grants.
///
/// This registry tracks tools that are blocked waiting for permission
/// and provides methods for the UI to query and respond to these requests.
/// It also caches session-level grants to avoid re-asking.
///
/// The two maps are guarded by separate `tokio::sync::Mutex`es, and all
/// methods take `&self`.
pub struct PermissionRegistry {
    /// Pending permission requests keyed by tool_use_id.
    pending: Mutex<HashMap<String, PendingPermission>>,
    /// Session-level grants (session_id -> set of grants).
    session_grants: Mutex<HashMap<i64, HashSet<PermissionGrant>>>,
    /// Channel used to send `ControllerEvent::PermissionRequired` events to
    /// the controller when a request is registered.
    event_tx: mpsc::Sender<ControllerEvent>,
}
83
84impl PermissionRegistry {
85    /// Create a new PermissionRegistry.
86    ///
87    /// # Arguments
88    /// * `event_tx` - Channel to send events when permissions are requested.
89    pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
90        Self {
91            pending: Mutex::new(HashMap::new()),
92            session_grants: Mutex::new(HashMap::new()),
93            event_tx,
94        }
95    }
96
97    /// Check if permission is already granted for the session.
98    ///
99    /// This checks if a previous session-level grant covers this request.
100    /// A grant matches if:
101    /// - The category matches AND
102    /// - Either the grant is category-wide (action_pattern is None), OR
103    /// - The action_pattern exactly matches the request action
104    ///
105    /// # Arguments
106    /// * `session_id` - Session to check.
107    /// * `request` - The permission request to check.
108    ///
109    /// # Returns
110    /// True if permission was previously granted for the session.
111    pub async fn is_granted(&self, session_id: i64, request: &PermissionRequest) -> bool {
112        let grants = self.session_grants.lock().await;
113        if let Some(session_grants) = grants.get(&session_id) {
114            // Check if any grant matches the request
115            session_grants.iter().any(|grant| {
116                if grant.category != request.category {
117                    return false;
118                }
119                // Category-wide grant (None) matches everything in that category
120                match &grant.action_pattern {
121                    None => true,
122                    Some(pattern) => pattern == &request.action,
123                }
124            })
125        } else {
126            false
127        }
128    }
129
130    /// Register a permission request and get a receiver to await on.
131    ///
132    /// This is called by the AskForPermissionsTool when it starts executing.
133    /// The tool will await on the returned receiver until the UI responds.
134    ///
135    /// # Arguments
136    /// * `tool_use_id` - Unique ID for this tool use request.
137    /// * `session_id` - Session that requested the permission.
138    /// * `request` - The permission request details.
139    /// * `turn_id` - Optional turn ID for this request.
140    ///
141    /// # Returns
142    /// A oneshot receiver that will receive the user's response.
143    pub async fn register(
144        &self,
145        tool_use_id: String,
146        session_id: i64,
147        request: PermissionRequest,
148        turn_id: Option<TurnId>,
149    ) -> Result<oneshot::Receiver<PermissionResponse>, PermissionError> {
150        let (tx, rx) = oneshot::channel();
151
152        // Store the pending request
153        {
154            let mut pending = self.pending.lock().await;
155            pending.insert(
156                tool_use_id.clone(),
157                PendingPermission {
158                    session_id,
159                    request: request.clone(),
160                    turn_id: turn_id.clone(),
161                    responder: tx,
162                },
163            );
164        }
165
166        // Emit event to notify UI
167        self.event_tx
168            .send(ControllerEvent::PermissionRequired {
169                session_id,
170                tool_use_id,
171                request,
172                turn_id,
173            })
174            .await
175            .map_err(|_| PermissionError::EventSendFailed)?;
176
177        Ok(rx)
178    }
179
180    /// Respond to a pending permission request.
181    ///
182    /// This is called by the UI when the user has granted or denied permission.
183    ///
184    /// # Arguments
185    /// * `tool_use_id` - ID of the tool use to respond to.
186    /// * `response` - The user's response (grant/deny).
187    ///
188    /// # Returns
189    /// Ok(()) if the response was sent successfully, or an error.
190    pub async fn respond(
191        &self,
192        tool_use_id: &str,
193        response: PermissionResponse,
194    ) -> Result<(), PermissionError> {
195        let pending_permission = {
196            let mut pending = self.pending.lock().await;
197            pending
198                .remove(tool_use_id)
199                .ok_or(PermissionError::NotFound)?
200        };
201
202        // If granted with session or category-session scope, cache the grant
203        if response.granted {
204            if let Some(ref scope) = response.scope {
205                use super::ask_for_permissions::PermissionScope;
206                match scope {
207                    PermissionScope::Session => {
208                        // Grant for this specific action pattern
209                        let mut grants = self.session_grants.lock().await;
210                        let session_grants = grants
211                            .entry(pending_permission.session_id)
212                            .or_insert_with(HashSet::new);
213                        session_grants.insert(PermissionGrant {
214                            category: pending_permission.request.category.clone(),
215                            action_pattern: Some(pending_permission.request.action.clone()),
216                        });
217                    }
218                    PermissionScope::CategorySession => {
219                        // Category-wide grant (allow all in this category)
220                        let mut grants = self.session_grants.lock().await;
221                        let session_grants = grants
222                            .entry(pending_permission.session_id)
223                            .or_insert_with(HashSet::new);
224                        session_grants.insert(PermissionGrant {
225                            category: pending_permission.request.category.clone(),
226                            action_pattern: None, // None = category-wide
227                        });
228                    }
229                    PermissionScope::Once => {
230                        // Do not cache - one-time grant
231                    }
232                }
233            }
234        }
235
236        pending_permission
237            .responder
238            .send(response)
239            .map_err(|_| PermissionError::SendFailed)
240    }
241
242    /// Cancel a pending permission request (user declined).
243    ///
244    /// This is called by the UI when the user closes the permission dialog
245    /// without responding.
246    ///
247    /// # Arguments
248    /// * `tool_use_id` - ID of the tool use to cancel.
249    ///
250    /// # Returns
251    /// Ok(()) if the request was found and cancelled, or NotFound error.
252    pub async fn cancel(&self, tool_use_id: &str) -> Result<(), PermissionError> {
253        let mut pending = self.pending.lock().await;
254        if pending.remove(tool_use_id).is_some() {
255            // Dropping the sender will cause the tool to receive a RecvError
256            // which will be converted to "User denied permission"
257            Ok(())
258        } else {
259            Err(PermissionError::NotFound)
260        }
261    }
262
263    /// Get all pending permission requests for a session.
264    ///
265    /// This is called by the UI when switching sessions to display
266    /// any pending permission requests for that session.
267    ///
268    /// # Arguments
269    /// * `session_id` - Session ID to query.
270    ///
271    /// # Returns
272    /// List of pending permission info for the session.
273    pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingPermissionInfo> {
274        let pending = self.pending.lock().await;
275        pending
276            .iter()
277            .filter(|(_, perm)| perm.session_id == session_id)
278            .map(|(tool_use_id, perm)| PendingPermissionInfo {
279                tool_use_id: tool_use_id.clone(),
280                session_id: perm.session_id,
281                request: perm.request.clone(),
282                turn_id: perm.turn_id.clone(),
283            })
284            .collect()
285    }
286
287    /// Cancel all pending permission requests for a session.
288    ///
289    /// This is called when a session is destroyed. It drops the senders,
290    /// which will cause the awaiting tools to receive a RecvError.
291    ///
292    /// # Arguments
293    /// * `session_id` - Session ID to cancel.
294    pub async fn cancel_session(&self, session_id: i64) {
295        let mut pending = self.pending.lock().await;
296        pending.retain(|_, perm| perm.session_id != session_id);
297        // Dropped senders will cause RecvError on the tool side
298    }
299
300    /// Clear all grants for a session.
301    ///
302    /// This is called when a session ends or is reset.
303    ///
304    /// # Arguments
305    /// * `session_id` - Session ID to clear grants for.
306    pub async fn clear_session(&self, session_id: i64) {
307        // Cancel pending requests
308        self.cancel_session(session_id).await;
309
310        // Clear session grants
311        let mut grants = self.session_grants.lock().await;
312        grants.remove(&session_id);
313    }
314
315    /// Check if there are any pending permission requests for a session.
316    ///
317    /// # Arguments
318    /// * `session_id` - Session ID to check.
319    ///
320    /// # Returns
321    /// True if there are pending permission requests.
322    pub async fn has_pending(&self, session_id: i64) -> bool {
323        let pending = self.pending.lock().await;
324        pending.values().any(|perm| perm.session_id == session_id)
325    }
326
327    /// Get the count of pending permission requests.
328    pub async fn pending_count(&self) -> usize {
329        let pending = self.pending.lock().await;
330        pending.len()
331    }
332
333    /// Get all session grants for a session.
334    ///
335    /// # Arguments
336    /// * `session_id` - Session ID to query.
337    ///
338    /// # Returns
339    /// Set of grants for the session (empty if none).
340    pub async fn session_grants(&self, session_id: i64) -> HashSet<PermissionGrant> {
341        let grants = self.session_grants.lock().await;
342        grants.get(&session_id).cloned().unwrap_or_default()
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_for_permissions::PermissionScope;

    /// A FileDelete request for `/tmp/foo.txt`.
    fn create_test_request() -> PermissionRequest {
        PermissionRequest {
            action: "Delete file /tmp/foo.txt".to_string(),
            reason: Some("User requested cleanup".to_string()),
            resources: vec!["/tmp/foo.txt".to_string()],
            category: PermissionCategory::FileDelete,
        }
    }

    /// Same category as `create_test_request`, but a different action.
    fn create_other_action_request() -> PermissionRequest {
        PermissionRequest {
            action: "Delete file /tmp/bar.txt".to_string(),
            resources: vec!["/tmp/bar.txt".to_string()],
            ..create_test_request()
        }
    }

    fn create_grant_response() -> PermissionResponse {
        PermissionResponse {
            granted: true,
            scope: Some(PermissionScope::Session),
            message: None,
        }
    }

    fn create_deny_response() -> PermissionResponse {
        PermissionResponse {
            granted: false,
            scope: None,
            message: Some("Not allowed".to_string()),
        }
    }

    #[tokio::test]
    async fn test_register_and_respond() {
        let (event_tx, mut event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let response = create_grant_response();

        // Register permission request
        let rx = registry
            .register("tool_123".to_string(), 1, request.clone(), None)
            .await
            .unwrap();

        // Verify event was emitted
        let event = event_rx.recv().await.unwrap();
        if let ControllerEvent::PermissionRequired {
            session_id,
            tool_use_id,
            ..
        } = event
        {
            assert_eq!(session_id, 1);
            assert_eq!(tool_use_id, "tool_123");
        } else {
            panic!("Expected PermissionRequired event");
        }

        // Respond to request
        registry
            .respond("tool_123", response.clone())
            .await
            .unwrap();

        // Verify response was received
        let received = rx.await.unwrap();
        assert!(received.granted);
        assert_eq!(received.scope, Some(PermissionScope::Session));
    }

    #[tokio::test]
    async fn test_respond_not_found() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let response = create_grant_response();
        let result = registry.respond("nonexistent", response).await;

        assert_eq!(result, Err(PermissionError::NotFound));
    }

    #[tokio::test]
    async fn test_respond_twice_is_not_found() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let rx = registry
            .register("tool_1".to_string(), 1, create_test_request(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_deny_response()).await.unwrap();
        let _ = rx.await;

        // The first response removed the request, so a second one finds nothing.
        let result = registry.respond("tool_1", create_grant_response()).await;
        assert_eq!(result, Err(PermissionError::NotFound));
    }

    #[tokio::test]
    async fn test_session_grant_caching() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        // Not granted initially
        assert!(!registry.is_granted(1, &request).await);

        // Register and grant with session scope
        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();

        // Consume the response to complete the flow
        let _ = rx.await;

        // Now should be granted
        assert!(registry.is_granted(1, &request).await);

        // Different session should not be granted
        assert!(!registry.is_granted(2, &request).await);
    }

    #[tokio::test]
    async fn test_session_grant_is_action_specific() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let other = create_other_action_request();

        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();
        let _ = rx.await;

        // A Session-scope grant only covers the exact action that was approved.
        assert!(registry.is_granted(1, &request).await);
        assert!(!registry.is_granted(1, &other).await);

        let grants = registry.session_grants(1).await;
        assert_eq!(grants.len(), 1);
        assert!(grants.contains(&PermissionGrant {
            category: PermissionCategory::FileDelete,
            action_pattern: Some(request.action.clone()),
        }));
    }

    #[tokio::test]
    async fn test_category_session_grant_covers_category() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let other = create_other_action_request();

        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond(
                "tool_1",
                PermissionResponse {
                    granted: true,
                    scope: Some(PermissionScope::CategorySession),
                    message: None,
                },
            )
            .await
            .unwrap();
        let _ = rx.await;

        // A category-wide grant covers any action in the same category...
        assert!(registry.is_granted(1, &request).await);
        assert!(registry.is_granted(1, &other).await);
        // ...but only within the session that granted it.
        assert!(!registry.is_granted(2, &other).await);

        let grants = registry.session_grants(1).await;
        assert_eq!(grants.len(), 1);
        assert!(grants.contains(&PermissionGrant {
            category: PermissionCategory::FileDelete,
            action_pattern: None,
        }));
    }

    #[tokio::test]
    async fn test_once_scope_not_cached() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        // Register and grant with Once scope
        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry
            .respond(
                "tool_1",
                PermissionResponse {
                    granted: true,
                    scope: Some(PermissionScope::Once),
                    message: None,
                },
            )
            .await
            .unwrap();

        // Consume the response to complete the flow
        let _ = rx.await;

        // Should NOT be cached (Once scope)
        assert!(!registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_denied_not_cached() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        // Register and deny
        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_deny_response()).await.unwrap();

        // Consume the response to complete the flow
        let _ = rx.await;

        // Should not be granted
        assert!(!registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_pending_for_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        // Register requests for different sessions
        let _ = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await;
        let _ = registry
            .register("tool_2".to_string(), 1, request.clone(), None)
            .await;
        let _ = registry
            .register("tool_3".to_string(), 2, request.clone(), None)
            .await;

        // Query session 1
        let pending = registry.pending_for_session(1).await;
        assert_eq!(pending.len(), 2);

        // Query session 2
        let pending = registry.pending_for_session(2).await;
        assert_eq!(pending.len(), 1);

        // Query nonexistent session
        let pending = registry.pending_for_session(999).await;
        assert_eq!(pending.len(), 0);
    }

    /// Exercises `clear_session`: grants for the cleared session go away while
    /// other sessions' pending requests keep working.
    #[tokio::test]
    async fn test_cancel_session() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        // Register and grant for session 1
        let rx1 = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();

        // Consume the response to complete the flow
        let _ = rx1.await;

        // Register request for session 2
        let rx2 = registry
            .register("tool_2".to_string(), 2, request.clone(), None)
            .await
            .unwrap();

        // Clear session 1
        registry.clear_session(1).await;

        // Session 1 grants should be gone
        assert!(!registry.is_granted(1, &request).await);

        // Session 2 should still have pending
        assert!(registry.has_pending(2).await);

        // Session 2 receiver should still work
        registry.respond("tool_2", create_grant_response()).await.unwrap();
        let received = rx2.await.unwrap();
        assert!(received.granted);
    }

    #[tokio::test]
    async fn test_cancel_session_drops_receivers_but_keeps_grants() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();

        // Establish a session grant first.
        let rx = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.respond("tool_1", create_grant_response()).await.unwrap();
        let _ = rx.await;

        // Then leave another request pending and cancel the session.
        let rx_pending = registry
            .register("tool_2".to_string(), 1, request.clone(), None)
            .await
            .unwrap();
        registry.cancel_session(1).await;

        // The waiter observes the dropped sender...
        assert!(rx_pending.await.is_err());
        assert!(!registry.has_pending(1).await);
        // ...but cancel_session (unlike clear_session) keeps the grants.
        assert!(registry.is_granted(1, &request).await);
    }

    #[tokio::test]
    async fn test_has_pending() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        assert!(!registry.has_pending(1).await);

        let request = create_test_request();
        let _ = registry
            .register("tool_1".to_string(), 1, request, None)
            .await;

        assert!(registry.has_pending(1).await);
        assert!(!registry.has_pending(2).await);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        assert_eq!(registry.pending_count().await, 0);

        let request = create_test_request();
        let _ = registry
            .register("tool_1".to_string(), 1, request.clone(), None)
            .await;
        assert_eq!(registry.pending_count().await, 1);

        let _ = registry
            .register("tool_2".to_string(), 1, request, None)
            .await;
        assert_eq!(registry.pending_count().await, 2);
    }

    #[tokio::test]
    async fn test_cancel() {
        let (event_tx, _event_rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(event_tx);

        let request = create_test_request();
        let rx = registry
            .register("tool_1".to_string(), 1, request, None)
            .await
            .unwrap();

        // Cancel the request
        registry.cancel("tool_1").await.unwrap();

        // Receiver should get error
        assert!(rx.await.is_err());

        // Cancel nonexistent should fail
        let result = registry.cancel("nonexistent").await;
        assert_eq!(result, Err(PermissionError::NotFound));
    }
}