use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use orpc_procedure::{DynInput, DynOutput, ProcedureError, Route};
use serde::Serialize;
use crate::handler::BoxFuture;
/// Adapts a plain async middleware function into the boxed-future form used by the chain.
///
/// `f` receives the upstream context and a [`MiddlewareCtx`] through which it can continue
/// the pipeline or short-circuit it. The returned callable invokes `f` and pins the resulting
/// future on the heap, so every middleware shares the same
/// `BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>` return type.
pub fn middleware_fn<TCtx, TNextCtx, F, Fut>(
    f: F,
) -> impl Fn(
    TCtx,
    MiddlewareCtx<TNextCtx>,
) -> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
       + Send
       + Sync
       + 'static
where
    F: Fn(TCtx, MiddlewareCtx<TNextCtx>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<MiddlewareOutput, ProcedureError>> + Send + 'static,
{
    move |upstream_ctx: TCtx,
          mw_ctx: MiddlewareCtx<TNextCtx>|
          -> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>> {
        // `Fut` is `Send + 'static`, so it can be erased into the shared boxed future type.
        let pending = f(upstream_ctx, mw_ctx);
        Box::pin(pending)
    }
}
/// One-shot continuation that runs the remainder of the pipeline.
///
/// It takes the context produced so far plus the (possibly rewritten) input and yields a
/// boxed future resolving to the procedure output. Being `FnOnce`, it can be called at most
/// once. The `'static` bound is the default for a boxed trait object and is written out here
/// only for clarity.
type InnerHandler<TCtx> = Box<
    dyn FnOnce(TCtx, DynInput) -> BoxFuture<'static, Result<DynOutput, ProcedureError>>
        + Send
        + 'static,
>;
/// A composed stack of middleware sitting in front of a procedure handler.
///
/// `TBaseCtx` is the context the chain is entered with. `TCurrentCtx` is the context that the
/// innermost handler receives after every middleware in the stack has run.
pub(crate) trait MiddlewareChain<TBaseCtx, TCurrentCtx>: Send + Sync + 'static {
/// Runs the chain for one call, starting from `ctx` and `input`, and finishes with
/// `inner_handler`.
///
/// `meta` describes the procedure being executed and is passed along to the middleware.
/// The returned future resolves to the handler's output, or to whatever output a middleware
/// produced instead of calling the handler.
fn run(
&self,
ctx: TBaseCtx,
input: DynInput,
meta: ProcedureMeta,
inner_handler: InnerHandler<TCurrentCtx>,
) -> BoxFuture<'static, Result<DynOutput, ProcedureError>>;
}
/// The empty middleware chain: the handler runs directly with the caller's context.
///
/// Its input and output context types are identical, so it can serve as the base on which
/// [`ComposedChain`] layers are stacked.
///
/// Being a stateless unit struct, it cheaply derives the common standard traits.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct IdentityChain;

impl<TCtx: Send + 'static> MiddlewareChain<TCtx, TCtx> for IdentityChain {
    /// Calls `inner_handler` immediately with the unchanged `ctx` and `input`.
    ///
    /// No middleware runs here, so `meta` is not consulted.
    fn run(
        &self,
        ctx: TCtx,
        input: DynInput,
        _meta: ProcedureMeta,
        inner_handler: InnerHandler<TCtx>,
    ) -> BoxFuture<'static, Result<DynOutput, ProcedureError>> {
        inner_handler(ctx, input)
    }
}
/// One middleware layer stacked on top of an existing chain.
///
/// `prev` maps the base context `TBaseCtx` to `TMidCtx`. `middleware` then maps `TMidCtx` to
/// `TCurrentCtx`, which is the context the innermost handler receives.
pub(crate) struct ComposedChain<TBaseCtx, TMidCtx, TCurrentCtx, M> {
    /// The chain that runs before this layer.
    prev: Arc<dyn MiddlewareChain<TBaseCtx, TMidCtx>>,
    /// This layer's middleware function.
    middleware: Arc<M>,
    /// Records the context type parameters without storing values of those types.
    ///
    /// A `fn(..)` pointer is always `Send + Sync`, so the phantom does not affect the
    /// struct's auto traits.
    _phantom: PhantomData<fn(TBaseCtx, TMidCtx, TCurrentCtx)>,
}

