use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use tokio::sync::{Mutex, mpsc, oneshot, watch};
/// How aggressively tool calls are gated behind user approval.
///
/// Serialized (via serde) and displayed as lowercase names. The explicit discriminants are
/// the raw values stored by `SharedApproveMode` and decoded by `ApproveMode::from_u8`, so
/// they must stay in sync with that decoder.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ApproveMode {
    /// Never prompt for ordinary calls (destructive commands are still gated separately).
    Yolo = 0,
    /// Prompt only for tools that are not considered safe. This is the default mode.
    #[default]
    Auto = 1,
    /// Prompt for every tool call.
    Manual = 2,
}
impl ApproveMode {
    /// Decodes the raw byte stored by `SharedApproveMode`.
    ///
    /// Any byte that is not the discriminant of `Yolo` or `Manual` decodes as `Auto`.
    fn from_u8(v: u8) -> Self {
        if v == Self::Yolo as u8 {
            Self::Yolo
        } else if v == Self::Manual as u8 {
            Self::Manual
        } else {
            Self::Auto
        }
    }

    /// Returns the mode after `self` in the cycle `Manual -> Auto -> Yolo -> Manual`.
    pub fn next(self) -> Self {
        match self {
            Self::Yolo => Self::Manual,
            Self::Manual => Self::Auto,
            Self::Auto => Self::Yolo,
        }
    }
}
impl std::fmt::Display for ApproveMode {
    /// Writes the lowercase mode name, the same spelling accepted by `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Yolo => "yolo",
            Self::Auto => "auto",
            Self::Manual => "manual",
        };
        f.write_str(label)
    }
}
impl std::str::FromStr for ApproveMode {
    type Err = String;

    /// Parses a mode name case-insensitively.
    ///
    /// On failure the error message echoes the input exactly as given.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let mode = match normalized.as_str() {
            "yolo" => Self::Yolo,
            "auto" => Self::Auto,
            "manual" => Self::Manual,
            _ => return Err(format!("unknown approve mode: {s}")),
        };
        Ok(mode)
    }
}
/// Built-in tools that may run without a prompt in `Auto` mode.
const SAFE_TOOLS: &[&str] = &[
    "file_read",
    "search",
    "search_symbol",
    "list_dir",
    "git_diff",
    "git_log",
    "git_status",
    "codebase_summary",
];

/// Hash-set view of [`SAFE_TOOLS`], built lazily on first use.
static SAFE_TOOLS_SET: std::sync::LazyLock<HashSet<&'static str>> =
    std::sync::LazyLock::new(|| HashSet::from_iter(SAFE_TOOLS.iter().copied()));

/// Returns `true` if `name` is a tool that may run without approval in `Auto` mode.
///
/// A name is safe when it is one of the built-in [`SAFE_TOOLS`]. It is also safe when it is
/// an MCP tool of the form `mcp__<server>__…__<op>` (at least three `__`-separated
/// segments) whose final segment is one of the read-style operations listed below.
pub fn is_safe_tool(name: &str) -> bool {
    if SAFE_TOOLS_SET.contains(name) {
        return true;
    }
    if !name.starts_with("mcp__") {
        return false;
    }
    let segments: Vec<&str> = name.split("__").collect();
    match segments.as_slice() {
        // At least three segments; only the last one names the operation.
        [_, _, .., op] => matches!(
            *op,
            "search" | "read" | "get" | "list" | "query" | "find" | "resolve"
        ),
        _ => false,
    }
}
/// Returns whether a call to `tool_name` must be confirmed under `mode`.
///
/// `Manual` always asks and `Yolo` never asks. `Auto` asks only when
/// [`is_safe_tool`] rejects the tool.
pub fn needs_approval(mode: ApproveMode, tool_name: &str) -> bool {
    match mode {
        ApproveMode::Manual => true,
        ApproveMode::Auto => !is_safe_tool(tool_name),
        ApproveMode::Yolo => false,
    }
}
/// Substrings that mark a `bash` invocation as destructive.
///
/// Matching runs against the *lowercased* tool arguments, so every entry must be lowercase;
/// an entry containing an uppercase letter could never match. A trailing `\n`, ` ` or `"`
/// on an entry ends the word at a line break, before another argument, or at the closing
/// quote of a JSON string value, respectively.
const DESTRUCTIVE_PATTERNS: &[&str] = &[
    "git push --force",
    "git push -f ",
    "git push -f\n",
    "git push -f\"",
    "git push origin +",
    "git reset --hard",
    // Also catches `git branch -D`, because the input is lowercased before matching.
    "git branch -d ",
    "drop table",
    "drop database",
    "drop schema",
    "dd if=",
    "mkfs.",
    "mkfs ",
    "shutdown -",
    "reboot\n",
    "reboot ",
    "reboot\"",
    " halt\n",
    " halt ",
    "poweroff\n",
    "poweroff ",
    ":(){ :|:",
    "chmod -r 777 /",
    "chmod -r 000 /",
    "chown -r nobody /",
    "chown -r root /",
    "| bash",
    "| sh\n",
    "| sh ",
    "| sh\"",
    "|bash",
    "|sh",
    "kill -9 -1",
    "kill -9 1",
    "pkill -9 ",
    "killall -9 ",
    "crontab -r",
    "iptables -f",
    "iptables --flush",
    "systemctl mask ",
    "systemctl disable ",
    "history -c",
    "shred ",
];

