// agent_proxy_rust_core/middleware.rs

//! The [`ProxyMiddleware`] trait — the core extension point for the proxy engine.
2
3use async_trait::async_trait;
4
5use crate::{
6    error::ProxyError,
7    types::{ConnectionContext, ProxyRequest, ProxyResponse},
8};
9
10/// The central extension trait for `agent-proxy-rust`.
11///
12/// Implementors can intercept and transform requests/responses flowing through
13/// the proxy. Execution order:
14///
15/// - `on_init`: registration order (at startup)
16/// - `on_request`: registration order
17/// - `on_response`: **reverse** registration order
18/// - `on_disconnect`: reverse registration order
19/// - `on_shutdown`: reverse registration order
20///
21/// All methods have default no-op implementations except `on_request`,`on_response`, and `name`.
22#[async_trait]
23pub trait ProxyMiddleware: Send + Sync {
24    /// Called before forwarding the request to upstream.
25    ///
26    /// Middleware may modify the request body, headers, or context extensions.
27    /// For example, the compress middleware reduces tool definition sizes,
28    /// and the model-router middleware selects a channel and sets the upstream URL.
29    async fn on_request(
30        &self,
31        req: &mut ProxyRequest,
32        ctx: &mut ConnectionContext,
33    ) -> Result<(), ProxyError>;
34
35    /// Called after receiving the response from upstream.
36    ///
37    /// Middleware may modify the response body, headers, or context extensions.
38    /// Called in **reverse** registration order for symmetry with `on_request`.
39    async fn on_response(
40        &self,
41        res: &mut ProxyResponse,
42        ctx: &ConnectionContext,
43    ) -> Result<(), ProxyError>;
44
45    /// Called when a new connection is established. Runs in registration order.
46    async fn on_connect(&self, _ctx: &ConnectionContext) {}
47
48    /// Called when a connection is closed. Runs in reverse registration order.
49    async fn on_disconnect(&self, _ctx: &ConnectionContext) {}
50
51    /// Called once when the proxy starts. Use for opening DB pools, loading config, etc.
52    async fn on_init(&self) -> Result<(), ProxyError> {
53        Ok(())
54    }
55
56    /// Called once when the proxy shuts down gracefully.
57    async fn on_shutdown(&self) -> Result<(), ProxyError> {
58        Ok(())
59    }
60
61    /// Returns the unique name of this middleware.
62    ///
63    /// Used for logging and debugging.
64    fn name(&self) -> &'static str;
65}
66
67/// Runs the `on_request` chain in registration order.
68///
69/// If any middleware returns `Err`, the chain is aborted and the error is returned.
70///
71/// # Errors
72///
73/// Returns the first [`ProxyError`] encountered from any middleware in the chain.
74pub async fn run_on_request_chain(
75    middlewares: &[Box<dyn ProxyMiddleware>],
76    req: &mut ProxyRequest,
77    ctx: &mut ConnectionContext,
78) -> Result<(), ProxyError> {
79    for mw in middlewares {
80        mw.on_request(req, ctx).await?;
81    }
82    Ok(())
83}
84
85/// Runs the `on_response` chain in **reverse** registration order.
86///
87/// If any middleware returns `Err`, the chain is aborted and the error is returned.
88///
89/// # Errors
90///
91/// Returns the first [`ProxyError`] encountered from any middleware in the chain.
92pub async fn run_on_response_chain(
93    middlewares: &[Box<dyn ProxyMiddleware>],
94    res: &mut ProxyResponse,
95    ctx: &ConnectionContext,
96) -> Result<(), ProxyError> {
97    for mw in middlewares.iter().rev() {
98        mw.on_response(res, ctx).await?;
99    }
100    Ok(())
101}
102
// ── Cost recorder trait ────────────────────────────────────────────────
104
105/// Post-response cost recording hook.
106///
107/// Called after the `on_response` middleware chain completes and before the
108/// axum response is built. Implementors (typically the `cost` crate) use this
109/// to extract usage, calculate cost, and persist a `CostRecord`.
110///
111/// This is deliberately not part of [`ProxyMiddleware`] because cost recording
112/// needs to happen after ALL other response transformations are done.
113#[async_trait::async_trait]
114pub trait CostRecorder: Send + Sync + std::fmt::Debug {
115    /// Record a cost entry for the completed request.
116    ///
117    /// `response_body` is the final response body JSON (after all middleware
118    /// transforms have been applied).
119    async fn record(
120        &self,
121        ctx: &crate::types::ConnectionContext,
122        response_body: &serde_json::Value,
123    ) -> Result<(), crate::error::ProxyError>;
124}
125
#[cfg(test)]
// Tests should fail loudly on a poisoned lock or an unexpected chain error.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::{Arc, Mutex};

    use bytes::Bytes;
    use http::{HeaderMap, Method, StatusCode};

    use super::*;

    /// Shared, ordered record of the middleware names whose hooks actually ran.
    type CallLog = Arc<Mutex<Vec<&'static str>>>;

    /// Test double that appends its name to a shared [`CallLog`] whenever one
    /// of its hooks runs, and can be configured to fail `on_request`.
    struct RecordingMiddleware {
        name: &'static str,
        calls: CallLog,
        /// When set, `on_request` fails with `ProxyError::BadRequest` carrying
        /// this message and records nothing.
        request_err: Option<&'static str>,
    }

    impl RecordingMiddleware {
        /// Builds a boxed middleware whose hooks always succeed.
        fn boxed(name: &'static str, calls: &CallLog) -> Box<dyn ProxyMiddleware> {
            Box::new(Self {
                name,
                calls: Arc::clone(calls),
                request_err: None,
            })
        }

        /// Builds a boxed middleware whose `on_request` fails with `message`.
        fn failing(
            name: &'static str,
            calls: &CallLog,
            message: &'static str,
        ) -> Box<dyn ProxyMiddleware> {
            Box::new(Self {
                name,
                calls: Arc::clone(calls),
                request_err: Some(message),
            })
        }

        /// Appends this middleware's name to the shared log.
        fn record_call(&self) {
            self.calls.lock().unwrap().push(self.name);
        }
    }

    #[async_trait]
    impl ProxyMiddleware for RecordingMiddleware {
        async fn on_request(
            &self,
            _req: &mut ProxyRequest,
            _ctx: &mut ConnectionContext,
        ) -> Result<(), ProxyError> {
            if let Some(message) = self.request_err {
                return Err(ProxyError::BadRequest(message.into()));
            }
            self.record_call();
            Ok(())
        }

        async fn on_response(
            &self,
            _res: &mut ProxyResponse,
            _ctx: &ConnectionContext,
        ) -> Result<(), ProxyError> {
            self.record_call();
            Ok(())
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    fn make_request() -> ProxyRequest {
        ProxyRequest::new(
            Method::POST,
            "/v1/messages".into(),
            HeaderMap::new(),
            Bytes::from(r#"{"model":"test"}"#),
        )
    }

    fn make_context() -> ConnectionContext {
        ConnectionContext::new(1, crate::types::AgentType::Unknown, None, None)
    }

    fn make_response() -> ProxyResponse {
        ProxyResponse::new(StatusCode::OK, HeaderMap::new(), Bytes::new(), false)
    }

    #[tokio::test]
    async fn test_on_request_runs_in_registration_order() {
        let calls = CallLog::default();
        let middlewares = vec![
            RecordingMiddleware::boxed("A", &calls),
            RecordingMiddleware::boxed("B", &calls),
            RecordingMiddleware::boxed("C", &calls),
        ];

        let mut req = make_request();
        let mut ctx = make_context();

        run_on_request_chain(&middlewares, &mut req, &mut ctx)
            .await
            .unwrap();

        // The log pins the exact sequence, not just the number of calls.
        assert_eq!(*calls.lock().unwrap(), ["A", "B", "C"]);
    }

    #[tokio::test]
    async fn test_on_response_runs_in_reverse_registration_order() {
        let calls = CallLog::default();
        let middlewares = vec![
            RecordingMiddleware::boxed("A", &calls),
            RecordingMiddleware::boxed("B", &calls),
            RecordingMiddleware::boxed("C", &calls),
        ];

        let mut res = make_response();
        let ctx = make_context();

        run_on_response_chain(&middlewares, &mut res, &ctx)
            .await
            .unwrap();

        assert_eq!(*calls.lock().unwrap(), ["C", "B", "A"]);
    }

    #[tokio::test]
    async fn test_on_request_aborts_on_error() {
        let calls = CallLog::default();
        let middlewares = vec![
            RecordingMiddleware::boxed("ok", &calls),
            RecordingMiddleware::failing("err", &calls, "test error"),
            RecordingMiddleware::boxed("never", &calls),
        ];

        let mut req = make_request();
        let mut ctx = make_context();

        let result = run_on_request_chain(&middlewares, &mut req, &mut ctx).await;

        assert!(
            matches!(&result, Err(ProxyError::BadRequest(msg)) if msg == "test error"),
            "unexpected chain result: {result:?}"
        );
        // Only the middleware ahead of the failure ran; "never" was skipped.
        assert_eq!(*calls.lock().unwrap(), ["ok"]);
    }
}