Skip to main content

github_copilot_sdk/
permission.rs

//! Permission-policy helpers that compose with an existing
//! [`SessionHandler`](crate::handler::SessionHandler).
//!
//! These wrap an inner handler and override **only** permission requests,
//! forwarding every other event (tool calls, user input, elicitation,
//! session events) to the inner handler. Use them when you have a custom
//! tool handler — typically a [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) —
//! but want a one-line policy for permission prompts.
//!
//! For a standalone handler that approves or denies everything (with no inner
//! handler to forward to), see
//! [`ApproveAllHandler`](crate::handler::ApproveAllHandler) and
//! [`DenyAllHandler`](crate::handler::DenyAllHandler).
//!
//! # Example
//!
//! ```rust,no_run
//! # use std::sync::Arc;
//! # use github_copilot_sdk::handler::ApproveAllHandler;
//! # use github_copilot_sdk::permission;
//! # use github_copilot_sdk::tool::ToolHandlerRouter;
//! let router = ToolHandlerRouter::new(vec![], Arc::new(ApproveAllHandler));
//! // Inherit the router's tool dispatch but auto-approve all permission prompts:
//! let handler = permission::approve_all(Arc::new(router));
//! ```
25
26use std::sync::Arc;
27
28use async_trait::async_trait;
29
30use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler};
31use crate::types::PermissionRequestData;
32
33/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is
34/// auto-approved. All other events are forwarded to `inner`.
35pub fn approve_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
36    Arc::new(PermissionOverrideHandler {
37        inner,
38        policy: Policy::ApproveAll,
39    })
40}
41
42/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is
43/// auto-denied. All other events are forwarded to `inner`.
44pub fn deny_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
45    Arc::new(PermissionOverrideHandler {
46        inner,
47        policy: Policy::DenyAll,
48    })
49}
50
51/// Wrap `inner` with a closure-based policy: `predicate` is called for each
52/// permission request; `true` approves, `false` denies. All other events
53/// are forwarded to `inner`.
54///
55/// ```rust,no_run
56/// # use std::sync::Arc;
57/// # use github_copilot_sdk::handler::ApproveAllHandler;
58/// # use github_copilot_sdk::permission;
59/// let inner = Arc::new(ApproveAllHandler);
60/// let handler = permission::approve_if(inner, |data| {
61///     // Inspect data.extra (the raw JSON payload) for custom policy.
62///     data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")
63/// });
64/// # let _ = handler;
65/// ```
66pub fn approve_if<F>(inner: Arc<dyn SessionHandler>, predicate: F) -> Arc<dyn SessionHandler>
67where
68    F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static,
69{
70    Arc::new(PermissionOverrideHandler {
71        inner,
72        policy: Policy::Predicate(Arc::new(predicate)),
73    })
74}
75
76enum Policy {
77    ApproveAll,
78    DenyAll,
79    Predicate(Arc<dyn Fn(&PermissionRequestData) -> bool + Send + Sync>),
80}
81
82struct PermissionOverrideHandler {
83    inner: Arc<dyn SessionHandler>,
84    policy: Policy,
85}
86
87#[async_trait]
88impl SessionHandler for PermissionOverrideHandler {
89    async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
90        match event {
91            HandlerEvent::PermissionRequest { ref data, .. } => {
92                let approved = match &self.policy {
93                    Policy::ApproveAll => true,
94                    Policy::DenyAll => false,
95                    Policy::Predicate(f) => f(data),
96                };
97                HandlerResponse::Permission(if approved {
98                    PermissionResult::Approved
99                } else {
100                    PermissionResult::Denied
101                })
102            }
103            other => self.inner.on_event(other).await,
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::ApproveAllHandler;
    use crate::types::{RequestId, SessionId};

    /// Build a permission request whose raw JSON payload names `tool`.
    fn request_for(tool: &str) -> HandlerEvent {
        HandlerEvent::PermissionRequest {
            session_id: SessionId::from("s1"),
            request_id: RequestId::new("1"),
            data: PermissionRequestData {
                extra: serde_json::json!({ "tool": tool }),
                ..Default::default()
            },
        }
    }

    /// Permission request for the `shell` tool, used by most tests.
    fn request() -> HandlerEvent {
        request_for("shell")
    }

    /// An inner handler that denies every permission prompt.
    ///
    /// Wrapping it with an approving policy proves the wrapper answers the
    /// prompt itself: if the request were forwarded to the inner handler, the
    /// result would be `Denied` instead of `Approved`.
    fn denying_inner() -> Arc<dyn SessionHandler> {
        deny_all(Arc::new(ApproveAllHandler))
    }

    /// Example policy: approve everything except the `shell` tool.
    fn not_shell(data: &PermissionRequestData) -> bool {
        data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")
    }

    #[tokio::test]
    async fn approve_all_approves_permission_requests() {
        let h = approve_all(denying_inner());
        match h.on_event(request()).await {
            HandlerResponse::Permission(PermissionResult::Approved) => {}
            other => panic!("expected Approved, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn deny_all_denies_permission_requests() {
        // The inner handler would approve, so `Denied` can only come from the override.
        let h = deny_all(Arc::new(ApproveAllHandler));
        match h.on_event(request()).await {
            HandlerResponse::Permission(PermissionResult::Denied) => {}
            other => panic!("expected Denied, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn approve_if_consults_predicate() {
        // The inner handler would approve, so `Denied` can only come from the predicate.
        let h = approve_if(Arc::new(ApproveAllHandler), not_shell);
        match h.on_event(request()).await {
            HandlerResponse::Permission(PermissionResult::Denied) => {}
            other => panic!("expected Denied for shell, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn approve_if_approves_when_predicate_accepts() {
        // The inner handler would deny, so `Approved` can only come from the predicate.
        let h = approve_if(denying_inner(), not_shell);
        match h.on_event(request_for("read")).await {
            HandlerResponse::Permission(PermissionResult::Approved) => {}
            other => panic!("expected Approved for non-shell tool, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn non_permission_events_forward_to_inner() {
        let h = deny_all(Arc::new(ApproveAllHandler));
        let event = HandlerEvent::ExitPlanMode {
            session_id: SessionId::from("s1"),
            data: crate::types::ExitPlanModeData::default(),
        };
        match h.on_event(event).await {
            HandlerResponse::ExitPlanMode(_) => {}
            other => panic!("expected ExitPlanMode forwarded, got {other:?}"),
        }
    }
}