agent_base/engine/
middleware.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::Value;
5use tokio::sync::broadcast;
6
7use crate::types::{AgentResult, AgentEvent, ChatMessage, SessionId};
8
9#[derive(Clone)]
10pub struct UserMessageCtx {
11 pub session_id: SessionId,
12 pub user_input: String,
13 pub event_bus: broadcast::Sender<AgentEvent>,
14}
15
16#[derive(Clone)]
17pub struct PreLlmCtx {
18 pub session_id: SessionId,
19 pub messages: Vec<ChatMessage>,
20 pub tools: Vec<Value>,
21 pub event_bus: broadcast::Sender<AgentEvent>,
22}
23
24#[derive(Clone)]
25pub struct PostLlmCtx {
26 pub session_id: SessionId,
27 pub full_text: String,
28 pub is_tool_call: bool,
29 pub tool_calls: Vec<(String, String, String)>,
30 pub event_bus: broadcast::Sender<AgentEvent>,
31}
32
33#[async_trait]
34pub trait Middleware: Send + Sync {
35 async fn on_user_message(&self, _ctx: &mut UserMessageCtx) -> AgentResult<()> {
36 Ok(())
37 }
38
39 async fn on_pre_llm(&self, _ctx: &mut PreLlmCtx) -> AgentResult<()> {
40 Ok(())
41 }
42
43 async fn on_post_llm(&self, _ctx: &mut PostLlmCtx) -> AgentResult<()> {
44 Ok(())
45 }
46}
47
48pub(crate) type MiddlewareRef = Arc<dyn Middleware>;