use super::context::QueryContext;
use super::types::{
BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse, SharedMiddleware,
};
use std::sync::Arc;
/// An ordered list of middleware wrapped around a final query handler.
///
/// Middleware run in insertion order. Each enabled middleware receives the
/// query context plus a [`Next`] handle; running that handle continues with the
/// next enabled middleware in the chain and, after the last one, the final handler.
pub struct MiddlewareChain {
    middlewares: Vec<SharedMiddleware>,
}

impl MiddlewareChain {
    /// Creates a chain with no middleware.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Creates a chain from an already-shared list of middleware, keeping its order.
    pub fn with(middlewares: Vec<SharedMiddleware>) -> Self {
        Self { middlewares }
    }

    /// Appends `middleware` so it runs after every middleware already in the chain.
    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
        self.middlewares.push(Arc::new(middleware));
    }

    /// Inserts `middleware` at the front so it runs before every other middleware.
    ///
    /// Shifts all existing entries, so this is O(n) in the chain length.
    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) {
        self.middlewares.insert(0, Arc::new(middleware));
    }

    /// Number of middleware registered (enabled or not).
    pub fn len(&self) -> usize {
        self.middlewares.len()
    }

    /// `true` when no middleware is registered.
    pub fn is_empty(&self) -> bool {
        self.middlewares.is_empty()
    }

    /// Runs `ctx` through every enabled middleware, ending in `final_handler`.
    ///
    /// Middleware whose `enabled()` returns `false` when its turn comes are
    /// skipped. With no enabled middleware, `final_handler` is invoked directly.
    pub fn execute<'a, F>(
        &'a self,
        ctx: QueryContext,
        final_handler: F,
    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
    where
        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
    {
        self.execute_at(0, ctx, final_handler)
    }

    /// Runs the chain starting from the first enabled middleware at or after `index`.
    ///
    /// The [`Next`] handed to that middleware re-enters this function at the
    /// following position. Downstream middleware therefore run before
    /// `final_handler` is reached.
    fn execute_at<'a, F>(
        &'a self,
        index: usize,
        ctx: QueryContext,
        final_handler: F,
    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
    where
        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
    {
        // Skip over disabled middleware without recursing once per skipped entry.
        let next_enabled = self
            .middlewares
            .iter()
            .enumerate()
            .skip(index)
            .find(|(_, middleware)| middleware.enabled());

        let (position, middleware) = match next_enabled {
            Some(found) => found,
            // Every remaining middleware is disabled (or none remain): finish the chain.
            None => return final_handler(ctx),
        };

        // Continue with the *rest of the chain* rather than jumping straight to
        // the final handler, so every downstream middleware gets its turn.
        let next = Next {
            inner: Box::new(move |ctx: QueryContext| {
                self.execute_at(position + 1, ctx, final_handler)
            }),
        };

        Box::pin(async move { middleware.handle(ctx, next).await })
    }
}

impl Default for MiddlewareChain {
    fn default() -> Self {
        Self::new()
    }
}
/// Owned, fluent front-end for assembling and running a [`MiddlewareChain`].
///
/// All storage and execution are delegated to the wrapped chain. This type only
/// adds by-value (`with`) and `&mut`-chaining (`push` / `prepend`) helpers.
pub struct MiddlewareStack {
    chain: MiddlewareChain,
}

impl MiddlewareStack {
    /// Returns a stack that holds no middleware yet.
    pub fn new() -> Self {
        let chain = MiddlewareChain::new();
        Self { chain }
    }

