use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc, oneshot};
/// Size at which a pending-request map is swept for stale entries before the next insert.
const PENDING_CLEANUP_THRESHOLD: usize = 50;
/// Pending entries at least this old are discarded by the stale-entry sweep (five minutes).
const PENDING_MAX_AGE: Duration = Duration::from_secs(5 * 60);
use super::{
BatchPermissionRequest, BatchPermissionResponse, Grant, GrantTarget, PermissionRequest,
};
use crate::controller::types::{ControllerEvent, TurnId};
/// Snapshot of one pending (not yet answered) individual permission request, as returned by
/// `PermissionRegistry::pending_for_session`.
#[derive(Debug, Clone)]
pub struct PendingPermissionInfo {
    /// Key the request is pending under, i.e. the `id` of the original `PermissionRequest`.
    pub tool_use_id: String,
    /// Session the request belongs to.
    pub session_id: i64,
    /// The request awaiting a decision.
    pub request: PermissionRequest,
    /// Turn the request was raised in, if the caller supplied one.
    pub turn_id: Option<TurnId>,
}
/// A decision on a single permission prompt, delivered to the receiver returned by
/// `PermissionRegistry::request_permission`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PermissionPanelResponse {
    /// Whether the request was allowed.
    pub granted: bool,
    /// Optional grant to remember for the session: when `granted` is `true` and this is
    /// `Some`, `PermissionRegistry::respond_to_request` adds it to the session's grants.
    /// Skipped by serde, so it is never serialized and deserializes as `None`.
    #[serde(skip)]
    pub grant: Option<Grant>,
    /// Optional free-form message; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
/// Process-wide counter backing [`generate_batch_id`]; the first id handed out is 1.
static BATCH_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns a fresh batch identifier of the form `batch-<n>`.
///
/// `n` comes from a process-wide atomic counter that is incremented on every call, so ids
/// are distinct across all registries in the process (until the `u64` counter wraps).
pub fn generate_batch_id() -> String {
    let next = BATCH_COUNTER.fetch_add(1, Ordering::SeqCst);
    format!("batch-{next}")
}
/// Errors reported by `PermissionRegistry` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionError {
    /// No pending request or batch exists under the given id.
    NotFound,
    /// The request was already answered. Not currently produced by `PermissionRegistry`.
    AlreadyResponded,
    /// The response could not be delivered because the waiting receiver was dropped.
    SendFailed,
    /// The prompt event could not be published on the controller event channel.
    EventSendFailed,
    /// The batch was already answered. Not currently produced by `PermissionRegistry`.
    BatchAlreadyProcessed,
}

impl std::fmt::Display for PermissionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::NotFound => "No pending permission request found",
            Self::AlreadyResponded => "Permission already responded to",
            Self::SendFailed => "Failed to send response",
            Self::EventSendFailed => "Failed to send event notification",
            Self::BatchAlreadyProcessed => "Batch has already been processed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for PermissionError {}
