use std::marker::PhantomData;
use std::sync::Arc;
use orpc_procedure::{
DynInput, DynOutput, ErasedSchema, ErrorMap, Meta, ProcedureError, ProcedureStream, Route,
};
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::context::Context;
use crate::error::ORPCError;
use crate::handler::{BoxFuture, Handler};
use crate::middleware::{
ComposedChain, IdentityChain, MiddlewareChain, MiddlewareCtx, MiddlewareOutput, ProcedureMeta,
};
use crate::procedure::Procedure;
use crate::schema::{InputValidator, Schema, make_input_validator};
pub fn os<TCtx: Context>() -> Builder<TCtx, TCtx> {
Builder {
middleware_chain: Arc::new(IdentityChain),
error_map: ErrorMap::default(),
route: Route::default(),
meta: Meta::default(),
_phantom: PhantomData,
}
}
/// Procedure builder stage with neither an input nor an output schema attached yet.
///
/// - `TBaseCtx`: context type the finished procedure is executed with.
/// - `TCtx`: context type produced by the middleware chain and handed to the handler.
/// - `TError`: error type returned by the handler; defaults to [`ORPCError`].
pub struct Builder<TBaseCtx, TCtx, TError = ORPCError> {
    /// Middleware accumulated so far, taking a `TBaseCtx` to a `TCtx`.
    pub(crate) middleware_chain: Arc<dyn MiddlewareChain<TBaseCtx, TCtx>>,
    /// Error map carried through to the finished procedure.
    pub(crate) error_map: ErrorMap,
    /// Route carried through to the finished procedure.
    pub(crate) route: Route,
    /// Metadata carried through to the finished procedure.
    pub(crate) meta: Meta,
    /// Marker for `TError`. No value is stored, and a `fn(TError)` marker is
    /// `Send + Sync` whatever `TError` is.
    pub(crate) _phantom: PhantomData<fn(TError)>,
}
impl<TBaseCtx: Context, TCtx: Context, TError> Builder<TBaseCtx, TCtx, TError> {
    /// Appends middleware `m`, switching the handler context from `TCtx` to `TNextCtx`.
    ///
    /// `m` is called with the current `TCtx` and a [`MiddlewareCtx`] typed for
    /// `TNextCtx`. The error map, route and metadata carry over unchanged.
    pub fn use_middleware<TNextCtx, M>(self, m: M) -> Builder<TBaseCtx, TNextCtx, TError>
    where
        M: Fn(
                TCtx,
                MiddlewareCtx<TNextCtx>,
            ) -> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
            + Send
            + Sync
            + 'static,
        TNextCtx: Context,
    {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            ..
        } = self;
        Builder {
            middleware_chain: Arc::new(ComposedChain::new(middleware_chain, Arc::new(m))),
            error_map,
            route,
            meta,
            _phantom: PhantomData,
        }
    }

    /// Replaces the route attached to the procedure being built.
    pub fn route(self, route: Route) -> Self {
        Self { route, ..self }
    }

    /// Attaches an output schema. The resulting stage builds procedures without
    /// an input schema.
    pub fn output<S: Schema>(
        self,
        schema: S,
    ) -> BuilderWithOutput<TBaseCtx, TCtx, S::Output, TError>
    where
        S::Output: Serialize + 'static,
    {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            ..
        } = self;
        let output_schema = schema.into_erased();
        BuilderWithOutput {
            middleware_chain,
            error_map,
            route,
            meta,
            output_schema,
            _phantom: PhantomData,
        }
    }

    /// Attaches an input schema.
    ///
    /// An input validator is built only when the schema does not report itself
    /// as passthrough. Passthrough schemas get no validator.
    pub fn input<S: Schema>(self, schema: S) -> BuilderWithInput<TBaseCtx, TCtx, S::Output, TError>
    where
        S::Output: Serialize + 'static,
    {
        // Query passthrough before the schema is converted into its erased form.
        let validate = !schema.is_passthrough();
        let input_schema = schema.into_erased();
        let input_validator = if validate {
            make_input_validator()
        } else {
            None
        };
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            ..
        } = self;
        BuilderWithInput {
            middleware_chain,
            error_map,
            route,
            meta,
            input_schema,
            input_validator,
            _phantom: PhantomData,
        }
    }

    /// Finishes a procedure with no input schema, no validator and no output schema.
    /// The handler's input type is `()`.
    pub fn handler<F, TOutput>(self, f: F) -> Procedure<TBaseCtx, (), TOutput, TError>
    where
        F: Handler<TCtx, (), TOutput, TError>,
        TOutput: Serialize + Send + 'static,
        TError: Into<ProcedureError> + Send + 'static,
    {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            ..
        } = self;
        build_procedure(
            middleware_chain,
            f,
            None, // input schema
            None, // input validator
            None, // output schema
            error_map,
            route,
            meta,
        )
    }
}
/// Builder stage with an input schema attached but no output schema yet.
///
/// `TInput` is the output type of the input schema, and it is the type the
/// handler receives.
pub struct BuilderWithInput<TBaseCtx, TCtx, TInput, TError = ORPCError> {
    /// Middleware accumulated so far, taking a `TBaseCtx` to a `TCtx`.
    middleware_chain: Arc<dyn MiddlewareChain<TBaseCtx, TCtx>>,
    /// Error map carried through to the finished procedure.
    error_map: ErrorMap,
    /// Route carried through to the finished procedure.
    route: Route,
    /// Metadata carried through to the finished procedure.
    meta: Meta,
    /// Type-erased input schema.
    input_schema: Box<dyn ErasedSchema>,
    /// Validator applied to raw input before deserialization. `Builder::input`
    /// leaves this `None` for passthrough schemas.
    input_validator: Option<InputValidator>,
    /// Marker for `TInput` and `TError`; no values are stored.
    _phantom: PhantomData<fn(TInput, TError)>,
}
impl<TBaseCtx: Context, TCtx: Context, TInput, TError>
    BuilderWithInput<TBaseCtx, TCtx, TInput, TError>
{
    /// Attaches an output schema and keeps the input schema and validator.
    pub fn output<S: Schema>(
        self,
        schema: S,
    ) -> BuilderWithIO<TBaseCtx, TCtx, TInput, S::Output, TError> {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            input_schema,
            input_validator,
            ..
        } = self;
        let output_schema = schema.into_erased();
        BuilderWithIO {
            middleware_chain,
            error_map,
            route,
            meta,
            input_schema,
            input_validator,
            output_schema,
            _phantom: PhantomData,
        }
    }

    /// Finishes the procedure using the input schema and validator, with no
    /// output schema.
    pub fn handler<F, TOutput>(self, f: F) -> Procedure<TBaseCtx, TInput, TOutput, TError>
    where
        F: Handler<TCtx, TInput, TOutput, TError>,
        TInput: DeserializeOwned + Send + 'static,
        TOutput: Serialize + Send + 'static,
        TError: Into<ProcedureError> + Send + 'static,
    {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            input_schema,
            input_validator,
            ..
        } = self;
        build_procedure(
            middleware_chain,
            f,
            Some(input_schema),
            input_validator,
            None, // output schema
            error_map,
            route,
            meta,
        )
    }
}
/// Builder stage with both an input and an output schema attached.
///
/// `TInput` and `TOutput` are the output types of the input schema and the
/// output schema, respectively.
pub struct BuilderWithIO<TBaseCtx, TCtx, TInput, TOutput, TError = ORPCError> {
    /// Middleware accumulated so far, taking a `TBaseCtx` to a `TCtx`.
    middleware_chain: Arc<dyn MiddlewareChain<TBaseCtx, TCtx>>,
    /// Error map carried through to the finished procedure.
    error_map: ErrorMap,
    /// Route carried through to the finished procedure.
    route: Route,
    /// Metadata carried through to the finished procedure.
    meta: Meta,
    /// Type-erased input schema.
    input_schema: Box<dyn ErasedSchema>,
    /// Optional validator applied to raw input before deserialization.
    input_validator: Option<InputValidator>,
    /// Type-erased output schema.
    output_schema: Box<dyn ErasedSchema>,
    /// Marker for `TInput`, `TOutput` and `TError`; no values are stored.
    _phantom: PhantomData<fn(TInput, TOutput, TError)>,
}
impl<TBaseCtx: Context, TCtx: Context, TInput, TOutput, TError>
    BuilderWithIO<TBaseCtx, TCtx, TInput, TOutput, TError>
{
    /// Finishes the procedure with the input schema, the optional validator and
    /// the output schema all attached.
    pub fn handler<F>(self, f: F) -> Procedure<TBaseCtx, TInput, TOutput, TError>
    where
        F: Handler<TCtx, TInput, TOutput, TError>,
        TInput: DeserializeOwned + Send + 'static,
        TOutput: Serialize + Send + 'static,
        TError: Into<ProcedureError> + Send + 'static,
    {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            input_schema,
            input_validator,
            output_schema,
            ..
        } = self;
        build_procedure(
            middleware_chain,
            f,
            Some(input_schema),
            input_validator,
            Some(output_schema),
            error_map,
            route,
            meta,
        )
    }
}
/// Builder stage with an output schema but no input schema; its handler
/// receives `()` as input.
///
/// `TOutput` is the output type of the attached output schema.
pub struct BuilderWithOutput<TBaseCtx, TCtx, TOutput, TError = ORPCError> {
    /// Middleware accumulated so far, taking a `TBaseCtx` to a `TCtx`.
    middleware_chain: Arc<dyn MiddlewareChain<TBaseCtx, TCtx>>,
    /// Error map carried through to the finished procedure.
    error_map: ErrorMap,
    /// Route carried through to the finished procedure.
    route: Route,
    /// Metadata carried through to the finished procedure.
    meta: Meta,
    /// Type-erased output schema.
    output_schema: Box<dyn ErasedSchema>,
    /// Marker for `TOutput` and `TError`; no values are stored.
    _phantom: PhantomData<fn(TOutput, TError)>,
}
impl<TBaseCtx: Context, TCtx: Context, TOutput, TError>
    BuilderWithOutput<TBaseCtx, TCtx, TOutput, TError>
{
    /// Finishes the procedure with only the output schema attached. The handler's
    /// input type is `()`, and no input validation is configured.
    pub fn handler<F>(self, f: F) -> Procedure<TBaseCtx, (), TOutput, TError>
    where
        F: Handler<TCtx, (), TOutput, TError>,
        TOutput: Serialize + Send + 'static,
        TError: Into<ProcedureError> + Send + 'static,
    {
        let Self {
            middleware_chain,
            error_map,
            route,
            meta,
            output_schema,
            ..
        } = self;
        build_procedure(
            middleware_chain,
            f,
            None, // input schema
            None, // input validator
            Some(output_schema),
            error_map,
            route,
            meta,
        )
    }
}
/// Assembles a [`Procedure`] from the state gathered by the builder stages.
///
/// Each execution of the returned procedure does the following:
/// 1. Clones the shared handler, middleware chain and input validator, and builds a
///    [`ProcedureMeta`] from a clone of `route`.
/// 2. Calls `chain.run` with the caller's `TBaseCtx`, the raw [`DynInput`], that meta
///    and a boxed closure.
/// 3. If the chain invokes that closure with a `TCtx` and an input, the closure:
///    - applies `input_validator` when one is set;
///    - deserializes the input into `TInput`;
///    - awaits `handler.call`;
///    - converts a handler error through `Into<ProcedureError>`;
///    - wraps a successful value in [`DynOutput`].
///
/// The resulting future is wrapped with `ProcedureStream::from_future`.
/// `input_schema`, `output_schema`, `error_map`, `route` and `meta` are stored on the
/// procedure unchanged.
// Every builder stage forwards its fields individually, which exceeds clippy's argument limit.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_procedure<TBaseCtx, TCtx, TInput, TOutput, TError, F>(
    middleware_chain: Arc<dyn MiddlewareChain<TBaseCtx, TCtx>>,
    handler: F,
    input_schema: Option<Box<dyn ErasedSchema>>,
    input_validator: Option<InputValidator>,
    output_schema: Option<Box<dyn ErasedSchema>>,
    error_map: ErrorMap,
    route: Route,
    meta: Meta,
) -> Procedure<TBaseCtx, TInput, TOutput, TError>
where
    TBaseCtx: Context,
    TCtx: Context,
    TInput: DeserializeOwned + Send + 'static,
    TOutput: Serialize + Send + 'static,
    TError: Into<ProcedureError> + Send + 'static,
    F: Handler<TCtx, TInput, TOutput, TError>,
{
    // Shared so that each call's `async move` future can hold its own handle to the handler.
    let handler = Arc::new(handler);
    // `route` itself moves into the `Procedure` below; this copy feeds the per-call `ProcedureMeta`.
    let route_for_meta = route.clone();
    let exec = Arc::new(move |base_ctx: TBaseCtx, dyn_input: DynInput| {
        // Per-call clones: the `async move` blocks below take ownership of what they use.
        let handler = handler.clone();
        let chain = middleware_chain.clone();
        let input_validator = input_validator.clone();
        let procedure_meta = ProcedureMeta {
            route: route_for_meta.clone(),
        };
        ProcedureStream::from_future(async move {
            chain
                .run(
                    base_ctx,
                    dyn_input,
                    procedure_meta,
                    // Final step handed to the chain: validate, deserialize, then call the handler.
                    Box::new(move |ctx: TCtx, input: DynInput| -> BoxFuture<'static, Result<DynOutput, ProcedureError>> {
                        Box::pin(async move {
                            // Validation runs only when a validator was configured.
                            let input = match input_validator {
                                Some(ref validator) => validator(input)?,
                                None => input,
                            };
                            let typed_input: TInput = input.deserialize()?;
                            let result = handler
                                .call(ctx, typed_input)
                                .await
                                .map_err(|e| -> ProcedureError { e.into() })?;
                            Ok(DynOutput::new(result))
                        })
                    }),
                )
                .await
        })
    });
    Procedure {
        exec,
        input_schema,
        output_schema,
        error_map,
        route,
        meta,
        _phantom: PhantomData,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Identity;
    use futures_util::StreamExt;
    use serde::Deserialize;

    /// Request body shared by the greeting procedures in these tests.
    #[derive(Debug, Deserialize, Serialize)]
    struct GreetInput {
        name: String,
    }

    /// Unit-context handler that greets `input.name`.
    async fn greet_handler(_ctx: (), input: GreetInput) -> Result<String, ORPCError> {
        let GreetInput { name } = input;
        Ok(format!("Hello, {}!", name))
    }

    /// A route plus an input schema, with no middleware, runs the handler directly.
    #[tokio::test]
    async fn basic_builder_no_middleware() {
        let erased = os::<()>()
            .route(Route::post("/greet"))
            .input(Identity::<GreetInput>::new())
            .handler(greet_handler)
            .into_erased();

        let request = DynInput::from_value(serde_json::json!({"name": "World"}));
        let mut responses = erased.exec((), request);
        let first = responses.next().await.unwrap().unwrap();

        let expected = serde_json::json!("Hello, World!");
        assert_eq!(first.to_value().unwrap(), expected);
    }

    /// Middleware can replace the context type seen by the handler.
    #[tokio::test]
    async fn builder_with_middleware_context_switch() {
        struct AppCtx {
            user_id: u32,
        }
        struct AuthCtx {
            user: String,
        }

        // Derives an `AuthCtx` from the incoming `AppCtx` and continues the chain with it.
        let auth_mw = |ctx: AppCtx, mw: MiddlewareCtx<AuthCtx>| {
            Box::pin(async move {
                let auth = AuthCtx {
                    user: format!("user-{}", ctx.user_id),
                };
                mw.next(auth).await
            }) as BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
        };

        async fn greet_authed(ctx: AuthCtx, input: GreetInput) -> Result<String, ORPCError> {
            Ok(format!("Hello {}, from {}!", input.name, ctx.user))
        }

        let erased = os::<AppCtx>()
            .use_middleware(auth_mw)
            .input(Identity::<GreetInput>::new())
            .handler(greet_authed)
            .into_erased();

        let request = DynInput::from_value(serde_json::json!({"name": "World"}));
        let mut responses = erased.exec(AppCtx { user_id: 42 }, request);
        let first = responses.next().await.unwrap().unwrap();

        let expected = serde_json::json!("Hello World, from user-42!");
        assert_eq!(first.to_value().unwrap(), expected);
    }

    /// A builder without `.input()` hands `()` to the handler.
    #[tokio::test]
    async fn builder_no_input_handler() {
        async fn ping(_ctx: (), _input: ()) -> Result<String, ORPCError> {
            Ok("pong".into())
        }

        let erased = os::<()>().handler(ping).into_erased();

        let request = DynInput::from_value(serde_json::json!(null));
        let mut responses = erased.exec((), request);
        let first = responses.next().await.unwrap().unwrap();

        assert_eq!(first.to_value().unwrap(), serde_json::json!("pong"));
    }

    /// Input and output schemas are both recorded on the finished procedure.
    #[tokio::test]
    async fn builder_with_output_schema() {
        let procedure = os::<()>()
            .input(Identity::<GreetInput>::new())
            .output(Identity::<String>::new())
            .handler(greet_handler);
        assert!(procedure.input_schema.is_some());
        assert!(procedure.output_schema.is_some());

        let erased = procedure.into_erased();
        let request = DynInput::from_value(serde_json::json!({"name": "Test"}));
        let mut responses = erased.exec((), request);
        let first = responses.next().await.unwrap().unwrap();

        let expected = serde_json::json!("Hello, Test!");
        assert_eq!(first.to_value().unwrap(), expected);
    }

    /// One erased procedure can be executed repeatedly with different contexts.
    #[tokio::test]
    async fn multiple_calls_to_same_procedure() {
        let erased = os::<u32>()
            .input(Identity::<String>::new())
            .handler(|ctx: u32, input: String| async move {
                Ok::<_, ORPCError>(format!("{ctx}:{input}"))
            })
            .into_erased();

        for i in 0..3 {
            let request = DynInput::from_value(serde_json::json!(format!("call-{i}")));
            let mut responses = erased.exec(i, request);
            let first = responses.next().await.unwrap().unwrap();
            assert_eq!(
                first.to_value().unwrap(),
                serde_json::json!(format!("{i}:call-{i}"))
            );
        }
    }
}