use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use rig::agent::{HookAction, PromptHook, ToolCallHookAction};
use rig::completion::CompletionModel;
use rig::message::Message;
use tokio::sync::{Mutex, mpsc, oneshot};
use crate::action::Action;
use crate::approval::display::format_tool_display;
use crate::config::types::ApprovalMode;
/// Coarse risk bucket a tool falls into; together with the active
/// [`ApprovalMode`] it decides whether a call needs the user's approval.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolCategory {
    /// Tools allowed without a prompt in every mode, including `Plan`.
    Read,
    /// File-editing tools; auto-approved in `Yolo` and `AutoEdit`.
    Write,
    /// Shell commands and any unrecognized tool; checked against deny rules
    /// and auto-approved only in `Yolo`.
    Shell,
}
/// Maps a tool name to the [`ToolCategory`] that governs its approval.
///
/// Only names on the explicit lists below are treated as
/// [`ToolCategory::Read`] or [`ToolCategory::Write`]. Every other name,
/// including unknown and empty ones, falls back to [`ToolCategory::Shell`],
/// the most restrictive category. Matching is exact and case-sensitive.
pub fn classify_tool(name: &str) -> ToolCategory {
    const READ_TOOLS: &[&str] = &[
        "read",
        "grep",
        "glob",
        "ls",
        "web_fetch",
        "web_search",
        "save_memory",
    ];
    const WRITE_TOOLS: &[&str] = &["write", "edit"];

    if READ_TOOLS.contains(&name) {
        ToolCategory::Read
    } else if WRITE_TOOLS.contains(&name) {
        ToolCategory::Write
    } else {
        ToolCategory::Shell
    }
}
/// Returns the first deny rule found in the shell command carried by `args_json`.
///
/// `args_json` is parsed as JSON and its string `"command"` field is inspected.
/// Invalid JSON, a missing field, or a non-string field yields `None`.
///
/// A rule matches when it occurs as a substring of the command after every run
/// of whitespace in both has been collapsed to a single space. This means
/// `"rm  -rf /"` or `"rm\t-rf /"` cannot slip past a `"rm -rf /"` rule.
///
/// Rules that are empty or whitespace-only are ignored. Every string contains
/// `""`, so such a rule would otherwise block every shell command.
///
/// This is a best-effort substring filter, not a sandbox. Equivalent commands
/// spelled differently (e.g. `rm -r -f /`) are not caught.
///
/// The returned string is the rule exactly as configured, not its normalized form.
pub fn matches_deny_rule(deny_rules: &[String], args_json: &str) -> Option<String> {
    // Replaces each run of consecutive whitespace characters with one ASCII
    // space. Leading/trailing spacing is kept (as a single space) so that a
    // rule such as "rm " keeps its word-boundary intent.
    fn collapse_whitespace(text: &str) -> String {
        let mut out = String::with_capacity(text.len());
        let mut in_whitespace = false;
        for ch in text.chars() {
            if ch.is_whitespace() {
                if !in_whitespace {
                    out.push(' ');
                }
                in_whitespace = true;
            } else {
                out.push(ch);
                in_whitespace = false;
            }
        }
        out
    }

    let args: serde_json::Value = serde_json::from_str(args_json).ok()?;
    let command = collapse_whitespace(args.get("command")?.as_str()?);

    deny_rules
        .iter()
        .find(|rule| {
            // A blank rule would match everything; treat it as misconfiguration.
            !rule.trim().is_empty() && command.contains(collapse_whitespace(rule).as_str())
        })
        .cloned()
}
/// The user's answer to an [`ApprovalRequest`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApprovalDecision {
    /// Run this single tool call.
    Approve,
    /// Skip this tool call.
    Deny,
    /// Run this call and skip the prompt for later calls to the same tool name
    /// on this hook and its clones. Deny rules and mode policy are still
    /// consulted first.
    ApproveAll,
}
/// A pending tool call handed to whoever answers approvals.
///
/// The receiver answers by sending an [`ApprovalDecision`] through
/// `response_tx`. If the sender is dropped without an answer, the hook skips
/// the call as cancelled.
pub struct ApprovalRequest {
    /// Name of the tool the model wants to invoke.
    pub tool_name: String,
    /// Raw JSON arguments exactly as supplied with the tool call.
    pub args_json: String,
    /// Human-readable rendering of the call, produced by `format_tool_display`.
    pub formatted_display: String,
    /// One-shot channel on which the decision is expected.
    pub response_tx: oneshot::Sender<ApprovalDecision>,
}
/// Prompt hook that gates tool calls behind the configured [`ApprovalMode`],
/// shell deny rules and interactive user approval, and counts completion turns.
///
/// Clones share the "approve all" set and the turn counter, since both are
/// held behind `Arc`.
#[derive(Clone)]
pub struct ApprovalHook {
    /// Outgoing channel for calls that need an interactive decision.
    approval_tx: mpsc::UnboundedSender<ApprovalRequest>,
    /// Outgoing channel for UI notifications (e.g. `Action::ToolDenied`).
    action_tx: mpsc::UnboundedSender<Action>,
    /// Active approval policy.
    mode: ApprovalMode,
    /// Substrings that block a shell-category command when present.
    deny_rules: Vec<String>,
    /// Tool names the user approved with [`ApprovalDecision::ApproveAll`].
    approved_all: Arc<Mutex<HashSet<String>>>,
    /// Number of completion calls observed so far.
    turn_counter: Arc<AtomicUsize>,
    /// Turn limit. Within this file it is only exposed via
    /// `max_turns_for_display`, not enforced.
    max_turns: usize,
}
impl ApprovalHook {
    /// Creates a hook with an empty "approve all" set and a turn counter at zero.
    ///
    /// * `mode` — policy applied by [`Self::should_auto_decide`].
    /// * `deny_rules` — substrings that block matching shell-category commands.
    /// * `approval_tx` — destination for interactive [`ApprovalRequest`]s.
    /// * `action_tx` — destination for `Action::ToolDenied` notifications.
    /// * `max_turns` — exposed through [`Self::max_turns_for_display`]; this type
    ///   does not enforce it.
    pub fn new(
        mode: ApprovalMode,
        deny_rules: Vec<String>,
        approval_tx: mpsc::UnboundedSender<ApprovalRequest>,
        action_tx: mpsc::UnboundedSender<Action>,
        max_turns: usize,
    ) -> Self {
        let approved_all = Arc::new(Mutex::new(HashSet::new()));
        let turn_counter = Arc::new(AtomicUsize::new(0));
        Self {
            approval_tx,
            action_tx,
            mode,
            deny_rules,
            approved_all,
            turn_counter,
            max_turns,
        }
    }

