use crate::Result;
use std::time::Duration;
/// Data handed to [`Interceptor::before_request`].
///
/// `T` is caller-defined per-request state and defaults to `()`. This is the
/// only hook context that borrows the state mutably, so interceptors can
/// write into it.
#[derive(Debug)]
pub struct BeforeRequestContext<'a, T = ()> {
    /// Label of the operation being performed.
    pub operation: &'a str,
    /// Model identifier associated with the request.
    pub model: &'a str,
    /// Request payload text (JSON, going by the field name).
    pub request_json: &'a str,
    /// Exclusive borrow of the caller's state.
    pub state: &'a mut T,
}
/// Data handed to [`Interceptor::after_response`].
///
/// `T` is caller-defined per-request state and defaults to `()`; it is
/// borrowed immutably here.
#[derive(Debug)]
pub struct AfterResponseContext<'a, T = ()> {
    /// Label of the operation being performed.
    pub operation: &'a str,
    /// Model identifier associated with the request.
    pub model: &'a str,
    /// Request payload text (JSON, going by the field name).
    pub request_json: &'a str,
    /// Response payload text (JSON, going by the field name).
    pub response_json: &'a str,
    /// Time span reported for the request.
    pub duration: Duration,
    /// Input token count, when one is available.
    pub input_tokens: Option<i64>,
    /// Output token count, when one is available.
    pub output_tokens: Option<i64>,
    /// Shared borrow of the caller's state.
    pub state: &'a T,
}
/// Data handed to [`Interceptor::on_stream_chunk`].
///
/// `T` is caller-defined per-request state and defaults to `()`; it is
/// borrowed immutably here.
#[derive(Debug)]
pub struct StreamChunkContext<'a, T = ()> {
    /// Label of the operation being performed.
    pub operation: &'a str,
    /// Model identifier associated with the request.
    pub model: &'a str,
    /// Request payload text (JSON, going by the field name).
    pub request_json: &'a str,
    /// Payload text of this chunk (JSON, going by the field name).
    pub chunk_json: &'a str,
    /// Position of this chunk within the stream, as supplied by the caller.
    pub chunk_index: usize,
    /// Shared borrow of the caller's state.
    pub state: &'a T,
}
/// Data handed to [`Interceptor::on_stream_end`].
///
/// `T` is caller-defined per-request state and defaults to `()`; it is
/// borrowed immutably here.
#[derive(Debug)]
pub struct StreamEndContext<'a, T = ()> {
    /// Label of the operation being performed.
    pub operation: &'a str,
    /// Model identifier associated with the request.
    pub model: &'a str,
    /// Request payload text (JSON, going by the field name).
    pub request_json: &'a str,
    /// Number of chunks reported for the stream.
    pub total_chunks: usize,
    /// Time span reported for the stream.
    pub duration: Duration,
    /// Input token count, when one is available.
    pub input_tokens: Option<i64>,
    /// Output token count, when one is available.
    pub output_tokens: Option<i64>,
    /// Shared borrow of the caller's state.
    pub state: &'a T,
}
/// Data handed to [`Interceptor::on_error`].
///
/// Unlike the other hook contexts, the model, request payload and state are
/// all optional here.
#[derive(Debug)]
pub struct ErrorContext<'a, T = ()> {
    /// Label of the operation being performed.
    pub operation: &'a str,
    /// Model identifier, if one is available.
    pub model: Option<&'a str>,
    /// Request payload text, if one is available.
    pub request_json: Option<&'a str>,
    /// The error being reported.
    pub error: &'a crate::Error,
    /// Shared borrow of the caller's state, if one is available.
    pub state: Option<&'a T>,
}
/// Hooks for observing a request as it moves through its lifecycle.
///
/// Every method has a default implementation that does nothing, so an
/// implementor overrides only the hooks it needs. `T` is the caller-defined
/// state type carried by the context structs and defaults to `()`.
#[async_trait::async_trait]
pub trait Interceptor<T = ()>: Send + Sync {
    /// Hook receiving a [`BeforeRequestContext`]; the context (and the state
    /// it borrows) is mutable.
    ///
    /// The default implementation returns `Ok(())`.
    async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Hook receiving an [`AfterResponseContext`].
    ///
    /// The default implementation returns `Ok(())`.
    async fn after_response(&self, _ctx: &AfterResponseContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Hook receiving a [`StreamChunkContext`].
    ///
    /// The default implementation returns `Ok(())`.
    async fn on_stream_chunk(&self, _ctx: &StreamChunkContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Hook receiving a [`StreamEndContext`].
    ///
    /// The default implementation returns `Ok(())`.
    async fn on_stream_end(&self, _ctx: &StreamEndContext<'_, T>) -> Result<()> {
        Ok(())
    }

    /// Hook receiving an [`ErrorContext`].
    ///
    /// This hook returns nothing, so it has no way to report a failure of its
    /// own. The default implementation does nothing.
    async fn on_error(&self, _ctx: &ErrorContext<'_, T>) {}
}
/// An ordered collection of boxed [`Interceptor`]s that are invoked one after
/// another.
///
/// `T` is the state type shared with the hook contexts and defaults to `()`.
pub struct InterceptorChain<T = ()> {
    /// Registered interceptors, kept in insertion order.
    interceptors: Vec<Box<dyn Interceptor<T>>>,
}
/// The default chain is the empty chain.
impl<T> Default for InterceptorChain<T> {
    /// Delegates to [`InterceptorChain::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<T> InterceptorChain<T> {
    /// Creates a chain with no interceptors registered.
    pub fn new() -> Self {
        let interceptors = Vec::new();
        Self { interceptors }
    }

    /// Appends `interceptor` to the end of the chain.
    ///
    /// Hooks are dispatched in the order interceptors were added.
    pub fn add(&mut self, interceptor: Box<dyn Interceptor<T>>) {
        self.interceptors.push(interceptor);
    }

    /// Calls [`Interceptor::before_request`] on each interceptor in turn.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an interceptor; interceptors after
    /// the failing one are not called.
    pub async fn before_request(&self, ctx: &mut BeforeRequestContext<'_, T>) -> Result<()> {
        for hook in &self.interceptors {
            hook.before_request(ctx).await?;
        }
        Ok(())
    }

    /// Calls [`Interceptor::after_response`] on each interceptor in turn.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an interceptor; interceptors after
    /// the failing one are not called.
    pub async fn after_response(&self, ctx: &AfterResponseContext<'_, T>) -> Result<()>
    where
        T: Sync,
    {
        for hook in &self.interceptors {
            hook.after_response(ctx).await?;
        }
        Ok(())
    }

    /// Calls [`Interceptor::on_stream_chunk`] on each interceptor in turn.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an interceptor; interceptors after
    /// the failing one are not called.
    pub async fn on_stream_chunk(&self, ctx: &StreamChunkContext<'_, T>) -> Result<()>
    where
        T: Sync,
    {
        for hook in &self.interceptors {
            hook.on_stream_chunk(ctx).await?;
        }
        Ok(())
    }

    /// Calls [`Interceptor::on_stream_end`] on each interceptor in turn.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an interceptor; interceptors after
    /// the failing one are not called.
    pub async fn on_stream_end(&self, ctx: &StreamEndContext<'_, T>) -> Result<()>
    where
        T: Sync,
    {
        for hook in &self.interceptors {
            hook.on_stream_end(ctx).await?;
        }
        Ok(())
    }

    /// Calls [`Interceptor::on_error`] on every interceptor in turn.
    ///
    /// The hook returns nothing, so there is no error to short-circuit on and
    /// every interceptor is awaited.
    pub async fn on_error(&self, ctx: &ErrorContext<'_, T>)
    where
        T: Sync,
    {
        for hook in &self.interceptors {
            hook.on_error(ctx).await;
        }
    }

    /// Returns `true` when no interceptors are registered.
    pub fn is_empty(&self) -> bool {
        self.interceptors.is_empty()
    }

    /// Returns the number of registered interceptors.
    pub fn len(&self) -> usize {
        self.interceptors.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Interceptor that counts how many times each hook is invoked.
    ///
    /// The counters are `Arc`s so a test can keep a handle after the
    /// interceptor has been boxed and moved into a chain.
    // Each field mirrors a hook name, so the shared `_count` suffix is intentional.
    #[allow(clippy::struct_field_names)]
    struct TestInterceptor {
        before_request_count: Arc<AtomicUsize>,
        after_response_count: Arc<AtomicUsize>,
        on_stream_chunk_count: Arc<AtomicUsize>,
        on_stream_end_count: Arc<AtomicUsize>,
        on_error_count: Arc<AtomicUsize>,
    }

    impl TestInterceptor {
        /// Creates an interceptor with every counter at zero.
        fn new() -> Self {
            Self {
                before_request_count: Arc::new(AtomicUsize::new(0)),
                after_response_count: Arc::new(AtomicUsize::new(0)),
                on_stream_chunk_count: Arc::new(AtomicUsize::new(0)),
                on_stream_end_count: Arc::new(AtomicUsize::new(0)),
                on_error_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait::async_trait]
    impl Interceptor for TestInterceptor {
        async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
            self.before_request_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn after_response(&self, _ctx: &AfterResponseContext<'_>) -> Result<()> {
            self.after_response_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_stream_chunk(&self, _ctx: &StreamChunkContext<'_>) -> Result<()> {
            self.on_stream_chunk_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_stream_end(&self, _ctx: &StreamEndContext<'_>) -> Result<()> {
            self.on_stream_end_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_error(&self, _ctx: &ErrorContext<'_>) {
            self.on_error_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Interceptors must run in the order they were added, not merely all run.
    #[tokio::test]
    async fn test_interceptor_chain_executes_in_order() {
        /// Appends its label to a shared log when `before_request` fires.
        struct OrderRecorder {
            label: &'static str,
            log: Arc<Mutex<Vec<&'static str>>>,
        }

        #[async_trait::async_trait]
        impl Interceptor for OrderRecorder {
            async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
                self.log.lock().unwrap().push(self.label);
                Ok(())
            }
        }

        let log = Arc::new(Mutex::new(Vec::new()));
        let mut chain: InterceptorChain<()> = InterceptorChain::new();
        chain.add(Box::new(OrderRecorder {
            label: "first",
            log: Arc::clone(&log),
        }));
        chain.add(Box::new(OrderRecorder {
            label: "second",
            log: Arc::clone(&log),
        }));
        assert_eq!(chain.len(), 2);

        let mut state = ();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        chain.before_request(&mut ctx).await.unwrap();

        assert_eq!(*log.lock().unwrap(), ["first", "second"]);
    }

    /// Each chain method must forward to the matching hook exactly once.
    #[tokio::test]
    async fn test_interceptor_chain_dispatches_every_hook() {
        let interceptor = TestInterceptor::new();
        let before_request_count = Arc::clone(&interceptor.before_request_count);
        let after_response_count = Arc::clone(&interceptor.after_response_count);
        let on_stream_chunk_count = Arc::clone(&interceptor.on_stream_chunk_count);
        let on_stream_end_count = Arc::clone(&interceptor.on_stream_end_count);
        let on_error_count = Arc::clone(&interceptor.on_error_count);

        let mut chain: InterceptorChain<()> = InterceptorChain::new();
        chain.add(Box::new(interceptor));
        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);

        let mut state = ();
        let mut before_ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        chain.before_request(&mut before_ctx).await.unwrap();

        let after_ctx = AfterResponseContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            response_json: "{}",
            duration: Duration::from_millis(1),
            input_tokens: Some(1),
            output_tokens: Some(2),
            state: &(),
        };
        chain.after_response(&after_ctx).await.unwrap();

        let chunk_ctx = StreamChunkContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            chunk_json: "{}",
            chunk_index: 0,
            state: &(),
        };
        chain.on_stream_chunk(&chunk_ctx).await.unwrap();

        let end_ctx = StreamEndContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            total_chunks: 1,
            duration: Duration::from_millis(1),
            input_tokens: None,
            output_tokens: None,
            state: &(),
        };
        chain.on_stream_end(&end_ctx).await.unwrap();

        let error = crate::Error::Internal("Test".to_string());
        let error_ctx = ErrorContext {
            operation: "test",
            model: None,
            request_json: None,
            error: &error,
            state: None,
        };
        chain.on_error(&error_ctx).await;

        assert_eq!(before_request_count.load(Ordering::SeqCst), 1);
        assert_eq!(after_response_count.load(Ordering::SeqCst), 1);
        assert_eq!(on_stream_chunk_count.load(Ordering::SeqCst), 1);
        assert_eq!(on_stream_end_count.load(Ordering::SeqCst), 1);
        assert_eq!(on_error_count.load(Ordering::SeqCst), 1);
    }

    /// A failing interceptor's error is returned and later interceptors are skipped.
    #[tokio::test]
    async fn test_interceptor_chain_handles_errors() {
        struct FailingInterceptor;

        #[async_trait::async_trait]
        impl Interceptor for FailingInterceptor {
            async fn before_request(&self, _ctx: &mut BeforeRequestContext<'_>) -> Result<()> {
                Err(crate::Error::Internal("Test error".to_string()))
            }
        }

        let skipped = TestInterceptor::new();
        let skipped_count = Arc::clone(&skipped.before_request_count);

        let mut chain: InterceptorChain<()> = InterceptorChain::new();
        chain.add(Box::new(FailingInterceptor));
        chain.add(Box::new(skipped));

        let mut state = ();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        let result = chain.before_request(&mut ctx).await;

        assert!(result.is_err());
        // The chain stops at the first error, so the second interceptor never ran.
        assert_eq!(skipped_count.load(Ordering::SeqCst), 0);
    }

    /// An empty chain reports itself as empty and its hooks succeed trivially.
    #[tokio::test]
    async fn test_interceptor_chain_empty() {
        let chain = InterceptorChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);

        let mut state = ();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        chain.before_request(&mut ctx).await.unwrap();
    }

    /// Writes made to the state in `before_request` are visible to the caller.
    #[tokio::test]
    async fn test_state_passing() {
        struct StateInterceptor;

        #[async_trait::async_trait]
        impl Interceptor<HashMap<String, String>> for StateInterceptor {
            async fn before_request(
                &self,
                ctx: &mut BeforeRequestContext<'_, HashMap<String, String>>,
            ) -> Result<()> {
                ctx.state
                    .insert("test_key".to_string(), "test_value".to_string());
                Ok(())
            }
        }

        let mut chain = InterceptorChain::new();
        chain.add(Box::new(StateInterceptor));

        let mut state = HashMap::new();
        let mut ctx = BeforeRequestContext {
            operation: "test",
            model: "gpt-4",
            request_json: "{}",
            state: &mut state,
        };
        chain.before_request(&mut ctx).await.unwrap();

        assert_eq!(state.get("test_key"), Some(&"test_value".to_string()));
    }

    /// `on_error` has no result to short-circuit on, so every interceptor is notified.
    #[tokio::test]
    async fn test_on_error_notifies_every_interceptor() {
        let first = TestInterceptor::new();
        let second = TestInterceptor::new();
        let first_count = Arc::clone(&first.on_error_count);
        let second_count = Arc::clone(&second.on_error_count);

        let mut chain: InterceptorChain<()> = InterceptorChain::new();
        chain.add(Box::new(first));
        chain.add(Box::new(second));

        let error = crate::Error::Internal("Test".to_string());
        let ctx = ErrorContext {
            operation: "test",
            model: None,
            request_json: None,
            error: &error,
            state: None,
        };
        chain.on_error(&ctx).await;

        assert_eq!(first_count.load(Ordering::SeqCst), 1);
        assert_eq!(second_count.load(Ordering::SeqCst), 1);
    }
}