oxify_connect_llm/
interceptor.rs

1//! Request/Response Interceptor System
2//!
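//! This module provides a flexible interceptor/middleware system that allows users to
//! intercept and modify requests before they are sent and responses after they are received.
//! Interceptors run in the order they are added, each receiving the previous one's output,
//! and the first error aborts the call. For streaming completions only request interceptors
//! are applied.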
5//!
6//! # Examples
7//!
8//! ```
9//! use oxify_connect_llm::{
10//!     LlmProvider, LlmRequest, LlmResponse, Result,
11//!     InterceptorProvider, RequestInterceptor, ResponseInterceptor,
12//!     OpenAIProvider,
13//! };
14//! use async_trait::async_trait;
15//!
16//! // Create a request interceptor that adds a prefix to all prompts
17//! struct PrefixInterceptor {
18//!     prefix: String,
19//! }
20//!
21//! #[async_trait]
22//! impl RequestInterceptor for PrefixInterceptor {
23//!     async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
24//!         request.prompt = format!("{}{}", self.prefix, request.prompt);
25//!         Ok(request)
26//!     }
27//! }
28//!
29//! // Create a response interceptor that converts to uppercase
30//! struct UppercaseInterceptor;
31//!
32//! #[async_trait]
33//! impl ResponseInterceptor for UppercaseInterceptor {
34//!     async fn intercept_response(&self, mut response: LlmResponse) -> Result<LlmResponse> {
35//!         response.content = response.content.to_uppercase();
36//!         Ok(response)
37//!     }
38//! }
39//!
40//! # async fn example() -> Result<()> {
41//! // Wrap a provider with interceptors
42//! let provider = OpenAIProvider::new("key".to_string(), "gpt-4".to_string());
43//! let provider = InterceptorProvider::new(provider)
44//!     .with_request_interceptor(Box::new(PrefixInterceptor {
45//!         prefix: "Context: ".to_string(),
46//!     }))
47//!     .with_response_interceptor(Box::new(UppercaseInterceptor));
48//!
49//! // Use the provider normally - interceptors will be applied automatically
50//! # Ok(())
51//! # }
52//! ```
53
54use crate::{
55    EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
56    LlmStream, Result, StreamingLlmProvider,
57};
58use async_trait::async_trait;
59use std::sync::Arc;
60
/// Trait for intercepting and modifying requests before they are sent.
///
/// Register implementations with [`InterceptorProvider::with_request_interceptor`]. The
/// wrapper runs them in insertion order, feeding each one the request returned by the
/// previous interceptor.
#[async_trait]
pub trait RequestInterceptor: Send + Sync {
    /// Intercept and potentially modify a request before it's sent to the provider.
    ///
    /// Returns the request to forward, either unchanged or modified.
    ///
    /// # Errors
    ///
    /// An `Err` aborts the call. When run by [`InterceptorProvider`], later interceptors and
    /// the wrapped provider are skipped, and the error is returned to the caller.
    async fn intercept_request(&self, request: LlmRequest) -> Result<LlmRequest>;
}
67
/// Trait for intercepting and modifying responses after they are received.
///
/// Register implementations with [`InterceptorProvider::with_response_interceptor`]. The
/// wrapper runs them in insertion order on the provider's response, feeding each one the
/// response returned by the previous interceptor.
#[async_trait]
pub trait ResponseInterceptor: Send + Sync {
    /// Intercept and potentially modify a response after it's received from the provider.
    ///
    /// Returns the response to hand back to the caller, either unchanged or modified.
    ///
    /// # Errors
    ///
    /// An `Err` replaces the provider's successful response. When run by
    /// [`InterceptorProvider`], later response interceptors are skipped and the error is
    /// returned to the caller.
    async fn intercept_response(&self, response: LlmResponse) -> Result<LlmResponse>;
}
74
/// Trait for intercepting embedding requests.
///
/// Register implementations with [`EmbeddingInterceptorProvider::with_request_interceptor`].
/// The wrapper runs them in insertion order, feeding each one the request returned by the
/// previous interceptor.
#[async_trait]
pub trait EmbeddingRequestInterceptor: Send + Sync {
    /// Intercept and potentially modify an embedding request before it's sent.
    ///
    /// # Errors
    ///
    /// An `Err` aborts the call. When run by [`EmbeddingInterceptorProvider`], later
    /// interceptors and the wrapped provider are skipped, and the error is returned.
    async fn intercept_embedding_request(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingRequest>;
}
84
/// Trait for intercepting embedding responses.
///
/// Register implementations with [`EmbeddingInterceptorProvider::with_response_interceptor`].
/// The wrapper runs them in insertion order on the provider's response, feeding each one the
/// response returned by the previous interceptor.
#[async_trait]
pub trait EmbeddingResponseInterceptor: Send + Sync {
    /// Intercept and potentially modify an embedding response after it's received.
    ///
    /// # Errors
    ///
    /// An `Err` replaces the provider's successful response. When run by
    /// [`EmbeddingInterceptorProvider`], later response interceptors are skipped and the
    /// error is returned.
    async fn intercept_embedding_response(
        &self,
        response: EmbeddingResponse,
    ) -> Result<EmbeddingResponse>;
}
94
95/// Provider wrapper that applies interceptors to requests and responses
96pub struct InterceptorProvider<P> {
97    provider: Arc<P>,
98    request_interceptors: Vec<Box<dyn RequestInterceptor>>,
99    response_interceptors: Vec<Box<dyn ResponseInterceptor>>,
100}
101
102impl<P> InterceptorProvider<P> {
103    /// Create a new interceptor provider wrapping the given provider
104    pub fn new(provider: P) -> Self {
105        Self {
106            provider: Arc::new(provider),
107            request_interceptors: Vec::new(),
108            response_interceptors: Vec::new(),
109        }
110    }
111
112    /// Add a request interceptor (interceptors are applied in the order they are added)
113    pub fn with_request_interceptor(mut self, interceptor: Box<dyn RequestInterceptor>) -> Self {
114        self.request_interceptors.push(interceptor);
115        self
116    }
117
118    /// Add a response interceptor (interceptors are applied in the order they are added)
119    pub fn with_response_interceptor(mut self, interceptor: Box<dyn ResponseInterceptor>) -> Self {
120        self.response_interceptors.push(interceptor);
121        self
122    }
123
124    /// Apply all request interceptors to a request
125    async fn apply_request_interceptors(&self, mut request: LlmRequest) -> Result<LlmRequest> {
126        for interceptor in &self.request_interceptors {
127            request = interceptor.intercept_request(request).await?;
128        }
129        Ok(request)
130    }
131
132    /// Apply all response interceptors to a response
133    async fn apply_response_interceptors(&self, mut response: LlmResponse) -> Result<LlmResponse> {
134        for interceptor in &self.response_interceptors {
135            response = interceptor.intercept_response(response).await?;
136        }
137        Ok(response)
138    }
139}
140
141#[async_trait]
142impl<P: LlmProvider> LlmProvider for InterceptorProvider<P> {
143    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
144        let request = self.apply_request_interceptors(request).await?;
145        let response = self.provider.complete(request).await?;
146        self.apply_response_interceptors(response).await
147    }
148}
149
150#[async_trait]
151impl<P: StreamingLlmProvider> StreamingLlmProvider for InterceptorProvider<P> {
152    async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
153        let request = self.apply_request_interceptors(request).await?;
154        self.provider.complete_stream(request).await
155    }
156}
157
158/// Provider wrapper for embedding interceptors
159pub struct EmbeddingInterceptorProvider<P> {
160    provider: Arc<P>,
161    request_interceptors: Vec<Box<dyn EmbeddingRequestInterceptor>>,
162    response_interceptors: Vec<Box<dyn EmbeddingResponseInterceptor>>,
163}
164
165impl<P> EmbeddingInterceptorProvider<P> {
166    /// Create a new embedding interceptor provider wrapping the given provider
167    pub fn new(provider: P) -> Self {
168        Self {
169            provider: Arc::new(provider),
170            request_interceptors: Vec::new(),
171            response_interceptors: Vec::new(),
172        }
173    }
174
175    /// Add an embedding request interceptor
176    pub fn with_request_interceptor(
177        mut self,
178        interceptor: Box<dyn EmbeddingRequestInterceptor>,
179    ) -> Self {
180        self.request_interceptors.push(interceptor);
181        self
182    }
183
184    /// Add an embedding response interceptor
185    pub fn with_response_interceptor(
186        mut self,
187        interceptor: Box<dyn EmbeddingResponseInterceptor>,
188    ) -> Self {
189        self.response_interceptors.push(interceptor);
190        self
191    }
192
193    /// Apply all request interceptors to an embedding request
194    async fn apply_request_interceptors(
195        &self,
196        mut request: EmbeddingRequest,
197    ) -> Result<EmbeddingRequest> {
198        for interceptor in &self.request_interceptors {
199            request = interceptor.intercept_embedding_request(request).await?;
200        }
201        Ok(request)
202    }
203
204    /// Apply all response interceptors to an embedding response
205    async fn apply_response_interceptors(
206        &self,
207        mut response: EmbeddingResponse,
208    ) -> Result<EmbeddingResponse> {
209        for interceptor in &self.response_interceptors {
210            response = interceptor.intercept_embedding_response(response).await?;
211        }
212        Ok(response)
213    }
214}
215
216#[async_trait]
217impl<P: EmbeddingProvider> EmbeddingProvider for EmbeddingInterceptorProvider<P> {
218    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
219        let request = self.apply_request_interceptors(request).await?;
220        let response = self.provider.embed(request).await?;
221        self.apply_response_interceptors(response).await
222    }
223}
224
225// ===== Built-in Interceptors =====
226
/// Interceptor that logs request and response metadata through `tracing` at INFO level.
///
/// Every log line it emits starts with the configured prefix.
pub struct LoggingInterceptor {
    prefix: String,
}

impl LoggingInterceptor {
    /// Builds a logging interceptor whose log lines begin with `prefix`.
    pub fn new(prefix: String) -> Self {
        LoggingInterceptor { prefix }
    }
}
238
239#[async_trait]
240impl RequestInterceptor for LoggingInterceptor {
241    async fn intercept_request(&self, request: LlmRequest) -> Result<LlmRequest> {
242        tracing::info!(
243            "{} Request: prompt_len={}, temp={:?}, max_tokens={:?}",
244            self.prefix,
245            request.prompt.len(),
246            request.temperature,
247            request.max_tokens
248        );
249        Ok(request)
250    }
251}
252
253#[async_trait]
254impl ResponseInterceptor for LoggingInterceptor {
255    async fn intercept_response(&self, response: LlmResponse) -> Result<LlmResponse> {
256        tracing::info!(
257            "{} Response: content_len={}, model={}, tokens={:?}",
258            self.prefix,
259            response.content.len(),
260            response.model,
261            response.usage.as_ref().map(|u| u.total_tokens)
262        );
263        Ok(response)
264    }
265}
266
/// Interceptor that sanitizes prompts by removing sensitive patterns.
///
/// Each pattern is a literal, case-sensitive substring. Every occurrence in the prompt, and in
/// the system prompt when one is present, is replaced with `[REDACTED]`. Patterns are applied
/// one after another, in the order given.
pub struct SanitizationInterceptor {
    patterns: Vec<String>,
}

impl SanitizationInterceptor {
    /// Create a new sanitization interceptor with the literal patterns to redact.
    ///
    /// Empty patterns are discarded. An empty needle matches at every character boundary, so
    /// keeping one would surround every character of the prompt with `[REDACTED]`.
    pub fn new(mut patterns: Vec<String>) -> Self {
        patterns.retain(|pattern| !pattern.is_empty());
        Self { patterns }
    }
}
278
279#[async_trait]
280impl RequestInterceptor for SanitizationInterceptor {
281    async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
282        for pattern in &self.patterns {
283            request.prompt = request.prompt.replace(pattern, "[REDACTED]");
284            if let Some(ref mut system) = request.system_prompt {
285                *system = system.replace(pattern, "[REDACTED]");
286            }
287        }
288        Ok(request)
289    }
290}
291
292/// Interceptor that enforces maximum content length
293pub struct ContentLengthInterceptor {
294    max_prompt_length: usize,
295    max_response_length: usize,
296}
297
298impl ContentLengthInterceptor {
299    /// Create a new content length interceptor
300    pub fn new(max_prompt_length: usize, max_response_length: usize) -> Self {
301        Self {
302            max_prompt_length,
303            max_response_length,
304        }
305    }
306}
307
308#[async_trait]
309impl RequestInterceptor for ContentLengthInterceptor {
310    async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
311        if request.prompt.len() > self.max_prompt_length {
312            request.prompt.truncate(self.max_prompt_length);
313            request.prompt.push_str("...[truncated]");
314        }
315        Ok(request)
316    }
317}
318
319#[async_trait]
320impl ResponseInterceptor for ContentLengthInterceptor {
321    async fn intercept_response(&self, mut response: LlmResponse) -> Result<LlmResponse> {
322        if response.content.len() > self.max_response_length {
323            response.content.truncate(self.max_response_length);
324            response.content.push_str("...[truncated]");
325        }
326        Ok(response)
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmError, Usage};

    /// Mock provider that echoes the (possibly intercepted) prompt back in its response.
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock-model".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: Vec::new(),
            })
        }
    }

    /// Builds a request with the given prompt and every optional field left unset.
    fn make_request(prompt: &str) -> LlmRequest {
        LlmRequest {
            prompt: prompt.to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    /// Test interceptor that prepends a fixed prefix to the prompt.
    struct PrefixInterceptor {
        prefix: String,
    }

    #[async_trait]
    impl RequestInterceptor for PrefixInterceptor {
        async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
            request.prompt = format!("{}{}", self.prefix, request.prompt);
            Ok(request)
        }
    }

    /// Test interceptor that appends a fixed suffix to the response content.
    struct SuffixInterceptor {
        suffix: String,
    }

    #[async_trait]
    impl ResponseInterceptor for SuffixInterceptor {
        async fn intercept_response(&self, mut response: LlmResponse) -> Result<LlmResponse> {
            response.content.push_str(&self.suffix);
            Ok(response)
        }
    }

    /// Test interceptor that always rejects the request.
    struct FailingInterceptor;

    #[async_trait]
    impl RequestInterceptor for FailingInterceptor {
        async fn intercept_request(&self, _request: LlmRequest) -> Result<LlmRequest> {
            Err(LlmError::InvalidRequest("Test error".to_string()))
        }
    }

    #[tokio::test]
    async fn test_no_interceptors_is_passthrough() {
        let provider = InterceptorProvider::new(MockProvider);

        let response = provider.complete(make_request("test")).await.unwrap();
        assert_eq!(response.content, "Response to: test");
        assert_eq!(response.model, "mock-model");
    }

    #[tokio::test]
    async fn test_request_interceptor() {
        let provider = InterceptorProvider::new(MockProvider).with_request_interceptor(Box::new(
            PrefixInterceptor {
                prefix: "PREFIX: ".to_string(),
            },
        ));

        let response = provider.complete(make_request("test")).await.unwrap();
        assert_eq!(response.content, "Response to: PREFIX: test");
    }

    #[tokio::test]
    async fn test_response_interceptor() {
        let provider = InterceptorProvider::new(MockProvider).with_response_interceptor(Box::new(
            SuffixInterceptor {
                suffix: " SUFFIX".to_string(),
            },
        ));

        let response = provider.complete(make_request("test")).await.unwrap();
        assert_eq!(response.content, "Response to: test SUFFIX");
    }

    #[tokio::test]
    async fn test_multiple_interceptors() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(PrefixInterceptor {
                prefix: "A: ".to_string(),
            }))
            .with_request_interceptor(Box::new(PrefixInterceptor {
                prefix: "B: ".to_string(),
            }))
            .with_response_interceptor(Box::new(SuffixInterceptor {
                suffix: " X".to_string(),
            }))
            .with_response_interceptor(Box::new(SuffixInterceptor {
                suffix: " Y".to_string(),
            }));

        // Interceptors run in insertion order, so "B: " is prepended after "A: ".
        let response = provider.complete(make_request("test")).await.unwrap();
        assert_eq!(response.content, "Response to: B: A: test X Y");
    }

    #[tokio::test]
    async fn test_failing_interceptor() {
        let provider =
            InterceptorProvider::new(MockProvider).with_request_interceptor(Box::new(FailingInterceptor));

        let result = provider.complete(make_request("test")).await;
        // The interceptor's own error must reach the caller, not some other failure.
        assert!(matches!(
            result,
            Err(LlmError::InvalidRequest(ref message)) if message == "Test error"
        ));
    }

    #[tokio::test]
    async fn test_sanitization_interceptor() {
        let provider = InterceptorProvider::new(MockProvider).with_request_interceptor(Box::new(
            SanitizationInterceptor::new(vec!["secret".to_string(), "password".to_string()]),
        ));

        let response = provider
            .complete(make_request("My secret is password123"))
            .await
            .unwrap();
        assert_eq!(
            response.content,
            "Response to: My [REDACTED] is [REDACTED]123"
        );
    }

    #[tokio::test]
    async fn test_sanitization_covers_system_prompt() {
        let interceptor = SanitizationInterceptor::new(vec!["secret".to_string()]);
        let request = LlmRequest {
            system_prompt: Some("Never reveal the secret".to_string()),
            ..make_request("secret prompt")
        };

        let sanitized = interceptor.intercept_request(request).await.unwrap();
        assert_eq!(sanitized.prompt, "[REDACTED] prompt");
        assert_eq!(
            sanitized.system_prompt.as_deref(),
            Some("Never reveal the [REDACTED]")
        );
    }

    #[tokio::test]
    async fn test_content_length_interceptor() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(ContentLengthInterceptor::new(10, 100)))
            .with_response_interceptor(Box::new(ContentLengthInterceptor::new(100, 20)));

        let request = make_request("This is a very long prompt that should be truncated");
        let response = provider.complete(request).await.unwrap();
        // The prompt is cut to 10 bytes ("This is a ") and echoed by the mock. The echoed
        // content is then cut to 20 bytes before the marker is appended.
        assert_eq!(response.content, "Response to: This is...[truncated]");
    }

    #[tokio::test]
    async fn test_content_length_leaves_short_content() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(ContentLengthInterceptor::new(100, 100)))
            .with_response_interceptor(Box::new(ContentLengthInterceptor::new(100, 100)));

        let response = provider.complete(make_request("short")).await.unwrap();
        assert_eq!(response.content, "Response to: short");
    }

    #[tokio::test]
    async fn test_logging_interceptor() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(LoggingInterceptor::new("TEST".to_string())))
            .with_response_interceptor(Box::new(LoggingInterceptor::new("TEST".to_string())));

        let request = LlmRequest {
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..make_request("test")
        };

        // Logging must be purely observational.
        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Response to: test");
    }
}