/// Bookkeeping for an individual request registered by `request_permission` that is waiting
/// for `respond_to_request`, `cancel`, or session cancellation.
struct PendingRequest {
    /// Session that owns the request.
    session_id: i64,
    /// The request awaiting a decision.
    request: PermissionRequest,
    /// Turn the request was raised in, if any.
    turn_id: Option<TurnId>,
    /// One-shot channel used to deliver the decision; dropping it (e.g. on cancel) makes the
    /// waiting receiver resolve with an error.
    responder: oneshot::Sender<PermissionPanelResponse>,
    /// Registration time, compared against `PENDING_MAX_AGE` when pruning stale entries.
    created_at: Instant,
}
/// Bookkeeping for a batch registered by `register_batch` that is waiting for
/// `respond_to_batch`, `cancel_batch`, or session cancellation.
struct PendingBatch {
    /// Session that owns the batch.
    session_id: i64,
    /// Only the requests that needed approval; auto-approved ones are not stored here.
    // Stored for completeness but not read anywhere in the registry, hence the allowance.
    #[allow(dead_code)]
    requests: Vec<PermissionRequest>,
    /// Turn the batch was raised in, if any.
    // Stored for completeness but not read anywhere in the registry, hence the allowance.
    #[allow(dead_code)]
    turn_id: Option<TurnId>,
    /// One-shot channel used to deliver the batch decision to the waiter.
    responder: oneshot::Sender<BatchPermissionResponse>,
    /// Registration time, compared against `PENDING_MAX_AGE` when pruning stale entries.
    created_at: Instant,
}
/// Tracks per-session permission grants and the permission prompts still awaiting an answer.
///
/// Each collection sits behind its own `tokio::sync::Mutex`; the registry's methods release
/// those locks before awaiting the event channel.
pub struct PermissionRegistry {
    /// Grants remembered for each session, keyed by session id.
    session_grants: Mutex<HashMap<i64, Vec<Grant>>>,
    /// Unanswered individual requests, keyed by the request's `id`.
    pending_requests: Mutex<HashMap<String, PendingRequest>>,
    /// Unanswered batches, keyed by the id produced by `generate_batch_id`.
    pending_batches: Mutex<HashMap<String, PendingBatch>>,
    /// Channel on which permission prompts are published to the controller.
    event_tx: mpsc::Sender<ControllerEvent>,
}
impl PermissionRegistry {
pub fn new(event_tx: mpsc::Sender<ControllerEvent>) -> Self {
Self {
session_grants: Mutex::new(HashMap::new()),
pending_requests: Mutex::new(HashMap::new()),
pending_batches: Mutex::new(HashMap::new()),
event_tx,
}
}
pub async fn add_grant(&self, session_id: i64, grant: Grant) {
let mut grants = self.session_grants.lock().await;
let session_grants = grants.entry(session_id).or_insert_with(Vec::new);
session_grants.push(grant);
}
pub async fn add_grants(&self, session_id: i64, new_grants: Vec<Grant>) {
let mut grants = self.session_grants.lock().await;
let session_grants = grants.entry(session_id).or_insert_with(Vec::new);
session_grants.extend(new_grants);
}
pub async fn cleanup_expired(&self, session_id: i64) {
let mut grants = self.session_grants.lock().await;
if let Some(session_grants) = grants.get_mut(&session_id) {
session_grants.retain(|g| !g.is_expired());
}
}
pub async fn revoke_grants(&self, session_id: i64, target: &GrantTarget) -> usize {
let mut grants = self.session_grants.lock().await;
if let Some(session_grants) = grants.get_mut(&session_id) {
let original_len = session_grants.len();
session_grants.retain(|g| &g.target != target);
original_len - session_grants.len()
} else {
0
}
}
pub async fn get_grants(&self, session_id: i64) -> Vec<Grant> {
let grants = self.session_grants.lock().await;
grants.get(&session_id).cloned().unwrap_or_default()
}
pub async fn clear_grants(&self, session_id: i64) {
let mut grants = self.session_grants.lock().await;
grants.remove(&session_id);
}
pub async fn check(&self, session_id: i64, request: &PermissionRequest) -> bool {
let grants = self.session_grants.lock().await;
if let Some(session_grants) = grants.get(&session_id) {
session_grants.iter().any(|grant| grant.satisfies(request))
} else {
false
}
}
pub async fn check_batch(
&self,
session_id: i64,
requests: &[PermissionRequest],
) -> HashSet<String> {
let grants = self.session_grants.lock().await;
let session_grants = grants.get(&session_id);
let mut granted = HashSet::new();
for request in requests {
if let Some(sg) = session_grants
&& sg.iter().any(|grant| grant.satisfies(request))
{
granted.insert(request.id.clone());
}
}
granted
}
pub async fn find_satisfying_grant(
&self,
session_id: i64,
request: &PermissionRequest,
) -> Option<Grant> {
let grants = self.session_grants.lock().await;
if let Some(session_grants) = grants.get(&session_id) {
session_grants
.iter()
.find(|grant| grant.satisfies(request))
.cloned()
} else {
None
}
}
pub async fn request_permission(
&self,
session_id: i64,
request: PermissionRequest,
turn_id: Option<TurnId>,
) -> Result<oneshot::Receiver<PermissionPanelResponse>, PermissionError> {
if self.check(session_id, &request).await {
let (tx, rx) = oneshot::channel();
let _ = tx.send(PermissionPanelResponse {
granted: true,
grant: None, message: None,
});
return Ok(rx);
}
let (tx, rx) = oneshot::channel();
let request_id = request.id.clone();
{
let mut pending = self.pending_requests.lock().await;
if pending.len() >= PENDING_CLEANUP_THRESHOLD {
let now = Instant::now();
pending.retain(|id, req| {
let keep = now.duration_since(req.created_at) < PENDING_MAX_AGE;
if !keep {
tracing::warn!(
request_id = %id,
age_secs = now.duration_since(req.created_at).as_secs(),
"Cleaning up stale pending permission request"
);
}
keep
});
}
pending.insert(
request_id.clone(),
PendingRequest {
session_id,
request: request.clone(),
turn_id: turn_id.clone(),
responder: tx,
created_at: Instant::now(),
},
);
}
self.event_tx
.send(ControllerEvent::PermissionRequired {
session_id,
tool_use_id: request_id,
request,
turn_id,
})
.await
.map_err(|_| PermissionError::EventSendFailed)?;
Ok(rx)
}
pub async fn respond_to_request(
&self,
request_id: &str,
response: PermissionPanelResponse,
) -> Result<(), PermissionError> {
let pending = {
let mut pending = self.pending_requests.lock().await;
pending
.remove(request_id)
.ok_or(PermissionError::NotFound)?
};
if response.granted
&& let Some(ref g) = response.grant
{
self.add_grant(pending.session_id, g.clone()).await;
}
pending
.responder
.send(response)
.map_err(|_| PermissionError::SendFailed)
}
pub async fn cancel(&self, request_id: &str) -> Result<(), PermissionError> {
let mut pending = self.pending_requests.lock().await;
if pending.remove(request_id).is_some() {
Ok(())
} else {
Err(PermissionError::NotFound)
}
}
pub async fn pending_for_session(&self, session_id: i64) -> Vec<PendingPermissionInfo> {
let pending = self.pending_requests.lock().await;
pending
.iter()
.filter(|(_, req)| req.session_id == session_id)
.map(|(tool_use_id, req)| PendingPermissionInfo {
tool_use_id: tool_use_id.clone(),
session_id: req.session_id,
request: req.request.clone(),
turn_id: req.turn_id.clone(),
})
.collect()
}
pub async fn is_granted(&self, session_id: i64, request: &PermissionRequest) -> bool {
self.check(session_id, request).await
}
pub async fn register_batch(
&self,
session_id: i64,
requests: Vec<PermissionRequest>,
turn_id: Option<TurnId>,
) -> Result<oneshot::Receiver<BatchPermissionResponse>, PermissionError> {
let auto_approved = self.check_batch(session_id, &requests).await;
let needs_approval: Vec<_> = requests
.iter()
.filter(|r| !auto_approved.contains(&r.id))
.cloned()
.collect();
if needs_approval.is_empty() {
let (tx, rx) = oneshot::channel();
let response =
BatchPermissionResponse::with_auto_approved(generate_batch_id(), auto_approved);
let _ = tx.send(response);
return Ok(rx);
}
let batch_id = generate_batch_id();
let (tx, rx) = oneshot::channel();
let batch = BatchPermissionRequest::new(batch_id.clone(), needs_approval.clone());
{
let mut pending = self.pending_batches.lock().await;
if pending.len() >= PENDING_CLEANUP_THRESHOLD {
let now = Instant::now();
pending.retain(|id, batch| {
let keep = now.duration_since(batch.created_at) < PENDING_MAX_AGE;
if !keep {
tracing::warn!(
batch_id = %id,
age_secs = now.duration_since(batch.created_at).as_secs(),
"Cleaning up stale pending batch permission request"
);
}
keep
});
}
pending.insert(
batch_id.clone(),
PendingBatch {
session_id,
requests: needs_approval,
turn_id: turn_id.clone(),
responder: tx,
created_at: Instant::now(),
},
);
}
self.event_tx
.send(ControllerEvent::BatchPermissionRequired {
session_id,
batch,
turn_id,
})
.await
.map_err(|_| PermissionError::EventSendFailed)?;
Ok(rx)
}
pub async fn respond_to_batch(
&self,
batch_id: &str,
mut response: BatchPermissionResponse,
) -> Result<(), PermissionError> {
let pending = {
let mut pending = self.pending_batches.lock().await;
pending.remove(batch_id).ok_or(PermissionError::NotFound)?
};
if !response.approved_grants.is_empty() {
self.add_grants(pending.session_id, response.approved_grants.clone())
.await;
}
response.batch_id = batch_id.to_string();
pending
.responder
.send(response)
.map_err(|_| PermissionError::SendFailed)
}
pub async fn cancel_batch(&self, batch_id: &str) -> Result<(), PermissionError> {
let mut pending = self.pending_batches.lock().await;
if pending.remove(batch_id).is_some() {
Ok(())
} else {
Err(PermissionError::NotFound)
}
}
pub async fn cancel_session(&self, session_id: i64) {
{
let mut pending = self.pending_requests.lock().await;
pending.retain(|_, p| p.session_id != session_id);
}
{
let mut pending = self.pending_batches.lock().await;
pending.retain(|_, p| p.session_id != session_id);
}
}
pub async fn clear_session(&self, session_id: i64) {
self.cancel_session(session_id).await;
self.clear_grants(session_id).await;
}
pub async fn has_pending(&self, session_id: i64) -> bool {
let individual_pending = {
let pending = self.pending_requests.lock().await;
pending.values().any(|p| p.session_id == session_id)
};
if individual_pending {
return true;
}
let pending = self.pending_batches.lock().await;
pending.values().any(|p| p.session_id == session_id)
}
pub async fn pending_count(&self) -> usize {
let individual = self.pending_requests.lock().await.len();
let batch = self.pending_batches.lock().await.len();
individual + batch
}
pub async fn pending_request_ids(&self, session_id: i64) -> Vec<String> {
let pending = self.pending_requests.lock().await;
pending
.iter()
.filter(|(_, p)| p.session_id == session_id)
.map(|(id, _)| id.clone())
.collect()
}
pub async fn pending_batch_ids(&self, session_id: i64) -> Vec<String> {
let pending = self.pending_batches.lock().await;
pending
.iter()
.filter(|(_, p)| p.session_id == session_id)
.map(|(id, _)| id.clone())
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::PermissionLevel;

    /// Builds a file-read request with the given id and path.
    fn create_read_request(id: &str, path: &str) -> PermissionRequest {
        PermissionRequest::file_read(id, path)
    }

    /// Builds a file-write request with the given id and path.
    fn create_write_request(id: &str, path: &str) -> PermissionRequest {
        PermissionRequest::file_write(id, path)
    }

    #[tokio::test]
    async fn test_add_and_check_grant() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project/src", true);
        registry.add_grant(1, grant).await;
        let request = create_read_request("req-1", "/project/src/main.rs");
        assert!(registry.check(1, &request).await);
        assert!(!registry.check(2, &request).await);
    }

