use std::marker::PhantomData;
use std::sync::Arc;
use orpc_procedure::{ErasedSchema, ProcedureError};
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::builder::{Builder, build_procedure};
use crate::context::Context;
use crate::contract::ContractProcedure;
use crate::handler::{BoxFuture, Handler};
use crate::middleware::{IdentityChain, MiddlewareCtx, MiddlewareOutput};
use crate::procedure::Procedure;
/// Starts implementing a contract-first procedure.
///
/// Consumes `contract` and returns a [`ContractImplementer`] whose internal
/// builder begins with an [`IdentityChain`] middleware chain and takes over the
/// contract's error map, route and meta. The contract's input and output
/// schemas are moved into the implementer unchanged.
///
/// `TBaseCtx` is the context type of the resulting builder; it is not
/// determined by the contract argument, so callers typically name it
/// explicitly (e.g. `implement::<MyCtx, _, _, _>(contract)`).
pub fn implement<TBaseCtx: Context, TInput, TOutput, TError>(
    contract: ContractProcedure<TInput, TOutput, TError>,
) -> ContractImplementer<TBaseCtx, TInput, TOutput, TError>
where
    TError: Into<ProcedureError> + Send + 'static,
{
    // No middleware has been registered yet, so the builder's start and
    // current context types are both `TBaseCtx`.
    let builder = Builder {
        middleware_chain: Arc::new(IdentityChain),
        error_map: contract.error_map,
        route: contract.route,
        meta: contract.meta,
        _phantom: PhantomData,
    };

    ContractImplementer {
        builder,
        input_schema: contract.input_schema,
        output_schema: contract.output_schema,
        _phantom: PhantomData,
    }
}
/// Implementation state for a contract before any middleware has been added.
///
/// Produced by [`implement`]. The builder's two context parameters are both
/// `TBaseCtx` at this stage.
pub struct ContractImplementer<TBaseCtx, TInput, TOutput, TError> {
/// Builder holding the middleware chain, error map, route and meta.
builder: Builder<TBaseCtx, TBaseCtx, TError>,
/// Type-erased input schema taken from the contract, if one was set.
input_schema: Option<Box<dyn ErasedSchema>>,
/// Type-erased output schema taken from the contract, if one was set.
output_schema: Option<Box<dyn ErasedSchema>>,
/// Ties `TInput` and `TOutput` to this type without storing values of them.
_phantom: PhantomData<fn(TInput, TOutput)>,
}
impl<TBaseCtx: Context, TInput, TOutput, TError>
ContractImplementer<TBaseCtx, TInput, TOutput, TError>
where
TInput: DeserializeOwned + Send + 'static,
TOutput: Serialize + Send + 'static,
TError: Into<ProcedureError> + Send + 'static,
{
pub fn use_middleware<TNextCtx, M>(
self,
m: M,
) -> ContractImplementerWithMw<TBaseCtx, TNextCtx, TInput, TOutput, TError>
where
M: Fn(
TBaseCtx,
MiddlewareCtx<TNextCtx>,
) -> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
+ Send
+ Sync
+ 'static,
TNextCtx: Context,
{
ContractImplementerWithMw {
builder: self.builder.use_middleware(m),
input_schema: self.input_schema,
output_schema: self.output_schema,
_phantom: PhantomData,
}
}
pub fn handler<F>(self, f: F) -> Procedure<TBaseCtx, TInput, TOutput, TError>
where
F: Handler<TBaseCtx, TInput, TOutput, TError>,
{
build_procedure(
self.builder.middleware_chain,
f,
self.input_schema,
None, self.output_schema,
self.builder.error_map,
self.builder.route,
self.builder.meta,
)
}
}
/// Implementation state for a contract after at least one middleware was added.
///
/// `TBaseCtx` is the builder's first context parameter and `TCtx` its second;
/// `TCtx` is the context type that further middleware and the handler accept.
pub struct ContractImplementerWithMw<TBaseCtx, TCtx, TInput, TOutput, TError> {
/// Builder holding the middleware chain, error map, route and meta.
builder: Builder<TBaseCtx, TCtx, TError>,
/// Type-erased input schema carried over from the contract, if one was set.
input_schema: Option<Box<dyn ErasedSchema>>,
/// Type-erased output schema carried over from the contract, if one was set.
output_schema: Option<Box<dyn ErasedSchema>>,
/// Ties `TInput` and `TOutput` to this type without storing values of them.
_phantom: PhantomData<fn(TInput, TOutput)>,
}
impl<TBaseCtx: Context, TCtx: Context, TInput, TOutput, TError>
ContractImplementerWithMw<TBaseCtx, TCtx, TInput, TOutput, TError>
where
TInput: DeserializeOwned + Send + 'static,
TOutput: Serialize + Send + 'static,
TError: Into<ProcedureError> + Send + 'static,
{
pub fn use_middleware<TNextCtx, M>(
self,
m: M,
) -> ContractImplementerWithMw<TBaseCtx, TNextCtx, TInput, TOutput, TError>
where
M: Fn(
TCtx,
MiddlewareCtx<TNextCtx>,
) -> BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
+ Send
+ Sync
+ 'static,
TNextCtx: Context,
{
ContractImplementerWithMw {
builder: self.builder.use_middleware(m),
input_schema: self.input_schema,
output_schema: self.output_schema,
_phantom: PhantomData,
}
}
pub fn handler<F>(self, f: F) -> Procedure<TBaseCtx, TInput, TOutput, TError>
where
F: Handler<TCtx, TInput, TOutput, TError>,
{
build_procedure(
self.builder.middleware_chain,
f,
self.input_schema,
None, self.output_schema,
self.builder.error_map,
self.builder.route,
self.builder.meta,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contract::oc;
    use crate::schema::Identity;
    use futures_util::StreamExt;
    use orpc_procedure::{DynInput, Route};
    use serde::Deserialize;

    /// Request payload used by every test in this module.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct GetUserInput {
        id: u32,
    }

    /// Response payload used by every test in this module.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct User {
        id: u32,
        name: String,
    }

    /// Context-free handler that echoes the id and derives the name from it.
    async fn get_user_handler(_ctx: (), input: GetUserInput) -> Result<User, crate::ORPCError> {
        let GetUserInput { id } = input;
        Ok(User {
            id,
            name: format!("User {}", id),
        })
    }

    /// A contract implemented without middleware runs the handler directly.
    #[tokio::test]
    async fn implement_without_middleware() {
        let contract = oc()
            .route(Route::get("/users/{id}"))
            .input(Identity::<GetUserInput>::new())
            .output(Identity::<User>::new())
            .build();

        let procedure = implement::<(), _, _, _>(contract).handler(get_user_handler);
        let erased = procedure.into_erased();

        let payload = DynInput::from_value(serde_json::json!({"id": 42}));
        let mut outputs = erased.exec((), payload);
        let first = outputs.next().await.unwrap().unwrap();
        let user: User = serde_json::from_value(first.to_value().unwrap()).unwrap();

        let expected = User {
            id: 42,
            name: "User 42".into(),
        };
        assert_eq!(user, expected);
    }

    /// Middleware can replace the context seen by the handler, or fail the call.
    #[tokio::test]
    async fn implement_with_middleware() {
        struct AppCtx {
            token: String,
        }
        struct AuthCtx {
            user_name: String,
        }

        // Rejects every token except "valid"; otherwise continues with an
        // `AuthCtx` for "Alice".
        let auth_mw = |ctx: AppCtx, mw: MiddlewareCtx<AuthCtx>| {
            Box::pin(async move {
                if ctx.token != "valid" {
                    Err(ProcedureError::Resolver(Box::new(
                        crate::ORPCError::unauthorized("Invalid token"),
                    )))
                } else {
                    mw.next(AuthCtx {
                        user_name: "Alice".into(),
                    })
                    .await
                }
            }) as BoxFuture<'static, Result<MiddlewareOutput, ProcedureError>>
        };

        async fn handler(ctx: AuthCtx, input: GetUserInput) -> Result<User, crate::ORPCError> {
            let AuthCtx { user_name } = ctx;
            Ok(User {
                id: input.id,
                name: user_name,
            })
        }

        let contract = oc()
            .input(Identity::<GetUserInput>::new())
            .output(Identity::<User>::new())
            .build();
        let procedure = implement::<AppCtx, _, _, _>(contract)
            .use_middleware(auth_mw)
            .handler(handler);
        let erased = procedure.into_erased();

        // Accepted token: the handler output carries the middleware's user name.
        let payload = DynInput::from_value(serde_json::json!({"id": 1}));
        let accepted_ctx = AppCtx {
            token: "valid".into(),
        };
        let mut outputs = erased.exec(accepted_ctx, payload);
        let first = outputs.next().await.unwrap().unwrap();
        let user: User = serde_json::from_value(first.to_value().unwrap()).unwrap();
        assert_eq!(user.name, "Alice");

        // Rejected token: the first stream item is the middleware's error.
        let payload = DynInput::from_value(serde_json::json!({"id": 1}));
        let rejected_ctx = AppCtx {
            token: "bad".into(),
        };
        let mut outputs = erased.exec(rejected_ctx, payload);
        let outcome = outputs.next().await.unwrap();
        assert!(matches!(outcome, Err(ProcedureError::Resolver(_))));
    }

    /// Route metadata declared on the contract is present on the erased procedure.
    #[test]
    fn contract_preserves_route_in_procedure() {
        let contract = oc()
            .route(Route::get("/users/{id}").tag("users"))
            .input(Identity::<GetUserInput>::new())
            .output(Identity::<User>::new())
            .build();

        let erased = implement::<(), _, _, _>(contract)
            .handler(get_user_handler)
            .into_erased();

        assert_eq!(erased.route.method, Some(orpc_procedure::HttpMethod::Get));
        assert_eq!(erased.route.path.as_deref(), Some("/users/{id}"));
        assert_eq!(erased.route.tags, vec!["users"]);
    }
}