    /// Consumes the stack, appends `middleware` at the end, and hands the stack back.
    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
        self.push(middleware);
        self
    }

    /// Appends `middleware` at the end; returns `&mut self` so calls can be chained.
    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
        self.chain.push(middleware);
        self
    }

    /// Inserts `middleware` at the front; returns `&mut self` so calls can be chained.
    pub fn prepend<M: Middleware + 'static>(&mut self, middleware: M) -> &mut Self {
        self.chain.prepend(middleware);
        self
    }

    /// Number of middleware currently registered.
    pub fn len(&self) -> usize {
        self.chain.len()
    }

    /// `true` when no middleware has been registered.
    pub fn is_empty(&self) -> bool {
        self.chain.is_empty()
    }

    /// Unwraps the stack, yielding the underlying chain.
    pub fn into_chain(self) -> MiddlewareChain {
        let Self { chain } = self;
        chain
    }

    /// Runs `ctx` through the stack by delegating to [`MiddlewareChain::execute`].
    pub fn execute<'a, F>(
        &'a self,
        ctx: QueryContext,
        final_handler: F,
    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>>
    where
        F: FnOnce(QueryContext) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> + Send + 'a,
    {
        MiddlewareChain::execute(&self.chain, ctx, final_handler)
    }
}

impl Default for MiddlewareStack {
    fn default() -> Self {
        MiddlewareStack::new()
    }
}

impl From<MiddlewareStack> for MiddlewareChain {
    fn from(stack: MiddlewareStack) -> Self {
        stack.into_chain()
    }
}
/// Consuming builder that collects middleware in registration order and
/// produces either a [`MiddlewareChain`] or a [`MiddlewareStack`].
pub struct MiddlewareBuilder {
    middlewares: Vec<SharedMiddleware>,
}

impl MiddlewareBuilder {
    /// Starts a builder with nothing registered.
    pub fn new() -> Self {
        let middlewares = Vec::new();
        Self { middlewares }
    }

    /// Registers `middleware` after everything added so far.
    pub fn with<M: Middleware + 'static>(mut self, middleware: M) -> Self {
        let shared = Arc::new(middleware);
        self.middlewares.push(shared);
        self
    }

    /// Registers `middleware` only when `condition` is `true`.
    ///
    /// Otherwise the builder is returned unchanged and `middleware` is dropped.
    pub fn with_if<M: Middleware + 'static>(self, condition: bool, middleware: M) -> Self {
        if !condition {
            return self;
        }
        self.with(middleware)
    }

    /// Finishes the builder, yielding a chain in registration order.
    pub fn build(self) -> MiddlewareChain {
        let Self { middlewares } = self;
        MiddlewareChain::with(middlewares)
    }

    /// Finishes the builder, yielding a [`MiddlewareStack`] around the built chain.
    pub fn build_stack(self) -> MiddlewareStack {
        let chain = self.build();
        MiddlewareStack { chain }
    }
}

impl Default for MiddlewareBuilder {
    fn default() -> Self {
        MiddlewareBuilder::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_middleware_chain_empty() {
        // A freshly constructed chain must report zero entries both ways.
        let fresh = MiddlewareChain::new();
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    #[test]
    fn test_middleware_stack_builder() {
        /// Test double that simply forwards to the next stage.
        struct Forward;

        impl Middleware for Forward {
            fn handle<'a>(
                &'a self,
                ctx: QueryContext,
                next: Next<'a>,
            ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
                Box::pin(async move { next.run(ctx).await })
            }
        }

        // Two by-value `with` calls should register two entries.
        let stack = MiddlewareStack::new().with(Forward).with(Forward);
        assert_eq!(stack.len(), 2);
    }

    #[test]
    fn test_middleware_builder() {
        /// Test double that simply forwards to the next stage.
        struct PassThrough;

        impl Middleware for PassThrough {
            fn handle<'a>(
                &'a self,
                ctx: QueryContext,
                next: Next<'a>,
            ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
                Box::pin(async move { next.run(ctx).await })
            }
        }

        // Unconditional + `with_if(true)` register; `with_if(false)` must not.
        let built = MiddlewareBuilder::new()
            .with(PassThrough)
            .with_if(true, PassThrough)
            .with_if(false, PassThrough)
            .build();
        assert_eq!(built.len(), 2);
    }
}