/// Returns `true` if `haystack` contains `prefix` followed by end of input or by a character
/// that ends the path argument. So `rm -rf /` and `rm -rf /*` match, but
/// `rm -rf /tmp/build` does not.
fn has_root_target(haystack: &str, prefix: &str) -> bool {
    haystack.match_indices(prefix).any(|(idx, _)| {
        match haystack[idx + prefix.len()..].chars().next() {
            None => true,
            Some(c) => c.is_whitespace() || matches!(c, '"' | '*' | ';' | '&' | '|'),
        }
    })
}

/// Returns `true` if a `bash` tool call looks destructive and must always be confirmed.
///
/// Only `tool_name == "bash"` is inspected; every other tool returns `false`. The check is a
/// case-insensitive substring heuristic over the raw arguments. It is not a shell parser, so
/// it can report false positives and cannot catch every obfuscated command.
pub fn is_destructive_command(tool_name: &str, tool_args: &str) -> bool {
    if tool_name != "bash" {
        return false;
    }
    // Arguments arrive as JSON, where a newline inside the command string is encoded as the
    // two characters `\n`. Decode it so the newline-anchored patterns above also fire for
    // multi-line scripts such as `"echo hi\nreboot"`.
    let lower = tool_args.to_lowercase().replace("\\n", "\n");
    if DESTRUCTIVE_PATTERNS.iter().any(|p| lower.contains(p)) {
        return true;
    }
    if has_root_target(&lower, "rm -rf /") || has_root_target(&lower, "rm -fr /") {
        return true;
    }
    // Conservatively treat any recursive forced delete under the home directory as
    // destructive.
    lower.contains("rm -rf ~/") || lower.contains("rm -fr ~/")
}
/// Lexically normalizes `path` without touching the filesystem; symlinks are not followed.
///
/// `.` components are dropped and each `..` removes the preceding normal component. A `..`
/// directly under a root or prefix is discarded, because you cannot climb above `/`. In a
/// relative path, a `..` with nothing left to remove is kept, so `../x` stays `../x`
/// instead of silently becoming `x` and appearing to stay inside its base directory.
pub(crate) fn normalize_path_lexical(path: &std::path::Path) -> std::path::PathBuf {
    use std::path::Component;
    let mut out = std::path::PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match out.components().next_back() {
                Some(Component::Normal(_)) => {
                    out.pop();
                }
                // Already at the root: `..` cannot go any higher.
                Some(Component::RootDir | Component::Prefix(_)) => {}
                // Empty or only leading `..` so far: keep the upward step.
                Some(Component::CurDir | Component::ParentDir) | None => out.push(".."),
            },
            c => out.push(c),
        }
    }
    out
}
/// Returns `true` if `resolved_path` is covered by any entry in `deny_patterns`.
///
/// Supported pattern forms:
/// - `~` or `~/…`: expanded against the user's home directory. The pattern is skipped if
///   the home directory is unknown.
/// - `**/name`: matches any path ending in `/name`, or exactly `name`.
/// - anything else: a path prefix that matches the path itself and everything beneath it.
///   Trailing slashes are ignored, so `/etc/` behaves like `/etc`.
///
/// Empty patterns are ignored rather than acting as "deny every absolute path".
/// Matching is purely textual. NOTE(review): assumes callers pass an already normalized
/// path; confirm at call sites.
pub fn is_path_denied(resolved_path: &str, deny_patterns: &[String]) -> bool {
    let home = dirs::home_dir().map(|h| h.to_string_lossy().to_string());
    for pattern in deny_patterns {
        if pattern.is_empty() {
            continue;
        }
        let expanded: std::borrow::Cow<str> = match pattern.strip_prefix('~') {
            Some(rest) if rest.is_empty() || rest.starts_with('/') => {
                let Some(h) = home.as_deref() else { continue };
                format!("{h}{rest}").into()
            }
            _ => pattern.as_str().into(),
        };
        if let Some(glob_suffix) = expanded.strip_prefix("**/") {
            let suffix = format!("/{glob_suffix}");
            if resolved_path.ends_with(suffix.as_str()) || resolved_path == glob_suffix {
                return true;
            }
        } else {
            // Drop trailing slashes so `/etc/` still matches `/etc/shadow`; a bare `/` is kept.
            let base = match expanded.trim_end_matches('/') {
                "" => expanded.as_ref(),
                trimmed => trimmed,
            };
            let prefix_slash = format!("{base}/");
            if resolved_path == base || resolved_path.starts_with(prefix_slash.as_str()) {
                return true;
            }
        }
    }
    false
}
/// Returns `true` if a file tool targets a `path` argument outside `working_dir`.
///
/// Only `file_read`, `file_write`, `file_edit` and `git_patch` are inspected. Other tools,
/// arguments that are not JSON, and arguments without a string `path` field all return
/// `false`. A relative `path` is joined onto `working_dir`, then both sides are normalized
/// lexically. Symlinks are not resolved.
///
/// When the two paths cannot be compared lexically, the answer is conservative (`true`).
/// That happens when the target is absolute but `working_dir` is relative or empty, or when
/// a `..` survives normalization and would climb out of the working directory.
pub fn is_path_outside_workdir(tool_name: &str, tool_args: &str, working_dir: &str) -> bool {
    use std::path::{Component, Path, PathBuf};
    if !matches!(
        tool_name,
        "file_read" | "file_write" | "file_edit" | "git_patch"
    ) {
        return false;
    }
    let Ok(val) = serde_json::from_str::<serde_json::Value>(tool_args) else {
        return false;
    };
    let Some(path) = val.get("path").and_then(|p| p.as_str()) else {
        return false;
    };
    let workdir = Path::new(working_dir);
    let candidate = if Path::new(path).is_absolute() {
        PathBuf::from(path)
    } else {
        workdir.join(path)
    };
    let normalized = normalize_path_lexical(&candidate);
    let workdir_norm = normalize_path_lexical(workdir);
    // An absolute target cannot be placed relative to a relative working directory.
    if normalized.is_absolute() != workdir_norm.is_absolute() {
        return true;
    }
    match normalized.strip_prefix(&workdir_norm) {
        // Inside only if no remaining `..` climbs back out of the working directory.
        Ok(rest) => rest.components().any(|c| matches!(c, Component::ParentDir)),
        Err(_) => true,
    }
}
/// Mutable state behind `SessionApprovals`, guarded by its async mutex.
#[derive(Default)]
struct SessionApprovalsInner {
/// Tool names approved for the rest of the session.
allowed: HashSet<String>,
/// Tool names denied for the rest of the session.
denied: HashSet<String>,
/// Tool names with a prompt in flight. Other callers subscribe to the sender and receive
/// `Some(approved)` when the entry is resolved (or `Some(false)` on `clear`).
pending: HashMap<String, watch::Sender<Option<bool>>>,
}
/// Outcome of `SessionApprovals::pre_check` for one tool name.
pub enum SessionCheckResult {
/// The tool was approved earlier in this session.
Allowed,
/// The tool was denied earlier in this session.
Denied,
/// No cached decision and no prompt in flight. The caller is now registered as the
/// pending prompter. Until it calls `resolve` or `resolve_once` (or the cache is
/// cleared), other callers for this tool will wait on it.
NeedsApproval,
/// Another caller is already prompting; wait on the receiver for `Some(approved)`.
WaitForResult(watch::Receiver<Option<bool>>),
}
/// Per-session cache of approval decisions keyed by tool name; clones share the same state.
///
/// Besides remembering allow/deny decisions, it coalesces concurrent prompts. The first
/// caller to `pre_check` an undecided tool becomes the prompter, and later callers get a
/// receiver that yields the prompter's answer.
#[derive(Clone, Default)]
pub struct SessionApprovals(Arc<Mutex<SessionApprovalsInner>>);
impl SessionApprovals {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Looks up `tool_name`. If there is neither a cached decision nor a prompt in flight,
    /// the caller is registered as the pending prompter.
    pub async fn pre_check(&self, tool_name: &str) -> SessionCheckResult {
        let mut state = self.0.lock().await;
        if state.allowed.contains(tool_name) {
            return SessionCheckResult::Allowed;
        }
        if state.denied.contains(tool_name) {
            return SessionCheckResult::Denied;
        }
        match state.pending.get(tool_name).map(|sender| sender.subscribe()) {
            Some(receiver) => SessionCheckResult::WaitForResult(receiver),
            None => {
                let (sender, _initial_rx) = watch::channel(None::<bool>);
                state.pending.insert(tool_name.to_owned(), sender);
                SessionCheckResult::NeedsApproval
            }
        }
    }

