use std::borrow::Cow;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use http::HeaderMap;
use crate::api::list_rewrite::ListRewrite;
use crate::error::ProxyError;
use crate::maybe_send::{MaybeSend, MaybeSync};
use crate::route_handler::HandlerAction;
use crate::types::{BucketConfig, ResolvedIdentity, S3Operation};
/// Summary of a finished request, passed to the `after_dispatch` middleware hooks.
///
/// Every field is either a shared borrow or a plain value, so the struct is `Copy`: callers
/// can hand the same summary to several observers without re-borrowing or cloning data.
#[derive(Clone, Copy)]
pub struct CompletedRequest<'a> {
    /// Identifier of the request.
    pub request_id: &'a str,
    /// Identity the request was attributed to, if one was resolved.
    pub identity: Option<&'a ResolvedIdentity>,
    /// S3 operation of the request, if one was determined.
    pub operation: Option<&'a S3Operation>,
    /// Target bucket name, if any.
    pub bucket: Option<&'a str>,
    /// Response status code.
    pub status: u16,
    /// Response size in bytes, when known.
    pub response_bytes: Option<u64>,
    /// Request size in bytes, when known.
    pub request_bytes: Option<u64>,
    /// Whether the request was forwarded (NOTE(review): presumably to a backend — confirm
    /// against the code that builds this struct).
    pub was_forwarded: bool,
    /// Client IP address, when available.
    pub source_ip: Option<IpAddr>,
}
/// Per-request state handed through the middleware chain and on to the terminal dispatcher.
///
/// Most data is borrowed for `'a`; `bucket_config` is a `Cow`, so it can hold either a
/// borrowed or an owned `BucketConfig`.
pub struct DispatchContext<'a> {
    /// Identity resolved for the request.
    pub identity: &'a ResolvedIdentity,
    /// The S3 operation being performed.
    pub operation: &'a S3Operation,
    /// Configuration of the target bucket, if any (borrowed or owned via `Cow`).
    pub bucket_config: Option<Cow<'a, BucketConfig>>,
    /// Request headers.
    pub headers: &'a HeaderMap,
    /// Client IP address, when available.
    pub source_ip: Option<IpAddr>,
    /// Identifier of the request.
    pub request_id: &'a str,
    /// Optional list-response rewrite settings (semantics defined by `ListRewrite`).
    pub list_rewrite: Option<&'a ListRewrite>,
    /// Optional display name (NOTE(review): presumably of the caller — confirm).
    pub display_name: Option<&'a str>,
    /// Type-keyed map for extra per-request data (presumably values attached by earlier
    /// stages for later ones — confirm against users).
    pub extensions: http::Extensions,
}
/// Pinned, boxed future used by the type-erased middleware plumbing; `Send` on native targets.
#[cfg(not(target_arch = "wasm32"))]
type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Pinned, boxed future used by the type-erased middleware plumbing; no `Send` bound on wasm32.
#[cfg(target_arch = "wasm32")]
type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

/// Future produced by [`Dispatch::dispatch`] and [`ErasedMiddleware::handle`].
pub(crate) type DispatchFuture<'a> = BoxedFuture<'a, Result<HandlerAction, ProxyError>>;

/// Future produced by [`ErasedMiddleware::after_dispatch`].
pub(crate) type AfterDispatchFuture<'a> = BoxedFuture<'a, ()>;
/// Terminal stage of the middleware chain: reached through [`Next::run`] once no middleware
/// remain, and responsible for producing the final [`HandlerAction`].
pub(crate) trait Dispatch
where
    Self: MaybeSend + MaybeSync,
{
    /// Handles the request described by `ctx`.
    fn dispatch<'a>(&'a self, ctx: DispatchContext<'a>) -> DispatchFuture<'a>;
}
/// Object-safe form of [`Middleware`], so middleware of different concrete types can be
/// stored together as `Box<dyn ErasedMiddleware>`; every `Middleware` implements it through
/// the blanket impl in this module.
pub(crate) trait ErasedMiddleware
where
    Self: MaybeSend + MaybeSync,
{
    /// Type-erased [`Middleware::handle`]: processes `ctx`, optionally continuing via `next`.
    fn handle<'a>(&'a self, ctx: DispatchContext<'a>, next: Next<'a>) -> DispatchFuture<'a>;

    /// Type-erased [`Middleware::after_dispatch`]: receives the completed-request summary.
    fn after_dispatch<'a>(
        &'a self,
        completed: &'a CompletedRequest<'a>,
    ) -> AfterDispatchFuture<'a>;
}
impl<T: Middleware> ErasedMiddleware for T {
fn handle<'a>(&'a self, ctx: DispatchContext<'a>, next: Next<'a>) -> DispatchFuture<'a> {
Box::pin(<Self as Middleware>::handle(self, ctx, next))
}
fn after_dispatch<'a>(
&'a self,
completed: &'a CompletedRequest<'a>,
) -> AfterDispatchFuture<'a> {
Box::pin(<Self as Middleware>::after_dispatch(self, completed))
}
}
/// Handle to the remainder of the middleware chain.
///
/// A middleware receives a `Next` in [`Middleware::handle`] and calls [`Next::run`] to pass
/// the request on; returning without calling it short-circuits the chain.
pub struct Next<'a> {
    /// Middleware that have not run yet, front first.
    middleware: &'a [Box<dyn ErasedMiddleware>],
    /// Terminal dispatcher, reached once `middleware` is exhausted.
    dispatch: &'a dyn Dispatch,
}

