Skip to main content

abu_agent/middleware/
tool.rs

1use std::{convert::Infallible, io::Write};
2use abu_base::chat::ToolCall;
3use abu_tool::ToolCallResult;
4use super::{MiddlewareFlow, ToolCallMiddleware, ToolResultMiddleware};
5use std::collections::HashSet;
6
7// ====================================================================== //
8//                    HitlMiddleware
9// ====================================================================== //
10
/// Human-in-the-loop gate for "dangerous" tools.
///
/// Calls to any tool registered here are paused so an operator on the
/// terminal can approve, edit, or reject them before they run.
#[derive(Debug, Clone)]
pub struct HitlMiddleware {
    /// Tool names that require human approval. A set gives O(1) lookups and
    /// collapses duplicate names, matching `ToolPermissionGuardMiddleware`.
    dangerous_tools: HashSet<String>,
}

impl HitlMiddleware {
    /// Creates a gate for the given tool names.
    ///
    /// Accepts any iterable of string-like values, e.g. `["shell", "rm"]` or a
    /// `Vec<String>`. Duplicate names are stored once.
    pub fn new<S: Into<String>>(tools: impl IntoIterator<Item = S>) -> Self {
        Self {
            dangerous_tools: tools.into_iter().map(Into::into).collect(),
        }
    }
}
25
26#[async_trait::async_trait]
27impl ToolCallMiddleware for HitlMiddleware {
28    type Error = Infallible;
29
30    async fn intercept(&self, tool_call: &mut ToolCall) -> Result<MiddlewareFlow, Self::Error> {
31        if self.dangerous_tools.contains(&tool_call.name) {
32            println!("⚠️ [HITL] AI 想要执行高危操作: {}", tool_call.name);
33            println!("   参数: {}", tool_call.arguments);
34            print!("   同意执行吗?(y/N/edit): ");
35            std::io::stdout().flush().unwrap();
36
37            let mut input = String::new();
38            std::io::stdin().read_line(&mut input).unwrap();
39            let input = input.trim().to_lowercase();
40
41            match input.as_str() {
42                "y" | "yes" => {
43                    println!("✅ 人类已批准。");
44                    Ok(MiddlewareFlow::Continue)
45                }
46                "edit" => {
47                    print!("   请输入新的 JSON 参数: ");
48                    std::io::stdout().flush().unwrap();
49                    let mut new_args = String::new();
50                    std::io::stdin().read_line(&mut new_args).unwrap();
51                    tool_call.arguments = new_args.trim().to_string();
52                    Ok(MiddlewareFlow::Continue)
53                }
54                _ => {
55                    println!("🚫 人类已拒绝执行该操作。");
56                    Ok(MiddlewareFlow::Break("Rejected".to_string()))
57                }
58            }
59        } else {
60            Ok(MiddlewareFlow::Continue)
61        }
62    }
63}
64
65// ====================================================================== //
66//                    ToolPermissionGuardMiddleware
67// ====================================================================== //
68
/// Deny-list guard: tool calls whose name is listed here are never executed.
pub struct ToolPermissionGuardMiddleware {
    /// Names of tools that must not be run.
    forbidden_tools: HashSet<String>,
}

impl ToolPermissionGuardMiddleware {
    /// Builds a guard that denies every tool named in `forbidden`.
    ///
    /// Duplicate names are stored once.
    pub fn new(forbidden: &[&str]) -> Self {
        Self {
            forbidden_tools: forbidden.iter().copied().map(String::from).collect(),
        }
    }
}
79
80#[async_trait::async_trait]
81impl ToolCallMiddleware for ToolPermissionGuardMiddleware {
82    type Error = Infallible;
83
84    async fn intercept(&self, tool_call: &mut ToolCall) -> Result<MiddlewareFlow, Self::Error> {
85        if self.forbidden_tools.contains(&tool_call.name) {
86            let error_msg = format!(
87                "System Error: Permission denied to execute tool '{}'. Please do not try to use this tool.",
88                tool_call.name
89            );
90            return Ok(MiddlewareFlow::Break(error_msg));
91        }
92
93        Ok(MiddlewareFlow::Continue)
94    }
95}
96
/// Shortens overly long tool results to a fixed character budget.
pub struct ResultTruncatorMiddleware {
    /// Maximum number of characters (Unicode scalar values) kept from a result.
    max_length: usize,
}

impl ResultTruncatorMiddleware {
    /// Creates a truncator that keeps at most `max_length` characters.
    pub fn new(max_length: usize) -> Self {
        ResultTruncatorMiddleware { max_length }
    }
}
106
// ====================================================================== //
//                    ResultTruncatorMiddleware
// ====================================================================== //
110
111#[async_trait::async_trait]
112impl ToolResultMiddleware for ResultTruncatorMiddleware {
113    type Error = Infallible;
114
115    async fn intercept(&self, _tool_name: &str, result: &mut ToolCallResult) -> Result<MiddlewareFlow, Self::Error> {
116        let current_len = result.context.chars().count();
117
118        if current_len > self.max_length {
119            let truncated: String = result.context.chars().take(self.max_length).collect();            
120            result.context = format!(
121                "{}...\n[System]: Output was too long and has been truncated to {} characters.",
122                truncated, self.max_length
123            );
124        }
125
126        Ok(MiddlewareFlow::Continue)
127    }
128}