    /// Records `approved` for the rest of the session and wakes any waiters for this tool.
    pub async fn resolve(&self, tool_name: &str, approved: bool) {
        let mut state = self.0.lock().await;
        let decided = if approved {
            &mut state.allowed
        } else {
            &mut state.denied
        };
        decided.insert(tool_name.to_owned());
        if let Some(sender) = state.pending.remove(tool_name) {
            let _ = sender.send(Some(approved));
        }
    }

    /// Approves only the in-flight prompt: waiters are woken with `true`, but nothing is cached.
    pub async fn resolve_once(&self, tool_name: &str) {
        let mut state = self.0.lock().await;
        if let Some(sender) = state.pending.remove(tool_name) {
            let _ = sender.send(Some(true));
        }
    }

    /// Forgets every cached decision and denies all prompts still in flight.
    pub async fn clear(&self) {
        let mut state = self.0.lock().await;
        state.allowed.clear();
        state.denied.clear();
        state.pending.drain().for_each(|(_, sender)| {
            let _ = sender.send(Some(false));
        });
    }
}
/// A prompt sent to the approval frontend for a single tool call.
#[derive(Debug)]
pub struct ApprovalRequest {
/// Name of the tool being invoked.
pub tool_name: String,
/// Raw, unparsed tool arguments as passed to `ApprovalGate::check`.
pub tool_args: String,
/// Channel on which the frontend sends its decision. If it is dropped without an answer,
/// `ApprovalGate::check` applies its "response channel dropped" fallback.
pub response_tx: oneshot::Sender<ApprovalResponse>,
}
/// The frontend's answer to an `ApprovalRequest`.
///
/// `ApprovalGate::check` decides how each variant is recorded in the session cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApprovalResponse {
/// Allow this call.
Approve,
/// Reject this call.
Deny,
/// Allow this call and trust the tool for the rest of the session.
ApproveAll,
}
/// A cheaply clonable, thread-safe handle to the current `ApproveMode`.
///
/// Every clone shares one atomic byte. A change made through any handle is therefore seen
/// by every `ApprovalGate` holding another clone.
#[derive(Clone)]
pub struct SharedApproveMode(Arc<AtomicU8>);
impl SharedApproveMode {
    /// Creates a new shared handle initialised to `mode`.
    pub fn new(mode: ApproveMode) -> Self {
        Self(Arc::new(AtomicU8::new(mode as u8)))
    }

