github_copilot_sdk/
permission.rs1use std::sync::Arc;
16
17use async_trait::async_trait;
18
19use crate::handler::{PermissionHandler, PermissionResult};
20use crate::types::{PermissionRequestData, RequestId, SessionId};
21
22pub fn approve_all() -> Arc<dyn PermissionHandler> {
24 Arc::new(PolicyHandler {
25 policy: Policy::ApproveAll,
26 })
27}
28
29pub fn deny_all() -> Arc<dyn PermissionHandler> {
31 Arc::new(PolicyHandler {
32 policy: Policy::DenyAll,
33 })
34}
35
36pub fn approve_if<F>(predicate: F) -> Arc<dyn PermissionHandler>
47where
48 F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static,
49{
50 Arc::new(PolicyHandler {
51 policy: Policy::Predicate(Arc::new(predicate)),
52 })
53}
54
55#[derive(Clone)]
63pub(crate) enum Policy {
64 ApproveAll,
65 DenyAll,
66 Predicate(Arc<dyn Fn(&PermissionRequestData) -> bool + Send + Sync>),
67}
68
69impl std::fmt::Debug for Policy {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Self::ApproveAll => f.write_str("Policy::ApproveAll"),
73 Self::DenyAll => f.write_str("Policy::DenyAll"),
74 Self::Predicate(_) => f.write_str("Policy::Predicate(<fn>)"),
75 }
76 }
77}
78
79pub(crate) fn resolve_handler(
91 handler: Option<Arc<dyn PermissionHandler>>,
92 policy: Option<Policy>,
93) -> Option<Arc<dyn PermissionHandler>> {
94 match (handler, policy) {
95 (_, Some(policy)) => Some(Arc::new(PolicyHandler { policy })),
96 (Some(h), None) => Some(h),
97 (None, None) => None,
98 }
99}
100
101struct PolicyHandler {
102 policy: Policy,
103}
104
105#[async_trait]
106impl PermissionHandler for PolicyHandler {
107 async fn handle(
108 &self,
109 _session_id: SessionId,
110 _request_id: RequestId,
111 data: PermissionRequestData,
112 ) -> PermissionResult {
113 let approved = match &self.policy {
114 Policy::ApproveAll => true,
115 Policy::DenyAll => false,
116 Policy::Predicate(f) => f(&data),
117 };
118 if approved {
119 PermissionResult::approve_once()
120 } else {
121 PermissionResult::reject(None)
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::PermissionDecision;

    /// A request whose `extra` payload names the `shell` tool.
    fn data() -> PermissionRequestData {
        PermissionRequestData {
            extra: serde_json::json!({ "tool": "shell" }),
            ..Default::default()
        }
    }

    /// Reads the `"tool"` string out of a request's `extra` payload, if present.
    fn tool(d: &PermissionRequestData) -> Option<&str> {
        d.extra.get("tool").and_then(|v| v.as_str())
    }

    /// Sends `data()` through `h` using fixed session and request ids.
    async fn ask(h: &Arc<dyn PermissionHandler>) -> PermissionResult {
        h.handle(SessionId::from("s"), RequestId::new("1"), data())
            .await
    }

    /// A custom handler that approves everything. It is used to check how
    /// `resolve_handler` combines custom handlers with policies.
    struct AlwaysApprove;

    #[async_trait]
    impl PermissionHandler for AlwaysApprove {
        async fn handle(
            &self,
            _: SessionId,
            _: RequestId,
            _: PermissionRequestData,
        ) -> PermissionResult {
            PermissionResult::approve_once()
        }
    }

    #[tokio::test]
    async fn approve_all_approves() {
        assert!(matches!(
            ask(&approve_all()).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn deny_all_denies() {
        assert!(matches!(
            ask(&deny_all()).await,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    #[tokio::test]
    async fn approve_if_consults_predicate() {
        let h = approve_if(|d| tool(d) != Some("shell"));
        assert!(matches!(
            ask(&h).await,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    // Without this positive case, a handler that ignored its predicate and
    // always rejected would still pass `approve_if_consults_predicate`.
    #[tokio::test]
    async fn approve_if_approves_when_predicate_holds() {
        let h = approve_if(|d| tool(d) == Some("shell"));
        assert!(matches!(
            ask(&h).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn resolve_handler_policy_wins() {
        let resolved =
            resolve_handler(Some(Arc::new(AlwaysApprove)), Some(Policy::DenyAll)).unwrap();
        assert!(matches!(
            ask(&resolved).await,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    #[tokio::test]
    async fn resolve_handler_with_only_handler() {
        let resolved = resolve_handler(Some(Arc::new(AlwaysApprove)), None).unwrap();
        assert!(matches!(
            ask(&resolved).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[test]
    fn resolve_handler_with_neither_returns_none() {
        assert!(resolve_handler(None, None).is_none());
    }
}