impl<'a> Next<'a> {
    /// Builds a chain that runs `middleware` front-to-back and then `dispatch`.
    pub(crate) fn new(
        middleware: &'a [Box<dyn ErasedMiddleware>],
        dispatch: &'a dyn Dispatch,
    ) -> Self {
        Next { middleware, dispatch }
    }

    /// Runs the next stage of the chain with `ctx`.
    ///
    /// When middleware remain, the first one is invoked with a `Next` over the rest;
    /// otherwise the request goes straight to the terminal dispatcher.
    ///
    /// # Errors
    /// Propagates whatever error the invoked middleware or dispatcher returns.
    pub async fn run(self, ctx: DispatchContext<'a>) -> Result<HandlerAction, ProxyError> {
        let Next { middleware, dispatch } = self;
        let Some((current, remaining)) = middleware.split_first() else {
            return dispatch.dispatch(ctx).await;
        };
        let rest = Next {
            middleware: remaining,
            dispatch,
        };
        current.handle(ctx, rest).await
    }
}
/// Request middleware: can act on a request, answer it directly (short-circuit), or pass it
/// further down the chain.
///
/// Implementors must be `'static` and satisfy `MaybeSend`/`MaybeSync`. Both methods may be
/// implemented with `async fn`.
pub trait Middleware: MaybeSend + MaybeSync + 'static {
    /// Processes `ctx`, either returning a [`HandlerAction`] itself or delegating to the rest
    /// of the chain through [`Next::run`].
    ///
    /// # Errors
    /// Returns a [`ProxyError`] if this middleware, or a later stage it delegates to, fails.
    fn handle<'a>(
        &'a self,
        ctx: DispatchContext<'a>,
        next: Next<'a>,
    ) -> impl Future<Output = Result<HandlerAction, ProxyError>> + MaybeSend + 'a;

    /// Hook that receives a summary of the completed request. Defaults to a no-op.
    ///
    /// `self`, `completed` and the returned future share one lifetime `'a`, mirroring
    /// `handle` and `ErasedMiddleware::after_dispatch`. Bounding the future by `self` alone
    /// would reject implementations (e.g. an `async fn`) whose future borrows `completed`,
    /// because that borrow is not known to outlive `self`. Implementations written with the
    /// more general elided signature are still accepted.
    fn after_dispatch<'a>(
        &'a self,
        _completed: &'a CompletedRequest<'a>,
    ) -> impl Future<Output = ()> + MaybeSend + 'a {
        async {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::route_handler::{ProxyResponseBody, ProxyResult};
    use crate::types::{BucketConfig, ResolvedIdentity, S3Operation};
    use std::sync::{Arc, Mutex};

    /// Middleware that never calls `next` and answers with status 429, ending the chain.
    pub(crate) struct BlockingMiddleware;

    impl Middleware for BlockingMiddleware {
        async fn handle<'a>(
            &'a self,
            _ctx: DispatchContext<'a>,
            _next: Next<'a>,
        ) -> Result<HandlerAction, ProxyError> {
            Ok(HandlerAction::Response(ProxyResult {
                status: 429,
                headers: HeaderMap::new(),
                body: ProxyResponseBody::Empty,
            }))
        }
    }

    /// Middleware that forwards the request to the rest of the chain unchanged.
    pub(crate) struct PassthroughMiddleware;

    impl Middleware for PassthroughMiddleware {
        async fn handle<'a>(
            &'a self,
            ctx: DispatchContext<'a>,
            next: Next<'a>,
        ) -> Result<HandlerAction, ProxyError> {
            next.run(ctx).await
        }
    }

    /// Ordered record of which `RecordingMiddleware` instances ran.
    type CallLog = Arc<Mutex<Vec<&'static str>>>;

    /// Middleware that appends its `label` to a shared [`CallLog`] and then forwards the
    /// request, so tests can observe which middleware ran and in what order.
    struct RecordingMiddleware {
        label: &'static str,
        log: CallLog,
    }

    impl Middleware for RecordingMiddleware {
        async fn handle<'a>(
            &'a self,
            ctx: DispatchContext<'a>,
            next: Next<'a>,
        ) -> Result<HandlerAction, ProxyError> {
            // The lock guard is a temporary dropped at the end of this statement, so it is
            // not held across the `.await` below (a held std `MutexGuard` is not `Send`).
            self.log.lock().unwrap().push(self.label);
            next.run(ctx).await
        }
    }

    /// Boxes a `RecordingMiddleware` that writes `label` into `log`.
    fn recorder(label: &'static str, log: &CallLog) -> Box<dyn ErasedMiddleware> {
        Box::new(RecordingMiddleware {
            label,
            log: Arc::clone(log),
        })
    }

    /// Snapshot of the labels recorded so far.
    fn recorded(log: &CallLog) -> Vec<&'static str> {
        log.lock().unwrap().clone()
    }

    /// Terminal dispatcher that always answers with status 200.
    struct TestDispatch;

    impl Dispatch for TestDispatch {
        fn dispatch<'a>(&'a self, _ctx: DispatchContext<'a>) -> DispatchFuture<'a> {
            Box::pin(async {
                Ok(HandlerAction::Response(ProxyResult {
                    status: 200,
                    headers: HeaderMap::new(),
                    body: ProxyResponseBody::Empty,
                }))
            })
        }
    }

    /// Minimal context backed by statics, so it can be `'static`.
    fn test_context() -> DispatchContext<'static> {
        static IDENTITY: ResolvedIdentity = ResolvedIdentity::Anonymous;
        static OPERATION: S3Operation = S3Operation::ListBuckets;
        static HEADERS: std::sync::LazyLock<HeaderMap> = std::sync::LazyLock::new(HeaderMap::new);
        static BUCKET_CONFIG: std::sync::LazyLock<BucketConfig> =
            std::sync::LazyLock::new(|| BucketConfig {
                name: "test".to_string(),
                backend_type: "s3".to_string(),
                backend_prefix: None,
                anonymous_access: false,
                allowed_roles: Vec::new(),
                backend_options: Default::default(),
            });
        DispatchContext {
            identity: &IDENTITY,
            operation: &OPERATION,
            bucket_config: Some(Cow::Borrowed(&*BUCKET_CONFIG)),
            headers: &*HEADERS,
            source_ip: None,
            request_id: "test-request-id",
            list_rewrite: None,
            display_name: None,
            extensions: http::Extensions::new(),
        }
    }

    /// Extracts the status from a `HandlerAction::Response`, panicking on other variants.
    fn response_status(action: &HandlerAction) -> u16 {
        match action {
            HandlerAction::Response(r) => r.status,
            _ => panic!("expected Response variant"),
        }
    }

    /// Runs `middleware` followed by `TestDispatch` to completion on the current thread.
    fn run_chain(middleware: &[Box<dyn ErasedMiddleware>]) -> Result<HandlerAction, ProxyError> {
        let dispatch = TestDispatch;
        futures::executor::block_on(Next::new(middleware, &dispatch).run(test_context()))
    }

    #[test]
    fn empty_chain_calls_dispatch() {
        let middleware: Vec<Box<dyn ErasedMiddleware>> = vec![];
        assert_eq!(response_status(&run_chain(&middleware).unwrap()), 200);
    }

    #[test]
    fn blocking_middleware_short_circuits() {
        let log = CallLog::default();
        let middleware: Vec<Box<dyn ErasedMiddleware>> =
            vec![Box::new(BlockingMiddleware), recorder("unreached", &log)];
        assert_eq!(response_status(&run_chain(&middleware).unwrap()), 429);
        // Nothing after the blocker may run.
        assert!(recorded(&log).is_empty());
    }

    #[test]
    fn passthrough_then_blocking_runs_in_order() {
        let log = CallLog::default();
        let middleware: Vec<Box<dyn ErasedMiddleware>> = vec![
            Box::new(PassthroughMiddleware),
            recorder("before-block", &log),
            Box::new(BlockingMiddleware),
            recorder("after-block", &log),
        ];
        assert_eq!(response_status(&run_chain(&middleware).unwrap()), 429);
        // The passthrough forwarded to the recorder, which ran before the blocker; the
        // blocker then stopped the chain. A reversed chain would record "after-block".
        assert_eq!(recorded(&log), ["before-block"]);
    }

    #[test]
    fn middleware_run_in_registration_order() {
        let log = CallLog::default();
        let middleware = vec![
            recorder("first", &log),
            recorder("second", &log),
            recorder("third", &log),
        ];
        assert_eq!(response_status(&run_chain(&middleware).unwrap()), 200);
        assert_eq!(recorded(&log), ["first", "second", "third"]);
    }

    #[test]
    fn passthrough_reaches_dispatch() {
        let middleware: Vec<Box<dyn ErasedMiddleware>> = vec![Box::new(PassthroughMiddleware)];
        assert_eq!(response_status(&run_chain(&middleware).unwrap()), 200);
    }

    #[test]
    fn after_dispatch_default_is_noop() {
        let middleware: Box<dyn ErasedMiddleware> = Box::new(PassthroughMiddleware);
        futures::executor::block_on(async {
            let completed = CompletedRequest {
                request_id: "test",
                identity: None,
                operation: None,
                bucket: None,
                status: 200,
                response_bytes: None,
                request_bytes: None,
                was_forwarded: false,
                source_ip: None,
            };
            middleware.after_dispatch(&completed).await;
        });
    }
}