    /// Returns the current mode.
    pub fn get(&self) -> ApproveMode {
        ApproveMode::from_u8(self.0.load(Ordering::Acquire))
    }

    /// Replaces the current mode.
    pub fn set(&self, mode: ApproveMode) {
        self.0.store(mode as u8, Ordering::Release);
    }

    /// Atomically advances to the next mode (see `ApproveMode::next`) and returns it.
    ///
    /// The read and the write happen in one atomic operation. Concurrent `cycle` calls
    /// therefore cannot both start from the same mode and lose one of the transitions.
    pub fn cycle(&self) -> ApproveMode {
        let advance = |raw: u8| ApproveMode::from_u8(raw).next();
        // The closure always returns `Some`, so `fetch_update` never reports failure. Either
        // way the payload is the value that was current when the update was applied.
        let previous = match self
            .0
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |raw| {
                Some(advance(raw) as u8)
            }) {
            Ok(prev) | Err(prev) => prev,
        };
        advance(previous)
    }
}
/// Decides whether a tool call may proceed, prompting a frontend when required.
///
/// Cheap to clone: the mode handle, request sender and session cache are all shared.
#[derive(Clone)]
pub struct ApprovalGate {
// Current approval mode; may be changed at runtime through any clone of the handle.
mode: SharedApproveMode,
// Channel to the interactive frontend; `None` for the `yolo()` and `headless()` gates.
request_tx: Option<Arc<mpsc::UnboundedSender<ApprovalRequest>>>,
// Optional per-session decision cache, possibly shared with other gates.
session_approvals: Option<SessionApprovals>,
}
impl ApprovalGate {
    /// Gate in `Yolo` mode with no frontend.
    ///
    /// Non-destructive calls are approved. Destructive `bash` commands are denied because
    /// there is no one to confirm them.
    pub fn yolo() -> Self {
        Self {
            mode: SharedApproveMode::new(ApproveMode::Yolo),
            request_tx: None,
            session_approvals: None,
        }
    }

