1use async_trait::async_trait;
4
5use crate::{
6 error::ProxyError,
7 types::{ConnectionContext, ProxyRequest, ProxyResponse},
8};
9
10#[async_trait]
23pub trait ProxyMiddleware: Send + Sync {
24 async fn on_request(
30 &self,
31 req: &mut ProxyRequest,
32 ctx: &mut ConnectionContext,
33 ) -> Result<(), ProxyError>;
34
35 async fn on_response(
40 &self,
41 res: &mut ProxyResponse,
42 ctx: &ConnectionContext,
43 ) -> Result<(), ProxyError>;
44
45 async fn on_connect(&self, _ctx: &ConnectionContext) {}
47
48 async fn on_disconnect(&self, _ctx: &ConnectionContext) {}
50
51 async fn on_init(&self) -> Result<(), ProxyError> {
53 Ok(())
54 }
55
56 async fn on_shutdown(&self) -> Result<(), ProxyError> {
58 Ok(())
59 }
60
61 fn name(&self) -> &'static str;
65}
66
67pub async fn run_on_request_chain(
75 middlewares: &[Box<dyn ProxyMiddleware>],
76 req: &mut ProxyRequest,
77 ctx: &mut ConnectionContext,
78) -> Result<(), ProxyError> {
79 for mw in middlewares {
80 mw.on_request(req, ctx).await?;
81 }
82 Ok(())
83}
84
85pub async fn run_on_response_chain(
93 middlewares: &[Box<dyn ProxyMiddleware>],
94 res: &mut ProxyResponse,
95 ctx: &ConnectionContext,
96) -> Result<(), ProxyError> {
97 for mw in middlewares.iter().rev() {
98 mw.on_response(res, ctx).await?;
99 }
100 Ok(())
101}
102
103#[async_trait::async_trait]
114pub trait CostRecorder: Send + Sync + std::fmt::Debug {
115 async fn record(
120 &self,
121 ctx: &crate::types::ConnectionContext,
122 response_body: &serde_json::Value,
123 ) -> Result<(), crate::error::ProxyError>;
124}
125
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::{Arc, Mutex};

    use bytes::Bytes;
    use http::{HeaderMap, Method, StatusCode};

    use super::*;

    /// Ordered record of middleware names, shared by every middleware in a
    /// chain so a test can assert the exact sequence in which hooks ran.
    type CallLog = Arc<Mutex<Vec<&'static str>>>;

    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Copies the current contents of `log`, so assertions don't hold the lock.
    fn snapshot(log: &CallLog) -> Vec<&'static str> {
        log.lock().unwrap().clone()
    }

    /// Middleware that appends its `name` to a shared log whenever a hook runs.
    ///
    /// A log, unlike a bare counter, distinguishes "never called" from
    /// "called first" and pins the relative order of every middleware.
    struct RecordingMiddleware {
        name: &'static str,
        request_log: CallLog,
        response_log: CallLog,
        /// When `Some(msg)`, `on_request` fails with `ProxyError::BadRequest`
        /// carrying `msg` and does not record itself.
        request_err: Option<&'static str>,
    }

    impl RecordingMiddleware {
        fn new(name: &'static str, request_log: &CallLog, response_log: &CallLog) -> Self {
            Self {
                name,
                request_log: Arc::clone(request_log),
                response_log: Arc::clone(response_log),
                request_err: None,
            }
        }

        /// Makes this middleware's `on_request` hook fail with `msg`.
        fn failing_request(mut self, msg: &'static str) -> Self {
            self.request_err = Some(msg);
            self
        }
    }

    #[async_trait]
    impl ProxyMiddleware for RecordingMiddleware {
        async fn on_request(
            &self,
            _req: &mut ProxyRequest,
            _ctx: &mut ConnectionContext,
        ) -> Result<(), ProxyError> {
            if let Some(msg) = self.request_err {
                return Err(ProxyError::BadRequest(msg.to_owned()));
            }
            self.request_log.lock().unwrap().push(self.name);
            Ok(())
        }

        async fn on_response(
            &self,
            _res: &mut ProxyResponse,
            _ctx: &ConnectionContext,
        ) -> Result<(), ProxyError> {
            self.response_log.lock().unwrap().push(self.name);
            Ok(())
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Builds a chain of non-failing recorders, registered in `names` order.
    fn recording_chain(
        names: &[&'static str],
        request_log: &CallLog,
        response_log: &CallLog,
    ) -> Vec<Box<dyn ProxyMiddleware>> {
        names
            .iter()
            .map(|&name| {
                Box::new(RecordingMiddleware::new(name, request_log, response_log))
                    as Box<dyn ProxyMiddleware>
            })
            .collect()
    }

    fn make_request() -> ProxyRequest {
        ProxyRequest::new(
            Method::POST,
            "/v1/messages".into(),
            HeaderMap::new(),
            Bytes::from(r#"{"model":"test"}"#),
        )
    }

    fn make_context() -> ConnectionContext {
        ConnectionContext::new(1, crate::types::AgentType::Unknown, None, None)
    }

    fn make_response() -> ProxyResponse {
        ProxyResponse::new(StatusCode::OK, HeaderMap::new(), Bytes::new(), false)
    }

    #[tokio::test]
    async fn test_on_request_runs_in_registration_order() {
        let request_log = new_log();
        let response_log = new_log();
        let middlewares = recording_chain(&["A", "B", "C"], &request_log, &response_log);

        let mut req = make_request();
        let mut ctx = make_context();
        run_on_request_chain(&middlewares, &mut req, &mut ctx)
            .await
            .unwrap();

        assert_eq!(snapshot(&request_log), ["A", "B", "C"]);
        // The request pass must not trigger response hooks.
        assert!(snapshot(&response_log).is_empty());
    }

    #[tokio::test]
    async fn test_on_response_runs_in_reverse_registration_order() {
        let request_log = new_log();
        let response_log = new_log();
        let middlewares = recording_chain(&["A", "B", "C"], &request_log, &response_log);

        let mut res = make_response();
        let ctx = make_context();
        run_on_response_chain(&middlewares, &mut res, &ctx)
            .await
            .unwrap();

        assert_eq!(snapshot(&response_log), ["C", "B", "A"]);
        // The response pass must not trigger request hooks.
        assert!(snapshot(&request_log).is_empty());
    }

    #[tokio::test]
    async fn test_on_request_aborts_on_error() {
        let request_log = new_log();
        let response_log = new_log();
        let middlewares: Vec<Box<dyn ProxyMiddleware>> = vec![
            Box::new(RecordingMiddleware::new("ok", &request_log, &response_log)),
            Box::new(
                RecordingMiddleware::new("err", &request_log, &response_log)
                    .failing_request("test error"),
            ),
            Box::new(RecordingMiddleware::new("never", &request_log, &response_log)),
        ];

        let mut req = make_request();
        let mut ctx = make_context();
        let result = run_on_request_chain(&middlewares, &mut req, &mut ctx).await;

        assert!(matches!(result, Err(ProxyError::BadRequest(_))));
        // Only the middleware registered before the failing one ran;
        // "never" must have been skipped.
        assert_eq!(snapshot(&request_log), ["ok"]);
    }

    #[tokio::test]
    async fn test_empty_chains_are_noops() {
        let mut req = make_request();
        let mut res = make_response();
        let mut ctx = make_context();

        run_on_request_chain(&[], &mut req, &mut ctx).await.unwrap();
        run_on_response_chain(&[], &mut res, &ctx).await.unwrap();
    }
}