// openai_ergonomic/interceptor.rs

1//! Interceptor system for middleware and observability.
2//!
3//! Interceptors provide hooks into the request/response lifecycle, enabling:
4//! - Telemetry and tracing
5//! - Logging and debugging
6//! - Metrics collection
7//! - Request/response transformation
8//! - Custom error handling
9//! - Authentication enhancement
10//!
11//! # Architecture
12//!
13//! The interceptor system follows a chain-of-responsibility pattern where
14//! multiple interceptors can be registered and executed in order. Each
15//! interceptor can:
16//!
17//! - Modify request context before sending
18//! - Observe and react to responses
19//! - Handle streaming chunks
20//! - Process errors
21//!
22//! # Example
23//!
24//! ```rust,ignore
25//! use openai_ergonomic::{Client, Interceptor, BeforeRequestContext};
26//!
27//! struct LoggingInterceptor;
28//!
29//! #[async_trait::async_trait]
30//! impl Interceptor for LoggingInterceptor {
31//!     async fn before_request(&self, ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
32//!         println!("Calling {} with model {}", ctx.operation, ctx.model);
33//!         Ok(())
34//!     }
35//! }
36//!
37//! let client = Client::from_env()?
38//!     .with_interceptor(Box::new(LoggingInterceptor))
39//!     .build();
40//! ```
41
42use crate::Result;
43use std::time::Duration;
44
/// Context provided before a request is sent.
///
/// This context contains all the information about the request that's about
/// to be made, and allows interceptors to store state that will be
/// carried through the request lifecycle.
///
/// The generic type parameter `T` allows interceptors to define their own
/// state type for passing data between lifecycle hooks; it defaults to `()`
/// for interceptors that don't need state. This is the only context in the
/// lifecycle that exposes the state mutably.
#[derive(Debug)]
pub struct BeforeRequestContext<'a, T = ()> {
    /// The operation being performed (e.g., "chat", "embedding", "`image_generation`")
    pub operation: &'a str,
    /// The model being used for the request
    pub model: &'a str,
    /// The serialized request body as JSON
    pub request_json: &'a str,
    /// Typed state for passing data between interceptor hooks; values
    /// written here are visible to the later (read-only) contexts
    pub state: &'a mut T,
}
64
/// Context provided after a successful non-streaming response.
///
/// This context contains the complete request and response information,
/// allowing interceptors to observe and react to successful API calls.
/// The state written in `before_request` is available here read-only.
#[derive(Debug)]
pub struct AfterResponseContext<'a, T = ()> {
    /// The operation that was performed
    pub operation: &'a str,
    /// The model that was used
    pub model: &'a str,
    /// The original request body as JSON
    pub request_json: &'a str,
    /// The response body as JSON
    pub response_json: &'a str,
    /// Time taken for the request
    pub duration: Duration,
    /// Number of input tokens used (`None` when the API did not report usage)
    pub input_tokens: Option<i64>,
    /// Number of output tokens generated (`None` when the API did not report usage)
    pub output_tokens: Option<i64>,
    /// Typed state from the request context (read-only)
    pub state: &'a T,
}
88
/// Context provided for each chunk in a streaming response.
///
/// This context allows interceptors to process streaming data as it arrives,
/// useful for real-time monitoring or transformation. One context is
/// produced per chunk; `chunk_index` distinguishes them.
#[derive(Debug)]
pub struct StreamChunkContext<'a, T = ()> {
    /// The operation being performed
    pub operation: &'a str,
    /// The model being used
    pub model: &'a str,
    /// The original request body as JSON
    pub request_json: &'a str,
    /// The current chunk data as JSON
    pub chunk_json: &'a str,
    /// Zero-based index of this chunk within the stream
    pub chunk_index: usize,
    /// Typed state from the request context (read-only)
    pub state: &'a T,
}
108
/// Context provided when a streaming response completes.
///
/// This context provides summary information about the completed stream,
/// including total chunks, total duration, and token usage when the API
/// reported it.
#[derive(Debug)]
pub struct StreamEndContext<'a, T = ()> {
    /// The operation that was performed
    pub operation: &'a str,
    /// The model that was used
    pub model: &'a str,
    /// The original request body as JSON
    pub request_json: &'a str,
    /// Total number of chunks received over the life of the stream
    pub total_chunks: usize,
    /// Total time for the streaming response
    pub duration: Duration,
    /// Total input tokens used (`None` when the API did not report usage)
    pub input_tokens: Option<i64>,
    /// Total output tokens generated (`None` when the API did not report usage)
    pub output_tokens: Option<i64>,
    /// Typed state from the request context (read-only)
    pub state: &'a T,
}
132
/// Context provided when an error occurs.
///
/// This context allows interceptors to observe and react to errors,
/// useful for error tracking and recovery strategies. Unlike the other
/// contexts, most fields are `Option`s because an error can occur before
/// the model, request body, or interceptor state are known.
#[derive(Debug)]
pub struct ErrorContext<'a, T = ()> {
    /// The operation that failed
    pub operation: &'a str,
    /// The model being used (`None` if not yet known when the error occurred)
    pub model: Option<&'a str>,
    /// The request body as JSON (`None` if not available)
    pub request_json: Option<&'a str>,
    /// The error that occurred
    pub error: &'a crate::Error,
    /// Typed state from the request context (`None` if not available)
    pub state: Option<&'a T>,
}
150
/// Trait for implementing interceptors.
///
/// Interceptors can hook into various stages of the request/response lifecycle.
/// All methods have default no-op implementations, so you only need to implement
/// the hooks you're interested in.
///
/// The generic type parameter `T` defines the state type that will be passed
/// through the request lifecycle. Use `()` (the default) for simple interceptors
/// that don't need to maintain state across hooks.
#[async_trait::async_trait]
pub trait Interceptor<T = ()>: Send + Sync {
    /// Called before a request is sent.
    ///
    /// This method can modify the state that will be passed through
    /// the request lifecycle.
    ///
    /// # Errors
    ///
    /// An error returned here is propagated by
    /// [`InterceptorChain::before_request`] and stops any interceptors
    /// later in the chain from running.
    async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Called after a successful non-streaming response is received.
    ///
    /// # Errors
    ///
    /// An error returned here is propagated by the owning chain.
    async fn after_response(&self, _ctx: &AfterResponseContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Called for each chunk in a streaming response.
    ///
    /// # Errors
    ///
    /// An error returned here is propagated by the owning chain.
    async fn on_stream_chunk(&self, _ctx: &StreamChunkContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Called when a streaming response completes successfully.
    ///
    /// # Errors
    ///
    /// An error returned here is propagated by the owning chain.
    async fn on_stream_end(&self, _ctx: &StreamEndContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Called when an error occurs at any stage.
    ///
    /// Note: This method doesn't return a Result as it's called during
    /// error handling and shouldn't fail.
    async fn on_error(&self, _ctx: &ErrorContext<'_, T>) {
        // Default: no-op
    }
}
193
/// A chain of interceptors that are executed in order.
///
/// This struct manages multiple interceptors and ensures they are
/// called in registration order for each lifecycle event.
///
/// The generic type parameter `T` defines the state type that will be passed
/// through the request lifecycle. Use `()` (the default) for interceptors
/// that don't need to maintain state.
pub struct InterceptorChain<T = ()> {
    // Boxed trait objects so heterogeneous interceptor types can share one
    // chain; invoked in insertion order.
    interceptors: Vec<Box<dyn Interceptor<T>>>,
}
205
206impl<T> Default for InterceptorChain<T> {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl<T> InterceptorChain<T> {
213    /// Create a new, empty interceptor chain.
214    pub fn new() -> Self {
215        Self {
216            interceptors: Vec::new(),
217        }
218    }
219
220    /// Add an interceptor to the chain.
221    ///
222    /// Interceptors are executed in the order they are added.
223    pub fn add(&mut self, interceptor: Box<dyn Interceptor<T>>) {
224        self.interceptors.push(interceptor);
225    }
226
227    /// Execute the `before_request` hook for all interceptors.
228    pub async fn before_request(&self, ctx: &mut BeforeRequestContext<'_, T>) -> Result<()> {
229        for interceptor in &self.interceptors {
230            interceptor.before_request(ctx).await?;
231        }
232        Ok(())
233    }
234
235    /// Execute the `after_response` hook for all interceptors.
236    pub async fn after_response(&self, ctx: &AfterResponseContext<'_, T>) -> Result<()>
237    where
238        T: Sync,
239    {
240        for interceptor in &self.interceptors {
241            interceptor.after_response(ctx).await?;
242        }
243        Ok(())
244    }
245
246    /// Execute the `on_stream_chunk` hook for all interceptors.
247    pub async fn on_stream_chunk(&self, ctx: &StreamChunkContext<'_, T>) -> Result<()>
248    where
249        T: Sync,
250    {
251        for interceptor in &self.interceptors {
252            interceptor.on_stream_chunk(ctx).await?;
253        }
254        Ok(())
255    }
256
257    /// Execute the `on_stream_end` hook for all interceptors.
258    pub async fn on_stream_end(&self, ctx: &StreamEndContext<'_, T>) -> Result<()>
259    where
260        T: Sync,
261    {
262        for interceptor in &self.interceptors {
263            interceptor.on_stream_end(ctx).await?;
264        }
265        Ok(())
266    }
267
268    /// Execute the `on_error` hook for all interceptors.
269    ///
270    /// Errors in individual interceptors are ignored to prevent
271    /// cascading failures during error handling.
272    pub async fn on_error(&self, ctx: &ErrorContext<'_, T>)
273    where
274        T: Sync,
275    {
276        for interceptor in &self.interceptors {
277            // Ignore errors in error handlers to prevent cascading failures
278            interceptor.on_error(ctx).await;
279        }
280    }
281
282    /// Check if the chain has any interceptors.
283    pub fn is_empty(&self) -> bool {
284        self.interceptors.is_empty()
285    }
286
287    /// Get the number of interceptors in the chain.
288    pub fn len(&self) -> usize {
289        self.interceptors.len()
290    }
291}
292
293#[cfg(test)]
294mod tests {
295    use super::*;
296    use std::collections::HashMap;
297    use std::sync::atomic::{AtomicUsize, Ordering};
298    use std::sync::Arc;
299
    /// A test interceptor that tracks how many times each method was called.
    ///
    /// Counters are `Arc`-shared so tests can keep a handle to them after
    /// moving the interceptor (boxed) into a chain.
    #[allow(clippy::struct_field_names)]
    struct TestInterceptor {
        // One atomic counter per lifecycle hook.
        before_request_count: Arc<AtomicUsize>,
        after_response_count: Arc<AtomicUsize>,
        on_stream_chunk_count: Arc<AtomicUsize>,
        on_stream_end_count: Arc<AtomicUsize>,
        on_error_count: Arc<AtomicUsize>,
    }
309
310    impl TestInterceptor {
311        fn new() -> Self {
312            Self {
313                before_request_count: Arc::new(AtomicUsize::new(0)),
314                after_response_count: Arc::new(AtomicUsize::new(0)),
315                on_stream_chunk_count: Arc::new(AtomicUsize::new(0)),
316                on_stream_end_count: Arc::new(AtomicUsize::new(0)),
317                on_error_count: Arc::new(AtomicUsize::new(0)),
318            }
319        }
320    }
321
    // Counting-only implementation: every hook just bumps its own counter
    // and (where applicable) reports success.
    #[async_trait::async_trait]
    impl Interceptor for TestInterceptor {
        async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
            self.before_request_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_response(&self, _ctx: &AfterResponseContext<'_>) -> Result<()> {
            self.after_response_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_stream_chunk(&self, _ctx: &StreamChunkContext<'_>) -> Result<()> {
            self.on_stream_chunk_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_stream_end(&self, _ctx: &StreamEndContext<'_>) -> Result<()> {
            self.on_stream_end_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_error(&self, _ctx: &ErrorContext<'_>) {
            self.on_error_count.fetch_add(1, Ordering::SeqCst);
        }
    }
348
349    #[tokio::test]
350    async fn test_interceptor_chain_executes_in_order() {
351        let mut chain = InterceptorChain::new();
352        let interceptor1 = TestInterceptor::new();
353        let interceptor2 = TestInterceptor::new();
354
355        let count1 = interceptor1.before_request_count.clone();
356        let count2 = interceptor2.before_request_count.clone();
357
358        chain.add(Box::new(interceptor1));
359        chain.add(Box::new(interceptor2));
360
361        // Test before_request
362        let mut state = ();
363        let mut ctx = BeforeRequestContext {
364            operation: "test",
365            model: "gpt-4",
366            request_json: "{}",
367            state: &mut state,
368        };
369        chain.before_request(&mut ctx).await.unwrap();
370
371        assert_eq!(count1.load(Ordering::SeqCst), 1);
372        assert_eq!(count2.load(Ordering::SeqCst), 1);
373    }
374
375    #[tokio::test]
376    async fn test_interceptor_chain_handles_errors() {
377        struct FailingInterceptor;
378
379        #[async_trait::async_trait]
380        impl Interceptor for FailingInterceptor {
381            async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
382                Err(crate::Error::Internal("Test error".to_string()))
383            }
384        }
385
386        let mut chain = InterceptorChain::new();
387        chain.add(Box::new(FailingInterceptor));
388
389        let mut state = ();
390        let mut ctx = BeforeRequestContext {
391            operation: "test",
392            model: "gpt-4",
393            request_json: "{}",
394            state: &mut state,
395        };
396
397        let result = chain.before_request(&mut ctx).await;
398        assert!(result.is_err());
399    }
400
401    #[tokio::test]
402    async fn test_interceptor_chain_empty() {
403        let chain = InterceptorChain::new();
404        assert!(chain.is_empty());
405        assert_eq!(chain.len(), 0);
406
407        // Empty chain should succeed without doing anything
408        let mut state = ();
409        let mut ctx = BeforeRequestContext {
410            operation: "test",
411            model: "gpt-4",
412            request_json: "{}",
413            state: &mut state,
414        };
415        chain.before_request(&mut ctx).await.unwrap();
416    }
417
418    #[tokio::test]
419    async fn test_state_passing() {
420        struct StateInterceptor;
421
422        #[async_trait::async_trait]
423        impl Interceptor<HashMap<String, String>> for StateInterceptor {
424            async fn before_request(
425                &self,
426                ctx: &mut BeforeRequestContext<'_, HashMap<String, String>>,
427            ) -> Result<()> {
428                ctx.state
429                    .insert("test_key".to_string(), "test_value".to_string());
430                Ok(())
431            }
432        }
433
434        let mut chain = InterceptorChain::new();
435        chain.add(Box::new(StateInterceptor));
436
437        let mut state = HashMap::new();
438        let mut ctx = BeforeRequestContext {
439            operation: "test",
440            model: "gpt-4",
441            request_json: "{}",
442            state: &mut state,
443        };
444
445        chain.before_request(&mut ctx).await.unwrap();
446        assert_eq!(state.get("test_key"), Some(&"test_value".to_string()));
447    }
448
449    #[tokio::test]
450    async fn test_error_handler_doesnt_propagate_errors() {
451        #[allow(dead_code)]
452        struct ErrorInterceptor {
453            called: Arc<AtomicUsize>,
454        }
455
456        #[async_trait::async_trait]
457        impl Interceptor for ErrorInterceptor {
458            async fn on_error(&self, _ctx: &ErrorContext<'_>) {
459                self.called.fetch_add(1, Ordering::SeqCst);
460                // This would panic in a real scenario, but shouldn't crash the chain
461                panic!("This panic should be caught");
462            }
463        }
464
465        let chain: InterceptorChain<()> = InterceptorChain::new();
466        let error = crate::Error::Internal("Test".to_string());
467        let ctx = ErrorContext {
468            operation: "test",
469            model: None,
470            request_json: None,
471            error: &error,
472            state: None,
473        };
474
475        // Should not panic even though the interceptor panics
476        chain.on_error(&ctx).await;
477    }
478}