use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use crate::tool::{ToolCall, ToolHook};
/// A request sent from the agent side to a human operator.
///
/// Serialized by serde as an internally tagged enum: the variant name is
/// written to a `"type"` field alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum HumanRequest {
/// Asks whether a pending tool call may proceed. This is the request that
/// `ConfirmationHook::pre_call` emits for tools matching its patterns.
Confirm {
/// Name of the tool awaiting approval.
tool_name: String,
/// Arguments of the pending call, as JSON.
arguments: serde_json::Value,
/// Human-readable question to show to the operator.
prompt: String,
},
/// Asks the operator for input, carrying only the prompt text.
/// NOTE(review): nothing in this file constructs this variant; presumably
/// it is used by another component. Confirm against callers.
Input { prompt: String },
}
/// A human operator's answer to a [`HumanRequest`].
///
/// Serialized by serde as an internally tagged enum: the variant name is
/// written to a `"type"` field alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum HumanResponse {
/// The request was approved; `ConfirmationHook::pre_call` returns `Ok(())`.
Approved,
/// The request was rejected; `reason` is included in the error that
/// `ConfirmationHook::pre_call` returns.
Denied { reason: String },
/// Text supplied by the operator. `ConfirmationHook::pre_call` treats this
/// as an error, since it is not a valid answer to a `Confirm` request.
Input { content: String },
/// The responding side reports a timeout. `ConfirmationHook::pre_call`
/// maps this to the same error as its own local timeout.
Timeout,
}
/// A tool hook that asks a human for approval before matching tool calls run.
///
/// Requests are pushed onto `tx`; the answer is read from `rx`, waiting at
/// most `timeout`.
pub struct ConfirmationHook {
// Substrings tested against the tool name; a call needs confirmation when
// its name contains any of them.
patterns: Vec<String>,
// Outgoing side of the request channel (hook -> human).
tx: UnboundedSender<HumanRequest>,
// Incoming side of the response channel (human -> hook). Wrapped in an
// `Arc<Mutex<..>>` so the receiver can be locked through `&self` and shared
// with other holders of the same `Arc`.
rx: Arc<Mutex<UnboundedReceiver<HumanResponse>>>,
// Maximum time to wait for a response once the receiver lock is held.
timeout: Duration,
}
/// Manual `Debug` that prints only the configuration (`patterns` and
/// `timeout`); the channel endpoints are omitted and the output is marked
/// non-exhaustive.
impl std::fmt::Debug for ConfirmationHook {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            patterns, timeout, ..
        } = self;
        let mut out = f.debug_struct("ConfirmationHook");
        out.field("patterns", patterns);
        out.field("timeout", timeout);
        out.finish_non_exhaustive()
    }
}
impl ConfirmationHook {
    /// Creates a hook from its parts.
    ///
    /// * `patterns` - substrings; a tool whose name contains any of them
    ///   requires confirmation.
    /// * `tx` - channel on which confirmation requests are sent.
    /// * `rx` - shared receiver from which the answer is read.
    /// * `timeout` - how long to wait for an answer.
    pub fn new(
        patterns: Vec<String>,
        tx: UnboundedSender<HumanRequest>,
        rx: Arc<Mutex<UnboundedReceiver<HumanResponse>>>,
        timeout: Duration,
    ) -> Self {
        Self {
            patterns,
            tx,
            rx,
            timeout,
        }
    }

