github_copilot_sdk/
permission.rs1use std::sync::Arc;
27
28use async_trait::async_trait;
29
30use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler};
31use crate::types::PermissionRequestData;
32
33pub fn approve_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
36 Arc::new(PermissionOverrideHandler {
37 inner,
38 policy: Policy::ApproveAll,
39 })
40}
41
42pub fn deny_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
45 Arc::new(PermissionOverrideHandler {
46 inner,
47 policy: Policy::DenyAll,
48 })
49}
50
51pub fn approve_if<F>(inner: Arc<dyn SessionHandler>, predicate: F) -> Arc<dyn SessionHandler>
67where
68 F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static,
69{
70 Arc::new(PermissionOverrideHandler {
71 inner,
72 policy: Policy::Predicate(Arc::new(predicate)),
73 })
74}
75
76enum Policy {
77 ApproveAll,
78 DenyAll,
79 Predicate(Arc<dyn Fn(&PermissionRequestData) -> bool + Send + Sync>),
80}
81
82struct PermissionOverrideHandler {
83 inner: Arc<dyn SessionHandler>,
84 policy: Policy,
85}
86
87#[async_trait]
88impl SessionHandler for PermissionOverrideHandler {
89 async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
90 match event {
91 HandlerEvent::PermissionRequest { ref data, .. } => {
92 let approved = match &self.policy {
93 Policy::ApproveAll => true,
94 Policy::DenyAll => false,
95 Policy::Predicate(f) => f(data),
96 };
97 HandlerResponse::Permission(if approved {
98 PermissionResult::Approved
99 } else {
100 PermissionResult::Denied
101 })
102 }
103 other => self.inner.on_event(other).await,
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::ApproveAllHandler;
    use crate::types::{RequestId, SessionId};

    /// Builds a permission request for session `s1` whose `extra` payload
    /// names the `shell` tool.
    fn shell_permission_request() -> HandlerEvent {
        let data = PermissionRequestData {
            extra: serde_json::json!({"tool": "shell"}),
            ..Default::default()
        };
        HandlerEvent::PermissionRequest {
            session_id: SessionId::from("s1"),
            request_id: RequestId::new("1"),
            data,
        }
    }

    #[tokio::test]
    async fn approve_all_approves_permission_requests() {
        let handler = approve_all(Arc::new(ApproveAllHandler));
        let response = handler.on_event(shell_permission_request()).await;
        assert!(
            matches!(
                response,
                HandlerResponse::Permission(PermissionResult::Approved)
            ),
            "expected Approved, got {response:?}"
        );
    }

    #[tokio::test]
    async fn deny_all_denies_permission_requests() {
        let handler = deny_all(Arc::new(ApproveAllHandler));
        let response = handler.on_event(shell_permission_request()).await;
        assert!(
            matches!(
                response,
                HandlerResponse::Permission(PermissionResult::Denied)
            ),
            "expected Denied, got {response:?}"
        );
    }

    #[tokio::test]
    async fn approve_if_consults_predicate() {
        // The predicate rejects the `shell` tool, which is what the fixture requests.
        let handler = approve_if(Arc::new(ApproveAllHandler), |req| {
            let tool = req.extra.get("tool").and_then(|value| value.as_str());
            tool != Some("shell")
        });
        let response = handler.on_event(shell_permission_request()).await;
        assert!(
            matches!(
                response,
                HandlerResponse::Permission(PermissionResult::Denied)
            ),
            "expected Denied for shell, got {response:?}"
        );
    }

    #[tokio::test]
    async fn non_permission_events_forward_to_inner() {
        // Even a deny-all wrapper must pass non-permission events to the inner handler.
        let handler = deny_all(Arc::new(ApproveAllHandler));
        let plan_exit = HandlerEvent::ExitPlanMode {
            session_id: SessionId::from("s1"),
            data: crate::types::ExitPlanModeData::default(),
        };
        let response = handler.on_event(plan_exit).await;
        assert!(
            matches!(response, HandlerResponse::ExitPlanMode(_)),
            "expected ExitPlanMode forwarded, got {response:?}"
        );
    }
}