    /// Returns another handle to the shared completion-turn counter. The hook
    /// keeps incrementing this same counter.
    pub fn turn_counter(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.turn_counter)
    }

    /// Number of completion calls seen so far (relaxed atomic load).
    pub fn turn_count(&self) -> usize {
        self.turn_counter.load(Ordering::Relaxed)
    }

    /// The configured turn limit, for display purposes.
    pub fn max_turns_for_display(&self) -> usize {
        self.max_turns
    }

    /// Decides a tool call without asking the user, when policy allows it.
    ///
    /// Returns `Some(Continue)` to run the tool, `Some(Skip { .. })` to refuse
    /// it, or `None` when the user must be asked. Deny rules are consulted
    /// first and apply to shell-category tools in every mode, including `Yolo`.
    ///
    /// | mode       | Read | Write | Shell |
    /// |------------|------|-------|-------|
    /// | `Yolo`     | run  | run   | run   |
    /// | `Plan`     | run  | skip  | skip  |
    /// | `AutoEdit` | run  | run   | ask   |
    /// | `Default`  | run  | ask   | ask   |
    pub fn should_auto_decide(
        &self,
        tool_name: &str,
        args_json: &str,
    ) -> Option<ToolCallHookAction> {
        let category = classify_tool(tool_name);

        // Only shell-category calls are parsed for a deny-rule match.
        let blocked_by = (category == ToolCategory::Shell)
            .then(|| matches_deny_rule(&self.deny_rules, args_json))
            .flatten();
        if let Some(rule) = blocked_by {
            return Some(ToolCallHookAction::skip(format!(
                "Command blocked by deny rule: {rule}"
            )));
        }

        match (&self.mode, category) {
            (ApprovalMode::Yolo, _)
            | (_, ToolCategory::Read)
            | (ApprovalMode::AutoEdit, ToolCategory::Write) => Some(ToolCallHookAction::Continue),
            (ApprovalMode::Plan, _) => Some(ToolCallHookAction::skip(
                "Tool execution denied in Plan mode (read-only)",
            )),
            (ApprovalMode::AutoEdit | ApprovalMode::Default, _) => None,
        }
    }
}
impl<M: CompletionModel> PromptHook<M> for ApprovalHook {
    /// Gates a tool call in three stages:
    /// 1. mode/deny-rule policy,
    /// 2. the remembered "approve all" set,
    /// 3. an interactive request sent on `approval_tx`.
    ///
    /// Policy refusals are also reported on `action_tx` as `Action::ToolDenied`.
    /// Refusals caused by the user, a closed approval channel, or a dropped
    /// response sender are returned as `Skip` without that notification.
    async fn on_tool_call(
        &self,
        tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        args: &str,
    ) -> ToolCallHookAction {
        if let Some(verdict) = self.should_auto_decide(tool_name, args) {
            if let ToolCallHookAction::Skip { reason } = &verdict {
                // Best effort: a closed UI channel must not change the verdict.
                let _ = self.action_tx.send(Action::ToolDenied {
                    name: tool_name.to_string(),
                    reason: reason.clone(),
                });
            }
            return verdict;
        }

        // The guard is a temporary of the `if` condition and is released
        // before the body runs.
        if self.approved_all.lock().await.contains(tool_name) {
            return ToolCallHookAction::Continue;
        }

        let (response_tx, response_rx) = oneshot::channel();
        let pending = ApprovalRequest {
            tool_name: tool_name.to_string(),
            args_json: args.to_string(),
            formatted_display: format_tool_display(tool_name, args),
            response_tx,
        };
        if self.approval_tx.send(pending).is_err() {
            return ToolCallHookAction::skip("Approval channel closed");
        }

        let Ok(decision) = response_rx.await else {
            // The approver dropped the sender without answering.
            return ToolCallHookAction::skip("Approval request cancelled");
        };

        match decision {
            ApprovalDecision::Approve => ToolCallHookAction::Continue,
            ApprovalDecision::Deny => ToolCallHookAction::skip("Tool execution denied by user"),
            ApprovalDecision::ApproveAll => {
                self.approved_all
                    .lock()
                    .await
                    .insert(tool_name.to_string());
                ToolCallHookAction::Continue
            }
        }
    }