    /// Returns `true` when `tool_name` contains at least one configured
    /// pattern as a substring. With no patterns nothing matches; an empty
    /// pattern string matches every name.
    fn matches(&self, tool_name: &str) -> bool {
        for needle in &self.patterns {
            if tool_name.contains(needle.as_str()) {
                return true;
            }
        }
        false
    }
}
#[async_trait]
impl ToolHook for ConfirmationHook {
    /// Gates a tool call on human approval.
    ///
    /// Calls whose name matches none of the configured patterns pass straight
    /// through. For matching calls a [`HumanRequest::Confirm`] is sent and the
    /// next [`HumanResponse`] is awaited for at most `self.timeout`.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// * the request channel is closed,
    /// * the response channel is closed,
    /// * the human answers `Denied` (the reason is included in the message),
    /// * the human answers `Timeout` or no answer arrives within the timeout,
    /// * the answer is an `Input`, which is not valid for a `Confirm` request.
    async fn pre_call(&self, call: &ToolCall) -> anyhow::Result<()> {
        if !self.matches(&call.name) {
            return Ok(());
        }

        // Take the receiver lock *before* publishing the request. The
        // responses carry no correlation id, so the only thing tying an
        // answer to its question is ordering. If the request were sent first,
        // two concurrent calls could enqueue A then B while B wins the lock,
        // and B would consume the answer meant for A. Holding the lock across
        // send + receive makes each exchange atomic; the wait is bounded by
        // `self.timeout` below.
        let mut responses = self.rx.lock().await;

        let request = HumanRequest::Confirm {
            tool_name: call.name.clone(),
            arguments: call.arguments.clone(),
            prompt: format!("Tool '{}' requires confirmation. Allow?", call.name),
        };
        if self.tx.send(request).is_err() {
            anyhow::bail!("human confirmation channel closed");
        }

        // NOTE(review): if this wait times out and the human answers later,
        // that late reply stays queued and will be read by the next call.
        // Draining here would also discard answers that callers queue ahead
        // of time, so a proper fix needs a request id in the protocol.
        match tokio::time::timeout(self.timeout, responses.recv()).await {
            Ok(Some(HumanResponse::Approved)) => Ok(()),
            Ok(Some(HumanResponse::Denied { reason })) => {
                anyhow::bail!("User denied: {reason}")
            }
            Ok(Some(HumanResponse::Timeout)) => {
                anyhow::bail!("confirmation timeout")
            }
            Ok(Some(HumanResponse::Input { .. })) => {
                anyhow::bail!("unexpected Input response to Confirm request")
            }
            // `recv` yields `None` once every sender is gone and the queue is empty.
            Ok(None) => {
                anyhow::bail!("human response channel closed")
            }
            Err(_elapsed) => {
                anyhow::bail!("confirmation timeout")
            }
        }
    }
}
/// Both ends of the request channel and the response channel, as created by
/// `human_channel`.
///
/// `request_tx` and `response_rx` have the types `ConfirmationHook::new`
/// expects; `request_rx` and `response_tx` are presumably held by whatever
/// presents requests to the human (verify against callers).
pub struct HumanChannelPair {
/// Sender for requests addressed to the human.
pub request_tx: UnboundedSender<HumanRequest>,
/// Receiver on which those requests arrive.
pub request_rx: UnboundedReceiver<HumanRequest>,
/// Sender for the human's answers.
pub response_tx: UnboundedSender<HumanResponse>,
/// Shared, lockable receiver on which those answers arrive.
pub response_rx: Arc<Mutex<UnboundedReceiver<HumanResponse>>>,
}
pub fn human_channel() -> HumanChannelPair {
let (request_tx, request_rx) = tokio::sync::mpsc::unbounded_channel();
let (response_tx, resp_rx) = tokio::sync::mpsc::unbounded_channel();
HumanChannelPair {
request_tx,
request_rx,
response_tx,
response_rx: Arc::new(Mutex::new(resp_rx)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hook with a single `pattern`, returning it together with the
    /// channel ends a test needs to play the human: the request receiver and
    /// the response sender.
    fn hook_with_channels(
        pattern: &str,
        timeout: Duration,
    ) -> (
        ConfirmationHook,
        UnboundedReceiver<HumanRequest>,
        UnboundedSender<HumanResponse>,
    ) {
        let (req_tx, req_rx) = tokio::sync::mpsc::unbounded_channel();
        let (resp_tx, resp_rx) = tokio::sync::mpsc::unbounded_channel();
        let hook = ConfirmationHook::new(
            vec![pattern.to_owned()],
            req_tx,
            Arc::new(Mutex::new(resp_rx)),
            timeout,
        );
        (hook, req_rx, resp_tx)
    }

    /// Builds a tool call with empty JSON-object arguments.
    fn tool_call(id: &'static str, name: &'static str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            arguments: serde_json::json!({}),
        }
    }

    /// Spawns a stand-in for the human: waits for exactly one request, replies
    /// with `answer`, and hands the received request back through the join
    /// handle so the test can inspect it.
    fn respond_once(
        mut req_rx: UnboundedReceiver<HumanRequest>,
        resp_tx: UnboundedSender<HumanResponse>,
        answer: HumanResponse,
    ) -> tokio::task::JoinHandle<HumanRequest> {
        tokio::spawn(async move {
            let request = req_rx.recv().await.expect("hook should send a request");
            resp_tx.send(answer).expect("send response");
            request
        })
    }

    #[test]
    fn human_request_confirm_serialization() {
        let req = HumanRequest::Confirm {
            tool_name: "rm_file".into(),
            arguments: serde_json::json!({"path": "/tmp/foo"}),
            prompt: "Allow rm_file?".into(),
        };
        let json = serde_json::to_string(&req).expect("serialize");
        assert!(json.contains("Confirm"));
        assert!(json.contains("rm_file"));
        let deserialized: HumanRequest = serde_json::from_str(&json).expect("deserialize");
        match deserialized {
            HumanRequest::Confirm { tool_name, .. } => assert_eq!(tool_name, "rm_file"),
            _ => panic!("expected Confirm variant"),
        }
    }

    #[test]
    fn human_response_denied_serialization() {
        let resp = HumanResponse::Denied {
            reason: "too dangerous".into(),
        };
        let json = serde_json::to_string(&resp).expect("serialize");
        assert!(json.contains("Denied"));
        assert!(json.contains("too dangerous"));
        let deserialized: HumanResponse = serde_json::from_str(&json).expect("deserialize");
        match deserialized {
            HumanResponse::Denied { reason } => assert_eq!(reason, "too dangerous"),
            _ => panic!("expected Denied variant"),
        }
    }

    #[test]
    fn human_request_response_roundtrip() {
        let requests = vec![
            HumanRequest::Confirm {
                tool_name: "write".into(),
                arguments: serde_json::json!({}),
                prompt: "Allow?".into(),
            },
            HumanRequest::Input {
                prompt: "Enter value:".into(),
            },
        ];
        for req in &requests {
            let json = serde_json::to_string(req).expect("serialize");
            let back: HumanRequest = serde_json::from_str(&json).expect("deserialize");
            let json2 = serde_json::to_string(&back).expect("re-serialize");
            assert_eq!(json, json2);
        }

        let responses = vec![
            HumanResponse::Approved,
            HumanResponse::Denied {
                reason: "no".into(),
            },
            HumanResponse::Input {
                content: "hello".into(),
            },
            HumanResponse::Timeout,
        ];
        for resp in &responses {
            let json = serde_json::to_string(resp).expect("serialize");
            let back: HumanResponse = serde_json::from_str(&json).expect("deserialize");
            let json2 = serde_json::to_string(&back).expect("re-serialize");
            assert_eq!(json, json2);
        }
    }

    #[tokio::test]
    async fn confirmation_hook_blocks_matching_pattern() {
        let (hook, req_rx, resp_tx) = hook_with_channels("dangerous", Duration::from_secs(5));
        // Answer only after the request has actually arrived, as a real
        // operator would, instead of queueing the answer up front.
        let human = respond_once(req_rx, resp_tx, HumanResponse::Approved);

        let result = hook.pre_call(&tool_call("call-1", "dangerous_delete")).await;
        assert!(result.is_ok(), "approved tool should pass: {result:?}");

        match human.await.expect("responder task should not panic") {
            HumanRequest::Confirm { tool_name, .. } => assert_eq!(tool_name, "dangerous_delete"),
            other => panic!("expected Confirm request, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn non_matching_pattern_passes_through() {
        let (hook, mut req_rx, _resp_tx) = hook_with_channels("dangerous", Duration::from_secs(5));

        let result = hook.pre_call(&tool_call("call-2", "safe_read")).await;
        assert!(result.is_ok(), "non-matching tool should pass through");

        // Dropping the hook drops the only request sender, so `recv` resolves
        // immediately: `None` proves that no request was ever queued.
        drop(hook);
        assert!(
            req_rx.recv().await.is_none(),
            "non-matching tool must not prompt the human"
        );
    }

    #[tokio::test]
    async fn timeout_on_no_response() {
        // Both far ends are kept alive so the only possible failure is the timeout.
        let (hook, _req_rx, _resp_tx) = hook_with_channels("slow", Duration::from_millis(50));

        let result = hook.pre_call(&tool_call("call-3", "slow_tool")).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("confirmation timeout"),
            "error should mention timeout, got: {err}"
        );
    }

    #[tokio::test]
    async fn denied_reason_included_in_error() {
        let (hook, req_rx, resp_tx) = hook_with_channels("write", Duration::from_secs(5));
        let human = respond_once(
            req_rx,
            resp_tx,
            HumanResponse::Denied {
                reason: "not now".into(),
            },
        );

        let result = hook.pre_call(&tool_call("call-4", "write_file")).await;
        human.await.expect("responder task should not panic");

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("User denied"), "error should mention denial, got: {err}");
        assert!(err.contains("not now"), "error should include reason, got: {err}");
    }

    #[tokio::test]
    async fn closed_request_channel_is_an_error() {
        let (hook, req_rx, _resp_tx) = hook_with_channels("write", Duration::from_secs(5));
        // With the request receiver gone the hook cannot reach the human.
        drop(req_rx);

        let result = hook.pre_call(&tool_call("call-5", "write_file")).await;
        let err = result.expect_err("closed request channel must fail").to_string();
        assert!(
            err.contains("human confirmation channel closed"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn closed_response_channel_is_an_error() {
        let (hook, _req_rx, resp_tx) = hook_with_channels("write", Duration::from_secs(5));
        // With the response sender gone no answer can ever arrive.
        drop(resp_tx);

        let result = hook.pre_call(&tool_call("call-6", "write_file")).await;
        let err = result.expect_err("closed response channel must fail").to_string();
        assert!(
            err.contains("human response channel closed"),
            "unexpected error: {err}"
        );
    }
}