1use crate::{
55 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
56 LlmStream, Result, StreamingLlmProvider,
57};
58use async_trait::async_trait;
59use std::sync::Arc;
60
/// Hook that can inspect, modify, or reject an [`LlmRequest`] before it is
/// handed to the wrapped provider.
///
/// `InterceptorProvider` runs its registered request interceptors one after
/// another, feeding the output of each into the next.
#[async_trait]
pub trait RequestInterceptor: Send + Sync {
    /// Takes ownership of `request` and returns the request that should be
    /// used from this point on.
    ///
    /// # Errors
    ///
    /// Returning `Err` stops the interceptor chain in `InterceptorProvider`;
    /// the error is propagated to the caller and the provider is not invoked.
    async fn intercept_request(&self, request: LlmRequest) -> Result<LlmRequest>;
}
67
/// Hook that can inspect, modify, or reject an [`LlmResponse`] after the
/// wrapped provider has produced it.
///
/// `InterceptorProvider` runs its registered response interceptors one after
/// another, feeding the output of each into the next.
#[async_trait]
pub trait ResponseInterceptor: Send + Sync {
    /// Takes ownership of `response` and returns the response that should be
    /// used from this point on.
    ///
    /// # Errors
    ///
    /// Returning `Err` stops the interceptor chain in `InterceptorProvider`
    /// and the error is returned to the caller instead of a response.
    async fn intercept_response(&self, response: LlmResponse) -> Result<LlmResponse>;
}
74
/// Hook that can inspect, modify, or reject an [`EmbeddingRequest`] before it
/// is handed to the wrapped embedding provider.
///
/// `EmbeddingInterceptorProvider` runs its registered request interceptors one
/// after another, feeding the output of each into the next.
#[async_trait]
pub trait EmbeddingRequestInterceptor: Send + Sync {
    /// Takes ownership of `request` and returns the request that should be
    /// used from this point on.
    ///
    /// # Errors
    ///
    /// Returning `Err` stops the interceptor chain in
    /// `EmbeddingInterceptorProvider`; the provider is not invoked.
    async fn intercept_embedding_request(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingRequest>;
}
84
/// Hook that can inspect, modify, or reject an [`EmbeddingResponse`] after the
/// wrapped embedding provider has produced it.
///
/// `EmbeddingInterceptorProvider` runs its registered response interceptors
/// one after another, feeding the output of each into the next.
#[async_trait]
pub trait EmbeddingResponseInterceptor: Send + Sync {
    /// Takes ownership of `response` and returns the response that should be
    /// used from this point on.
    ///
    /// # Errors
    ///
    /// Returning `Err` stops the interceptor chain in
    /// `EmbeddingInterceptorProvider` and the error is returned to the caller.
    async fn intercept_embedding_response(
        &self,
        response: EmbeddingResponse,
    ) -> Result<EmbeddingResponse>;
}
94
95pub struct InterceptorProvider<P> {
97 provider: Arc<P>,
98 request_interceptors: Vec<Box<dyn RequestInterceptor>>,
99 response_interceptors: Vec<Box<dyn ResponseInterceptor>>,
100}
101
102impl<P> InterceptorProvider<P> {
103 pub fn new(provider: P) -> Self {
105 Self {
106 provider: Arc::new(provider),
107 request_interceptors: Vec::new(),
108 response_interceptors: Vec::new(),
109 }
110 }
111
112 pub fn with_request_interceptor(mut self, interceptor: Box<dyn RequestInterceptor>) -> Self {
114 self.request_interceptors.push(interceptor);
115 self
116 }
117
118 pub fn with_response_interceptor(mut self, interceptor: Box<dyn ResponseInterceptor>) -> Self {
120 self.response_interceptors.push(interceptor);
121 self
122 }
123
124 async fn apply_request_interceptors(&self, mut request: LlmRequest) -> Result<LlmRequest> {
126 for interceptor in &self.request_interceptors {
127 request = interceptor.intercept_request(request).await?;
128 }
129 Ok(request)
130 }
131
132 async fn apply_response_interceptors(&self, mut response: LlmResponse) -> Result<LlmResponse> {
134 for interceptor in &self.response_interceptors {
135 response = interceptor.intercept_response(response).await?;
136 }
137 Ok(response)
138 }
139}
140
141#[async_trait]
142impl<P: LlmProvider> LlmProvider for InterceptorProvider<P> {
143 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
144 let request = self.apply_request_interceptors(request).await?;
145 let response = self.provider.complete(request).await?;
146 self.apply_response_interceptors(response).await
147 }
148}
149
150#[async_trait]
151impl<P: StreamingLlmProvider> StreamingLlmProvider for InterceptorProvider<P> {
152 async fn complete_stream(&self, request: LlmRequest) -> Result<LlmStream> {
153 let request = self.apply_request_interceptors(request).await?;
154 self.provider.complete_stream(request).await
155 }
156}
157
158pub struct EmbeddingInterceptorProvider<P> {
160 provider: Arc<P>,
161 request_interceptors: Vec<Box<dyn EmbeddingRequestInterceptor>>,
162 response_interceptors: Vec<Box<dyn EmbeddingResponseInterceptor>>,
163}
164
165impl<P> EmbeddingInterceptorProvider<P> {
166 pub fn new(provider: P) -> Self {
168 Self {
169 provider: Arc::new(provider),
170 request_interceptors: Vec::new(),
171 response_interceptors: Vec::new(),
172 }
173 }
174
175 pub fn with_request_interceptor(
177 mut self,
178 interceptor: Box<dyn EmbeddingRequestInterceptor>,
179 ) -> Self {
180 self.request_interceptors.push(interceptor);
181 self
182 }
183
184 pub fn with_response_interceptor(
186 mut self,
187 interceptor: Box<dyn EmbeddingResponseInterceptor>,
188 ) -> Self {
189 self.response_interceptors.push(interceptor);
190 self
191 }
192
193 async fn apply_request_interceptors(
195 &self,
196 mut request: EmbeddingRequest,
197 ) -> Result<EmbeddingRequest> {
198 for interceptor in &self.request_interceptors {
199 request = interceptor.intercept_embedding_request(request).await?;
200 }
201 Ok(request)
202 }
203
204 async fn apply_response_interceptors(
206 &self,
207 mut response: EmbeddingResponse,
208 ) -> Result<EmbeddingResponse> {
209 for interceptor in &self.response_interceptors {
210 response = interceptor.intercept_embedding_response(response).await?;
211 }
212 Ok(response)
213 }
214}
215
216#[async_trait]
217impl<P: EmbeddingProvider> EmbeddingProvider for EmbeddingInterceptorProvider<P> {
218 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
219 let request = self.apply_request_interceptors(request).await?;
220 let response = self.provider.embed(request).await?;
221 self.apply_response_interceptors(response).await
222 }
223}
224
/// Interceptor that logs a one-line summary of each request and response
/// through `tracing` without modifying them.
pub struct LoggingInterceptor {
    /// Text placed at the start of every log line this interceptor emits.
    prefix: String,
}

