use std::{future::Future, pin::Pin, sync::Arc};
use super::context::Context;
use super::handler::{Handler, Settle};
use super::middleware::Layer;
/// Heap-allocated, pinned, `Send` future that resolves to a [`Settle`] and may borrow
/// data for `'a`.
///
/// This is the return type of the type-erased handler and middleware traits in this
/// module. A concrete, nameable return type is what lets them be used as trait objects.
type BoxFut<'a> = Pin<Box<dyn Future<Output = Settle> + Send + 'a>>;
/// Object-safe view of a [`Handler`].
///
/// Returning a [`BoxFut`] instead of an opaque future lets the terminal handler of a
/// middleware chain be stored as `&dyn ErasedHandler<I>` (see [`Next`]).
trait ErasedHandler<I>: Send + Sync {
/// Starts handling `input`. The returned future may borrow `self`, `input` and `ctx`
/// for `'a`.
fn handle_boxed<'a>(&'a self, input: &'a I, ctx: &'a mut Context<'_>) -> BoxFut<'a>;
}
/// Any [`Handler`] can be used as an [`ErasedHandler`]: its future is boxed and pinned.
impl<I, H> ErasedHandler<I> for H
where
    H: Handler<I>,
    I: Sync,
{
    fn handle_boxed<'a>(&'a self, input: &'a I, ctx: &'a mut Context<'_>) -> BoxFut<'a> {
        // Fully qualified so it is unambiguous which `handle` is being erased.
        let fut = <H as Handler<I>>::handle(self, input, ctx);
        Box::pin(fut)
    }
}
/// Middleware that can be stored as a trait object, e.g. inside a [`DynStack`].
///
/// `handle` receives the input, the mutable [`Context`], and a [`Next`] continuation.
/// Calling [`Next::run`] and awaiting its future runs the remaining middleware and then
/// the wrapped handler. A middleware that never calls it stops the chain at itself.
pub trait DynMiddleware<I>: Send + Sync {
/// Processes one input. Implementations return a boxed future, typically
/// `Box::pin(async move { ... next.run(input, ctx).await })`.
fn handle<'a>(
&'a self,
input: &'a I,
ctx: &'a mut Context<'_>,
next: Next<'a, I>,
) -> BoxFut<'a>;
}
/// Continuation handed to each [`DynMiddleware`].
///
/// It holds the middleware not yet entered plus the handler at the end of the chain, and
/// is consumed by [`Next::run`].
pub struct Next<'a, I> {
// Middleware still to run, in execution order (first element runs next).
rest: &'a [Arc<dyn DynMiddleware<I>>],
// Wrapped handler, invoked once `rest` is empty.
tail: &'a dyn ErasedHandler<I>,
}
impl<'a, I> Next<'a, I> {
    /// Continues the chain with `input` and `ctx`.
    ///
    /// If middleware remain, the first of them is called with a `Next` covering the rest.
    /// Otherwise the wrapped handler is called. Either way, the returned future drives
    /// the remainder of the chain.
    #[must_use]
    pub fn run(self, input: &'a I, ctx: &'a mut Context<'_>) -> BoxFut<'a> {
        let Self { rest, tail } = self;
        let Some((current, remaining)) = rest.split_first() else {
            // Every middleware has already been entered: hand off to the wrapped handler.
            return tail.handle_boxed(input, ctx);
        };
        let next = Next {
            rest: remaining,
            tail,
        };
        current.handle(input, ctx, next)
    }
}
/// Prints only the number of middleware still to run. The trait objects themselves do
/// not implement `Debug`.
impl<I> std::fmt::Debug for Next<'_, I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let remaining = self.rest.len();
        let mut out = f.debug_struct("Next");
        out.field("remaining", &remaining);
        out.finish_non_exhaustive()
    }
}
/// Ordered list of type-erased middleware, applied to a handler through [`Layer`].
///
/// The list lives behind a shared `Arc`, so cloning a `DynStack` only bumps a reference
/// count. The first entry runs first (outermost).
pub struct DynStack<I>(Arc<[Arc<dyn DynMiddleware<I>>]>);
impl<I> Clone for DynStack<I> {
fn clone(&self) -> Self {
Self(Arc::clone(&self.0))
}
}
impl<I> DynStack<I> {
    /// Builds a stack from `middleware`, preserving iteration order.
    ///
    /// The first item becomes the outermost layer. It runs first and decides, through
    /// its [`Next`], whether the rest of the chain runs.
    #[must_use]
    pub fn new(middleware: impl IntoIterator<Item = Arc<dyn DynMiddleware<I>>>) -> Self {
        let chain = middleware.into_iter().collect::<Arc<[_]>>();
        Self(chain)
    }
}
/// Reports how many middleware the stack holds. Trait objects are not `Debug`, so their
/// contents are elided.
impl<I> std::fmt::Debug for DynStack<I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.0.len();
        let mut out = f.debug_struct("DynStack");
        out.field("middleware", &count);
        out.finish_non_exhaustive()
    }
}
impl<I, H> Layer<H> for DynStack<I>
where
I: Sync,
H: Handler<I>,
{
type Handler = DynStackHandler<I, H>;
fn layer(&self, inner: H) -> Self::Handler {
DynStackHandler {
chain: self.0.clone(),
inner,
}
}
}
/// Handler produced by layering a [`DynStack`] over an inner handler `H`.
///
/// Each call enters the stack's middleware in order, and `inner` runs at the end of the
/// chain.
pub struct DynStackHandler<I, H> {
    // Middleware list shared with the `DynStack` this handler was built from.
    chain: Arc<[Arc<dyn DynMiddleware<I>>]>,
    // Wrapped handler, reached once the chain is exhausted.
    inner: H,
}

