use crate::error::OpenAIError;
use crate::service::request::Request;
use reqwest::Response;
use std::fmt::Debug;
use std::sync::Arc;
/// Hook points that can observe or rewrite a request, its response, or an
/// error.
///
/// Every method has a pass-through default, so an implementor only overrides
/// the hooks it needs. `InterceptorChain` in this file calls the hooks of all
/// registered interceptors in priority order and stops at the first `Err`.
#[async_trait::async_trait]
pub trait Interceptor: Send + Sync {
/// Priority used to order this interceptor inside an `InterceptorChain`.
///
/// Defaults to `InterceptorPriority::Medium`.
fn priority(&self) -> InterceptorPriority {
InterceptorPriority::Medium
}
/// Called with the request; returns the request to hand to the next stage.
///
/// The default implementation returns `request` unchanged.
async fn on_request(&self, request: Request) -> Result<Request, OpenAIError> {
Ok(request)
}
/// Called with the response; returns the response to hand to the next stage.
///
/// The default implementation returns `response` unchanged.
async fn on_response(&self, response: Response) -> Result<Response, OpenAIError> {
Ok(response)
}
/// Called with an error; returns the error to hand to the next stage.
///
/// `Ok(error)` carries the (possibly replaced) error forward, while `Err(_)`
/// reports that the hook itself failed. The default implementation returns
/// `error` unchanged.
async fn on_error(&self, error: OpenAIError) -> Result<OpenAIError, OpenAIError> {
Ok(error)
}
}
/// Ordering weight attached to an interceptor.
///
/// Each named level maps to a fixed integer weight and `Custom` carries its
/// own. Equality and ordering are both defined on that weight, so for example
/// `Custom(50)` compares equal to `Medium`.
#[derive(Debug, Clone)]
pub enum InterceptorPriority {
    /// Weight 100.
    Highest,
    /// Weight 75.
    High,
    /// Weight 50.
    Medium,
    /// Weight 25.
    Low,
    /// Weight 0.
    Lowest,
    /// Caller-chosen weight; may lie outside the 0..=100 range of the named levels.
    Custom(i32),
}

impl InterceptorPriority {
    /// Returns the integer weight that backs every comparison on this type.
    fn to_int(&self) -> i32 {
        match *self {
            Self::Custom(weight) => weight,
            Self::Highest => 100,
            Self::High => 75,
            Self::Medium => 50,
            Self::Low => 25,
            Self::Lowest => 0,
        }
    }
}

/// Two priorities are equal when their integer weights are equal.
impl PartialEq for InterceptorPriority {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.to_int(), other.to_int());
        lhs == rhs
    }
}

impl Eq for InterceptorPriority {}

/// Delegates to the total order defined by `Ord`.
impl PartialOrd for InterceptorPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders priorities by integer weight; a larger weight is the greater priority.
impl Ord for InterceptorPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let rhs = other.to_int();
        self.to_int().cmp(&rhs)
    }
}
/// A shared interceptor paired with the priority it reported when wrapped.
///
/// The priority is read once in [`PrioritizedInterceptor::new`] and stored, so
/// later changes to what `Interceptor::priority` returns are not reflected
/// here.
#[derive(Clone)]
pub struct PrioritizedInterceptor {
    interceptor: Arc<dyn Interceptor>,
    priority: InterceptorPriority,
}

impl PrioritizedInterceptor {
    /// Wraps `interceptor`, capturing the value of its `priority()` at this moment.
    pub fn new(interceptor: Arc<dyn Interceptor>) -> Self {
        Self {
            // Query the priority before the `Arc` is moved into the struct.
            priority: interceptor.priority(),
            interceptor,
        }
    }

    /// Borrows the wrapped interceptor.
    #[inline]
    pub fn interceptor(&self) -> &Arc<dyn Interceptor> {
        let Self { interceptor, .. } = self;
        interceptor
    }

    /// Borrows the stored priority.
    #[inline]
    pub fn priority(&self) -> &InterceptorPriority {
        let Self { priority, .. } = self;
        priority
    }