impl LoggingInterceptor {
    /// Creates a logging interceptor whose log lines start with `prefix`.
    pub fn new(prefix: String) -> Self {
        LoggingInterceptor { prefix }
    }
}
238
239#[async_trait]
240impl RequestInterceptor for LoggingInterceptor {
241 async fn intercept_request(&self, request: LlmRequest) -> Result<LlmRequest> {
242 tracing::info!(
243 "{} Request: prompt_len={}, temp={:?}, max_tokens={:?}",
244 self.prefix,
245 request.prompt.len(),
246 request.temperature,
247 request.max_tokens
248 );
249 Ok(request)
250 }
251}
252
253#[async_trait]
254impl ResponseInterceptor for LoggingInterceptor {
255 async fn intercept_response(&self, response: LlmResponse) -> Result<LlmResponse> {
256 tracing::info!(
257 "{} Response: content_len={}, model={}, tokens={:?}",
258 self.prefix,
259 response.content.len(),
260 response.model,
261 response.usage.as_ref().map(|u| u.total_tokens)
262 );
263 Ok(response)
264 }
265}
266
/// Interceptor that redacts configured substrings from outgoing requests.
pub struct SanitizationInterceptor {
    /// Literal substrings to be replaced in request text.
    patterns: Vec<String>,
}

