// agent_base/engine/middleware.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use tokio::sync::broadcast;
6
7use crate::types::{AgentResult, AgentEvent, ChatMessage, SessionId};
8
9#[derive(Clone)]
10pub struct UserMessageCtx {
11    pub session_id: SessionId,
12    pub user_input: String,
13    pub event_bus: broadcast::Sender<AgentEvent>,
14}
15
16#[derive(Clone)]
17pub struct PreLlmCtx {
18    pub session_id: SessionId,
19    pub messages: Vec<ChatMessage>,
20    pub tools: Vec<Value>,
21    pub event_bus: broadcast::Sender<AgentEvent>,
22}
23
24#[derive(Clone)]
25pub struct PostLlmCtx {
26    pub session_id: SessionId,
27    pub full_text: String,
28    pub is_tool_call: bool,
29    pub tool_calls: Vec<(String, String, String)>,
30    pub event_bus: broadcast::Sender<AgentEvent>,
31}
32
33#[async_trait]
34pub trait Middleware: Send + Sync {
35    async fn on_user_message(&self, _ctx: &mut UserMessageCtx) -> AgentResult<()> {
36        Ok(())
37    }
38
39    async fn on_pre_llm(&self, _ctx: &mut PreLlmCtx) -> AgentResult<()> {
40        Ok(())
41    }
42
43    async fn on_post_llm(&self, _ctx: &mut PostLlmCtx) -> AgentResult<()> {
44        Ok(())
45    }
46}
47
48pub(crate) type MiddlewareRef = Arc<dyn Middleware>;