    /// Gate that follows `mode` but has no interactive frontend.
    ///
    /// Calls that would need a prompt are approved only for safe tools (see `is_safe_tool`).
    /// Destructive commands and non-safe tools are denied.
    pub fn headless(mode: SharedApproveMode) -> Self {
        Self {
            mode,
            request_tx: None,
            session_approvals: None,
        }
    }

    /// Gate that forwards approval prompts to a frontend over `request_tx`.
    pub fn new(
        mode: SharedApproveMode,
        request_tx: mpsc::UnboundedSender<ApprovalRequest>,
    ) -> Self {
        Self {
            mode,
            request_tx: Some(Arc::new(request_tx)),
            session_approvals: None,
        }
    }

    /// Like `ApprovalGate::new`, but also caches decisions in `session_approvals`.
    pub fn new_with_session(
        mode: SharedApproveMode,
        request_tx: mpsc::UnboundedSender<ApprovalRequest>,
        session_approvals: SessionApprovals,
    ) -> Self {
        Self::new(mode, request_tx).with_session_approvals(session_approvals)
    }

    /// Attaches (or replaces) the session decision cache.
    pub fn with_session_approvals(mut self, sa: SessionApprovals) -> Self {
        self.session_approvals = Some(sa);
        self
    }

    /// Returns the current approval mode.
    pub fn mode(&self) -> ApproveMode {
        self.mode.get()
    }

    /// Returns a handle sharing this gate's mode, so callers can change it at runtime.
    pub fn shared_mode(&self) -> SharedApproveMode {
        self.mode.clone()
    }

