// abu_agent/middleware/mod.rs

1mod privacy;
2pub use privacy::*;
3mod tool;
4pub use tool::*;
5
6use abu_base::chat::{AssistantMessage, ToolCall};
7use abu_tool::ToolCallResult;
8use crate::{AgentError, AgentResult};
9
/// Decision returned by a middleware hook: keep running the pipeline or stop it.
///
/// `MiddlewareManager` runs the middleware of one stage in order and stops at the first
/// `Break`, handing that value back to its caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiddlewareFlow {
    /// Let the next middleware of the same stage run.
    Continue,
    /// Stop running the remaining middleware of this stage.
    ///
    /// NOTE(review): the meaning of the payload (e.g. a reason or replacement text) is
    /// decided by whoever consumes the flow — confirm against the agent loop.
    Break(String),
}
14
15#[async_trait::async_trait]
16pub trait LlmOutMiddleware: Send + Sync {
17    type Error: std::error::Error + Send + Sync + 'static;
18    async fn intercept(&self, ai_message: &mut AssistantMessage) -> Result<MiddlewareFlow, Self::Error>;
19}
20
21#[async_trait::async_trait]
22pub trait ToolCallMiddleware: Send + Sync {
23    type Error: std::error::Error + Send + Sync + 'static;
24    async fn intercept(&self, tool_call: &mut ToolCall) -> Result<MiddlewareFlow, Self::Error>;
25}
26
27#[async_trait::async_trait]
28pub trait ToolResultMiddleware: Send + Sync {
29    type Error: std::error::Error + Send + Sync + 'static;
30    async fn intercept(&self, tool_name: &str, result: &mut ToolCallResult) -> Result<MiddlewareFlow, Self::Error>;
31}
32
33pub enum Middleware {
34    LlmOut(Box<dyn DynLlmOutMiddleware>),
35    ToolCall(Box<dyn DynToolCallMiddleware>),
36    ToolResult(Box<dyn DynToolResultMiddleware>),
37} 
38
39impl Middleware {
40    pub fn llm_out<M: LlmOutMiddleware + 'static>(m: M) -> Self {
41        Self::LlmOut(Box::new(m))
42    }
43
44    pub fn tool_call<M: ToolCallMiddleware + 'static>(m: M) -> Self {
45        Self::ToolCall(Box::new(m))
46    }
47
48    pub fn tool_result<M: ToolResultMiddleware + 'static>(m: M) -> Self {
49        Self::ToolResult(Box::new(m))
50    }  
51}
52
53#[derive(Default)]
54pub struct MiddlewareManager {
55    llm_outs: Vec<Box<dyn DynLlmOutMiddleware>>,
56    tool_calls: Vec<Box<dyn DynToolCallMiddleware>>,
57    tool_results: Vec<Box<dyn DynToolResultMiddleware>>,
58}
59
60macro_rules! pass_middleware_flow {
61    ($flow:ident) => {
62        if matches!($flow, MiddlewareFlow::Break(_)) {
63            return Ok($flow);
64        }
65    };
66}
67
68impl MiddlewareManager {
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    pub async fn intercept_llm_out(&self, ai_message: &mut AssistantMessage) -> AgentResult<MiddlewareFlow> {
74        for middleware in self.llm_outs.iter() {
75            let flow = middleware.intercept(ai_message).await?;
76            pass_middleware_flow!(flow);
77        }
78        Ok(MiddlewareFlow::Continue)
79    }
80
81    pub async fn intercept_tool_call(&self, tool_call: &mut ToolCall) -> AgentResult<MiddlewareFlow> {
82        for middleware in self.tool_calls.iter() {
83            let flow = middleware.intercept(tool_call).await?;
84            pass_middleware_flow!(flow);
85        }
86        Ok(MiddlewareFlow::Continue)
87    }
88
89    pub async fn intercept_tool_result(&self, tool_name: &str, result: &mut ToolCallResult) -> AgentResult<MiddlewareFlow> {
90        for middleware in self.tool_results.iter() {
91            let flow = middleware.intercept(tool_name, result).await?;
92            pass_middleware_flow!(flow);
93        }
94        Ok(MiddlewareFlow::Continue)
95    }
96
97    pub fn add_llm_out<M: LlmOutMiddleware + 'static>(&mut self, middleware: M) {
98        self.llm_outs.push(Box::new(middleware));
99    }
100
101    pub fn add_tool_call<M: ToolCallMiddleware + 'static>(&mut self, middleware: M) {
102        self.tool_calls.push(Box::new(middleware));
103    }
104
105    pub fn add_tool_result<M: ToolResultMiddleware + 'static>(&mut self, middleware: M) {
106        self.tool_results.push(Box::new(middleware));
107    }
108
109    pub fn add_middleware(&mut self, middleware: impl Into<Middleware>) {
110        match middleware.into() {
111            Middleware::LlmOut(m) => self.llm_outs.push(m),
112            Middleware::ToolCall(m) => self.tool_calls.push(m),
113            Middleware::ToolResult(m) => self.tool_results.push(m),
114        }
115    }
116}
117
// ======================================================================================= //
//   Dyn traits: object-safe adapters that erase each middleware trait's `Error` type
// ======================================================================================= //
121
122#[async_trait::async_trait]
123pub trait DynLlmOutMiddleware: Send + Sync {
124    async fn intercept(&self, ai_message: &mut AssistantMessage) -> AgentResult<MiddlewareFlow>;
125}
126
127#[async_trait::async_trait]
128pub trait DynToolCallMiddleware: Send + Sync {
129    async fn intercept(&self, tool_call: &mut ToolCall) -> AgentResult<MiddlewareFlow>;
130}
131
132#[async_trait::async_trait]
133pub trait DynToolResultMiddleware: Send + Sync {
134    async fn intercept(&self, tool_name: &str, result: &mut ToolCallResult) -> AgentResult<MiddlewareFlow>;
135}
136
137#[async_trait::async_trait]
138impl<M: LlmOutMiddleware> DynLlmOutMiddleware for M {
139    #[inline]
140    async fn intercept(&self, ai_message: &mut AssistantMessage) -> AgentResult<MiddlewareFlow> {
141        let res = self
142            .intercept(ai_message).await
143            .map_err(|e| AgentError::Middleware("llm out", Box::new(e)))?;
144        Ok(res)
145    }
146}
147
148#[async_trait::async_trait]
149impl<M: ToolCallMiddleware> DynToolCallMiddleware for M {
150    #[inline]
151    async fn intercept(&self, tool_call: &mut ToolCall) -> AgentResult<MiddlewareFlow> {
152        let res = self
153            .intercept(tool_call).await
154            .map_err(|e| AgentError::Middleware("tool call", Box::new(e)))?;
155        Ok(res)
156    }
157}
158
159#[async_trait::async_trait]
160impl<M: ToolResultMiddleware> DynToolResultMiddleware for M {
161    #[inline]
162    async fn intercept(&self, tool_name: &str, result: &mut ToolCallResult) -> AgentResult<MiddlewareFlow> {
163        let res = self
164            .intercept(tool_name, result).await
165            .map_err(|e| AgentError::Middleware("tool result", Box::new(e)))?;
166        Ok(res)
167    }
168}