    #[tokio::test]
    async fn test_level_hierarchy() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::write_path("/project", true);
        registry.add_grant(1, grant).await;
        let read_request = create_read_request("req-1", "/project/file.rs");
        assert!(registry.check(1, &read_request).await);
        let write_request = create_write_request("req-2", "/project/file.rs");
        assert!(registry.check(1, &write_request).await);
    }

    #[tokio::test]
    async fn test_level_hierarchy_insufficient() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project", true);
        registry.add_grant(1, grant).await;
        let write_request = create_write_request("req-1", "/project/file.rs");
        assert!(!registry.check(1, &write_request).await);
    }

    #[tokio::test]
    async fn test_recursive_path_grant() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project", true);
        registry.add_grant(1, grant).await;
        let request = create_read_request("req-1", "/project/src/utils/mod.rs");
        assert!(registry.check(1, &request).await);
    }

    #[tokio::test]
    async fn test_non_recursive_path_grant() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project/src", false);
        registry.add_grant(1, grant).await;
        let direct = create_read_request("req-1", "/project/src/main.rs");
        assert!(registry.check(1, &direct).await);
        let nested = create_read_request("req-2", "/project/src/utils/mod.rs");
        assert!(!registry.check(1, &nested).await);
    }

    #[tokio::test]
    async fn test_check_batch() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project/src", true);
        registry.add_grant(1, grant).await;
        let requests = vec![
            create_read_request("req-1", "/project/src/main.rs"),
            create_read_request("req-2", "/project/tests/test.rs"),
            create_read_request("req-3", "/project/src/lib.rs"),
        ];
        let granted = registry.check_batch(1, &requests).await;
        assert!(granted.contains("req-1"));
        assert!(!granted.contains("req-2"));
        assert!(granted.contains("req-3"));
    }

    #[tokio::test]
    async fn test_request_permission_auto_approve() {
        let (tx, mut rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project", true);
        registry.add_grant(1, grant).await;
        let request = create_read_request("req-1", "/project/file.rs");
        let result_rx = registry.request_permission(1, request, None).await.unwrap();
        let response = result_rx.await.unwrap();
        assert!(response.granted);
        // Auto-approval must not emit an event.
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_request_permission_needs_approval() {
        let (tx, mut rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let request = create_read_request("req-1", "/project/file.rs");
        let result_rx = registry.request_permission(1, request, None).await.unwrap();
        let event = rx.recv().await.unwrap();
        if let ControllerEvent::PermissionRequired { tool_use_id, .. } = event {
            assert_eq!(tool_use_id, "req-1");
        } else {
            panic!("Expected PermissionRequired event");
        }
        let grant = Grant::read_path("/project", true);
        let response = PermissionPanelResponse {
            granted: true,
            grant: Some(grant),
            message: None,
        };
        registry
            .respond_to_request("req-1", response)
            .await
            .unwrap();
        let response = result_rx.await.unwrap();
        assert!(response.granted);
        // The grant carried by the response is remembered for later requests.
        let new_request = create_read_request("req-2", "/project/other.rs");
        assert!(registry.check(1, &new_request).await);
    }

    #[tokio::test]
    async fn test_request_permission_denied() {
        let (tx, mut rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let request = create_read_request("req-1", "/project/file.rs");
        let result_rx = registry.request_permission(1, request, None).await.unwrap();
        let _ = rx.recv().await.unwrap();
        let response = PermissionPanelResponse {
            granted: false,
            grant: None,
            message: None,
        };
        registry
            .respond_to_request("req-1", response)
            .await
            .unwrap();
        let response = result_rx.await.unwrap();
        assert!(!response.granted);
    }

    #[tokio::test]
    async fn test_request_permission_event_send_failed() {
        let (tx, rx) = mpsc::channel(10);
        // Closing the receiving side makes every event send fail.
        drop(rx);
        let registry = PermissionRegistry::new(tx);
        let request = create_read_request("req-1", "/project/file.rs");
        let result = registry.request_permission(1, request, None).await;
        assert!(matches!(result, Err(PermissionError::EventSendFailed)));
    }

    #[tokio::test]
    async fn test_unknown_ids_return_not_found() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let response = PermissionPanelResponse {
            granted: true,
            grant: None,
            message: None,
        };
        assert_eq!(
            registry.respond_to_request("missing", response).await,
            Err(PermissionError::NotFound)
        );
        assert_eq!(
            registry.cancel("missing").await,
            Err(PermissionError::NotFound)
        );
        assert_eq!(
            registry.cancel_batch("missing").await,
            Err(PermissionError::NotFound)
        );
    }

    #[tokio::test]
    async fn test_register_batch() {
        let (tx, mut rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let requests = vec![
            create_read_request("req-1", "/project/src/main.rs"),
            create_read_request("req-2", "/project/src/lib.rs"),
        ];
        let result_rx = registry.register_batch(1, requests, None).await.unwrap();
        let event = rx.recv().await.unwrap();
        let batch_id = if let ControllerEvent::BatchPermissionRequired { batch, .. } = event {
            assert_eq!(batch.requests.len(), 2);
            assert!(!batch.suggested_grants.is_empty());
            batch.batch_id.clone()
        } else {
            panic!("Expected BatchPermissionRequired event");
        };
        let grant = Grant::read_path("/project/src", true);
        let response = BatchPermissionResponse::all_granted(&batch_id, vec![grant]);
        registry
            .respond_to_batch(&batch_id, response)
            .await
            .unwrap();
        let result = result_rx.await.unwrap();
        assert!(!result.approved_grants.is_empty());
    }

    #[tokio::test]
    async fn test_register_batch_partial_auto_approve() {
        let (tx, mut rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project/src", true);
        registry.add_grant(1, grant).await;
        let requests = vec![
            create_read_request("req-1", "/project/src/main.rs"),
            create_read_request("req-2", "/project/tests/test.rs"),
        ];
        let result_rx = registry.register_batch(1, requests, None).await.unwrap();
        let event = rx.recv().await.unwrap();
        let batch_id = if let ControllerEvent::BatchPermissionRequired { batch, .. } = event {
            // Only the request not covered by an existing grant is sent for approval.
            assert_eq!(batch.requests.len(), 1);
            assert_eq!(batch.requests[0].id, "req-2");
            batch.batch_id.clone()
        } else {
            panic!("Expected BatchPermissionRequired event");
        };
        let grant = Grant::read_path("/project/tests", true);
        let response = BatchPermissionResponse::all_granted(&batch_id, vec![grant]);
        registry
            .respond_to_batch(&batch_id, response)
            .await
            .unwrap();
        let result = result_rx.await.unwrap();
        assert!(!result.approved_grants.is_empty());
        // The approved grant is remembered for the session.
        let covered = create_read_request("req-3", "/project/tests/other.rs");
        assert!(registry.check(1, &covered).await);
    }

    #[tokio::test]
    async fn test_register_batch_all_auto_approved() {
        let (tx, mut rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project", true);
        registry.add_grant(1, grant).await;
        let requests = vec![
            create_read_request("req-1", "/project/src/main.rs"),
            create_read_request("req-2", "/project/tests/test.rs"),
        ];
        let result_rx = registry.register_batch(1, requests, None).await.unwrap();
        let result = result_rx.await.unwrap();
        assert!(result.auto_approved.contains("req-1"));
        assert!(result.auto_approved.contains("req-2"));
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_register_batch_event_send_failed() {
        let (tx, rx) = mpsc::channel(10);
        drop(rx);
        let registry = PermissionRegistry::new(tx);
        let requests = vec![create_read_request("req-1", "/project/file.rs")];
        let result = registry.register_batch(1, requests, None).await;
        assert!(matches!(result, Err(PermissionError::EventSendFailed)));
    }

    #[tokio::test]
    async fn test_revoke_grants() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant1 = Grant::read_path("/project/src", true);
        let grant2 = Grant::read_path("/project/tests", true);
        registry.add_grant(1, grant1).await;
        registry.add_grant(1, grant2).await;
        let target = GrantTarget::path("/project/src", true);
        let revoked = registry.revoke_grants(1, &target).await;
        assert_eq!(revoked, 1);
        let request1 = create_read_request("req-1", "/project/src/file.rs");
        assert!(!registry.check(1, &request1).await);
        let request2 = create_read_request("req-2", "/project/tests/test.rs");
        assert!(registry.check(1, &request2).await);
    }

    #[tokio::test]
    async fn test_clear_session() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::read_path("/project", true);
        registry.add_grant(1, grant).await;
        registry.clear_session(1).await;
        let request = create_read_request("req-1", "/project/file.rs");
        assert!(!registry.check(1, &request).await);
    }

    #[tokio::test]
    async fn test_cancel_session() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let request = create_read_request("req-1", "/project/file.rs");
        let result_rx = registry.request_permission(1, request, None).await.unwrap();
        assert!(registry.has_pending(1).await);
        registry.cancel_session(1).await;
        assert!(!registry.has_pending(1).await);
        // Dropping the responder resolves the waiter with an error.
        assert!(result_rx.await.is_err());
    }

    #[tokio::test]
    async fn test_domain_grant() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::domain("*.github.com", PermissionLevel::Read);
        registry.add_grant(1, grant).await;
        let request =
            PermissionRequest::network_access("req-1", "api.github.com", PermissionLevel::Read);
        assert!(registry.check(1, &request).await);
        let other_domain =
            PermissionRequest::network_access("req-2", "api.gitlab.com", PermissionLevel::Read);
        assert!(!registry.check(1, &other_domain).await);
    }

    #[tokio::test]
    async fn test_command_grant() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::command("git *", PermissionLevel::Execute);
        registry.add_grant(1, grant).await;
        let request = PermissionRequest::command_execute("req-1", "git status");
        assert!(registry.check(1, &request).await);
        let other_cmd = PermissionRequest::command_execute("req-2", "docker run nginx");
        assert!(!registry.check(1, &other_cmd).await);
    }

    #[tokio::test]
    async fn test_find_satisfying_grant() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        let grant = Grant::write_path("/project", true);
        registry.add_grant(1, grant.clone()).await;
        let request = create_read_request("req-1", "/project/file.rs");
        let found = registry.find_satisfying_grant(1, &request).await;
        assert!(found.is_some());
        assert_eq!(found.unwrap().target, grant.target);
    }

    #[tokio::test]
    async fn test_pending_counts() {
        let (tx, _rx) = mpsc::channel(10);
        let registry = PermissionRegistry::new(tx);
        assert_eq!(registry.pending_count().await, 0);
        let request = create_read_request("req-1", "/project/file.rs");
        // Unwrap so a failed registration surfaces here instead of as a confusing count.
        let _result_rx = registry.request_permission(1, request, None).await.unwrap();
        assert_eq!(registry.pending_count().await, 1);
        let ids = registry.pending_request_ids(1).await;
        assert_eq!(ids.len(), 1);
        assert_eq!(ids[0], "req-1");
        assert!(registry.pending_batch_ids(1).await.is_empty());
    }
}