use std::sync::Arc;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
/// Type-erased handler for a unary (request/response) procedure.
///
/// It is called with the JSON input and returns a boxed `'static` future that
/// resolves to a [`ProcedureResult`]. It is shared behind an `Arc`, and the
/// `Send + Sync` bounds allow that shared handler to be used across threads.
pub type UnaryHandler =
    Arc<dyn Fn(serde_json::Value) -> BoxFuture<'static, ProcedureResult> + Sync + Send>;

/// Type-erased handler for a streaming procedure.
///
/// It is called with the JSON input and returns a boxed `'static` stream of
/// [`StreamFrame`]s.
pub type StreamHandler =
    Arc<dyn Fn(serde_json::Value) -> BoxStream<'static, StreamFrame> + Sync + Send>;

// NOTE(review): presumably retained for backward compatibility with code written
// before the unary/stream split — confirm before removing.
/// Alias for [`UnaryHandler`].
pub type ProcedureHandler = UnaryHandler;
/// Outcome of invoking a unary procedure handler.
///
/// `Ok` carries the JSON-serialized response. `Err` carries an error envelope:
/// an HTTP status, a machine-readable code, and a JSON payload.
#[derive(Debug, Clone)]
pub enum ProcedureResult {
    /// Successful response, already converted to a JSON value.
    Ok(serde_json::Value),
    /// Error envelope.
    Err {
        /// HTTP status code to report (the serialization fallback uses `500`).
        http_status: u16,
        /// Machine-readable error code, e.g. `"serialization_error"`.
        code: String,
        /// Error details as JSON; `Value::Null` when absent or unserializable.
        payload: serde_json::Value,
    },
}
impl ProcedureResult {
pub fn ok(value: impl serde::Serialize) -> Self {
match serde_json::to_value(&value) {
Ok(v) => ProcedureResult::Ok(v),
Err(_) => ProcedureResult::Err {
http_status: 500,
code: "serialization_error".to_string(),
payload: serde_json::Value::Null,
},
}
}
pub fn err(http_status: u16, code: impl Into<String>, payload: impl serde::Serialize) -> Self {
match serde_json::to_value(&payload) {
Ok(payload) => ProcedureResult::Err {
http_status,
code: code.into(),
payload,
},
Err(_) => ProcedureResult::Err {
http_status: 500,
code: "serialization_error".to_string(),
payload: serde_json::Value::Null,
},
}
}
#[allow(clippy::needless_pass_by_value)] pub fn from_taut_error<E: crate::TautError>(e: E) -> Self {
let code = e.code().to_string();
let http_status = e.http_status();
let payload = serde_json::to_value(&e).unwrap_or(serde_json::Value::Null);
ProcedureResult::Err {
http_status,
code,
payload,
}
}
#[must_use]
#[allow(clippy::needless_pass_by_value)] pub fn from_serialization(_e: serde_json::Error) -> Self {
ProcedureResult::Err {
http_status: 500,
code: "serialization_error".to_string(),
payload: serde_json::Value::Null,
}
}
}
#[derive(Debug, Clone)]
pub enum StreamFrame {
Data(serde_json::Value),
Error {
code: String,
payload: serde_json::Value,
},
}
impl StreamFrame {
pub fn data(value: impl serde::Serialize) -> Self {
match serde_json::to_value(&value) {
Ok(v) => StreamFrame::Data(v),
Err(_) => StreamFrame::Error {
code: "serialization_error".to_string(),
payload: serde_json::Value::Null,
},
}
}
pub fn err(code: impl Into<String>, payload: impl serde::Serialize) -> Self {
match serde_json::to_value(&payload) {
Ok(payload) => StreamFrame::Error {
code: code.into(),
payload,
},
Err(_) => StreamFrame::Error {
code: "serialization_error".to_string(),
payload: serde_json::Value::Null,
},
}
}
#[allow(clippy::needless_pass_by_value)] pub fn from_taut_error<E: crate::TautError>(e: E) -> Self {
let code = e.code().to_string();
let payload = serde_json::to_value(&e).unwrap_or(serde_json::Value::Null);
StreamFrame::Error { code, payload }
}
}
/// Executable part of a procedure: a unary handler or a streaming handler.
///
/// Cloning is cheap because both handler types are `Arc`s.
#[derive(Clone)]
pub enum ProcedureBody {
    /// Request/response handler that resolves to a single [`ProcedureResult`].
    Unary(UnaryHandler),
    /// Handler that produces a stream of [`StreamFrame`]s.
    Stream(StreamHandler),
}
/// Registration record for one procedure: its metadata plus its executable body.
#[derive(Clone)]
pub struct ProcedureDescriptor {
    /// Procedure name.
    pub name: &'static str,
    /// Runtime kind of the procedure (e.g. query or subscription).
    pub kind: crate::router::ProcKindRuntime,
    /// IR description of the procedure, including its input and output types.
    pub ir: crate::ir::Procedure,
    // NOTE(review): presumably the named types referenced by `ir` — confirm.
    /// Type definitions carried alongside this procedure.
    pub type_defs: Vec<crate::ir::TypeDef>,
    /// The handler that executes the procedure.
    pub body: ProcedureBody,
}
impl std::fmt::Debug for ProcedureDescriptor {
    /// Formats the descriptor for diagnostics.
    ///
    /// Handlers are `dyn Fn` closures with no `Debug` impl, so only the body's
    /// variant name (`"Unary"` / `"Stream"`) is printed. `type_defs` and the
    /// remaining IR fields are omitted; `finish_non_exhaustive` marks this with
    /// a trailing `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self.body {
            ProcedureBody::Stream(_) => "Stream",
            ProcedureBody::Unary(_) => "Unary",
        };

        let mut out = formatter.debug_struct("ProcedureDescriptor");
        out.field("name", &self.name);
        out.field("kind", &self.kind);
        out.field("body", &variant_name);
        out.field("input", &self.ir.input);
        out.field("output", &self.ir.output);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A value whose `Serialize` impl always fails. It drives the
    /// serialization-error fallback paths of the constructors.
    struct Unserializable;

    impl serde::Serialize for Unserializable {
        fn serialize<S: serde::Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
            Err(serde::ser::Error::custom("deliberately unserializable"))
        }
    }

    /// Asserts that `r` is the generic serialization-failure envelope.
    fn assert_serialization_error(r: ProcedureResult) {
        match r {
            ProcedureResult::Err {
                http_status,
                code,
                payload,
            } => {
                assert_eq!(http_status, 500);
                assert_eq!(code, "serialization_error");
                assert_eq!(payload, serde_json::Value::Null);
            }
            ProcedureResult::Ok(_) => panic!("expected Err"),
        }
    }

    /// Asserts that `f` is the serialization-failure error frame.
    fn assert_serialization_error_frame(f: StreamFrame) {
        match f {
            StreamFrame::Error { code, payload } => {
                assert_eq!(code, "serialization_error");
                assert_eq!(payload, serde_json::Value::Null);
            }
            StreamFrame::Data(_) => panic!("expected Error variant"),
        }
    }

    #[test]
    fn ok_serializes_to_expected_json_value() {
        let r = ProcedureResult::ok(42u32);
        match r {
            ProcedureResult::Ok(v) => assert_eq!(v, serde_json::json!(42)),
            ProcedureResult::Err { .. } => panic!("expected Ok"),
        }
    }

    #[test]
    fn ok_falls_back_to_serialization_error() {
        assert_serialization_error(ProcedureResult::ok(Unserializable));
    }

    #[test]
    fn err_builds_envelope_with_supplied_fields() {
        let r = ProcedureResult::err(404, "not_found", serde_json::Value::Null);
        match r {
            ProcedureResult::Err {
                http_status,
                code,
                payload,
            } => {
                assert_eq!(http_status, 404);
                assert_eq!(code, "not_found");
                assert_eq!(payload, serde_json::Value::Null);
            }
            ProcedureResult::Ok(_) => panic!("expected Err"),
        }
    }

    #[test]
    fn err_falls_back_to_serialization_error_when_payload_fails() {
        assert_serialization_error(ProcedureResult::err(404, "not_found", Unserializable));
    }

    #[test]
    fn from_serialization_builds_generic_envelope() {
        let e = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        assert_serialization_error(ProcedureResult::from_serialization(e));
    }

    #[test]
    fn from_taut_error_preserves_code_and_status() {
        let r = ProcedureResult::from_taut_error(crate::error::StandardError::Unauthenticated);
        match r {
            ProcedureResult::Err {
                http_status, code, ..
            } => {
                assert_eq!(code, "unauthenticated");
                assert_eq!(http_status, 401);
            }
            ProcedureResult::Ok(_) => panic!("expected Err"),
        }
    }

    #[test]
    fn ok_payload_roundtrips_through_serde_json_string() {
        let value = serde_json::json!({ "id": 7, "name": "ada" });
        let r = ProcedureResult::Ok(value.clone());
        let encoded = match r {
            ProcedureResult::Ok(v) => serde_json::to_string(&v).expect("encode"),
            ProcedureResult::Err { .. } => panic!("expected Ok"),
        };
        let decoded: serde_json::Value = serde_json::from_str(&encoded).expect("decode");
        assert_eq!(decoded, value);
    }

    fn dummy_procedure_ir(name: &str) -> crate::ir::Procedure {
        use crate::ir::{HttpMethod, Primitive, ProcKind, TypeRef};
        crate::ir::Procedure {
            name: name.to_string(),
            kind: ProcKind::Query,
            input: TypeRef::Primitive(Primitive::Unit),
            output: TypeRef::Primitive(Primitive::Unit),
            errors: vec![],
            http_method: HttpMethod::Post,
            doc: None,
        }
    }

    #[tokio::test]
    async fn unary_body_dispatches_through_handler() {
        let handler: UnaryHandler = Arc::new(|input: serde_json::Value| {
            Box::pin(async move { ProcedureResult::Ok(input) })
        });
        let desc = ProcedureDescriptor {
            name: "echo",
            kind: crate::router::ProcKindRuntime::Query,
            ir: dummy_procedure_ir("echo"),
            type_defs: vec![],
            body: ProcedureBody::Unary(handler),
        };
        let h = match &desc.body {
            ProcedureBody::Unary(h) => h.clone(),
            ProcedureBody::Stream(_) => panic!("expected Unary body"),
        };
        let result = h(serde_json::json!({"hello": "world"})).await;
        match result {
            ProcedureResult::Ok(v) => assert_eq!(v, serde_json::json!({"hello": "world"})),
            ProcedureResult::Err { .. } => panic!("expected Ok"),
        }
    }

    #[tokio::test]
    async fn stream_body_emits_collected_frames() {
        use futures::stream::{self, StreamExt};
        let handler: StreamHandler = Arc::new(|_input: serde_json::Value| {
            let frames = vec![
                StreamFrame::Data(serde_json::json!(1)),
                StreamFrame::Data(serde_json::json!(2)),
                StreamFrame::Data(serde_json::json!(3)),
            ];
            stream::iter(frames).boxed()
        });
        let desc = ProcedureDescriptor {
            name: "counter",
            kind: crate::router::ProcKindRuntime::Subscription,
            ir: dummy_procedure_ir("counter"),
            type_defs: vec![],
            body: ProcedureBody::Stream(handler),
        };
        let s = match &desc.body {
            ProcedureBody::Stream(s) => s.clone(),
            ProcedureBody::Unary(_) => panic!("expected Stream body"),
        };
        let frames: Vec<StreamFrame> = s(serde_json::Value::Null).collect().await;
        assert_eq!(frames.len(), 3);
        let values: Vec<serde_json::Value> = frames
            .into_iter()
            .map(|f| match f {
                StreamFrame::Data(v) => v,
                StreamFrame::Error { .. } => panic!("expected Data frame"),
            })
            .collect();
        assert_eq!(
            values,
            vec![
                serde_json::json!(1),
                serde_json::json!(2),
                serde_json::json!(3),
            ]
        );
    }

    #[test]
    fn stream_frame_data_serializes_payload_in_place() {
        let f = StreamFrame::data(42u32);
        match f {
            StreamFrame::Data(v) => assert_eq!(v, serde_json::json!(42)),
            StreamFrame::Error { .. } => panic!("expected Data variant"),
        }
    }

    #[test]
    fn stream_frame_data_falls_back_to_serialization_error() {
        assert_serialization_error_frame(StreamFrame::data(Unserializable));
    }

    #[test]
    fn stream_frame_err_builds_error_variant() {
        let f = StreamFrame::err("rate_limited", serde_json::json!({"retry_after": 5}));
        match f {
            StreamFrame::Error { code, payload } => {
                assert_eq!(code, "rate_limited");
                assert_eq!(payload, serde_json::json!({"retry_after": 5}));
            }
            StreamFrame::Data(_) => panic!("expected Error variant"),
        }
    }

    #[test]
    fn stream_frame_err_falls_back_to_serialization_error_when_payload_fails() {
        assert_serialization_error_frame(StreamFrame::err("rate_limited", Unserializable));
    }

    #[test]
    fn stream_frame_from_taut_error_preserves_code() {
        let f = StreamFrame::from_taut_error(crate::error::StandardError::Unauthenticated);
        match f {
            StreamFrame::Error { code, .. } => assert_eq!(code, "unauthenticated"),
            StreamFrame::Data(_) => panic!("expected Error variant"),
        }
    }
}