    /// Mutably borrows the wrapped interceptor, allowing it to be replaced.
    ///
    /// Replacing the interceptor does not refresh the stored priority.
    #[inline]
    pub fn interceptor_mut(&mut self) -> &mut Arc<dyn Interceptor> {
        let Self { interceptor, .. } = self;
        interceptor
    }

    /// Mutably borrows the stored priority.
    ///
    /// Changing it here does not by itself reorder a collection that already
    /// holds this entry.
    #[inline]
    pub fn priority_mut(&mut self) -> &mut InterceptorPriority {
        let Self { priority, .. } = self;
        priority
    }
}
#[derive(Default, Clone)]
pub struct InterceptorChain(Vec<PrioritizedInterceptor>);
impl InterceptorChain {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn add_interceptor(&mut self, interceptor: impl Interceptor + 'static) {
self.0
.push(PrioritizedInterceptor::new(Arc::new(interceptor)));
self.0.sort_by(|a, b| b.priority().cmp(a.priority()));
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub async fn execute_request_interceptors(
&self,
request: Request,
) -> Result<Request, OpenAIError> {
let mut current_request = request;
for prioritized_interceptor in &self.0 {
current_request = prioritized_interceptor
.interceptor()
.on_request(current_request)
.await?;
}
Ok(current_request)
}
pub async fn execute_response_interceptors(
&self,
response: Response,
) -> Result<Response, OpenAIError> {
let mut current_response = response;
for prioritized_interceptor in self.0.iter().rev() {
current_response = prioritized_interceptor
.interceptor()
.on_response(current_response)
.await?;
}
Ok(current_response)
}
pub async fn execute_error_interceptors(
&self,
error: OpenAIError,
) -> Result<OpenAIError, OpenAIError> {
let mut current_error = error;
for prioritized_interceptor in self.0.iter().rev() {
current_error = prioritized_interceptor
.interceptor()
.on_error(current_error)
.await?;
}
Ok(current_error)
}
}
impl std::ops::Index<usize> for InterceptorChain {
type Output = PrioritizedInterceptor;
fn index(&self, index: usize) -> &Self::Output {
&self.0[index]
}
}
impl std::ops::IndexMut<usize> for InterceptorChain {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.0[index]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interceptor that reports a configurable priority and logs each hook call.
    #[derive(Debug)]
    struct TestInterceptor {
        id: String,
        priority: InterceptorPriority,
    }

    impl TestInterceptor {
        /// Builds a test interceptor with the given label and priority.
        fn named(id: &str, priority: InterceptorPriority) -> Self {
            Self {
                id: id.to_string(),
                priority,
            }
        }
    }

    #[async_trait::async_trait]
    impl Interceptor for TestInterceptor {
        fn priority(&self) -> InterceptorPriority {
            self.priority.clone()
        }

        async fn on_request(&self, request: Request) -> Result<Request, OpenAIError> {
            println!("TestInterceptor {} processing request", self.id);
            Ok(request)
        }

        async fn on_response(&self, response: Response) -> Result<Response, OpenAIError> {
            println!("TestInterceptor {} processing response", self.id);
            Ok(response)
        }

        async fn on_error(&self, error: OpenAIError) -> Result<OpenAIError, OpenAIError> {
            println!("TestInterceptor {} processing error", self.id);
            Ok(error)
        }
    }

    /// Interceptors added out of order end up sorted by descending priority.
    #[tokio::test]
    async fn test_interceptor_chain_order() {
        let mut chain = InterceptorChain::new();

        // Deliberately registered in a non-sorted order.
        let registrations = [
            ("low", InterceptorPriority::Low),
            ("high", InterceptorPriority::High),
            ("medium", InterceptorPriority::Medium),
        ];
        for (id, priority) in registrations {
            chain.add_interceptor(TestInterceptor::named(id, priority));
        }

        assert_eq!(chain.len(), 3);

        let expected = [
            InterceptorPriority::High,
            InterceptorPriority::Medium,
            InterceptorPriority::Low,
        ];
        for (slot, want) in expected.iter().enumerate() {
            assert_eq!(chain[slot].priority, *want);
        }
    }
}