1use crate::Result;
43use std::time::Duration;
44
/// Data handed to [`Interceptor::before_request`] before a request goes out.
///
/// Unlike the other hook contexts, `state` is borrowed mutably here, so an
/// interceptor can write into the caller-supplied state.
#[derive(Debug)]
pub struct BeforeRequestContext<'ctx, T = ()> {
    /// Name of the operation being performed.
    pub operation: &'ctx str,
    /// Identifier of the model the request targets.
    pub model: &'ctx str,
    /// The outgoing request body as JSON text.
    pub request_json: &'ctx str,
    /// Caller-supplied state; mutable so interceptors can update it.
    pub state: &'ctx mut T,
}
64
/// Data handed to [`Interceptor::after_response`] once a response is available.
#[derive(Debug)]
pub struct AfterResponseContext<'ctx, T = ()> {
    /// Name of the operation that was performed.
    pub operation: &'ctx str,
    /// Identifier of the model the request targeted.
    pub model: &'ctx str,
    /// The request body that was sent, as JSON text.
    pub request_json: &'ctx str,
    /// The response body that came back, as JSON text.
    pub response_json: &'ctx str,
    /// Elapsed time for the call (presumably measured by the caller — confirm at call site).
    pub duration: Duration,
    /// Input token count, when known.
    pub input_tokens: Option<i64>,
    /// Output token count, when known.
    pub output_tokens: Option<i64>,
    /// Caller-supplied state, shared read-only with the hook.
    pub state: &'ctx T,
}
88
/// Data handed to [`Interceptor::on_stream_chunk`] for each streamed chunk.
#[derive(Debug)]
pub struct StreamChunkContext<'ctx, T = ()> {
    /// Name of the streaming operation.
    pub operation: &'ctx str,
    /// Identifier of the model producing the stream.
    pub model: &'ctx str,
    /// The request body that started the stream, as JSON text.
    pub request_json: &'ctx str,
    /// This chunk's payload as JSON text.
    pub chunk_json: &'ctx str,
    /// Position of this chunk within the stream (presumably zero-based — confirm at call site).
    pub chunk_index: usize,
    /// Caller-supplied state, shared read-only with the hook.
    pub state: &'ctx T,
}
108
/// Data handed to [`Interceptor::on_stream_end`] after a stream finishes.
#[derive(Debug)]
pub struct StreamEndContext<'ctx, T = ()> {
    /// Name of the streaming operation.
    pub operation: &'ctx str,
    /// Identifier of the model that produced the stream.
    pub model: &'ctx str,
    /// The request body that started the stream, as JSON text.
    pub request_json: &'ctx str,
    /// Number of chunks the stream delivered.
    pub total_chunks: usize,
    /// Elapsed time for the whole stream (presumably measured by the caller — confirm).
    pub duration: Duration,
    /// Input token count, when known.
    pub input_tokens: Option<i64>,
    /// Output token count, when known.
    pub output_tokens: Option<i64>,
    /// Caller-supplied state, shared read-only with the hook.
    pub state: &'ctx T,
}
132
133#[derive(Debug)]
138pub struct ErrorContext<'a, T = ()> {
139 pub operation: &'a str,
141 pub model: Option<&'a str>,
143 pub request_json: Option<&'a str>,
145 pub error: &'a crate::Error,
147 pub state: Option<&'a T>,
149}
150
151#[async_trait::async_trait]
161pub trait Interceptor<T = ()>: Send + Sync {
162 async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_, T>) -> Result<()> {
167 Ok(())
168 }
169
170 async fn after_response(&self, _ctx: &AfterResponseContext<'_, T>) -> Result<()> {
172 Ok(())
173 }
174
175 async fn on_stream_chunk(&self, _ctx: &StreamChunkContext<'_, T>) -> Result<()> {
177 Ok(())
178 }
179
180 async fn on_stream_end(&self, _ctx: &StreamEndContext<'_, T>) -> Result<()> {
182 Ok(())
183 }
184
185 async fn on_error(&self, _ctx: &ErrorContext<'_, T>) {
190 }
192}
193
194pub struct InterceptorChain<T = ()> {
203 interceptors: Vec<Box<dyn Interceptor<T>>>,
204}
205
206impl<T> Default for InterceptorChain<T> {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl<T> InterceptorChain<T> {
213 pub fn new() -> Self {
215 Self {
216 interceptors: Vec::new(),
217 }
218 }
219
220 pub fn add(&mut self, interceptor: Box<dyn Interceptor<T>>) {
224 self.interceptors.push(interceptor);
225 }
226
227 pub async fn before_request(&self, ctx: &mut BeforeRequestContext<'_, T>) -> Result<()> {
229 for interceptor in &self.interceptors {
230 interceptor.before_request(ctx).await?;
231 }
232 Ok(())
233 }
234
235 pub async fn after_response(&self, ctx: &AfterResponseContext<'_, T>) -> Result<()>
237 where
238 T: Sync,
239 {
240 for interceptor in &self.interceptors {
241 interceptor.after_response(ctx).await?;
242 }
243 Ok(())
244 }
245
246 pub async fn on_stream_chunk(&self, ctx: &StreamChunkContext<'_, T>) -> Result<()>
248 where
249 T: Sync,
250 {
251 for interceptor in &self.interceptors {
252 interceptor.on_stream_chunk(ctx).await?;
253 }
254 Ok(())
255 }
256
257 pub async fn on_stream_end(&self, ctx: &StreamEndContext<'_, T>) -> Result<()>
259 where
260 T: Sync,
261 {
262 for interceptor in &self.interceptors {
263 interceptor.on_stream_end(ctx).await?;
264 }
265 Ok(())
266 }
267
268 pub async fn on_error(&self, ctx: &ErrorContext<'_, T>)
273 where
274 T: Sync,
275 {
276 for interceptor in &self.interceptors {
277 interceptor.on_error(ctx).await;
279 }
280 }
281
282 pub fn is_empty(&self) -> bool {
284 self.interceptors.is_empty()
285 }
286
287 pub fn len(&self) -> usize {
289 self.interceptors.len()
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Records how many times each hook has been invoked.
    // Every field is a per-hook counter, hence the shared `_count` suffix.
    #[allow(clippy::struct_field_names)]
    struct TestInterceptor {
        before_request_count: Arc<AtomicUsize>,
        after_response_count: Arc<AtomicUsize>,
        on_stream_chunk_count: Arc<AtomicUsize>,
        on_stream_end_count: Arc<AtomicUsize>,
        on_error_count: Arc<AtomicUsize>,
    }

    impl TestInterceptor {
        fn new() -> Self {
            Self {
                before_request_count: Arc::new(AtomicUsize::new(0)),
                after_response_count: Arc::new(AtomicUsize::new(0)),
                on_stream_chunk_count: Arc::new(AtomicUsize::new(0)),
                on_stream_end_count: Arc::new(AtomicUsize::new(0)),
                on_error_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait::async_trait]
    impl Interceptor for TestInterceptor {
        async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
            self.before_request_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_response(&self, _ctx: &AfterResponseContext<'_>) -> Result<()> {
            self.after_response_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_stream_chunk(&self, _ctx: &StreamChunkContext<'_>) -> Result<()> {
            self.on_stream_chunk_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_stream_end(&self, _ctx: &StreamEndContext<'_>) -> Result<()> {
            self.on_stream_end_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_error(&self, _ctx: &ErrorContext<'_>) {
            self.on_error_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Pushes its `id` onto a shared log from `before_request`, so a test can
    /// observe the order in which the chain runs interceptors.
    struct OrderInterceptor {
        id: usize,
        log: Arc<Mutex<Vec<usize>>>,
    }

    #[async_trait::async_trait]
    impl Interceptor for OrderInterceptor {
        async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
            // The guard is a temporary dropped at the end of this statement,
            // so it is never held across an await point.
            self.log.lock().unwrap().push(self.id);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_interceptor_chain_executes_in_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut chain: InterceptorChain = InterceptorChain::new();
        for id in 0..3 {
            chain.add(Box::new(OrderInterceptor {
                id,
                log: Arc::clone(&log),
            }));
        }

        let mut state = ();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        chain.before_request(&mut ctx).await.unwrap();

        // Each interceptor ran exactly once, in registration order.
        assert_eq!(*log.lock().unwrap(), vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn test_interceptor_chain_handles_errors() {
        struct FailingInterceptor;

        #[async_trait::async_trait]
        impl Interceptor for FailingInterceptor {
            async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
                Err(crate::Error::Internal("Test error".to_string()))
            }
        }

        let downstream = TestInterceptor::new();
        let downstream_calls = Arc::clone(&downstream.before_request_count);

        let mut chain: InterceptorChain = InterceptorChain::new();
        chain.add(Box::new(FailingInterceptor));
        chain.add(Box::new(downstream));

        let mut state = ();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };

        let result = chain.before_request(&mut ctx).await;
        assert!(result.is_err());
        // The first error short-circuits the chain: later interceptors never run.
        assert_eq!(downstream_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_interceptor_chain_empty() {
        let chain: InterceptorChain = InterceptorChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);

        let mut state = ();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        chain.before_request(&mut ctx).await.unwrap();
    }

    #[tokio::test]
    async fn test_state_passing() {
        struct StateInterceptor;

        #[async_trait::async_trait]
        impl Interceptor<HashMap<String, String>> for StateInterceptor {
            async fn before_request(
                &self,
                ctx: &mut BeforeRequestContext<'_, HashMap<String, String>>,
            ) -> Result<()> {
                ctx.state
                    .insert("test_key".to_string(), "test_value".to_string());
                Ok(())
            }
        }

        let mut chain = InterceptorChain::new();
        chain.add(Box::new(StateInterceptor));

        let mut state = HashMap::new();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };

        chain.before_request(&mut ctx).await.unwrap();
        assert_eq!(state.get("test_key"), Some(&"test_value".to_string()));
    }

    #[tokio::test]
    async fn test_response_and_stream_hooks_reach_interceptor() {
        let interceptor = TestInterceptor::new();
        let after_calls = Arc::clone(&interceptor.after_response_count);
        let chunk_calls = Arc::clone(&interceptor.on_stream_chunk_count);
        let end_calls = Arc::clone(&interceptor.on_stream_end_count);

        let mut chain: InterceptorChain = InterceptorChain::new();
        chain.add(Box::new(interceptor));

        chain
            .after_response(&AfterResponseContext {
                operation: "test",
                model: "gpt-4",
                request_json: "{}",
                response_json: "{}",
                duration: Duration::from_millis(1),
                input_tokens: Some(1),
                output_tokens: Some(2),
                state: &(),
            })
            .await
            .unwrap();
        chain
            .on_stream_chunk(&StreamChunkContext {
                operation: "test",
                model: "gpt-4",
                request_json: "{}",
                chunk_json: "{}",
                chunk_index: 0,
                state: &(),
            })
            .await
            .unwrap();
        chain
            .on_stream_end(&StreamEndContext {
                operation: "test",
                model: "gpt-4",
                request_json: "{}",
                total_chunks: 1,
                duration: Duration::from_millis(1),
                input_tokens: None,
                output_tokens: None,
                state: &(),
            })
            .await
            .unwrap();

        assert_eq!(after_calls.load(Ordering::SeqCst), 1);
        assert_eq!(chunk_calls.load(Ordering::SeqCst), 1);
        assert_eq!(end_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_error_handler_notifies_every_interceptor() {
        let first = TestInterceptor::new();
        let second = TestInterceptor::new();
        let first_calls = Arc::clone(&first.on_error_count);
        let second_calls = Arc::clone(&second.on_error_count);

        let mut chain: InterceptorChain = InterceptorChain::new();
        chain.add(Box::new(first));
        chain.add(Box::new(second));

        let error = crate::Error::Internal("Test".to_string());
        let ctx = ErrorContext {
            operation: "test",
            model: None,
            request_json: None,
            error: &error,
            state: None,
        };

        // `on_error` returns `()`, so there is no error to propagate; what matters
        // is that every registered interceptor is notified exactly once.
        chain.on_error(&ctx).await;

        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
        assert_eq!(second_calls.load(Ordering::SeqCst), 1);
    }
}