use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, oneshot};
/// The answer to a single permission request, delivered to whoever is
/// awaiting the request's receiver.
///
/// The [`Default`] value is the most restrictive one: not approved and
/// nothing remembered.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PermissionDecision {
    /// Whether the request was approved.
    pub approved: bool,
    /// Asks that the decision be remembered for the rest of the session.
    /// NOTE(review): how this flag is honoured is decided by the relay and
    /// its callers, not by this type.
    pub remember_for_session: bool,
    /// Asks that the tool be allowed from now on; `PermissionRelay::resolve_with_tool`
    /// consults this flag when deciding whether to record the tool as
    /// session-allowed.
    pub always_allow: bool,
}
/// Cheaply cloneable handle to the shared relay state.
///
/// Every clone points at the same `RelayInner` behind an `Arc<Mutex<_>>`, so
/// a request registered through one handle can be resolved through another.
#[derive(Clone)]
pub struct PermissionRelay {
    /// Shared state, guarded by an async mutex.
    inner: Arc<Mutex<RelayInner>>,
}
/// Mutable state shared by all `PermissionRelay` handles.
struct RelayInner {
    /// Senders for requests that are still waiting for a decision, keyed by
    /// request id.
    pending: HashMap<String, oneshot::Sender<PermissionDecision>>,
    /// Names of tools that have been recorded as allowed for this session.
    session_allowed: Vec<String>,
    /// Timeout value handed out by `PermissionRelay::default_timeout`.
    default_timeout: Duration,
}
impl PermissionRelay {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(RelayInner {
pending: HashMap::new(),
session_allowed: Vec::new(),
default_timeout: Duration::from_secs(60),
})),
}
}
pub async fn is_session_allowed(&self, tool_name: &str) -> bool {
let inner = self.inner.lock().await;
inner.session_allowed.contains(&tool_name.to_string())
}
pub async fn register_request(
&self,
request_id: String,
) -> oneshot::Receiver<PermissionDecision> {
let (tx, rx) = oneshot::channel();
let mut inner = self.inner.lock().await;
inner.pending.insert(request_id, tx);
rx
}
pub async fn resolve(&self, request_id: &str, decision: PermissionDecision) -> bool {
let mut inner = self.inner.lock().await;
if decision.always_allow || decision.remember_for_session {
}
if let Some(tx) = inner.pending.remove(request_id) {
tx.send(decision).is_ok()
} else {
false
}
}
pub async fn resolve_with_tool(
&self,
request_id: &str,
tool_name: &str,
decision: PermissionDecision,
) -> bool {
let mut inner = self.inner.lock().await;
if decision.always_allow && !inner.session_allowed.contains(&tool_name.to_string()) {
inner.session_allowed.push(tool_name.to_string());
}
if let Some(tx) = inner.pending.remove(request_id) {
tx.send(decision).is_ok()
} else {
false
}
}
pub async fn cancel(&self, request_id: &str) -> bool {
let mut inner = self.inner.lock().await;
inner.pending.remove(request_id).is_some()
}
pub async fn pending_count(&self) -> usize {
let inner = self.inner.lock().await;
inner.pending.len()
}
pub async fn default_timeout(&self) -> Duration {
let inner = self.inner.lock().await;
inner.default_timeout
}
pub async fn add_session_allowed(&self, tool_name: &str) {
let mut inner = self.inner.lock().await;
if !inner.session_allowed.contains(&tool_name.to_string()) {
inner.session_allowed.push(tool_name.to_string());
}
}
pub async fn clear(&self) {
let mut inner = self.inner.lock().await;
inner.pending.clear();
}
}
impl Default for PermissionRelay {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a decision with `remember_for_session` switched off, which is
    /// all these tests need.
    fn decision(approved: bool, always_allow: bool) -> PermissionDecision {
        PermissionDecision {
            approved,
            remember_for_session: false,
            always_allow,
        }
    }

    /// A registered request receives the decision and leaves the pending set.
    #[tokio::test]
    async fn test_register_and_resolve() {
        let relay = PermissionRelay::new();
        let receiver = relay.register_request(String::from("req-1")).await;
        assert_eq!(relay.pending_count().await, 1);

        let delivered = relay.resolve("req-1", decision(true, false)).await;
        assert!(delivered);
        assert_eq!(relay.pending_count().await, 0);

        let received = receiver.await.unwrap();
        assert!(received.approved);
    }

    /// Resolving an id that was never registered reports failure.
    #[tokio::test]
    async fn test_resolve_unknown_request() {
        let relay = PermissionRelay::new();
        let delivered = relay.resolve("nonexistent", decision(true, false)).await;
        assert!(!delivered);
    }

    /// An approved `always_allow` decision marks the tool as session-allowed.
    #[tokio::test]
    async fn test_session_allowed() {
        let relay = PermissionRelay::new();
        assert!(!relay.is_session_allowed("bash").await);

        // "req-1" is not registered here; only the allow-list effect is checked.
        let _ = relay
            .resolve_with_tool("req-1", "bash", decision(true, true))
            .await;

        assert!(relay.is_session_allowed("bash").await);
    }

    /// Cancelling drops the sender, so the receiver ends with an error.
    #[tokio::test]
    async fn test_cancel_request() {
        let relay = PermissionRelay::new();
        let receiver = relay.register_request(String::from("req-1")).await;

        assert!(relay.cancel("req-1").await);
        assert_eq!(relay.pending_count().await, 0);
        assert!(receiver.await.is_err());
    }
}