abu_agent/middleware/
mod.rs1mod privacy;
2pub use privacy::*;
3mod tool;
4pub use tool::*;
5
6use abu_base::chat::{AssistantMessage, ToolCall};
7use abu_tool::ToolCallResult;
8use crate::{AgentError, AgentResult};
9
/// Outcome of a single middleware invocation.
///
/// `MiddlewareManager` keeps running the middlewares of a stage while they return
/// [`MiddlewareFlow::Continue`]. It hands the first [`MiddlewareFlow::Break`] back to
/// its caller without running the rest of that stage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiddlewareFlow {
    /// Let the next middleware in the stage run.
    Continue,
    /// Stop the stage here. The payload is handed back to the caller unchanged;
    /// presumably a human-readable reason (NOTE(review): confirm against callers).
    Break(String),
}

15#[async_trait::async_trait]
16pub trait LlmOutMiddleware: Send + Sync {
17 type Error: std::error::Error + Send + Sync + 'static;
18 async fn intercept(&self, ai_message: &mut AssistantMessage) -> Result<MiddlewareFlow, Self::Error>;
19}
20
21#[async_trait::async_trait]
22pub trait ToolCallMiddleware: Send + Sync {
23 type Error: std::error::Error + Send + Sync + 'static;
24 async fn intercept(&self, tool_call: &mut ToolCall) -> Result<MiddlewareFlow, Self::Error>;
25}
26
27#[async_trait::async_trait]
28pub trait ToolResultMiddleware: Send + Sync {
29 type Error: std::error::Error + Send + Sync + 'static;
30 async fn intercept(&self, tool_name: &str, result: &mut ToolCallResult) -> Result<MiddlewareFlow, Self::Error>;
31}
32
33pub enum Middleware {
34 LlmOut(Box<dyn DynLlmOutMiddleware>),
35 ToolCall(Box<dyn DynToolCallMiddleware>),
36 ToolResult(Box<dyn DynToolResultMiddleware>),
37}
38
39impl Middleware {
40 pub fn llm_out<M: LlmOutMiddleware + 'static>(m: M) -> Self {
41 Self::LlmOut(Box::new(m))
42 }
43
44 pub fn tool_call<M: ToolCallMiddleware + 'static>(m: M) -> Self {
45 Self::ToolCall(Box::new(m))
46 }
47
48 pub fn tool_result<M: ToolResultMiddleware + 'static>(m: M) -> Self {
49 Self::ToolResult(Box::new(m))
50 }
51}
52
53#[derive(Default)]
54pub struct MiddlewareManager {
55 llm_outs: Vec<Box<dyn DynLlmOutMiddleware>>,
56 tool_calls: Vec<Box<dyn DynToolCallMiddleware>>,
57 tool_results: Vec<Box<dyn DynToolResultMiddleware>>,
58}
59
/// Early-returns `Ok($flow)` from the enclosing function when `$flow` is a
/// `MiddlewareFlow::Break`. A `Continue` falls through, so the caller's loop moves on to
/// the next middleware.
///
/// `$flow` is restricted to an identifier, so naming it twice below has no side effects.
macro_rules! pass_middleware_flow {
    ($flow:ident) => {
        if let MiddlewareFlow::Break(_) = $flow {
            return Ok($flow);
        }
    };
}

68impl MiddlewareManager {
69 pub fn new() -> Self {
70 Self::default()
71 }
72
73 pub async fn intercept_llm_out(&self, ai_message: &mut AssistantMessage) -> AgentResult<MiddlewareFlow> {
74 for middleware in self.llm_outs.iter() {
75 let flow = middleware.intercept(ai_message).await?;
76 pass_middleware_flow!(flow);
77 }
78 Ok(MiddlewareFlow::Continue)
79 }
80
81 pub async fn intercept_tool_call(&self, tool_call: &mut ToolCall) -> AgentResult<MiddlewareFlow> {
82 for middleware in self.tool_calls.iter() {
83 let flow = middleware.intercept(tool_call).await?;
84 pass_middleware_flow!(flow);
85 }
86 Ok(MiddlewareFlow::Continue)
87 }
88
89 pub async fn intercept_tool_result(&self, tool_name: &str, result: &mut ToolCallResult) -> AgentResult<MiddlewareFlow> {
90 for middleware in self.tool_results.iter() {
91 let flow = middleware.intercept(tool_name, result).await?;
92 pass_middleware_flow!(flow);
93 }
94 Ok(MiddlewareFlow::Continue)
95 }
96
97 pub fn add_llm_out<M: LlmOutMiddleware + 'static>(&mut self, middleware: M) {
98 self.llm_outs.push(Box::new(middleware));
99 }
100
101 pub fn add_tool_call<M: ToolCallMiddleware + 'static>(&mut self, middleware: M) {
102 self.tool_calls.push(Box::new(middleware));
103 }
104
105 pub fn add_tool_result<M: ToolResultMiddleware + 'static>(&mut self, middleware: M) {
106 self.tool_results.push(Box::new(middleware));
107 }
108
109 pub fn add_middleware(&mut self, middleware: impl Into<Middleware>) {
110 match middleware.into() {
111 Middleware::LlmOut(m) => self.llm_outs.push(m),
112 Middleware::ToolCall(m) => self.tool_calls.push(m),
113 Middleware::ToolResult(m) => self.tool_results.push(m),
114 }
115 }
116}
117
118#[async_trait::async_trait]
123pub trait DynLlmOutMiddleware: Send + Sync {
124 async fn intercept(&self, ai_message: &mut AssistantMessage) -> AgentResult<MiddlewareFlow>;
125}
126
127#[async_trait::async_trait]
128pub trait DynToolCallMiddleware: Send + Sync {
129 async fn intercept(&self, tool_call: &mut ToolCall) -> AgentResult<MiddlewareFlow>;
130}
131
132#[async_trait::async_trait]
133pub trait DynToolResultMiddleware: Send + Sync {
134 async fn intercept(&self, tool_name: &str, result: &mut ToolCallResult) -> AgentResult<MiddlewareFlow>;
135}
136
137#[async_trait::async_trait]
138impl<M: LlmOutMiddleware> DynLlmOutMiddleware for M {
139 #[inline]
140 async fn intercept(&self, ai_message: &mut AssistantMessage) -> AgentResult<MiddlewareFlow> {
141 let res = self
142 .intercept(ai_message).await
143 .map_err(|e| AgentError::Middleware("llm out", Box::new(e)))?;
144 Ok(res)
145 }
146}
147
148#[async_trait::async_trait]
149impl<M: ToolCallMiddleware> DynToolCallMiddleware for M {
150 #[inline]
151 async fn intercept(&self, tool_call: &mut ToolCall) -> AgentResult<MiddlewareFlow> {
152 let res = self
153 .intercept(tool_call).await
154 .map_err(|e| AgentError::Middleware("tool call", Box::new(e)))?;
155 Ok(res)
156 }
157}
158
159#[async_trait::async_trait]
160impl<M: ToolResultMiddleware> DynToolResultMiddleware for M {
161 #[inline]
162 async fn intercept(&self, tool_name: &str, result: &mut ToolCallResult) -> AgentResult<MiddlewareFlow> {
163 let res = self
164 .intercept(tool_name, result).await
165 .map_err(|e| AgentError::Middleware("tool result", Box::new(e)))?;
166 Ok(res)
167 }
168}