use async_trait::async_trait;
use crate::{
error::ProxyError,
types::{ConnectionContext, ProxyRequest, ProxyResponse},
};
/// Extension point for observing and modifying proxied traffic.
///
/// Implementors must provide `on_request`, `on_response` and `name`; the
/// lifecycle hooks (`on_connect`, `on_disconnect`, `on_init`, `on_shutdown`)
/// have no-op default implementations. The trait requires `Send + Sync` and is
/// consumed as `Box<dyn ProxyMiddleware>` by [`run_on_request_chain`] and
/// [`run_on_response_chain`].
#[async_trait]
pub trait ProxyMiddleware: Send + Sync {
/// Request hook: receives mutable access to both the request and the
/// connection context, so it may rewrite either.
///
/// [`run_on_request_chain`] calls this on each middleware in slice
/// (registration) order.
///
/// # Errors
///
/// Returning `Err` makes [`run_on_request_chain`] stop immediately and return
/// that error; middlewares later in the slice are not invoked.
async fn on_request(
&self,
req: &mut ProxyRequest,
ctx: &mut ConnectionContext,
) -> Result<(), ProxyError>;
/// Response hook: receives mutable access to the response but only shared
/// access to the connection context.
///
/// [`run_on_response_chain`] calls this in *reverse* slice order, so the
/// first-registered middleware sees the response last.
///
/// # Errors
///
/// Returning `Err` makes [`run_on_response_chain`] stop immediately and return
/// that error; the remaining (earlier-registered) middlewares are skipped.
async fn on_response(
&self,
res: &mut ProxyResponse,
ctx: &ConnectionContext,
) -> Result<(), ProxyError>;
/// Connection-opened lifecycle hook. The default does nothing.
/// NOTE(review): the caller is not in this file — confirm when it fires.
async fn on_connect(&self, _ctx: &ConnectionContext) {}
/// Connection-closed lifecycle hook. The default does nothing.
/// NOTE(review): the caller is not in this file — confirm when it fires.
async fn on_disconnect(&self, _ctx: &ConnectionContext) {}
/// One-time initialization hook; the default returns `Ok(())`.
/// NOTE(review): presumably invoked at proxy startup — confirm at call site.
async fn on_init(&self) -> Result<(), ProxyError> {
Ok(())
}
/// Shutdown hook; the default returns `Ok(())`.
/// NOTE(review): presumably invoked when the proxy stops — confirm at call site.
async fn on_shutdown(&self) -> Result<(), ProxyError> {
Ok(())
}
/// Static identifier for this middleware (presumably used for logging or
/// diagnostics — confirm at call sites).
fn name(&self) -> &'static str;
}
/// Runs every middleware's [`ProxyMiddleware::on_request`] hook, in the order
/// the middlewares appear in `middlewares` (registration order).
///
/// Each hook gets a fresh reborrow of `req` and `ctx`, so changes made by one
/// middleware are visible to the next.
///
/// # Errors
///
/// Returns the first error produced by any hook. Middlewares after the failing
/// one are not invoked.
pub async fn run_on_request_chain(
    middlewares: &[Box<dyn ProxyMiddleware>],
    req: &mut ProxyRequest,
    ctx: &mut ConnectionContext,
) -> Result<(), ProxyError> {
    for middleware in middlewares.iter() {
        // Futures are lazy: nothing runs until the `.await` below.
        let hook = middleware.on_request(req, ctx);
        hook.await?;
    }
    Ok(())
}
/// Runs every middleware's [`ProxyMiddleware::on_response`] hook in *reverse*
/// registration order (last middleware in the slice first).
///
/// This mirrors [`run_on_request_chain`]: the middleware that saw the request
/// first sees the response last.
///
/// # Errors
///
/// Returns the first error produced by any hook. The remaining
/// (earlier-registered) middlewares are not invoked.
pub async fn run_on_response_chain(
    middlewares: &[Box<dyn ProxyMiddleware>],
    res: &mut ProxyResponse,
    ctx: &ConnectionContext,
) -> Result<(), ProxyError> {
    let last_registered_first = middlewares.iter().rev();
    for middleware in last_registered_first {
        let hook = middleware.on_response(res, ctx);
        hook.await?;
    }
    Ok(())
}
#[async_trait::async_trait]
pub trait CostRecorder: Send + Sync + std::fmt::Debug {
async fn record(
&self,
ctx: &crate::types::ConnectionContext,
response_body: &serde_json::Value,
) -> Result<(), crate::error::ProxyError>;
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // panicking on unexpected errors is the desired failure mode in tests
mod tests {
    use std::sync::{Arc, Mutex};

    use bytes::Bytes;
    use http::{HeaderMap, Method, StatusCode};

    use super::*;

    /// Ordered log of middleware names, shared between a test and the
    /// middlewares it builds so the test can read the call order back after
    /// the middlewares have been boxed and handed to a chain runner.
    type CallLog = Arc<Mutex<Vec<&'static str>>>;

    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Copies the current contents of `log` out of the mutex.
    fn snapshot(log: &CallLog) -> Vec<&'static str> {
        log.lock().unwrap().clone()
    }