    /// Decides whether the tool call may proceed.
    ///
    /// Order of evaluation:
    /// 1. Non-destructive calls that stay inside `working_dir` (not checked in `Yolo`) and
    ///    need no approval under the current mode are approved immediately.
    /// 2. Non-destructive calls consult the session cache, if any. A cached decision is
    ///    returned directly, and concurrent prompts for the same tool are coalesced.
    /// 3. Without a frontend, destructive commands and non-safe tools are denied and safe
    ///    tools are approved.
    /// 4. Otherwise the frontend is asked. If it has disconnected or drops the response
    ///    channel, destructive commands are denied and everything else is approved.
    ///
    /// Destructive commands never read or write the session cache, so each one is confirmed
    /// individually.
    ///
    /// NOTE(review): if this future is dropped while waiting on the frontend after
    /// registering as the pending prompter, the session entry is never resolved. Later
    /// callers for the same tool would then wait indefinitely.
    pub async fn check(
        &self,
        tool_name: &str,
        tool_args: &str,
        working_dir: &str,
    ) -> ApprovalResponse {
        let current_mode = self.mode.get();
        let destructive = is_destructive_command(tool_name, tool_args);
        let outside_workdir = current_mode != ApproveMode::Yolo
            && is_path_outside_workdir(tool_name, tool_args, working_dir);
        if !destructive && !outside_workdir && !needs_approval(current_mode, tool_name) {
            return ApprovalResponse::Approve;
        }

        // Destructive commands bypass the cache in both directions. A cached approval must
        // not skip their prompt, and their one-off answer must not approve or deny the
        // whole tool.
        let session = if destructive {
            None
        } else {
            self.session_approvals.as_ref()
        };
        if let Some(sa) = session {
            match sa.pre_check(tool_name).await {
                SessionCheckResult::Allowed => return ApprovalResponse::Approve,
                SessionCheckResult::Denied => return ApprovalResponse::Deny,
                SessionCheckResult::WaitForResult(mut rx) => {
                    // Another caller is already prompting; reuse its answer. A sender dropped
                    // without answering leaves `None`, which denies.
                    let _ = rx.wait_for(|v| v.is_some()).await;
                    return match *rx.borrow() {
                        Some(true) => ApprovalResponse::Approve,
                        _ => ApprovalResponse::Deny,
                    };
                }
                // We are now the pending prompter and must resolve the entry on every
                // path below.
                SessionCheckResult::NeedsApproval => {}
            }
        }

        let Some(tx) = &self.request_tx else {
            if destructive {
                tracing::warn!("No approval frontend, denying destructive command: {tool_name}");
                return ApprovalResponse::Deny;
            }
            if !is_safe_tool(tool_name) {
                tracing::warn!(
                    "No approval frontend, denying non-safe tool in headless mode: {tool_name}"
                );
                // Resolve our pending entry. Otherwise every later caller for this tool
                // would wait forever on a prompt that is never shown.
                if let Some(sa) = session {
                    sa.resolve(tool_name, false).await;
                }
                return ApprovalResponse::Deny;
            }
            if let Some(sa) = session {
                sa.resolve(tool_name, true).await;
            }
            tracing::warn!("No approval frontend, auto-approving safe tool: {tool_name}");
            return ApprovalResponse::Approve;
        };

        let (response_tx, response_rx) = oneshot::channel();
        let request = ApprovalRequest {
            tool_name: tool_name.to_string(),
            tool_args: tool_args.to_string(),
            response_tx,
        };
        if tx.send(request).is_err() {
            // A vanished frontend must not be more permissive than having none at all.
            if destructive {
                tracing::warn!(
                    "Approval frontend disconnected, denying destructive command: {tool_name}"
                );
                return ApprovalResponse::Deny;
            }
            tracing::warn!("Approval frontend disconnected, auto-approving {tool_name}");
            if let Some(sa) = session {
                sa.resolve(tool_name, true).await;
            }
            return ApprovalResponse::Approve;
        }

        let response = match response_rx.await {
            Ok(r) => r,
            Err(_) if destructive => {
                tracing::warn!(
                    "Approval response channel dropped for {tool_name}, denying destructive command"
                );
                ApprovalResponse::Deny
            }
            Err(_) => {
                tracing::warn!("Approval response channel dropped for {tool_name}, auto-approving");
                ApprovalResponse::Approve
            }
        };

        if let Some(sa) = session {
            match response {
                // A plain approval of a shell command covers only this invocation.
                ApprovalResponse::Approve if tool_name == "bash" => {
                    sa.resolve_once(tool_name).await;
                }
                _ => {
                    let approved = matches!(
                        response,
                        ApprovalResponse::Approve | ApprovalResponse::ApproveAll
                    );
                    sa.resolve(tool_name, approved).await;
                }
            }
        }
        response
    }
}
// Unit tests for mode handling, tool classification, path checks and the approval gate.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_safe_tools() {
assert!(is_safe_tool("file_read"));
assert!(is_safe_tool("search"));
assert!(!is_safe_tool("file_write"));
assert!(!is_safe_tool("bash"));
}
// MCP tools are safe only when their final `__` segment is a read-style operation.
#[test]
fn test_mcp_safe_tools() {
assert!(is_safe_tool("mcp__serena__search"));
assert!(is_safe_tool("mcp__context7__query"));
assert!(!is_safe_tool("mcp__serena__edit"));
}
#[test]
fn test_needs_approval() {
assert!(!needs_approval(ApproveMode::Yolo, "bash"));
assert!(!needs_approval(ApproveMode::Auto, "file_read"));
assert!(needs_approval(ApproveMode::Auto, "file_write"));
assert!(needs_approval(ApproveMode::Auto, "bash"));
assert!(needs_approval(ApproveMode::Manual, "file_read"));
}
#[test]
fn test_mode_cycle() {
assert_eq!(ApproveMode::Manual.next(), ApproveMode::Auto);
assert_eq!(ApproveMode::Auto.next(), ApproveMode::Yolo);
assert_eq!(ApproveMode::Yolo.next(), ApproveMode::Manual);
}
#[test]
fn test_shared_mode() {
let shared = SharedApproveMode::new(ApproveMode::Manual);
assert_eq!(shared.get(), ApproveMode::Manual);
assert_eq!(shared.cycle(), ApproveMode::Auto);
assert_eq!(shared.get(), ApproveMode::Auto);
}
// Arguments are JSON-encoded as the gate receives them; matching is case-insensitive.
#[test]
fn test_is_destructive_command() {
assert!(!is_destructive_command("file_write", r#"{"path":"x"}"#));
assert!(!is_destructive_command(
"bash",
r#"{"command":"git status"}"#
));
assert!(!is_destructive_command(
"bash",
r#"{"command":"git push origin main"}"#
));
assert!(!is_destructive_command(
"bash",
r#"{"command":"rm -r build/"}"#
));
assert!(!is_destructive_command(
"bash",
r#"{"command":"rm -rf /tmp/build"}"#
));
assert!(!is_destructive_command(
"bash",
r#"{"command":"rm -rf ./dist"}"#
));
assert!(!is_destructive_command(
"bash",
r#"{"command":"rm -rf node_modules"}"#
));
assert!(is_destructive_command("bash", r#"{"command":"rm -rf /"}"#));
assert!(is_destructive_command("bash", r#"{"command":"rm -fr /"}"#));
assert!(is_destructive_command("bash", r#"{"command":"rm -rf ~/"}"#));
assert!(is_destructive_command("bash", r#"{"command":"rm -fr ~/"}"#));
assert!(is_destructive_command(
"bash",
r#"{"command":"git push --force origin main"}"#
));
assert!(is_destructive_command(
"bash",
r#"{"command":"git push -f origin main"}"#
));
assert!(is_destructive_command(
"bash",
r#"{"command":"git reset --hard HEAD~1"}"#
));
assert!(is_destructive_command(
"bash",
r#"{"command":"git branch -D feature/old"}"#
));
assert!(is_destructive_command(
"bash",
r#"{"command":"DROP TABLE users;"}"#
));
}
#[test]
fn test_is_path_outside_workdir() {
let wd = "/home/user/project";
assert!(!is_path_outside_workdir(
"file_read",
r#"{"path":"src/main.rs"}"#,
wd
));
assert!(!is_path_outside_workdir(
"file_read",
r#"{"path":"/home/user/project/src/main.rs"}"#,
wd
));
assert!(is_path_outside_workdir(
"file_read",
r#"{"path":"/etc/passwd"}"#,
wd
));
// `..` segments are resolved lexically against the working directory.
assert!(is_path_outside_workdir(
"file_write",
r#"{"path":"../../etc/shadow"}"#,
wd
));
// Only file tools are inspected; bash arguments are never path-checked.
assert!(!is_path_outside_workdir(
"bash",
r#"{"command":"cat /etc/passwd"}"#,
wd
));
assert!(!is_path_outside_workdir("file_read", r#"{}"#, wd));
}
// In Auto mode a normally safe tool still prompts when it reaches outside the workdir.
#[tokio::test]
async fn test_auto_gate_file_read_outside_workdir_requires_approval() {
let shared = SharedApproveMode::new(ApproveMode::Auto);
let (tx, mut rx) = mpsc::unbounded_channel::<ApprovalRequest>();
let gate = ApprovalGate::new(shared, tx);
tokio::spawn(async move {
let req = rx.recv().await.unwrap();
assert_eq!(req.tool_name, "file_read");
req.response_tx.send(ApprovalResponse::Approve).unwrap();
});
let result = gate
.check(
"file_read",
r#"{"path":"/etc/passwd"}"#,
"/home/user/project",
)
.await;
assert_eq!(result, ApprovalResponse::Approve);
}
#[tokio::test]
async fn test_yolo_gate_file_read_outside_workdir_auto_approved() {
let gate = ApprovalGate::yolo();
let result = gate
.check(
"file_read",
r#"{"path":"/etc/passwd"}"#,
"/home/user/project",
)
.await;
assert_eq!(result, ApprovalResponse::Approve);
}
#[tokio::test]
async fn test_yolo_gate_safe() {
let gate = ApprovalGate::yolo();
assert_eq!(
gate.check("bash", r#"{"command":"git status"}"#, "/tmp")
.await,
ApprovalResponse::Approve
);
}
// Even in Yolo mode, a destructive command with no frontend to confirm it is denied.
#[tokio::test]
async fn test_yolo_gate_destructive_no_frontend() {
let gate = ApprovalGate::yolo();
assert_eq!(
gate.check("bash", r#"{"command":"git reset --hard HEAD"}"#, "/tmp")
.await,
ApprovalResponse::Deny
);
}
#[tokio::test]
async fn test_yolo_gate_destructive_with_frontend() {
let shared = SharedApproveMode::new(ApproveMode::Yolo);
let (tx, mut rx) = mpsc::unbounded_channel();
let gate = ApprovalGate::new(shared, tx);
tokio::spawn(async move {
let req = rx.recv().await.unwrap();
assert!(req.tool_args.contains("reset --hard"));
req.response_tx.send(ApprovalResponse::Approve).unwrap();
});
assert_eq!(
gate.check("bash", r#"{"command":"git reset --hard HEAD"}"#, "/tmp")
.await,
ApprovalResponse::Approve
);
}
#[tokio::test]
async fn test_auto_gate_safe_tool() {
let shared = SharedApproveMode::new(ApproveMode::Auto);
let (tx, _rx) = mpsc::unbounded_channel();
let gate = ApprovalGate::new(shared, tx);
assert_eq!(
gate.check("file_read", "{}", "/tmp").await,
ApprovalResponse::Approve
);
}
#[tokio::test]
async fn test_manual_gate_sends_request() {
let shared = SharedApproveMode::new(ApproveMode::Manual);
let (tx, mut rx) = mpsc::unbounded_channel();
let gate = ApprovalGate::new(shared, tx);
tokio::spawn(async move {
let req = rx.recv().await.unwrap();
assert_eq!(req.tool_name, "bash");
req.response_tx.send(ApprovalResponse::Deny).unwrap();
});
let result = gate.check("bash", "{}", "/tmp").await;
assert_eq!(result, ApprovalResponse::Deny);
}
// The gate reads the shared mode on every call, so external changes take effect at once.
#[tokio::test]
async fn test_runtime_mode_switch() {
let shared = SharedApproveMode::new(ApproveMode::Yolo);
let (tx, _rx) = mpsc::unbounded_channel();
let gate = ApprovalGate::new(shared.clone(), tx);
assert_eq!(
gate.check("bash", "{}", "/tmp").await,
ApprovalResponse::Approve
);
shared.set(ApproveMode::Manual);
assert_eq!(gate.mode(), ApproveMode::Manual);
}
// NOTE(review): a plain `Approve` for bash is resolved with `resolve_once` (not cached).
// The second call presumably succeeds because the responder task has exited by then, so
// `tx.send` fails and the gate takes its "frontend disconnected" auto-approve path.
#[tokio::test]
async fn test_session_approvals_cache() {
let sa = SessionApprovals::new();
let shared = SharedApproveMode::new(ApproveMode::Auto);
let (tx, mut rx) = mpsc::unbounded_channel::<ApprovalRequest>();
let gate = ApprovalGate::new_with_session(shared, tx, sa.clone());
tokio::spawn(async move {
let req = rx.recv().await.unwrap();
req.response_tx.send(ApprovalResponse::Approve).unwrap();
});
assert_eq!(
gate.check("bash", "{}", "/tmp").await,
ApprovalResponse::Approve
);
assert_eq!(
gate.check("bash", "{}", "/tmp").await,
ApprovalResponse::Approve
);
}
// A denial is cached for the session, so the second call never reaches the frontend.
#[tokio::test]
async fn test_session_approvals_deny_cached() {
let sa = SessionApprovals::new();
let shared = SharedApproveMode::new(ApproveMode::Manual);
let (tx, mut rx) = mpsc::unbounded_channel::<ApprovalRequest>();
let gate = ApprovalGate::new_with_session(shared, tx, sa.clone());
tokio::spawn(async move {
let req = rx.recv().await.unwrap();
req.response_tx.send(ApprovalResponse::Deny).unwrap();
});
assert_eq!(
gate.check("file_write", "{}", "/tmp").await,
ApprovalResponse::Deny
);
assert_eq!(
gate.check("file_write", "{}", "/tmp").await,
ApprovalResponse::Deny
);
}
// Two gates sharing one session: the second caller waits on the first caller's prompt
// instead of opening its own.
#[tokio::test]
async fn test_parallel_agents_share_approval() {
let sa = SessionApprovals::new();
let shared = SharedApproveMode::new(ApproveMode::Auto);
let (tx, mut rx) = mpsc::unbounded_channel::<ApprovalRequest>();
let gate_a = ApprovalGate::new_with_session(shared.clone(), tx.clone(), sa.clone());
let gate_b = ApprovalGate::new_with_session(shared.clone(), tx, sa.clone());
tokio::spawn(async move {
let req = rx.recv().await.unwrap();
req.response_tx.send(ApprovalResponse::Approve).unwrap();
assert!(
tokio::time::timeout(std::time::Duration::from_millis(50), rx.recv())
.await
.is_err(),
"Expected only one approval popup for parallel agents"
);
});
let (r_a, r_b) = tokio::join!(
gate_a.check("bash", "{}", "/tmp"),
gate_b.check("bash", "{}", "/tmp"),
);
assert_eq!(r_a, ApprovalResponse::Approve);
assert_eq!(r_b, ApprovalResponse::Approve);
}
// Helper: builds an owned pattern list for `is_path_denied`.
fn patterns(strs: &[&str]) -> Vec<String> {
strs.iter().map(|s| s.to_string()).collect()
}
#[test]
fn test_deny_prefix_exact_match() {
let p = patterns(&["/etc/passwd"]);
assert!(is_path_denied("/etc/passwd", &p));
}
#[test]
fn test_deny_prefix_child_match() {
let p = patterns(&["/etc"]);
assert!(is_path_denied("/etc/shadow", &p));
assert!(is_path_denied("/etc/ssh/sshd_config", &p));
}
// Prefix patterns match whole path components only.
#[test]
fn test_deny_prefix_no_partial_match() {
let p = patterns(&["/etc"]);
assert!(!is_path_denied("/etcfoo/bar", &p));
}
#[test]
fn test_deny_glob_suffix() {
let p = patterns(&["**/.env"]);
assert!(is_path_denied("/project/backend/.env", &p));
assert!(is_path_denied("/any/deeply/nested/.env", &p));
assert!(!is_path_denied("/project/.env.local", &p));
}
#[test]
fn test_deny_empty_patterns() {
assert!(!is_path_denied("/etc/passwd", &[]));
}
#[test]
fn test_deny_no_match() {
let p = patterns(&["/etc", "~/.ssh"]);
assert!(!is_path_denied("/home/user/projects/main.rs", &p));
}
}