use std::sync::Arc;
use async_trait::async_trait;
use crate::handler::{PermissionHandler, PermissionResult};
use crate::types::{PermissionRequestData, RequestId, SessionId};
/// Returns a handler that approves every permission request.
///
/// Each request is answered with [`PermissionResult::approve_once`]; the
/// request payload is never inspected.
pub fn approve_all() -> Arc<dyn PermissionHandler> {
    let handler = PolicyHandler {
        policy: Policy::ApproveAll,
    };
    Arc::new(handler)
}
/// Returns a handler that rejects every permission request.
///
/// Each request is answered with `PermissionResult::reject(None)`; the
/// request payload is never inspected.
pub fn deny_all() -> Arc<dyn PermissionHandler> {
    let handler = PolicyHandler {
        policy: Policy::DenyAll,
    };
    Arc::new(handler)
}
/// Returns a handler that lets `predicate` decide every permission request.
///
/// The predicate receives a reference to the request's
/// [`PermissionRequestData`]. A `true` result approves the request once, and
/// `false` rejects it with `PermissionResult::reject(None)`. The session and
/// request ids are not passed to the predicate.
pub fn approve_if<F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static>(
    predicate: F,
) -> Arc<dyn PermissionHandler> {
    let policy = Policy::Predicate(Arc::new(predicate));
    Arc::new(PolicyHandler { policy })
}
/// Decision strategy behind the handlers returned by [`approve_all`],
/// [`deny_all`] and [`approve_if`].
#[derive(Clone)]
pub(crate) enum Policy {
    /// Approve every request.
    ApproveAll,
    /// Reject every request.
    DenyAll,
    /// Approve when the closure returns `true` for the request data, and
    /// reject otherwise. Holding the closure in an `Arc` is what allows
    /// `#[derive(Clone)]` without requiring the closure itself to be `Clone`.
    Predicate(Arc<dyn Fn(&PermissionRequestData) -> bool + Send + Sync + 'static>),
}
impl std::fmt::Debug for Policy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ApproveAll => f.write_str("Policy::ApproveAll"),
Self::DenyAll => f.write_str("Policy::DenyAll"),
Self::Predicate(_) => f.write_str("Policy::Predicate(<fn>)"),
}
}
}
/// Picks the permission handler to use from an optional custom `handler` and
/// an optional [`Policy`].
///
/// When `policy` is present it takes precedence: it is wrapped in a
/// policy-backed handler, and `handler` is discarded. Otherwise `handler` is
/// returned unchanged, so the result is `None` when neither was supplied.
pub(crate) fn resolve_handler(
    handler: Option<Arc<dyn PermissionHandler>>,
    policy: Option<Policy>,
) -> Option<Arc<dyn PermissionHandler>> {
    match policy {
        Some(policy) => Some(Arc::new(PolicyHandler { policy })),
        None => handler,
    }
}
/// A [`PermissionHandler`] that answers every request according to a fixed
/// [`Policy`].
// `Policy` implements `Debug` by hand, so the wrapper can simply derive it.
#[derive(Debug)]
struct PolicyHandler {
    /// Strategy consulted for every incoming request.
    policy: Policy,
}
#[async_trait]
impl PermissionHandler for PolicyHandler {
    /// Applies `self.policy` to the request payload.
    ///
    /// An approval yields [`PermissionResult::approve_once`], and anything else
    /// yields `PermissionResult::reject(None)`. The session and request ids do
    /// not affect the decision.
    async fn handle(
        &self,
        _session_id: SessionId,
        _request_id: RequestId,
        data: PermissionRequestData,
    ) -> PermissionResult {
        match &self.policy {
            Policy::ApproveAll => PermissionResult::approve_once(),
            Policy::DenyAll => PermissionResult::reject(None),
            // The guard evaluates the predicate exactly once per request.
            Policy::Predicate(allows) if allows(&data) => PermissionResult::approve_once(),
            Policy::Predicate(_) => PermissionResult::reject(None),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::PermissionDecision;

    /// Builds a request payload whose `extra` field names the `shell` tool.
    fn data() -> PermissionRequestData {
        PermissionRequestData {
            extra: serde_json::json!({ "tool": "shell" }),
            ..Default::default()
        }
    }

    /// Sends `data()` through `h` using fixed session and request ids.
    async fn decide(h: &Arc<dyn PermissionHandler>) -> PermissionResult {
        h.handle(SessionId::from("s"), RequestId::new("1"), data())
            .await
    }

    #[tokio::test]
    async fn approve_all_approves() {
        let h = approve_all();
        assert!(matches!(
            decide(&h).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn deny_all_denies() {
        let h = deny_all();
        assert!(matches!(
            decide(&h).await,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    #[tokio::test]
    async fn approve_if_consults_predicate() {
        let h = approve_if(|d| d.extra.get("tool").and_then(|v| v.as_str()) != Some("shell"));
        assert!(matches!(
            decide(&h).await,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    // Covers the approving branch of `approve_if`, which the test above does not reach.
    #[tokio::test]
    async fn approve_if_approves_when_predicate_holds() {
        let h = approve_if(|d| d.extra.get("tool").and_then(|v| v.as_str()) == Some("shell"));
        assert!(matches!(
            decide(&h).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn resolve_handler_policy_wins() {
        struct AlwaysApprove;
        #[async_trait]
        impl PermissionHandler for AlwaysApprove {
            async fn handle(
                &self,
                _: SessionId,
                _: RequestId,
                _: PermissionRequestData,
            ) -> PermissionResult {
                PermissionResult::approve_once()
            }
        }
        let resolved =
            resolve_handler(Some(Arc::new(AlwaysApprove)), Some(Policy::DenyAll)).unwrap();
        assert!(matches!(
            decide(&resolved).await,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    #[tokio::test]
    async fn resolve_handler_with_only_handler() {
        struct H;
        #[async_trait]
        impl PermissionHandler for H {
            async fn handle(
                &self,
                _: SessionId,
                _: RequestId,
                _: PermissionRequestData,
            ) -> PermissionResult {
                PermissionResult::approve_once()
            }
        }
        let resolved = resolve_handler(Some(Arc::new(H)), None).unwrap();
        assert!(matches!(
            decide(&resolved).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn resolve_handler_with_only_policy() {
        let resolved = resolve_handler(None, Some(Policy::ApproveAll)).unwrap();
        assert!(matches!(
            decide(&resolved).await,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[test]
    fn resolve_handler_with_neither_returns_none() {
        assert!(resolve_handler(None, None).is_none());
    }

    // Pins the hand-written `Debug` output, including the closure placeholder.
    #[test]
    fn policy_debug_output() {
        assert_eq!(format!("{:?}", Policy::ApproveAll), "Policy::ApproveAll");
        assert_eq!(format!("{:?}", Policy::DenyAll), "Policy::DenyAll");
        let predicate = Policy::Predicate(Arc::new(|_: &PermissionRequestData| true));
        assert_eq!(format!("{:?}", predicate), "Policy::Predicate(<fn>)");
    }
}