use std::future::Future;
use std::sync::Arc;
use async_trait::async_trait;
/// A single permission request handed to [`ApprovalProxy::request_approval`].
///
/// All fields are free-form strings; nothing in this module inspects them.
/// The tests build one as `tool_name: "Write"`, `permission: "write"`,
/// `resource: "/tmp/x"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApprovalAsk {
/// Name of the tool asking for approval (e.g. `"Write"`).
pub tool_name: String,
/// The permission being requested (e.g. `"write"`).
pub permission: String,
/// What the permission applies to (e.g. a path such as `"/tmp/x"`).
pub resource: String,
}
/// Decides whether a requested action may proceed.
///
/// A proxy is installed for the duration of a future with
/// [`with_approval_proxy`] and retrieved with [`current_approval_proxy`].
/// The `Send + Sync` supertraits make `Arc<dyn ApprovalProxy>` itself
/// `Send + Sync`, so the shared handle can cross threads.
#[async_trait]
pub trait ApprovalProxy: Send + Sync {
/// Asks whether `ask` is allowed; `true` grants it, `false` refuses it.
async fn request_approval(&self, ask: ApprovalAsk) -> bool;
}
// Per-task slot holding the proxy for the innermost active `with_approval_proxy`
// scope. The value is an `Option` because a scope may be entered with `None`
// (explicitly "no proxy"); outside any scope the slot is not set at all, which
// `current_approval_proxy` also reports as `None`.
tokio::task_local! {
static APPROVAL_PROXY: Option<Arc<dyn ApprovalProxy>>;
}
/// Runs `fut` to completion with `proxy` installed as this task's approval proxy.
///
/// While `fut` is being polled, [`current_approval_proxy`] returns `proxy`.
/// Passing `None` deliberately masks any proxy from an enclosing scope. Scopes
/// nest: the innermost one wins, and the previous value is visible again once
/// `fut` finishes.
///
/// The value lives in a tokio task-local, so it is only visible to code polled
/// as part of `fut` itself, not to tasks spawned from inside it.
///
/// Returns whatever `fut` resolves to.
pub async fn with_approval_proxy<F, T>(proxy: Option<Arc<dyn ApprovalProxy>>, fut: F) -> T
where
    F: Future<Output = T>,
{
    // `scope` wraps `fut` so the task-local is set only while it is polled.
    let scoped_future = APPROVAL_PROXY.scope(proxy, fut);
    scoped_future.await
}
/// Returns the approval proxy installed for the current task, if any.
///
/// Yields `None` in two cases:
/// - when called outside every [`with_approval_proxy`] scope;
/// - when the innermost scope was entered with `None`.
///
/// Because it uses `try_with` rather than `with`, calling it outside a scope
/// does not panic.
pub fn current_approval_proxy() -> Option<Arc<dyn ApprovalProxy>> {
    match APPROVAL_PROXY.try_with(|slot| slot.as_ref().map(Arc::clone)) {
        // Inside a scope: hand out another reference to whatever was installed.
        Ok(installed) => installed,
        // The task-local is not set: no enclosing scope on this task.
        Err(_) => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double: counts how often it is consulted and always answers with
    /// the fixed `approve` decision.
    struct Recorder {
        approve: bool,
        seen: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl ApprovalProxy for Recorder {
        async fn request_approval(&self, _ask: ApprovalAsk) -> bool {
            self.seen.fetch_add(1, Ordering::SeqCst);
            self.approve
        }
    }

    /// Builds a `Recorder` as a trait object, plus the counter it increments.
    fn recorder(approve: bool) -> (Arc<dyn ApprovalProxy>, Arc<AtomicUsize>) {
        let seen = Arc::new(AtomicUsize::new(0));
        let proxy: Arc<dyn ApprovalProxy> = Arc::new(Recorder {
            approve,
            seen: Arc::clone(&seen),
        });
        (proxy, seen)
    }

    /// A representative request shared by the tests.
    fn sample_ask() -> ApprovalAsk {
        ApprovalAsk {
            tool_name: "Write".into(),
            permission: "write".into(),
            resource: "/tmp/x".into(),
        }
    }

    #[tokio::test]
    async fn current_proxy_is_none_outside_scope() {
        assert!(current_approval_proxy().is_none());
    }

    #[tokio::test]
    async fn scope_installs_and_clears_proxy() {
        let (proxy, seen) = recorder(true);
        with_approval_proxy(Some(proxy), async {
            let got = current_approval_proxy().expect("proxy installed in scope");
            assert!(got.request_approval(sample_ask()).await);
        })
        .await;
        assert_eq!(seen.load(Ordering::SeqCst), 1);
        assert!(current_approval_proxy().is_none());
    }

    #[tokio::test]
    async fn none_scope_keeps_proxy_unset() {
        with_approval_proxy(None, async {
            assert!(current_approval_proxy().is_none());
        })
        .await;
    }

    // The original suite only ever used an approving proxy; make sure a
    // refusal is observed unchanged through the installed proxy.
    #[tokio::test]
    async fn denying_proxy_decision_is_observed() {
        let (proxy, seen) = recorder(false);
        with_approval_proxy(Some(proxy), async {
            let got = current_approval_proxy().expect("proxy installed in scope");
            assert!(!got.request_approval(sample_ask()).await);
        })
        .await;
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn scope_returns_inner_future_output() {
        let out = with_approval_proxy(None, async { 7_u32 }).await;
        assert_eq!(out, 7);
    }

    // Inner scope must shadow the outer proxy, and the outer one must be
    // visible again once the inner scope completes.
    #[tokio::test]
    async fn nested_scope_overrides_then_restores_outer_proxy() {
        let (outer, outer_seen) = recorder(true);
        let (inner, inner_seen) = recorder(false);
        with_approval_proxy(Some(outer), async move {
            with_approval_proxy(Some(inner), async {
                let got = current_approval_proxy().expect("inner proxy installed");
                assert!(
                    !got.request_approval(sample_ask()).await,
                    "inner (denying) proxy should shadow the outer one"
                );
            })
            .await;

            let got = current_approval_proxy().expect("outer proxy restored after inner scope");
            assert!(got.request_approval(sample_ask()).await);
        })
        .await;
        assert_eq!(inner_seen.load(Ordering::SeqCst), 1);
        assert_eq!(outer_seen.load(Ordering::SeqCst), 1);
    }

    // Task-locals are per task: a task spawned from inside a scope is polled
    // by the runtime outside that scope and therefore sees no proxy.
    #[tokio::test]
    async fn spawned_tasks_do_not_inherit_proxy() {
        let (proxy, _seen) = recorder(true);
        with_approval_proxy(Some(proxy), async {
            let visible_in_child = tokio::spawn(async { current_approval_proxy().is_some() })
                .await
                .expect("child task panicked");
            assert!(!visible_in_child);
            assert!(current_approval_proxy().is_some());
        })
        .await;
    }
}