abu_agent/middleware/
tool.rs1use std::{convert::Infallible, io::Write};
2use abu_base::chat::ToolCall;
3use abu_tool::ToolCallResult;
4use super::{MiddlewareFlow, ToolCallMiddleware, ToolResultMiddleware};
5use std::collections::HashSet;
6
/// Human-in-the-loop middleware: calls to tools listed as "dangerous" are
/// paused so an operator at the terminal can approve, reject, or edit them.
pub struct HitlMiddleware {
    /// Names of tools that require operator confirmation before running.
    /// A set gives O(1) membership checks, matching `ToolPermissionGuardMiddleware`.
    dangerous_tools: HashSet<String>,
}

impl HitlMiddleware {
    /// Creates a middleware that gates the given tool names.
    ///
    /// Accepts any iterable of string-like values (`&str`, `String`, ...).
    /// Duplicate names are collapsed.
    pub fn new<S: Into<String>>(tools: impl IntoIterator<Item = S>) -> Self {
        Self {
            dangerous_tools: tools.into_iter().map(Into::into).collect(),
        }
    }
}
25
26#[async_trait::async_trait]
27impl ToolCallMiddleware for HitlMiddleware {
28 type Error = Infallible;
29
30 async fn intercept(&self, tool_call: &mut ToolCall) -> Result<MiddlewareFlow, Self::Error> {
31 if self.dangerous_tools.contains(&tool_call.name) {
32 println!("⚠️ [HITL] AI 想要执行高危操作: {}", tool_call.name);
33 println!(" 参数: {}", tool_call.arguments);
34 print!(" 同意执行吗?(y/N/edit): ");
35 std::io::stdout().flush().unwrap();
36
37 let mut input = String::new();
38 std::io::stdin().read_line(&mut input).unwrap();
39 let input = input.trim().to_lowercase();
40
41 match input.as_str() {
42 "y" | "yes" => {
43 println!("✅ 人类已批准。");
44 Ok(MiddlewareFlow::Continue)
45 }
46 "edit" => {
47 print!(" 请输入新的 JSON 参数: ");
48 std::io::stdout().flush().unwrap();
49 let mut new_args = String::new();
50 std::io::stdin().read_line(&mut new_args).unwrap();
51 tool_call.arguments = new_args.trim().to_string();
52 Ok(MiddlewareFlow::Continue)
53 }
54 _ => {
55 println!("🚫 人类已拒绝执行该操作。");
56 Ok(MiddlewareFlow::Break("Rejected".to_string()))
57 }
58 }
59 } else {
60 Ok(MiddlewareFlow::Continue)
61 }
62 }
63}
64
/// Deny-list middleware: calls to any listed tool are stopped with a
/// permission-denied message instead of being executed.
pub struct ToolPermissionGuardMiddleware {
    /// Tool names that may never be executed.
    forbidden_tools: HashSet<String>,
}

impl ToolPermissionGuardMiddleware {
    /// Builds a guard from a slice of forbidden tool names (duplicates collapse).
    pub fn new(forbidden: &[&str]) -> Self {
        let mut forbidden_tools = HashSet::with_capacity(forbidden.len());
        for name in forbidden {
            forbidden_tools.insert(String::from(*name));
        }
        ToolPermissionGuardMiddleware { forbidden_tools }
    }
}
79
80#[async_trait::async_trait]
81impl ToolCallMiddleware for ToolPermissionGuardMiddleware {
82 type Error = Infallible;
83
84 async fn intercept(&self, tool_call: &mut ToolCall) -> Result<MiddlewareFlow, Self::Error> {
85 if self.forbidden_tools.contains(&tool_call.name) {
86 let error_msg = format!(
87 "System Error: Permission denied to execute tool '{}'. Please do not try to use this tool.",
88 tool_call.name
89 );
90 return Ok(MiddlewareFlow::Break(error_msg));
91 }
92
93 Ok(MiddlewareFlow::Continue)
94 }
95}
96
/// Caps the size of tool output by truncating it to a fixed number of characters.
pub struct ResultTruncatorMiddleware {
    /// Maximum number of characters (Unicode scalar values) kept from a result.
    max_length: usize,
}

impl ResultTruncatorMiddleware {
    /// Creates a truncator that keeps at most `max_length` characters of each result.
    pub fn new(max_length: usize) -> Self {
        ResultTruncatorMiddleware { max_length }
    }
}
106
107#[async_trait::async_trait]
112impl ToolResultMiddleware for ResultTruncatorMiddleware {
113 type Error = Infallible;
114
115 async fn intercept(&self, _tool_name: &str, result: &mut ToolCallResult) -> Result<MiddlewareFlow, Self::Error> {
116 let current_len = result.context.chars().count();
117
118 if current_len > self.max_length {
119 let truncated: String = result.context.chars().take(self.max_length).collect();
120 result.context = format!(
121 "{}...\n[System]: Output was too long and has been truncated to {} characters.",
122 truncated, self.max_length
123 );
124 }
125
126 Ok(MiddlewareFlow::Continue)
127 }
128}