/// Clones share the middleware list and clone the inner handler.
///
/// Implemented by hand because `#[derive(Clone)]` would add an unnecessary `I: Clone`
/// bound (the same reason `DynStack` implements `Clone` manually).
impl<I, H: Clone> Clone for DynStackHandler<I, H> {
    fn clone(&self) -> Self {
        Self {
            chain: Arc::clone(&self.chain),
            inner: self.inner.clone(),
        }
    }
}
/// Shows the middleware count only. `H` carries no `Debug` bound here, so `inner` is
/// omitted.
impl<I, H> std::fmt::Debug for DynStackHandler<I, H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.chain.len();
        let mut out = f.debug_struct("DynStackHandler");
        out.field("middleware", &count);
        out.finish_non_exhaustive()
    }
}
impl<I, H> Handler<I> for DynStackHandler<I, H>
where
I: Sync,
H: Handler<I>,
{
async fn handle(&self, input: &I, ctx: &mut Context<'_>) -> Settle {
let tail: &dyn ErasedHandler<I> = &self.inner;
Next {
rest: &self.chain,
tail,
}
.run(input, ctx)
.await
}
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::super::HandlerExt;
    use super::super::context::{Context, State};
    use super::super::handler::{Handler, HandlerResult};
    use super::{BoxFut, DynMiddleware, DynStack, Next};
    use crate::Headers;

    struct Input;

    /// Middleware that appends its label to a shared log, then continues the chain.
    struct Recorder(Arc<Mutex<Vec<&'static str>>>, &'static str);

    impl DynMiddleware<Input> for Recorder {
        fn handle<'a>(
            &'a self,
            input: &'a Input,
            ctx: &'a mut Context<'_>,
            next: Next<'a, Input>,
        ) -> BoxFut<'a> {
            Box::pin(async move {
                // The guard is a temporary dropped at the end of this statement, i.e.
                // before the `.await` below.
                self.0.lock().expect("poisoned").push(self.1);
                next.run(input, ctx).await
            })
        }
    }

    #[tokio::test]
    async fn runs_middleware_in_order_then_inner() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack = DynStack::new([
            Arc::new(Recorder(Arc::clone(&log), "a")) as Arc<dyn DynMiddleware<Input>>,
            Arc::new(Recorder(Arc::clone(&log), "b")) as Arc<dyn DynMiddleware<Input>>,
        ]);
        let inner_log = Arc::clone(&log);
        let inner = move |_: &Input, _ctx: &mut Context<'_>| {
            let inner_log = Arc::clone(&inner_log);
            async move {
                inner_log.lock().expect("poisoned").push("inner");
                HandlerResult::Ack
            }
        };
        let handler = inner.with(stack);
        let state = State::default();
        let delivery = crate::runtime::dispatch::Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("test", &headers, &state, &delivery);
        assert_eq!(
            handler.handle(&Input, &mut ctx).await.outcome(),
            HandlerResult::Ack
        );
        assert_eq!(*log.lock().expect("poisoned"), vec!["a", "b", "inner"]);
    }

    /// A stack with no middleware must hand straight off to the inner handler.
    #[tokio::test]
    async fn empty_stack_runs_inner_directly() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack = DynStack::new(Vec::<Arc<dyn DynMiddleware<Input>>>::new());
        let inner_log = Arc::clone(&log);
        let inner = move |_: &Input, _ctx: &mut Context<'_>| {
            let inner_log = Arc::clone(&inner_log);
            async move {
                inner_log.lock().expect("poisoned").push("inner");
                HandlerResult::Ack
            }
        };
        let handler = inner.with(stack);
        let state = State::default();
        let delivery = crate::runtime::dispatch::Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("test", &headers, &state, &delivery);
        assert_eq!(
            handler.handle(&Input, &mut ctx).await.outcome(),
            HandlerResult::Ack
        );
        assert_eq!(*log.lock().expect("poisoned"), vec!["inner"]);
    }
}