    /// Test middleware that appends its `name` to a shared log whenever one of
    /// its hooks runs, so tests can assert the exact invocation order.
    struct RecordingMiddleware {
        name: &'static str,
        request_log: CallLog,
        response_log: CallLog,
        /// When `Some`, `on_request` fails with `ProxyError::BadRequest`
        /// carrying this message instead of logging.
        request_err: Option<&'static str>,
    }

    impl RecordingMiddleware {
        fn new(name: &'static str, request_log: &CallLog, response_log: &CallLog) -> Self {
            Self {
                name,
                request_log: Arc::clone(request_log),
                response_log: Arc::clone(response_log),
                request_err: None,
            }
        }

        /// Makes `on_request` fail with `message` instead of succeeding.
        fn failing_with(mut self, message: &'static str) -> Self {
            self.request_err = Some(message);
            self
        }
    }

    #[async_trait]
    impl ProxyMiddleware for RecordingMiddleware {
        async fn on_request(
            &self,
            _req: &mut ProxyRequest,
            _ctx: &mut ConnectionContext,
        ) -> Result<(), ProxyError> {
            if let Some(message) = self.request_err {
                return Err(ProxyError::BadRequest(message.to_string()));
            }
            self.request_log.lock().unwrap().push(self.name);
            Ok(())
        }

        async fn on_response(
            &self,
            _res: &mut ProxyResponse,
            _ctx: &ConnectionContext,
        ) -> Result<(), ProxyError> {
            self.response_log.lock().unwrap().push(self.name);
            Ok(())
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Builds one succeeding `RecordingMiddleware` per name, in order.
    fn chain_of(
        names: &[&'static str],
        request_log: &CallLog,
        response_log: &CallLog,
    ) -> Vec<Box<dyn ProxyMiddleware>> {
        names
            .iter()
            .map(|&name| {
                Box::new(RecordingMiddleware::new(name, request_log, response_log))
                    as Box<dyn ProxyMiddleware>
            })
            .collect()
    }

    fn make_request() -> ProxyRequest {
        ProxyRequest::new(
            Method::POST,
            "/v1/messages".into(),
            HeaderMap::new(),
            Bytes::from(r#"{"model":"test"}"#),
        )
    }

    fn make_context() -> ConnectionContext {
        ConnectionContext::new(1, crate::types::AgentType::Unknown, None, None)
    }

    fn make_response() -> ProxyResponse {
        ProxyResponse::new(StatusCode::OK, HeaderMap::new(), Bytes::new(), false)
    }

    #[tokio::test]
    async fn test_on_request_runs_in_registration_order() {
        let request_log = new_log();
        let response_log = new_log();
        let middlewares = chain_of(&["A", "B", "C"], &request_log, &response_log);
        let mut req = make_request();
        let mut ctx = make_context();

        run_on_request_chain(&middlewares, &mut req, &mut ctx)
            .await
            .unwrap();

        assert_eq!(snapshot(&request_log), ["A", "B", "C"]);
        assert!(snapshot(&response_log).is_empty());
    }

    #[tokio::test]
    async fn test_on_response_runs_in_reverse_registration_order() {
        let request_log = new_log();
        let response_log = new_log();
        let middlewares = chain_of(&["A", "B", "C"], &request_log, &response_log);
        let mut res = make_response();
        let ctx = make_context();

        run_on_response_chain(&middlewares, &mut res, &ctx)
            .await
            .unwrap();

        assert_eq!(snapshot(&response_log), ["C", "B", "A"]);
        assert!(snapshot(&request_log).is_empty());
    }

    #[tokio::test]
    async fn test_on_request_aborts_on_error() {
        let request_log = new_log();
        let response_log = new_log();
        let middlewares: Vec<Box<dyn ProxyMiddleware>> = vec![
            Box::new(RecordingMiddleware::new("ok", &request_log, &response_log)),
            Box::new(
                RecordingMiddleware::new("err", &request_log, &response_log)
                    .failing_with("test error"),
            ),
            Box::new(RecordingMiddleware::new("never", &request_log, &response_log)),
        ];
        let mut req = make_request();
        let mut ctx = make_context();

        let result = run_on_request_chain(&middlewares, &mut req, &mut ctx).await;

        // The failing middleware's error must be propagated unchanged.
        assert!(matches!(
            result,
            Err(ProxyError::BadRequest(ref message)) if message == "test error"
        ));
        // "ok" ran before the failure; "never" sits after it and must be skipped.
        assert_eq!(snapshot(&request_log), ["ok"]);
    }

    #[tokio::test]
    async fn test_empty_chains_succeed() {
        let middlewares: Vec<Box<dyn ProxyMiddleware>> = Vec::new();
        let mut req = make_request();
        let mut res = make_response();
        let mut ctx = make_context();

        run_on_request_chain(&middlewares, &mut req, &mut ctx)
            .await
            .unwrap();
        run_on_response_chain(&middlewares, &mut res, &ctx)
            .await
            .unwrap();
    }
}