use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use motosan_agent_loop::core::decision::ToolDecision;
use motosan_agent_loop::core::ext_error::ExtError;
use motosan_agent_loop::core::extension::Extension;
use motosan_agent_loop::core::hook_ctx::HookCtx;
use motosan_agent_loop::llm::ToolCallItem;
use motosan_agent_tool::ToolResult;
use serde_json::Value;
use tokio::sync::{mpsc, oneshot};
use super::policy::Policy;
use super::session_cache::SessionCache;
use super::Decision;
use crate::events::UiEvent;
/// What [`PermissionExtension`] does with a tool call that neither the policy
/// nor the session cache settles on its own.
pub enum PromptStrategy {
    /// Ask the user: a `UiEvent::PermissionRequested` carrying a one-shot
    /// resolver is sent on this channel, and the reply becomes the decision.
    Prompt(mpsc::Sender<UiEvent>),
    /// Non-interactive mode: deny the call instead of asking.
    HeadlessDeny,
}
/// Agent-loop extension that gates tool calls behind the project's permission
/// [`Policy`], a session decision cache and, when neither decides, a prompt.
pub struct PermissionExtension {
    /// Allow/block rules for bash commands, file writes/edits and MCP tools.
    policy: Arc<Policy>,
    /// Previously recorded decisions, looked up by `SessionCache::key(tool, args)`.
    cache: Arc<SessionCache>,
    /// Directory that relative `path` arguments are resolved against.
    project_root: PathBuf,
    /// Fallback used when neither the policy nor the cache decides a call.
    prompt: PromptStrategy,
}
impl PermissionExtension {
pub fn new(
policy: Arc<Policy>,
cache: Arc<SessionCache>,
project_root: PathBuf,
ui_tx: mpsc::Sender<UiEvent>,
) -> Self {
Self {
policy,
cache,
project_root,
prompt: PromptStrategy::Prompt(ui_tx),
}
}
pub fn headless(policy: Arc<Policy>, cache: Arc<SessionCache>, project_root: PathBuf) -> Self {
Self {
policy,
cache,
project_root,
prompt: PromptStrategy::HeadlessDeny,
}
}
async fn decide(&self, tool_name: &str, args: &Value) -> Decision {
if matches!(tool_name, "write" | "edit") {
if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
let abs = if std::path::Path::new(path).is_absolute() {
PathBuf::from(path)
} else {
self.project_root.join(path)
};
let blocked = match tool_name {
"edit" => self.policy.edit_is_blocked(&abs, &self.project_root),
_ => self.policy.write_is_blocked(&abs, &self.project_root),
};
if blocked {
return Decision::Denied(format!("{} is in a blocked path", abs.display()));
}
}
}
let policy_allowed = match tool_name {
"bash" => args
.get("command")
.and_then(|v| v.as_str())
.map(|c| self.policy.bash_is_allowed(c))
.unwrap_or(false),
"write" | "edit" => args
.get("path")
.and_then(|v| v.as_str())
.map(|p| {
let abs = std::path::PathBuf::from(p);
let abs = if abs.is_absolute() {
abs
} else {
self.project_root.join(&abs)
};
match tool_name {
"edit" => self.policy.edit_is_allowed(&abs, &self.project_root),
_ => self.policy.write_is_allowed(&abs, &self.project_root),
}
})
.unwrap_or(false),
"read" | "grep" | "find" | "ls" => return Decision::Allowed,
other if other.contains("__") => {
let mut parts = other.splitn(2, "__");
let server = parts.next().unwrap_or("");
let tool = parts.next().unwrap_or("");
self.policy.mcp_auto_allow(server, tool)
}
_ => false,
};
if policy_allowed {
return Decision::Allowed;
}
let cache_key = SessionCache::key(tool_name, args);
if let Some(cached) = self.cache.get(&cache_key) {
return cached;
}
match &self.prompt {
PromptStrategy::HeadlessDeny => {
Decision::Denied("non-interactive: tool requires approval".into())
}
PromptStrategy::Prompt(ui_tx) => {
let (resolver_tx, resolver_rx) = oneshot::channel::<Decision>();
if ui_tx
.send(UiEvent::PermissionRequested {
tool: tool_name.to_string(),
args: args.clone(),
resolver: resolver_tx,
})
.await
.is_err()
{
return Decision::Denied("no UI channel to prompt".into());
}
resolver_rx
.await
.unwrap_or(Decision::Denied("prompt cancelled".into()))
}
}
}
}
#[async_trait]
impl Extension for PermissionExtension {
fn name(&self) -> &'static str {
"capo-permissions"
}
async fn intercept_tool_call(
&mut self,
call: ToolCallItem,
_ctx: &mut HookCtx<'_>,
) -> Result<ToolDecision, ExtError> {
match self.decide(&call.name, &call.args).await {
Decision::Allowed => Ok(ToolDecision::Proceed(call)),
Decision::Denied(reason) => Ok(ToolDecision::ShortCircuit(ToolResult::error(format!(
"Permission denied: {reason}"
)))),
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::mpsc;

    use super::*;

    /// Returns a fresh default policy and an empty session cache.
    fn fixtures() -> (Arc<Policy>, Arc<SessionCache>) {
        (Arc::new(Policy::default()), Arc::new(SessionCache::new()))
    }

    /// Returns the project root used by the tests: the current directory, or
    /// an empty path if it cannot be read.
    fn root() -> PathBuf {
        std::env::current_dir().unwrap_or_default()
    }

    /// Builds an interactive extension plus the receiving end of its UI
    /// channel, so tests can check that no prompt was sent.
    fn interactive(
        policy: &Arc<Policy>,
        cache: &Arc<SessionCache>,
    ) -> (PermissionExtension, mpsc::Receiver<UiEvent>) {
        let (ui_tx, ui_rx) = mpsc::channel::<UiEvent>(4);
        let ext = PermissionExtension::new(Arc::clone(policy), Arc::clone(cache), root(), ui_tx);
        (ext, ui_rx)
    }

    #[tokio::test]
    async fn session_cache_short_circuits_prompt() {
        let (policy, cache) = fixtures();
        let args = serde_json::json!({"command": "curl https://example.com"});
        cache.insert(SessionCache::key("bash", &args), Decision::Allowed);
        let (ext, mut ui_rx) = interactive(&policy, &cache);

        let decision = ext.decide("bash", &args).await;

        assert!(matches!(decision, Decision::Allowed));
        assert!(ui_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn grep_find_ls_are_auto_allowed() {
        let (policy, cache) = fixtures();
        let (ext, mut ui_rx) = interactive(&policy, &cache);

        for tool in ["grep", "find", "ls"] {
            let decision = ext.decide(tool, &serde_json::json!({})).await;
            assert!(
                matches!(decision, Decision::Allowed),
                "{tool} not auto-allowed"
            );
        }

        assert!(ui_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn headless_denies_a_would_prompt_tool_but_keeps_auto_allows() {
        let (policy, cache) = fixtures();
        let ext = PermissionExtension::headless(Arc::clone(&policy), Arc::clone(&cache), root());

        let curl = serde_json::json!({"command": "curl https://x"});
        let denied = ext.decide("bash", &curl).await;
        assert!(matches!(denied, Decision::Denied(_)));

        let allowed = ext.decide("read", &serde_json::json!({})).await;
        assert!(matches!(allowed, Decision::Allowed));
    }
}