use std::sync::Arc;
use async_trait::async_trait;
use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler};
use crate::types::PermissionRequestData;
/// Wraps `inner` in a handler that approves every permission request.
///
/// Permission requests are answered with [`PermissionResult::Approved`]
/// without consulting `inner`; every other event is forwarded to `inner`
/// unchanged.
pub fn approve_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
    let policy = Policy::ApproveAll;
    let handler = PermissionOverrideHandler { inner, policy };
    Arc::new(handler)
}
/// Wraps `inner` in a handler that denies every permission request.
///
/// Permission requests are answered with [`PermissionResult::Denied`]
/// without consulting `inner`; every other event is forwarded to `inner`
/// unchanged.
pub fn deny_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
    let policy = Policy::DenyAll;
    let handler = PermissionOverrideHandler { inner, policy };
    Arc::new(handler)
}
/// Wraps `inner` in a handler that decides permission requests with `predicate`.
///
/// For each permission request the predicate is called with the request's
/// data: `true` yields [`PermissionResult::Approved`], `false` yields
/// [`PermissionResult::Denied`]. Every other event is forwarded to `inner`
/// unchanged.
pub fn approve_if<F>(inner: Arc<dyn SessionHandler>, predicate: F) -> Arc<dyn SessionHandler>
where
    F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static,
{
    // The closure is type-erased behind an `Arc` so the handler itself stays non-generic.
    let policy = Policy::Predicate(Arc::new(predicate));
    Arc::new(PermissionOverrideHandler { inner, policy })
}
/// Decision rule that `PermissionOverrideHandler` applies to permission requests.
enum Policy {
/// Approve every permission request.
ApproveAll,
/// Deny every permission request.
DenyAll,
/// Approve a request only when the stored closure returns `true` for its data.
Predicate(Arc<dyn Fn(&PermissionRequestData) -> bool + Send + Sync>),
}
/// `SessionHandler` wrapper that answers permission requests itself according
/// to `policy` and delegates every other event to `inner`.
struct PermissionOverrideHandler {
/// Handler that receives all events other than permission requests.
inner: Arc<dyn SessionHandler>,
/// Rule used to approve or deny permission requests.
policy: Policy,
}
#[async_trait]
impl SessionHandler for PermissionOverrideHandler {
    /// Answers permission requests from the configured policy and delegates
    /// every other event to the wrapped handler.
    ///
    /// The wrapped handler is never called for a permission request.
    async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
        let verdict = match event {
            HandlerEvent::PermissionRequest { ref data, .. } => match &self.policy {
                Policy::ApproveAll => PermissionResult::Approved,
                Policy::DenyAll => PermissionResult::Denied,
                Policy::Predicate(allows) => {
                    if allows(data) {
                        PermissionResult::Approved
                    } else {
                        PermissionResult::Denied
                    }
                }
            },
            // Not a permission request: hand the event over to the inner handler as-is.
            passthrough => return self.inner.on_event(passthrough).await,
        };

        HandlerResponse::Permission(verdict)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::ApproveAllHandler;
    use crate::types::{RequestId, SessionId};

    /// Builds a permission request whose `extra` payload names the `shell` tool.
    fn request() -> HandlerEvent {
        let session_id = SessionId::from("s1");
        let request_id = RequestId::new("1");
        let data = PermissionRequestData {
            extra: serde_json::json!({"tool": "shell"}),
            ..Default::default()
        };
        HandlerEvent::PermissionRequest {
            session_id,
            request_id,
            data,
        }
    }

    /// `approve_all` must answer a permission request with `Approved`.
    #[tokio::test]
    async fn approve_all_approves_permission_requests() {
        let handler = approve_all(Arc::new(ApproveAllHandler));
        let other = handler.on_event(request()).await;
        assert!(
            matches!(other, HandlerResponse::Permission(PermissionResult::Approved)),
            "expected Approved, got {other:?}"
        );
    }

    /// `deny_all` must answer `Denied` even though the inner handler is `ApproveAllHandler`.
    #[tokio::test]
    async fn deny_all_denies_permission_requests() {
        let handler = deny_all(Arc::new(ApproveAllHandler));
        let other = handler.on_event(request()).await;
        assert!(
            matches!(other, HandlerResponse::Permission(PermissionResult::Denied)),
            "expected Denied, got {other:?}"
        );
    }

    /// `approve_if` must use the predicate's answer; here it rejects the `shell` tool.
    #[tokio::test]
    async fn approve_if_consults_predicate() {
        let handler = approve_if(Arc::new(ApproveAllHandler), |data| {
            let tool = data.extra.get("tool").and_then(|value| value.as_str());
            tool != Some("shell")
        });
        let other = handler.on_event(request()).await;
        assert!(
            matches!(other, HandlerResponse::Permission(PermissionResult::Denied)),
            "expected Denied for shell, got {other:?}"
        );
    }

    /// Events other than permission requests must reach the inner handler.
    #[tokio::test]
    async fn non_permission_events_forward_to_inner() {
        let handler = deny_all(Arc::new(ApproveAllHandler));
        let session_id = SessionId::from("s1");
        let data = crate::types::ExitPlanModeData::default();
        let event = HandlerEvent::ExitPlanMode { session_id, data };
        let other = handler.on_event(event).await;
        assert!(
            matches!(other, HandlerResponse::ExitPlanMode(_)),
            "expected ExitPlanMode forwarded, got {other:?}"
        );
    }
}