    /// Counts one turn per completion request and always lets it proceed.
    async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction {
        self.turn_counter.fetch_add(1, Ordering::Relaxed);
        HookAction::cont()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Invokes `on_tool_call` through the `PromptHook` trait. The concrete model
    /// type does not matter because the impl never uses `M`.
    async fn call_tool(hook: &ApprovalHook, tool_name: &str, args: &str) -> ToolCallHookAction {
        <ApprovalHook as PromptHook<rig_bedrock::completion::CompletionModel>>::on_tool_call(
            hook, tool_name, None, "test-id", args,
        )
        .await
    }

    /// Runs one `on_tool_call` for `tool_name` while concurrently answering the
    /// approval request it emits with `decision`. Passing `None` drops the
    /// request without answering.
    async fn call_with_decision(
        hook: &ApprovalHook,
        approvals: &mut mpsc::UnboundedReceiver<ApprovalRequest>,
        tool_name: &str,
        decision: Option<ApprovalDecision>,
    ) -> ToolCallHookAction {
        let answer = async {
            let request = approvals
                .recv()
                .await
                .expect("hook should send an approval request");
            assert_eq!(request.tool_name, tool_name);
            if let Some(decision) = decision {
                request
                    .response_tx
                    .send(decision)
                    .expect("hook should await the decision");
            }
        };
        let (result, ()) =
            tokio::join!(call_tool(hook, tool_name, r#"{"command": "ls"}"#), answer);
        result
    }

    #[test]
    fn test_classify_read_tools() {
        assert_eq!(classify_tool("read"), ToolCategory::Read);
        assert_eq!(classify_tool("grep"), ToolCategory::Read);
        assert_eq!(classify_tool("glob"), ToolCategory::Read);
        assert_eq!(classify_tool("ls"), ToolCategory::Read);
        assert_eq!(classify_tool("web_fetch"), ToolCategory::Read);
        assert_eq!(classify_tool("web_search"), ToolCategory::Read);
    }

    #[test]
    fn test_classify_write_tools() {
        assert_eq!(classify_tool("write"), ToolCategory::Write);
        assert_eq!(classify_tool("edit"), ToolCategory::Write);
    }

    #[test]
    fn test_classify_shell() {
        assert_eq!(classify_tool("shell"), ToolCategory::Shell);
    }

    #[test]
    fn test_classify_save_memory_as_read() {
        assert_eq!(classify_tool("save_memory"), ToolCategory::Read);
    }

    #[test]
    fn test_classify_unknown_defaults_to_shell() {
        assert_eq!(classify_tool("unknown_tool"), ToolCategory::Shell);
        assert_eq!(classify_tool(""), ToolCategory::Shell);
        assert_eq!(classify_tool("custom"), ToolCategory::Shell);
    }

    #[test]
    fn test_deny_rule_match() {
        let rules = vec!["rm -rf /".to_string()];
        let args = r#"{"command": "rm -rf /"}"#;
        assert_eq!(
            matches_deny_rule(&rules, args),
            Some("rm -rf /".to_string())
        );
    }

    #[test]
    fn test_deny_rule_no_match() {
        let rules = vec!["rm -rf /".to_string()];
        let args = r#"{"command": "ls -la"}"#;
        assert_eq!(matches_deny_rule(&rules, args), None);
    }

    #[test]
    fn test_deny_rule_substring_match() {
        let rules = vec!["rm -rf /".to_string()];
        let args = r#"{"command": "sudo rm -rf /home"}"#;
        assert_eq!(
            matches_deny_rule(&rules, args),
            Some("rm -rf /".to_string())
        );
    }

    #[test]
    fn test_deny_rule_no_command_field() {
        let rules = vec!["rm -rf /".to_string()];
        let args = r#"{"not_command": "rm -rf /"}"#;
        assert_eq!(matches_deny_rule(&rules, args), None);
    }

    #[test]
    fn test_deny_rule_invalid_json() {
        let rules = vec!["rm -rf /".to_string()];
        assert_eq!(matches_deny_rule(&rules, "not json"), None);
    }

    #[test]
    fn test_deny_rule_empty_rules() {
        let rules: Vec<String> = vec![];
        let args = r#"{"command": "rm -rf /"}"#;
        assert_eq!(matches_deny_rule(&rules, args), None);
    }

    /// Builds a hook whose approval and action receivers are dropped on return,
    /// so any interactive request fails with "Approval channel closed".
    fn make_hook(mode: ApprovalMode) -> ApprovalHook {
        let (tx, _rx) = mpsc::unbounded_channel();
        let (atx, _arx) = mpsc::unbounded_channel();
        ApprovalHook::new(mode, vec![], tx, atx, 25)
    }

    fn make_hook_with_deny(mode: ApprovalMode, deny_rules: Vec<String>) -> ApprovalHook {
        let (tx, _rx) = mpsc::unbounded_channel();
        let (atx, _arx) = mpsc::unbounded_channel();
        ApprovalHook::new(mode, deny_rules, tx, atx, 25)
    }

    /// Builds a hook and returns the live approval receiver alongside it.
    fn make_hook_with_receiver(
        mode: ApprovalMode,
    ) -> (ApprovalHook, mpsc::UnboundedReceiver<ApprovalRequest>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let (atx, _arx) = mpsc::unbounded_channel();
        (ApprovalHook::new(mode, vec![], tx, atx, 25), rx)
    }

    #[test]
    fn test_yolo_approves_read() {
        let hook = make_hook(ApprovalMode::Yolo);
        let result = hook.should_auto_decide("read", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_yolo_approves_write() {
        let hook = make_hook(ApprovalMode::Yolo);
        let result = hook.should_auto_decide("write", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_yolo_approves_shell() {
        let hook = make_hook(ApprovalMode::Yolo);
        let result = hook.should_auto_decide("shell", r#"{"command": "ls"}"#);
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_plan_approves_read() {
        let hook = make_hook(ApprovalMode::Plan);
        let result = hook.should_auto_decide("read", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_plan_denies_write() {
        let hook = make_hook(ApprovalMode::Plan);
        let result = hook.should_auto_decide("write", "{}");
        assert!(matches!(result, Some(ToolCallHookAction::Skip { .. })));
    }

    #[test]
    fn test_plan_denies_shell() {
        let hook = make_hook(ApprovalMode::Plan);
        let result = hook.should_auto_decide("shell", r#"{"command": "ls"}"#);
        assert!(matches!(result, Some(ToolCallHookAction::Skip { .. })));
    }

    #[test]
    fn test_autoedit_approves_read() {
        let hook = make_hook(ApprovalMode::AutoEdit);
        let result = hook.should_auto_decide("read", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_autoedit_approves_write() {
        let hook = make_hook(ApprovalMode::AutoEdit);
        let result = hook.should_auto_decide("write", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_autoedit_falls_through_shell() {
        let hook = make_hook(ApprovalMode::AutoEdit);
        let result = hook.should_auto_decide("shell", r#"{"command": "ls"}"#);
        assert_eq!(result, None);
    }

    #[test]
    fn test_default_approves_read() {
        let hook = make_hook(ApprovalMode::Default);
        let result = hook.should_auto_decide("read", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[test]
    fn test_default_falls_through_write() {
        let hook = make_hook(ApprovalMode::Default);
        let result = hook.should_auto_decide("write", "{}");
        assert_eq!(result, None);
    }

    #[test]
    fn test_default_falls_through_shell() {
        let hook = make_hook(ApprovalMode::Default);
        let result = hook.should_auto_decide("shell", r#"{"command": "ls"}"#);
        assert_eq!(result, None);
    }

    #[test]
    fn test_deny_rule_overrides_yolo() {
        let hook = make_hook_with_deny(ApprovalMode::Yolo, vec!["rm -rf /".to_string()]);
        let result = hook.should_auto_decide("shell", r#"{"command": "rm -rf /"}"#);
        assert!(matches!(result, Some(ToolCallHookAction::Skip { .. })));
    }

    #[test]
    fn test_deny_rule_only_checks_shell() {
        let hook = make_hook_with_deny(ApprovalMode::Yolo, vec!["rm -rf /".to_string()]);
        let result = hook.should_auto_decide("read", "{}");
        assert_eq!(result, Some(ToolCallHookAction::Continue));
    }

    #[tokio::test]
    async fn test_approve_all_set() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let (atx, _arx) = mpsc::unbounded_channel();
        let hook = ApprovalHook::new(ApprovalMode::Default, vec![], tx, atx, 25);
        {
            let mut approved = hook.approved_all.lock().await;
            approved.insert("shell".to_string());
        }
        let result = call_tool(&hook, "shell", r#"{"command": "ls"}"#).await;
        assert_eq!(result, ToolCallHookAction::Continue);
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_deny_rule_emits_tool_denied_action() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let (atx, mut arx) = mpsc::unbounded_channel();
        let hook =
            ApprovalHook::new(ApprovalMode::Yolo, vec!["rm -rf /".to_string()], tx, atx, 25);
        let result = call_tool(&hook, "shell", r#"{"command": "rm -rf /"}"#).await;
        assert!(matches!(result, ToolCallHookAction::Skip { .. }));
        match arx.try_recv() {
            Ok(Action::ToolDenied { name, reason }) => {
                assert_eq!(name, "shell");
                assert!(reason.contains("rm -rf /"));
            }
            _ => panic!("expected an Action::ToolDenied notification"),
        }
        // Policy refusals never reach the interactive approver.
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_closed_approval_channel_skips() {
        let hook = make_hook(ApprovalMode::Default);
        let result = call_tool(&hook, "shell", r#"{"command": "ls"}"#).await;
        assert_eq!(result, ToolCallHookAction::skip("Approval channel closed"));
    }

    #[tokio::test]
    async fn test_user_approve_runs_once() {
        let (hook, mut rx) = make_hook_with_receiver(ApprovalMode::Default);
        let result =
            call_with_decision(&hook, &mut rx, "shell", Some(ApprovalDecision::Approve)).await;
        assert_eq!(result, ToolCallHookAction::Continue);
        assert!(!hook.approved_all.lock().await.contains("shell"));
    }

    #[tokio::test]
    async fn test_user_deny_skips() {
        let (hook, mut rx) = make_hook_with_receiver(ApprovalMode::Default);
        let result =
            call_with_decision(&hook, &mut rx, "shell", Some(ApprovalDecision::Deny)).await;
        assert_eq!(
            result,
            ToolCallHookAction::skip("Tool execution denied by user")
        );
    }

    #[tokio::test]
    async fn test_dropped_response_is_cancelled() {
        let (hook, mut rx) = make_hook_with_receiver(ApprovalMode::Default);
        let result = call_with_decision(&hook, &mut rx, "shell", None).await;
        assert_eq!(result, ToolCallHookAction::skip("Approval request cancelled"));
    }

    #[tokio::test]
    async fn test_approve_all_is_remembered() {
        let (hook, mut rx) = make_hook_with_receiver(ApprovalMode::Default);
        let first =
            call_with_decision(&hook, &mut rx, "shell", Some(ApprovalDecision::ApproveAll)).await;
        assert_eq!(first, ToolCallHookAction::Continue);
        let second = call_tool(&hook, "shell", r#"{"command": "pwd"}"#).await;
        assert_eq!(second, ToolCallHookAction::Continue);
        // The second call must not have prompted the user again.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_turn_counter_initial() {
        let hook = make_hook(ApprovalMode::Default);
        assert_eq!(hook.turn_count(), 0);
    }

    #[test]
    fn test_max_turns_for_display() {
        let hook = make_hook(ApprovalMode::Default);
        assert_eq!(hook.max_turns_for_display(), 25);
    }

    #[test]
    fn test_hook_is_clone_send_sync() {
        fn assert_clone_send_sync<T: Clone + Send + Sync>() {}
        assert_clone_send_sync::<ApprovalHook>();
    }
}