impl SanitizationInterceptor {
    /// Creates a sanitizer that redacts each entry of `patterns`.
    pub fn new(patterns: Vec<String>) -> Self {
        SanitizationInterceptor { patterns }
    }
}
278
279#[async_trait]
280impl RequestInterceptor for SanitizationInterceptor {
281 async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
282 for pattern in &self.patterns {
283 request.prompt = request.prompt.replace(pattern, "[REDACTED]");
284 if let Some(ref mut system) = request.system_prompt {
285 *system = system.replace(pattern, "[REDACTED]");
286 }
287 }
288 Ok(request)
289 }
290}
291
292pub struct ContentLengthInterceptor {
294 max_prompt_length: usize,
295 max_response_length: usize,
296}
297
298impl ContentLengthInterceptor {
299 pub fn new(max_prompt_length: usize, max_response_length: usize) -> Self {
301 Self {
302 max_prompt_length,
303 max_response_length,
304 }
305 }
306}
307
308#[async_trait]
309impl RequestInterceptor for ContentLengthInterceptor {
310 async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
311 if request.prompt.len() > self.max_prompt_length {
312 request.prompt.truncate(self.max_prompt_length);
313 request.prompt.push_str("...[truncated]");
314 }
315 Ok(request)
316 }
317}
318
319#[async_trait]
320impl ResponseInterceptor for ContentLengthInterceptor {
321 async fn intercept_response(&self, mut response: LlmResponse) -> Result<LlmResponse> {
322 if response.content.len() > self.max_response_length {
323 response.content.truncate(self.max_response_length);
324 response.content.push_str("...[truncated]");
325 }
326 Ok(response)
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LlmError, Usage};

    /// Builds a request that carries only `prompt`; every optional field is
    /// unset and every list is empty.
    fn request_with_prompt(prompt: &str) -> LlmRequest {
        LlmRequest {
            prompt: prompt.to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    /// Provider stub that echoes the prompt back with a fixed preamble.
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
            let usage = Usage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
            };
            Ok(LlmResponse {
                content: format!("Response to: {}", request.prompt),
                model: "mock-model".to_string(),
                usage: Some(usage),
                tool_calls: Vec::new(),
            })
        }
    }

    /// Request interceptor that prepends `prefix` to the prompt.
    struct PrefixInterceptor {
        prefix: String,
    }

    #[async_trait]
    impl RequestInterceptor for PrefixInterceptor {
        async fn intercept_request(&self, mut request: LlmRequest) -> Result<LlmRequest> {
            request.prompt.insert_str(0, &self.prefix);
            Ok(request)
        }
    }

    /// Response interceptor that appends `suffix` to the content.
    struct SuffixInterceptor {
        suffix: String,
    }

    #[async_trait]
    impl ResponseInterceptor for SuffixInterceptor {
        async fn intercept_response(&self, mut response: LlmResponse) -> Result<LlmResponse> {
            response.content += &self.suffix;
            Ok(response)
        }
    }

    /// Request interceptor that always rejects the request.
    struct FailingInterceptor;

    #[async_trait]
    impl RequestInterceptor for FailingInterceptor {
        async fn intercept_request(&self, _request: LlmRequest) -> Result<LlmRequest> {
            Err(LlmError::InvalidRequest("Test error".to_string()))
        }
    }

    #[tokio::test]
    async fn test_request_interceptor() {
        let prefixer = PrefixInterceptor {
            prefix: "PREFIX: ".to_string(),
        };
        let provider =
            InterceptorProvider::new(MockProvider).with_request_interceptor(Box::new(prefixer));

        let response = provider
            .complete(request_with_prompt("test"))
            .await
            .unwrap();
        assert_eq!(response.content, "Response to: PREFIX: test");
    }

    #[tokio::test]
    async fn test_response_interceptor() {
        let suffixer = SuffixInterceptor {
            suffix: " SUFFIX".to_string(),
        };
        let provider =
            InterceptorProvider::new(MockProvider).with_response_interceptor(Box::new(suffixer));

        let response = provider
            .complete(request_with_prompt("test"))
            .await
            .unwrap();
        assert!(response.content.ends_with(" SUFFIX"));
    }

    #[tokio::test]
    async fn test_multiple_interceptors() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(PrefixInterceptor {
                prefix: "A: ".to_string(),
            }))
            .with_request_interceptor(Box::new(PrefixInterceptor {
                prefix: "B: ".to_string(),
            }))
            .with_response_interceptor(Box::new(SuffixInterceptor {
                suffix: " X".to_string(),
            }))
            .with_response_interceptor(Box::new(SuffixInterceptor {
                suffix: " Y".to_string(),
            }));

        let response = provider
            .complete(request_with_prompt("test"))
            .await
            .unwrap();
        // Interceptors run in registration order, so "A: " is prepended
        // first and "B: " ends up outermost.
        assert_eq!(response.content, "Response to: B: A: test X Y");
    }

    #[tokio::test]
    async fn test_failing_interceptor() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(FailingInterceptor));

        let result = provider.complete(request_with_prompt("test")).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sanitization_interceptor() {
        let sanitizer =
            SanitizationInterceptor::new(vec!["secret".to_string(), "password".to_string()]);
        let provider =
            InterceptorProvider::new(MockProvider).with_request_interceptor(Box::new(sanitizer));

        let response = provider
            .complete(request_with_prompt("My secret is password123"))
            .await
            .unwrap();
        assert_eq!(
            response.content,
            "Response to: My [REDACTED] is [REDACTED]123"
        );
    }

    #[tokio::test]
    async fn test_content_length_interceptor() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(ContentLengthInterceptor::new(10, 100)))
            .with_response_interceptor(Box::new(ContentLengthInterceptor::new(100, 20)));

        let request = request_with_prompt("This is a very long prompt that should be truncated");

        let response = provider.complete(request).await.unwrap();
        assert!(response.content.ends_with("...[truncated]"));
        // The marker is appended after the cut, so allow for its length.
        assert!(response.content.len() <= 20 + "...[truncated]".len());
    }

    #[tokio::test]
    async fn test_logging_interceptor() {
        let provider = InterceptorProvider::new(MockProvider)
            .with_request_interceptor(Box::new(LoggingInterceptor::new("TEST".to_string())))
            .with_response_interceptor(Box::new(LoggingInterceptor::new("TEST".to_string())));

        let request = LlmRequest {
            temperature: Some(0.7),
            max_tokens: Some(100),
            ..request_with_prompt("test")
        };

        let response = provider.complete(request).await.unwrap();
        // Logging must not alter the data flowing through the chain.
        assert!(response.content.contains("Response to: test"));
    }
}