agent_core/controller/tools/
permission_registry.rs

1//! Permission registry for managing permission requests and session-level grants.
2//!
3//! This module provides a registry for tools that need to request user permission,
4//! such as the AskForPermissions tool.
5
6use std::collections::{HashMap, HashSet};
7
8use tokio::sync::{mpsc, oneshot, Mutex};
9
10use super::ask_for_permissions::{PermissionCategory, PermissionRequest, PermissionResponse};
11use crate::controller::types::{ControllerEvent, TurnId};
12
/// Failures that permission-registry operations can report.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PermissionError {
    /// The given `tool_use_id` has no pending permission request.
    NotFound,
    /// A response was already delivered for the request.
    AlreadyResponded,
    /// The response could not be delivered because the channel was closed.
    SendFailed,
    /// The event notification could not be sent.
    EventSendFailed,
}

impl std::fmt::Display for PermissionError {
    /// Write the fixed, human-readable message for this error variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotFound => "No pending permission request found",
            Self::AlreadyResponded => "Permission already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PermissionError {}
38
39/// Information about a pending permission request for UI display.
40#[derive(Debug, Clone)]
41pub struct PendingPermissionInfo {
42    /// Tool use ID for this permission request.
43    pub tool_use_id: String,
44    /// Session ID this permission belongs to.
45    pub session_id: i64,
46    /// The permission request details.
47    pub request: PermissionRequest,
48    /// Turn ID for this permission request.
49    pub turn_id: Option<TurnId>,
50}
51
52/// A grant that was approved for the session.
53#[derive(Debug, Clone, Hash, PartialEq, Eq)]
54pub struct PermissionGrant {
55    /// Category of the permission.
56    pub category: PermissionCategory,
57    /// The action pattern that was granted (for matching future requests).
58    pub action_pattern: String,
59}
60
61/// Internal state for a pending permission request.
62struct PendingPermission {
63    session_id: i64,
64    request: PermissionRequest,
65    turn_id: Option<TurnId>,
66    responder: oneshot::Sender<PermissionResponse>,
67}
68
69/// Registry for managing permission requests and session-level grants.
70///
71/// This registry tracks tools that are blocked waiting for permission
72/// and provides methods for the UI to query and respond to these requests.
73/// It also caches session-level grants to avoid re-asking.
74pub struct PermissionRegistry {
75    /// Pending permission requests keyed by tool_use_id.
76    pending: Mutex<HashMap<String, PendingPermission>>,
77    /// Session-level grants (session_id -> set of grants).
78    session_grants: Mutex<HashMap<i64, HashSet<PermissionGrant>>>,
79    /// Channel to send events to the controller.
80    event_tx: mpsc::Sender<ControllerEvent>,
81}
82
83impl PermissionRegistry {
84    /// Create a new PermissionRegistry.
85    ///
86    /// # Arguments
87    /// * `event_tx` - Channel to send events when permissions are requested.
88    pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
89        Self {
90            pending: Mutex::new(HashMap::new()),
91            session_grants: Mutex::new(HashMap::new()),
92            event_tx,
93        }
94    }
95
96    /// Check if permission is already granted for the session.
97    ///
98    /// This checks if a previous session-level grant covers this request.
99    ///
100    /// # Arguments
101    /// * `session_id` - Session to check.
102    /// * `request` - The permission request to check.
103    ///
104    /// # Returns
105    /// True if permission was previously granted for the session.
106    pub async fn is_granted(&self, session_id: i64, request: &PermissionRequest) -> bool {
107        let grants = self.session_grants.lock().await;
108        if let Some(session_grants) = grants.get(&session_id) {
109            // Check if any grant matches the request
110            session_grants.iter().any(|grant| {
111                grant.category == request.category && grant.action_pattern == request.action
112            })
113        } else {
114            false
115        }
116    }
117
118    /// Register a permission request and get a receiver to await on.
119    ///
120    /// This is called by the AskForPermissionsTool when it starts executing.
121    /// The tool will await on the returned receiver until the UI responds.
122    ///
123    /// # Arguments
124    /// * `tool_use_id` - Unique ID for this tool use request.
125    /// * `session_id` - Session that requested the permission.
126    /// * `request` - The permission request details.
127    /// * `turn_id` - Optional turn ID for this request.
128    ///
129    /// # Returns
130    /// A oneshot receiver that will receive the user's response.
131    pub async fn register(
132        &self,
133        tool_use_id: String,
134        session_id: i64,
135        request: PermissionRequest,
136        turn_id: Option<TurnId>,
137    ) -> Result<oneshot::Receiver<PermissionResponse>, PermissionError> {
138        let (tx, rx) = oneshot::channel();
139
140        // Store the pending request
141        {
142            let mut pending = self.pending.lock().await;
143            pending.insert(
144                tool_use_id.clone(),
145                PendingPermission {
146                    session_id,
147                    request: request.clone(),
148                    turn_id: turn_id.clone(),
149                    responder: tx,
150                },
151            );
152        }
153
154        // Emit event to notify UI
155        self.event_tx
156            .send(ControllerEvent::PermissionRequired {
157                session_id,
158                tool_use_id,
159                request,
160                turn_id,
161            })
162            .await
163            .map_err(|_| PermissionError::EventSendFailed)?;
164
165        Ok(rx)
166    }
167
168    /// Respond to a pending permission request.
169    ///
170    /// This is called by the UI when the user has granted or denied permission.
171    ///
172    /// # Arguments
173    /// * `tool_use_id` - ID of the tool use to respond to.
174    /// * `response` - The user's response (grant/deny).
175    ///
176    /// # Returns
177    /// Ok(()) if the response was sent successfully, or an error.
178    pub async fn respond(
179        &self,
180        tool_use_id: &str,
181        response: PermissionResponse,
182    ) -> Result<(), PermissionError> {
183        let pending_permission = {
184            let mut pending = self.pending.lock().await;
185            pending
186                .remove(tool_use_id)
187                .ok_or(PermissionError::NotFound)?
188        };
189
190        // If granted with session scope, cache the grant
191        if response.granted {
192            if let Some(ref scope) = response.scope {
193                if *scope == super::ask_for_permissions::PermissionScope::Session {
194                    let mut grants = self.session_grants.lock().await;
195                    let session_grants = grants
196                        .entry(pending_permission.session_id)
197                        .or_insert_with(HashSet::new);
198                    session_grants.insert(PermissionGrant {
199                        category: pending_permission.request.category.clone(),
200                        action_pattern: pending_permission.request.action.clone(),
201                    });
202                }
203            }
204        }
205
206        pending_permission
207            .responder
208            .send(response)
209            .map_err(|_| PermissionError::SendFailed)
210    }
211
212    /// Cancel a pending permission request (user declined).
213    ///
214    /// This is called by the UI when the user closes the permission dialog
215    /// without responding.
216    ///
217    /// # Arguments
218    /// * `tool_use_id` - ID of the tool use to cancel.
219    ///
220    /// # Returns
221    /// Ok(()) if the request was found and cancelled, or NotFound error.
222    pub async fn cancel(&self, tool_use_id: &str) -> Result<(), PermissionError> {
223        let mut pending = self.pending.lock().await;
224        if pending.remove(tool_use_id).is_some() {
225            // Dropping the sender will cause the tool to receive a RecvError
226            // which will be converted to "User denied permission"
227            Ok(())
228        } else {
229            Err(PermissionError::NotFound)
230        }
231    }
232
233    /// Get all pending permission requests for a session.
234    ///
235    /// This is called by the UI when switching sessions to display
236    /// any pending permission requests for that session.
237    ///
238    /// # Arguments
239    /// * `session_id` - Session ID to query.
240    ///
241    /// # Returns
242    /// List of pending permission info for the session.
243    pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingPermissionInfo> {
244        let pending = self.pending.lock().await;
245        pending
246            .iter()
247            .filter(|(_, perm)| perm.session_id == session_id)
248            .map(|(tool_use_id, perm)| PendingPermissionInfo {
249                tool_use_id: tool_use_id.clone(),
250                session_id: perm.session_id,
251                request: perm.request.clone(),
252                turn_id: perm.turn_id.clone(),
253            })
254            .collect()
255    }
256
257    /// Cancel all pending permission requests for a session.
258    ///
259    /// This is called when a session is destroyed. It drops the senders,
260    /// which will cause the awaiting tools to receive a RecvError.
261    ///
262    /// # Arguments
263    /// * `session_id` - Session ID to cancel.
264    pub async fn cancel_session(&self, session_id: i64) {
265        let mut pending = self.pending.lock().await;
266        pending.retain(|_, perm| perm.session_id != session_id);
267        // Dropped senders will cause RecvError on the tool side
268    }
269
270    /// Clear all grants for a session.
271    ///
272    /// This is called when a session ends or is reset.
273    ///
274    /// # Arguments
275    /// * `session_id` - Session ID to clear grants for.
276    pub async fn clear_session(&self, session_id: i64) {
277        // Cancel pending requests
278        self.cancel_session(session_id).await;
279
280        // Clear session grants
281        let mut grants = self.session_grants.lock().await;
282        grants.remove(&session_id);
283    }
284
285    /// Check if there are any pending permission requests for a session.
286    ///
287    /// # Arguments
288    /// * `session_id` - Session ID to check.
289    ///
290    /// # Returns
291    /// True if there are pending permission requests.
292    pub async fn has_pending(&self, session_id: i64) -> bool {
293        let pending = self.pending.lock().await;
294        pending.values().any(|perm| perm.session_id == session_id)
295    }
296
297    /// Get the count of pending permission requests.
298    pub async fn pending_count(&self) -> usize {
299        let pending = self.pending.lock().await;
300        pending.len()
301    }
302
303    /// Get all session grants for a session.
304    ///
305    /// # Arguments
306    /// * `session_id` - Session ID to query.
307    ///
308    /// # Returns
309    /// Set of grants for the session (empty if none).
310    pub async fn session_grants(&self, session_id: i64) -> HashSet<PermissionGrant> {
311        let grants = self.session_grants.lock().await;
312        grants.get(&session_id).cloned().unwrap_or_default()
313    }
314}
315
316#[cfg(test)]
317mod tests {
318    use super::*;
319    use crate::controller::tools::ask_for_permissions::PermissionScope;
320
321    fn create_test_request() -> PermissionRequest {
322        PermissionRequest {
323            action: "Delete file /tmp/foo.txt".to_string(),
324            reason: Some("User requested cleanup".to_string()),
325            resources: vec!["/tmp/foo.txt".to_string()],
326            category: PermissionCategory::FileDelete,
327        }
328    }
329
330    fn create_grant_response() -> PermissionResponse {
331        PermissionResponse {
332            granted: true,
333            scope: Some(PermissionScope::Session),
334            message: None,
335        }
336    }
337
338    fn create_deny_response() -> PermissionResponse {
339        PermissionResponse {
340            granted: false,
341            scope: None,
342            message: Some("Not allowed".to_string()),
343        }
344    }
345
346    #[tokio::test]
347    async fn test_register_and_respond() {
348        let (event_tx, mut event_rx) = mpsc::channel(10);
349        let registry = PermissionRegistry::new(event_tx);
350
351        let request = create_test_request();
352        let response = create_grant_response();
353
354        // Register permission request
355        let rx = registry
356            .register("tool_123".to_string(), 1, request.clone(), None)
357            .await
358            .unwrap();
359
360        // Verify event was emitted
361        let event = event_rx.recv().await.unwrap();
362        if let ControllerEvent::PermissionRequired {
363            session_id,
364            tool_use_id,
365            ..
366        } = event
367        {
368            assert_eq!(session_id, 1);
369            assert_eq!(tool_use_id, "tool_123");
370        } else {
371            panic!("Expected PermissionRequired event");
372        }
373
374        // Respond to request
375        registry
376            .respond("tool_123", response.clone())
377            .await
378            .unwrap();
379
380        // Verify response was received
381        let received = rx.await.unwrap();
382        assert!(received.granted);
383        assert_eq!(received.scope, Some(PermissionScope::Session));
384    }
385
386    #[tokio::test]
387    async fn test_respond_not_found() {
388        let (event_tx, _event_rx) = mpsc::channel(10);
389        let registry = PermissionRegistry::new(event_tx);
390
391        let response = create_grant_response();
392        let result = registry.respond("nonexistent", response).await;
393
394        assert_eq!(result, Err(PermissionError::NotFound));
395    }
396
397    #[tokio::test]
398    async fn test_session_grant_caching() {
399        let (event_tx, _event_rx) = mpsc::channel(10);
400        let registry = PermissionRegistry::new(event_tx);
401
402        let request = create_test_request();
403
404        // Not granted initially
405        assert!(!registry.is_granted(1, &request).await);
406
407        // Register and grant with session scope
408        let rx = registry
409            .register("tool_1".to_string(), 1, request.clone(), None)
410            .await
411            .unwrap();
412        registry.respond("tool_1", create_grant_response()).await.unwrap();
413
414        // Consume the response to complete the flow
415        let _ = rx.await;
416
417        // Now should be granted
418        assert!(registry.is_granted(1, &request).await);
419
420        // Different session should not be granted
421        assert!(!registry.is_granted(2, &request).await);
422    }
423
424    #[tokio::test]
425    async fn test_once_scope_not_cached() {
426        let (event_tx, _event_rx) = mpsc::channel(10);
427        let registry = PermissionRegistry::new(event_tx);
428
429        let request = create_test_request();
430
431        // Register and grant with Once scope
432        let rx = registry
433            .register("tool_1".to_string(), 1, request.clone(), None)
434            .await
435            .unwrap();
436        registry
437            .respond(
438                "tool_1",
439                PermissionResponse {
440                    granted: true,
441                    scope: Some(PermissionScope::Once),
442                    message: None,
443                },
444            )
445            .await
446            .unwrap();
447
448        // Consume the response to complete the flow
449        let _ = rx.await;
450
451        // Should NOT be cached (Once scope)
452        assert!(!registry.is_granted(1, &request).await);
453    }
454
455    #[tokio::test]
456    async fn test_denied_not_cached() {
457        let (event_tx, _event_rx) = mpsc::channel(10);
458        let registry = PermissionRegistry::new(event_tx);
459
460        let request = create_test_request();
461
462        // Register and deny
463        let rx = registry
464            .register("tool_1".to_string(), 1, request.clone(), None)
465            .await
466            .unwrap();
467        registry.respond("tool_1", create_deny_response()).await.unwrap();
468
469        // Consume the response to complete the flow
470        let _ = rx.await;
471
472        // Should not be granted
473        assert!(!registry.is_granted(1, &request).await);
474    }
475
476    #[tokio::test]
477    async fn test_pending_for_session() {
478        let (event_tx, _event_rx) = mpsc::channel(10);
479        let registry = PermissionRegistry::new(event_tx);
480
481        let request = create_test_request();
482
483        // Register requests for different sessions
484        let _ = registry
485            .register("tool_1".to_string(), 1, request.clone(), None)
486            .await;
487        let _ = registry
488            .register("tool_2".to_string(), 1, request.clone(), None)
489            .await;
490        let _ = registry
491            .register("tool_3".to_string(), 2, request.clone(), None)
492            .await;
493
494        // Query session 1
495        let pending = registry.pending_for_session(1).await;
496        assert_eq!(pending.len(), 2);
497
498        // Query session 2
499        let pending = registry.pending_for_session(2).await;
500        assert_eq!(pending.len(), 1);
501
502        // Query nonexistent session
503        let pending = registry.pending_for_session(999).await;
504        assert_eq!(pending.len(), 0);
505    }
506
507    #[tokio::test]
508    async fn test_cancel_session() {
509        let (event_tx, _event_rx) = mpsc::channel(10);
510        let registry = PermissionRegistry::new(event_tx);
511
512        let request = create_test_request();
513
514        // Register and grant for session 1
515        let rx1 = registry
516            .register("tool_1".to_string(), 1, request.clone(), None)
517            .await
518            .unwrap();
519        registry.respond("tool_1", create_grant_response()).await.unwrap();
520
521        // Consume the response to complete the flow
522        let _ = rx1.await;
523
524        // Register request for session 2
525        let rx2 = registry
526            .register("tool_2".to_string(), 2, request.clone(), None)
527            .await
528            .unwrap();
529
530        // Clear session 1
531        registry.clear_session(1).await;
532
533        // Session 1 grants should be gone
534        assert!(!registry.is_granted(1, &request).await);
535
536        // Session 2 should still have pending
537        assert!(registry.has_pending(2).await);
538
539        // Session 2 receiver should still work
540        registry.respond("tool_2", create_grant_response()).await.unwrap();
541        let received = rx2.await.unwrap();
542        assert!(received.granted);
543    }
544
545    #[tokio::test]
546    async fn test_has_pending() {
547        let (event_tx, _event_rx) = mpsc::channel(10);
548        let registry = PermissionRegistry::new(event_tx);
549
550        assert!(!registry.has_pending(1).await);
551
552        let request = create_test_request();
553        let _ = registry
554            .register("tool_1".to_string(), 1, request, None)
555            .await;
556
557        assert!(registry.has_pending(1).await);
558        assert!(!registry.has_pending(2).await);
559    }
560
561    #[tokio::test]
562    async fn test_pending_count() {
563        let (event_tx, _event_rx) = mpsc::channel(10);
564        let registry = PermissionRegistry::new(event_tx);
565
566        assert_eq!(registry.pending_count().await, 0);
567
568        let request = create_test_request();
569        let _ = registry
570            .register("tool_1".to_string(), 1, request.clone(), None)
571            .await;
572        assert_eq!(registry.pending_count().await, 1);
573
574        let _ = registry
575            .register("tool_2".to_string(), 1, request, None)
576            .await;
577        assert_eq!(registry.pending_count().await, 2);
578    }
579
580    #[tokio::test]
581    async fn test_cancel() {
582        let (event_tx, _event_rx) = mpsc::channel(10);
583        let registry = PermissionRegistry::new(event_tx);
584
585        let request = create_test_request();
586        let rx = registry
587            .register("tool_1".to_string(), 1, request, None)
588            .await
589            .unwrap();
590
591        // Cancel the request
592        registry.cancel("tool_1").await.unwrap();
593
594        // Receiver should get error
595        assert!(rx.await.is_err());
596
597        // Cancel nonexistent should fail
598        let result = registry.cancel("nonexistent").await;
599        assert_eq!(result, Err(PermissionError::NotFound));
600    }
601}