impl<TBaseCtx, TMidCtx, TCurrentCtx, M> ComposedChain<TBaseCtx, TMidCtx, TCurrentCtx, M> {
    /// Wraps `middleware` around `prev`, producing a chain that ends in `TCurrentCtx`.
    pub fn new(prev: Arc<dyn MiddlewareChain<TBaseCtx, TMidCtx>>, middleware: Arc<M>) -> Self {
        Self {
            middleware,
            prev,
            _phantom: PhantomData,
        }
    }
}
impl<TBaseCtx, TMidCtx, TCurrentCtx, M> MiddlewareChain<TBaseCtx, TCurrentCtx>
for ComposedChain<TBaseCtx, TMidCtx, TCurrentCtx, M>
where
TBaseCtx: Send + 'static,
TMidCtx: Send + 'static,
TCurrentCtx: Send + 'static,
M: Fn(
TMidCtx,
MiddlewareCtx<TCurrentCtx>,
) -> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
+ Send
+ Sync
+ 'static,
{
fn run(
&self,
ctx: TBaseCtx,
input: DynInput,
meta: ProcedureMeta,
inner_handler: InnerHandler<TCurrentCtx>,
) -> BoxFuture<'static, Result<DynOutput, ProcedureError>> {
let middleware = self.middleware.clone();
let meta_for_mw = meta.clone();
self.prev.run(
ctx,
input,
meta,
Box::new(move |mid_ctx: TMidCtx, input: DynInput| {
let mw_ctx = MiddlewareCtx {
next_fn: inner_handler,
dyn_input: input,
meta: meta_for_mw,
};
Box::pin(async move {
let result = middleware(mid_ctx, mw_ctx).await?;
Ok(result.output)
})
}),
)
}
}
/// Handle a middleware uses to interact with the rest of the pipeline.
///
/// It bundles the downstream continuation, the input that will be forwarded, and the
/// procedure metadata. It is consumed by `next`, `next_with_input` or `output`.
pub struct MiddlewareCtx<TNextCtx> {
/// One-shot continuation into the downstream chain/handler; it expects a `TNextCtx`.
next_fn: InnerHandler<TNextCtx>,
/// Input forwarded by `next`; `materialize_input` may replace it in place.
dyn_input: DynInput,
/// Metadata of the procedure being executed.
meta: ProcedureMeta,
}
impl<TNextCtx> MiddlewareCtx<TNextCtx> {
pub async fn next(self, ctx: TNextCtx) -> Result<MiddlewareOutput, ProcedureError> {
let output = (self.next_fn)(ctx, self.dyn_input).await?;
Ok(MiddlewareOutput { output })
}
pub async fn next_with_input(
self,
ctx: TNextCtx,
input: DynInput,
) -> Result<MiddlewareOutput, ProcedureError> {
let output = (self.next_fn)(ctx, input).await?;
Ok(MiddlewareOutput { output })
}
pub fn output<T: Serialize + Send + 'static>(
self,
value: T,
) -> Result<MiddlewareOutput, ProcedureError> {
Ok(MiddlewareOutput {
output: DynOutput::new(value),
})
}
pub fn input(&self) -> Option<&serde_json::Value> {
self.dyn_input.as_value()
}
pub fn materialize_input(&mut self) -> Result<(), ProcedureError> {
let input = std::mem::replace(
&mut self.dyn_input,
DynInput::from_value(serde_json::Value::Null),
);
self.dyn_input = input.materialize()?;
Ok(())
}
pub fn meta(&self) -> &ProcedureMeta {
&self.meta
}
}
/// Result of running a middleware.
///
/// Produced either by continuing to the downstream handler (`MiddlewareCtx::next` /
/// `next_with_input`) or by short-circuiting with `MiddlewareCtx::output`.
pub struct MiddlewareOutput {
/// The procedure output that is ultimately returned to the caller of the chain.
pub output: DynOutput,
}
/// Information about the procedure being invoked, exposed to middleware via
/// `MiddlewareCtx::meta`.
///
/// It is cloned once per middleware layer when the chain runs.
#[derive(Clone)]
pub struct ProcedureMeta {
/// The procedure's route.
pub route: Route,
}
// Unit tests for the identity and composed middleware chains.
#[cfg(test)]
mod tests {
use super::*;
use orpc_procedure::DynInput;
// Minimal metadata for tests; the route value is arbitrary.
fn test_meta() -> ProcedureMeta {
ProcedureMeta {
route: Route::get("/test"),
}
}
// IdentityChain must hand the caller's context and input straight to the handler.
#[tokio::test]
async fn identity_chain_passthrough() {
let chain = IdentityChain;
let input = DynInput::from_value(serde_json::json!(42));
let result = chain
.run(
"context",
input,
test_meta(),
Box::new(|ctx: &str, input: DynInput| {
Box::pin(async move {
let val: i32 = input.deserialize()?;
Ok(DynOutput::new(format!("{ctx}:{val}")))
})
}),
)
.await
.unwrap();
assert_eq!(result.to_value().unwrap(), serde_json::json!("context:42"));
}
// A single middleware converts a `u32` context into a `String` before the handler runs.
#[tokio::test]
async fn composed_chain_context_switch() {
let prev: Arc<dyn MiddlewareChain<u32, u32>> = Arc::new(IdentityChain);
let middleware = Arc::new(
|ctx: u32,
mw: MiddlewareCtx<String>|
-> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>> {
Box::pin(async move { mw.next(format!("user-{ctx}")).await })
},
);
let chain = ComposedChain::new(prev, middleware);
let input = DynInput::from_value(serde_json::json!("hello"));
let result = chain
.run(
42u32,
input,
test_meta(),
Box::new(|ctx: String, input: DynInput| {
Box::pin(async move {
let val: String = input.deserialize()?;
Ok(DynOutput::new(format!("{ctx}:{val}")))
})
}),
)
.await
.unwrap();
assert_eq!(
result.to_value().unwrap(),
serde_json::json!("user-42:hello")
);
}
// Returning `mw.output(..)` must skip the inner handler entirely (the handler panics if
// it is called).
#[tokio::test]
async fn middleware_output_short_circuit() {
let prev: Arc<dyn MiddlewareChain<(), ()>> = Arc::new(IdentityChain);
let middleware = Arc::new(
|_ctx: (),
mw: MiddlewareCtx<()>|
-> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>> {
Box::pin(async move { mw.output("cached response") })
},
);
let chain = ComposedChain::new(prev, middleware);
let input = DynInput::from_value(serde_json::json!(null));
let result = chain
.run(
(),
input,
test_meta(),
Box::new(|_ctx: (), _input: DynInput| {
Box::pin(async move { panic!("should not be called") })
}),
)
.await
.unwrap();
assert_eq!(
result.to_value().unwrap(),
serde_json::json!("cached response")
);
}
// Two stacked middlewares: u32 -> String -> (String, bool). The handler sees the final
// context and the untouched input.
#[tokio::test]
async fn double_middleware_chain() {
let identity: Arc<dyn MiddlewareChain<u32, u32>> = Arc::new(IdentityChain);
let mw1 = Arc::new(
|ctx: u32,
mw: MiddlewareCtx<String>|
-> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>> {
Box::pin(async move { mw.next(format!("user-{ctx}")).await })
},
);
let chain1: Arc<dyn MiddlewareChain<u32, String>> =
Arc::new(ComposedChain::new(identity, mw1));
let mw2 = Arc::new(
|ctx: String,
mw: MiddlewareCtx<(String, bool)>|
-> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>> {
Box::pin(async move { mw.next((ctx, true)).await })
},
);
let chain2 = ComposedChain::new(chain1, mw2);
let input = DynInput::from_value(serde_json::json!("test"));
let result = chain2
.run(
42u32,
input,
test_meta(),
Box::new(|ctx: (String, bool), input: DynInput| {
Box::pin(async move {
let val: String = input.deserialize()?;
Ok(DynOutput::new(format!("{}:{}:{}", ctx.0, ctx.1, val)))
})
}),
)
.await
.unwrap();
assert_eq!(
result.to_value().unwrap(),
serde_json::json!("user-42:true:test")
);
}
// Compile-time check that MiddlewareCtx and MiddlewareOutput are `Send`.
#[test]
fn middleware_ctx_is_send() {
fn assert_send<T: Send>() {}
assert_send::<MiddlewareCtx<()>>();
assert_send::<MiddlewareOutput>();
}
}