use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse2,
spanned::Spanned,
FnArg, GenericArgument, Ident, ItemFn, LitStr, PathArguments, ReturnType, Token,
TraitBoundModifier, Type, TypeImplTrait, TypeParamBound, TypePath,
};
fn err(span: Span, msg: &str, spec_anchor: &str) -> syn::Error {
syn::Error::new(span, format!("taut_rpc: {msg}\n see SPEC §{spec_anchor}"))
}
/// The procedure kind selected by the arguments of the `#[rpc]` attribute.
///
/// `expand` dispatches on this value: `Query` and `Mutation` go through
/// `expand_unary`, `Stream` goes through `expand_stream`.
#[derive(Debug, Clone, Copy)]
enum ProcKind {
/// Selected by the `query` keyword, and also the default when no kind is given.
Query,
/// Selected by the `mutation` keyword.
Mutation,
/// Selected by the `stream` keyword; emitted as a `Subscription` in the descriptor.
Stream,
}
/// The parsed arguments of one `#[rpc(...)]` attribute.
#[derive(Debug)]
struct RpcArgs {
/// Procedure kind requested by the attribute.
kind: ProcKind,
/// Span of the string literal passed as `method = "..."`, if one was given.
/// Only the span is stored; nothing in this file reads it back (hence the
/// leading underscore).
_method_override: Option<Span>,
}
impl Parse for RpcArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(Self {
kind: ProcKind::Query,
_method_override: None,
});
}
let mut kind = ProcKind::Query;
let mut method_override: Option<Span> = None;
loop {
if input.is_empty() {
break;
}
let ident: Ident = input.parse()?;
if ident == "stream" {
kind = ProcKind::Stream;
} else if ident == "mutation" {
kind = ProcKind::Mutation;
} else if ident == "query" {
kind = ProcKind::Query;
} else if ident == "method" {
input.parse::<Token![=]>()?;
let lit: LitStr = input.parse()?;
method_override = Some(lit.span());
} else {
return Err(err(
ident.span(),
&format!(
"unrecognised `#[rpc]` argument `{ident}`; \
expected one of `query`, `mutation`, `stream`, `method = \"...\"`"
),
"5",
));
}
if input.is_empty() {
break;
}
input.parse::<Token![,]>()?;
}
Ok(Self {
kind,
_method_override: method_override,
})
}
}
/// The success/error split of a procedure's return type, as produced by
/// `classify_return`.
struct ReturnShape {
/// Type of the successful value: the `T` of `Result<T, E>`, the declared return
/// type itself when it is not a `Result`, or `()` when no return type is written.
success: Type,
/// The `E` of `Result<T, E>`; `None` when the return type is not a `Result`.
error: Option<Type>,
}
/// Recognises `Result<T, E>` and returns `(T, E)`.
///
/// A type matches when it is a path without a qualified-self part, whose last segment
/// is named `Result` and carries angle-bracketed arguments containing exactly two
/// *type* arguments (non-type arguments such as lifetimes are ignored). Anything else
/// yields `None`.
fn match_result_type(ty: &Type) -> Option<(Type, Type)> {
    let Type::Path(TypePath { qself: None, path }) = ty else {
        return None;
    };

    let tail = path.segments.last().filter(|seg| seg.ident == "Result")?;
    let PathArguments::AngleBracketed(generics) = &tail.arguments else {
        return None;
    };

    let mut type_args = generics.args.iter().filter_map(|arg| match arg {
        GenericArgument::Type(t) => Some(t),
        _ => None,
    });

    // Exactly two type arguments: a third one disqualifies the match.
    match (type_args.next(), type_args.next(), type_args.next()) {
        (Some(ok), Some(err), None) => Some((ok.clone(), err.clone())),
        _ => None,
    }
}
fn classify_return(output: &ReturnType) -> ReturnShape {
match output {
ReturnType::Default => ReturnShape {
success: syn::parse_quote!(()),
error: None,
},
ReturnType::Type(_, ty) => match match_result_type(ty) {
Some((ok, err)) => ReturnShape {
success: ok,
error: Some(err),
},
None => ReturnShape {
success: (**ty).clone(),
error: None,
},
},
}
}
/// Extracts `T` from a return type of the form `impl Stream<Item = T> [+ ...]`.
///
/// Bounds are examined in order. Only unmodified trait bounds whose last path segment
/// is named `Stream` and has angle-bracketed arguments are considered; the first
/// `Item = ...` binding found among them is returned. Returns `None` when there is no
/// return type, the return type is not `impl Trait`, or no such binding exists.
fn extract_stream_item(output: &ReturnType) -> Option<Type> {
    let ReturnType::Type(_, boxed) = output else {
        return None;
    };
    let Type::ImplTrait(TypeImplTrait { bounds, .. }) = &**boxed else {
        return None;
    };

    bounds
        .iter()
        // Keep plain trait bounds (no `?` modifier) and look at their final segment.
        .filter_map(|bound| match bound {
            TypeParamBound::Trait(tb) if matches!(tb.modifier, TraitBoundModifier::None) => {
                tb.path.segments.last()
            }
            _ => None,
        })
        .filter(|seg| seg.ident == "Stream")
        .filter_map(|seg| match &seg.arguments {
            PathArguments::AngleBracketed(generics) => Some(generics),
            _ => None,
        })
        .flat_map(|generics| generics.args.iter())
        .find_map(|arg| match arg {
            GenericArgument::AssocType(assoc) if assoc.ident == "Item" => {
                Some(assoc.ty.clone())
            }
            _ => None,
        })
}
fn extract_input_type(func: &ItemFn) -> syn::Result<Option<Type>> {
let inputs = &func.sig.inputs;
if let Some(FnArg::Receiver(rcv)) = inputs.first() {
return Err(err(
rcv.span(),
"#[rpc] cannot be applied to methods; use a free-standing async fn",
"5",
));
}
match inputs.len() {
0 => Ok(None),
1 => match inputs.first().expect("len == 1") {
FnArg::Typed(pat_type) => Ok(Some((*pat_type.ty).clone())),
FnArg::Receiver(_) => unreachable!("handled above"),
},
_ => Err(err(
inputs.span(),
"multi-argument procedures are not yet supported in v0.1; \
wrap your arguments in a struct that derives Type and Deserialize",
"5",
)),
}
}
/// Checks the signature requirements shared by every procedure kind.
///
/// # Errors
///
/// Returns a spanned error, in this order of precedence, when the function
/// * is not `async`,
/// * has generic parameters or a `where` clause,
/// * is variadic.
fn validate_common(func: &ItemFn) -> syn::Result<()> {
    let sig = &func.sig;

    if sig.asyncness.is_none() {
        return Err(err(
            sig.fn_token.span(),
            "#[rpc] requires an async fn",
            "5",
        ));
    }

    let generics = &sig.generics;
    let is_generic = !generics.params.is_empty() || generics.where_clause.is_some();
    if is_generic {
        return Err(err(
            generics.span(),
            "generic procedures are not supported in v0.1; \
             monomorphise by writing a wrapper fn with concrete types",
            "5",
        ));
    }

    match &sig.variadic {
        Some(variadic) => Err(err(
            variadic.span(),
            "variadic procedures are not supported",
            "5",
        )),
        None => Ok(()),
    }
}
pub(crate) fn expand(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
let args: RpcArgs = parse2(attr)?;
let func: ItemFn = parse2(item)?;
validate_common(&func)?;
match args.kind {
ProcKind::Query | ProcKind::Mutation => expand_unary(&func, args.kind),
ProcKind::Stream => expand_stream(&func),
}
}
#[allow(clippy::too_many_lines)] fn expand_unary(func: &ItemFn, kind: ProcKind) -> syn::Result<TokenStream> {
let fn_ident = func.sig.ident.clone();
let fn_name_str = fn_ident.to_string();
let descriptor_ident = Ident::new(&format!("__taut_proc_{fn_name_str}"), Span::call_site());
let input_ty_opt = extract_input_type(func)?;
let return_shape = classify_return(&func.sig.output);
let (ir_kind_tok, runtime_kind_tok) = match kind {
ProcKind::Query => (
quote!(::taut_rpc::ir::ProcKind::Query),
quote!(::taut_rpc::ProcKindRuntime::Query),
),
ProcKind::Mutation => (
quote!(::taut_rpc::ir::ProcKind::Mutation),
quote!(::taut_rpc::ProcKindRuntime::Mutation),
),
ProcKind::Stream => unreachable!("expand_unary called with ProcKind::Stream"),
};
let input_ir_expr = if let Some(ty) = &input_ty_opt {
quote!(<#ty as ::taut_rpc::TautType>::ir_type_ref())
} else {
quote!(<() as ::taut_rpc::TautType>::ir_type_ref())
};
let input_collect = if let Some(ty) = &input_ty_opt {
quote!(<#ty as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs);)
} else {
quote!(<() as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs);)
};
let success_ty = &return_shape.success;
let output_ir_expr = quote!(<#success_ty as ::taut_rpc::TautType>::ir_type_ref());
let output_collect =
quote!(<#success_ty as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs););
let (errors_vec_expr, error_collect) = match &return_shape.error {
Some(err_ty) => (
quote!(vec![<#err_ty as ::taut_rpc::TautType>::ir_type_ref()]),
quote!(<#err_ty as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs);),
),
None => (quote!(::std::vec::Vec::new()), quote!()),
};
let decode_block = if let Some(ty) = &input_ty_opt {
quote! {
let __input: #ty = match ::serde_json::from_value(__input_value) {
::std::result::Result::Ok(v) => v,
::std::result::Result::Err(e) => {
return ::taut_rpc::ProcedureResult::Err {
http_status: 400,
code: ::std::string::String::from("decode_error"),
payload: ::serde_json::json!({ "message": e.to_string() }),
};
}
};
if let ::std::result::Result::Err(__errors) =
<#ty as ::taut_rpc::Validate>::validate(&__input)
{
return ::taut_rpc::ProcedureResult::Err {
http_status: 400,
code: ::std::string::String::from("validation_error"),
payload: ::serde_json::json!({ "errors": __errors }),
};
}
}
} else {
quote! {
if !__input_value.is_null() {
return ::taut_rpc::ProcedureResult::Err {
http_status: 400,
code: ::std::string::String::from("decode_error"),
payload: ::serde_json::json!({ "message": "expected null input" }),
};
}
}
};
let call_expr = if input_ty_opt.is_some() {
quote!(#fn_ident(__input).await)
} else {
quote!(#fn_ident().await)
};
let result_block = if return_shape.error.is_some() {
quote! {
match #call_expr {
::std::result::Result::Ok(__out) => match ::serde_json::to_value(&__out) {
::std::result::Result::Ok(__v) => ::taut_rpc::ProcedureResult::Ok(__v),
::std::result::Result::Err(__e) => ::taut_rpc::ProcedureResult::Err {
http_status: 500,
code: ::std::string::String::from("serialization_error"),
payload: ::serde_json::json!({ "message": __e.to_string() }),
},
},
::std::result::Result::Err(__err) => {
let __code = ::taut_rpc::TautError::code(&__err).to_string();
let __http_status = ::taut_rpc::TautError::http_status(&__err);
let __payload = ::serde_json::to_value(&__err)
.unwrap_or(::serde_json::Value::Null);
::taut_rpc::ProcedureResult::Err {
http_status: __http_status,
code: __code,
payload: __payload,
}
},
}
}
} else {
quote! {
let __out = #call_expr;
match ::serde_json::to_value(&__out) {
::std::result::Result::Ok(__v) => ::taut_rpc::ProcedureResult::Ok(__v),
::std::result::Result::Err(__e) => ::taut_rpc::ProcedureResult::Err {
http_status: 500,
code: ::std::string::String::from("serialization_error"),
payload: ::serde_json::json!({ "message": __e.to_string() }),
},
}
}
};
let doc_expr = extract_doc_expr(func);
let descriptor = quote! {
#[allow(non_snake_case)]
pub fn #descriptor_ident() -> ::taut_rpc::ProcedureDescriptor {
let input_ty = #input_ir_expr;
let output_ty = #output_ir_expr;
let mut type_defs: ::std::vec::Vec<::taut_rpc::ir::TypeDef> =
::std::vec::Vec::new();
#input_collect
#output_collect
#error_collect
{
let mut seen = ::std::collections::HashSet::<::std::string::String>::new();
type_defs.retain(|d| seen.insert(d.name.clone()));
}
::taut_rpc::ProcedureDescriptor {
name: #fn_name_str,
kind: #runtime_kind_tok,
ir: ::taut_rpc::ir::Procedure {
name: ::std::string::String::from(#fn_name_str),
kind: #ir_kind_tok,
input: input_ty,
output: output_ty,
errors: #errors_vec_expr,
http_method: ::taut_rpc::ir::HttpMethod::Post,
doc: #doc_expr,
},
type_defs,
body: ::taut_rpc::ProcedureBody::Unary(
::std::sync::Arc::new(|__input_value: ::serde_json::Value| {
::std::boxed::Box::pin(async move {
#decode_block
#result_block
})
})
),
}
}
};
Ok(quote! {
#func
#descriptor
})
}
#[allow(clippy::too_many_lines)] fn expand_stream(func: &ItemFn) -> syn::Result<TokenStream> {
let fn_ident = func.sig.ident.clone();
let fn_name_str = fn_ident.to_string();
let descriptor_ident = Ident::new(&format!("__taut_proc_{fn_name_str}"), Span::call_site());
let input_ty_opt = extract_input_type(func)?;
let item_ty = extract_stream_item(&func.sig.output).ok_or_else(|| {
err(
func.sig.output.span(),
"#[rpc(stream)] requires `async fn ... -> impl Stream<Item = T>`",
"5.1",
)
})?;
let input_ir_expr = if let Some(ty) = &input_ty_opt {
quote!(<#ty as ::taut_rpc::TautType>::ir_type_ref())
} else {
quote!(<() as ::taut_rpc::TautType>::ir_type_ref())
};
let input_collect = if let Some(ty) = &input_ty_opt {
quote!(<#ty as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs);)
} else {
quote!(<() as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs);)
};
let output_ir_expr = quote!(<#item_ty as ::taut_rpc::TautType>::ir_type_ref());
let output_collect =
quote!(<#item_ty as ::taut_rpc::TautType>::collect_type_defs(&mut type_defs););
let decode_block = if let Some(ty) = &input_ty_opt {
quote! {
let __input: #ty = match ::serde_json::from_value(__input_value) {
::std::result::Result::Ok(v) => v,
::std::result::Result::Err(__e) => {
yield ::taut_rpc::StreamFrame::Error {
code: ::std::string::String::from("decode_error"),
payload: ::serde_json::json!({ "message": __e.to_string() }),
};
return;
}
};
if let ::std::result::Result::Err(__errors) =
<#ty as ::taut_rpc::Validate>::validate(&__input)
{
yield ::taut_rpc::StreamFrame::Error {
code: ::std::string::String::from("validation_error"),
payload: ::serde_json::json!({ "errors": __errors }),
};
return;
}
}
} else {
quote! {
if !__input_value.is_null() {
yield ::taut_rpc::StreamFrame::Error {
code: ::std::string::String::from("decode_error"),
payload: ::serde_json::json!({ "message": "expected null input" }),
};
return;
}
}
};
let call_expr = if input_ty_opt.is_some() {
quote!(#fn_ident(__input).await)
} else {
quote!(#fn_ident().await)
};
let doc_expr = extract_doc_expr(func);
let descriptor = quote! {
#[allow(non_snake_case)]
pub fn #descriptor_ident() -> ::taut_rpc::ProcedureDescriptor {
let input_ty = #input_ir_expr;
let output_ty = #output_ir_expr;
let mut type_defs: ::std::vec::Vec<::taut_rpc::ir::TypeDef> =
::std::vec::Vec::new();
#input_collect
#output_collect
{
let mut seen = ::std::collections::HashSet::<::std::string::String>::new();
type_defs.retain(|d| seen.insert(d.name.clone()));
}
::taut_rpc::ProcedureDescriptor {
name: #fn_name_str,
kind: ::taut_rpc::ProcKindRuntime::Subscription,
ir: ::taut_rpc::ir::Procedure {
name: ::std::string::String::from(#fn_name_str),
kind: ::taut_rpc::ir::ProcKind::Subscription,
input: input_ty,
output: output_ty,
errors: ::std::vec::Vec::new(),
http_method: ::taut_rpc::ir::HttpMethod::Get,
doc: #doc_expr,
},
type_defs,
body: ::taut_rpc::ProcedureBody::Stream(
::std::sync::Arc::new(|__input_value: ::serde_json::Value| {
::std::boxed::Box::pin(::async_stream::stream! {
#decode_block
let __inner = #call_expr;
::futures::pin_mut!(__inner);
while let ::std::option::Option::Some(__item) =
::futures::StreamExt::next(&mut __inner).await
{
match ::serde_json::to_value(&__item) {
::std::result::Result::Ok(__v) => {
yield ::taut_rpc::StreamFrame::Data(__v);
}
::std::result::Result::Err(__e) => {
yield ::taut_rpc::StreamFrame::Error {
code: ::std::string::String::from("serialization_error"),
payload: ::serde_json::json!({ "message": __e.to_string() }),
};
return;
}
}
}
})
})
),
}
}
};
Ok(quote! {
#func
#descriptor
})
}
/// Builds the token expression for the procedure's `doc` field.
///
/// Collects every `#[doc = "..."]` attribute of `func` whose value is a string
/// literal, strips at most one leading space from each, and joins them with `\n`.
/// Emits `Option::Some(String::from("<joined>"))`, or `Option::None` when the function
/// has no such attribute.
fn extract_doc_expr(func: &ItemFn) -> TokenStream {
    let doc_lines: Vec<String> = func
        .attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| match &attr.meta {
            syn::Meta::NameValue(nv) => match &nv.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(s),
                    ..
                }) => Some(s.value()),
                _ => None,
            },
            _ => None,
        })
        .collect();

    if doc_lines.is_empty() {
        return quote!(::std::option::Option::None);
    }

    let joined = doc_lines
        .iter()
        .map(|line| line.strip_prefix(' ').unwrap_or(line))
        .collect::<Vec<&str>>()
        .join("\n");
    quote!(::std::option::Option::Some(::std::string::String::from(#joined)))
}
// Unit tests for the `#[rpc]` expansion.
//
// The helper functions are exercised directly on `syn` values built with
// `parse_quote!`. The expanders are exercised through `expand`, asserting on the
// stringified token output; those assertions therefore depend on the exact spacing
// `TokenStream::to_string` produces (e.g. `ProcKindRuntime :: Query`).
#[cfg(test)]
mod tests {
use super::*;
use syn::parse_quote;
// ---- match_result_type ----------------------------------------------------------
// A bare `Result<T, E>` is split into its two type arguments.
#[test]
fn match_result_recognises_bare_result() {
let ty: Type = parse_quote!(Result<i32, MyError>);
let (ok, err) = match_result_type(&ty).expect("should match");
let ok_str = quote!(#ok).to_string();
let err_str = quote!(#err).to_string();
assert_eq!(ok_str, "i32");
assert_eq!(err_str, "MyError");
}
// Only the last path segment is inspected, so a fully qualified path matches too.
#[test]
fn match_result_recognises_qualified_result() {
let ty: Type = parse_quote!(::std::result::Result<String, ()>);
assert!(match_result_type(&ty).is_some());
}
#[test]
fn match_result_rejects_non_result_path() {
let ty: Type = parse_quote!(Option<i32>);
assert!(match_result_type(&ty).is_none());
}
// A `Result` with a single type argument is not treated as `Result<T, E>`.
#[test]
fn match_result_rejects_wrong_arity() {
let ty: Type = parse_quote!(Result<i32>);
assert!(match_result_type(&ty).is_none());
}
// References and tuples are not path types and never match.
#[test]
fn match_result_rejects_non_path_types() {
let ty: Type = parse_quote!(&str);
assert!(match_result_type(&ty).is_none());
let ty: Type = parse_quote!((i32, String));
assert!(match_result_type(&ty).is_none());
}
// ---- classify_return ------------------------------------------------------------
// No written return type means success type `()` and no error type.
#[test]
fn classify_return_default_is_unit_no_error() {
let rt: ReturnType = ReturnType::Default;
let shape = classify_return(&rt);
assert!(shape.error.is_none());
let success = shape.success;
let rendered = quote!(#success).to_string();
// Spaces are removed so the check does not depend on how `()` is rendered.
assert_eq!(rendered.replace(' ', ""), "()");
}
#[test]
fn classify_return_plain_type() {
let rt: ReturnType = parse_quote!(-> String);
let shape = classify_return(&rt);
assert!(shape.error.is_none());
let success = shape.success;
assert_eq!(quote!(#success).to_string(), "String");
}
#[test]
fn classify_return_result_type() {
let rt: ReturnType = parse_quote!(-> Result<i32, AddError>);
let shape = classify_return(&rt);
let success = shape.success;
let err = shape.error.expect("error should be present");
assert_eq!(quote!(#success).to_string(), "i32");
assert_eq!(quote!(#err).to_string(), "AddError");
}
// ---- extract_input_type ---------------------------------------------------------
#[test]
fn extract_input_type_zero_args() {
let func: ItemFn = parse_quote!(
async fn ping() -> String {
String::new()
}
);
let got = extract_input_type(&func).unwrap();
assert!(got.is_none());
}
#[test]
fn extract_input_type_one_arg() {
let func: ItemFn = parse_quote!(
async fn add(input: AddInput) -> i32 {
0
}
);
let got = extract_input_type(&func).unwrap().unwrap();
assert_eq!(quote!(#got).to_string(), "AddInput");
}
#[test]
fn extract_input_type_rejects_multi_arg() {
let func: ItemFn = parse_quote!(
async fn add(a: i32, b: i32) -> i32 {
a + b
}
);
let err = extract_input_type(&func).unwrap_err();
assert!(err.to_string().contains("wrap your arguments in a struct"));
}
#[test]
fn extract_input_type_rejects_self_receiver() {
let func: ItemFn = parse_quote!(
async fn ping(&self) -> String {
String::new()
}
);
let err = extract_input_type(&func).unwrap_err();
assert!(err.to_string().contains("methods"));
}
// ---- RpcArgs parsing ------------------------------------------------------------
// An empty argument list defaults to a query.
#[test]
fn rpc_args_default_is_query() {
let args: RpcArgs = parse2(quote!()).unwrap();
assert!(matches!(args.kind, ProcKind::Query));
}
#[test]
fn rpc_args_mutation() {
let args: RpcArgs = parse2(quote!(mutation)).unwrap();
assert!(matches!(args.kind, ProcKind::Mutation));
}
#[test]
fn rpc_args_stream() {
let args: RpcArgs = parse2(quote!(stream)).unwrap();
assert!(matches!(args.kind, ProcKind::Stream));
}
// `method = "GET"` parses successfully and leaves the kind at its default.
#[test]
fn rpc_args_method_accepts_get() {
let args: RpcArgs = parse2(quote!(method = "GET")).unwrap();
assert!(matches!(args.kind, ProcKind::Query));
}
#[test]
fn rpc_args_unknown_token_errors() {
let err = parse2::<RpcArgs>(quote!(banana)).unwrap_err();
assert!(err.to_string().contains("unrecognised"));
}
// ---- expand: signature validation -----------------------------------------------
#[test]
fn expand_rejects_non_async() {
let item = quote! {
fn ping() -> String { String::new() }
};
let err = expand(quote!(), item).unwrap_err();
assert!(err.to_string().contains("requires an async fn"));
}
#[test]
fn expand_rejects_generic_fn() {
let item = quote! {
async fn ping<T>() -> T { todo!() }
};
let err = expand(quote!(), item).unwrap_err();
assert!(err.to_string().contains("generic"));
}
#[test]
fn expand_rejects_multi_arg() {
let item = quote! {
async fn add(a: i32, b: i32) -> i32 { a + b }
};
let err = expand(quote!(), item).unwrap_err();
assert!(err.to_string().contains("wrap your arguments in a struct"));
}
// ---- expand: unary output -------------------------------------------------------
// With no attribute arguments the descriptor is a query with a unary body.
#[test]
fn expand_emits_descriptor_for_simple_fn() {
let item = quote! {
async fn ping() -> String { String::from("pong") }
};
let out = expand(quote!(), item).unwrap().to_string();
assert!(out.contains("__taut_proc_ping"));
assert!(out.contains("ProcedureDescriptor"));
assert!(out.contains("ProcKindRuntime :: Query"));
assert!(
out.contains("ProcedureBody :: Unary"),
"expected unary handler to be wrapped in ProcedureBody::Unary, got: {out}"
);
}
// The error arm must obtain code and status from the `TautError` trait. The two
// negative assertions guard against a JSON-lookup approach (per their messages, one
// that was removed) reappearing in the output.
#[test]
fn expand_result_fn_uses_taut_error_trait_methods() {
let item = quote! {
async fn fails() -> Result<i32, MyErr> { todo!() }
};
let out = expand(quote!(), item).unwrap().to_string();
assert!(
out.contains("TautError :: code (& __err)"),
"expected TautError::code(&__err) call in emitted tokens, got: {out}"
);
assert!(
out.contains("TautError :: http_status (& __err)"),
"expected TautError::http_status(&__err) call in emitted tokens, got: {out}"
);
assert!(
!out.contains(". get (\"code\")"),
"expected the JSON-poking `__payload.get(\"code\")` lookup to be removed, got: {out}"
);
assert!(
!out.contains("unwrap_or (\"error\")"),
"expected the `unwrap_or(\"error\")` fallback to be removed, got: {out}"
);
}
// ---- extract_stream_item --------------------------------------------------------
#[test]
fn extract_stream_item_from_bare_stream() {
let rt: ReturnType = parse_quote!(-> impl Stream<Item = u64>);
let item = extract_stream_item(&rt).expect("should extract Item");
assert_eq!(quote!(#item).to_string(), "u64");
}
// Extra bounds (`Send`, `'static`) next to the `Stream` bound are skipped.
#[test]
fn extract_stream_item_from_qualified_path() {
let rt: ReturnType =
parse_quote!(-> impl ::futures::Stream<Item = String> + Send + 'static);
let item = extract_stream_item(&rt).expect("should extract Item");
assert_eq!(quote!(#item).to_string(), "String");
}
#[test]
fn extract_stream_item_from_futures_path() {
let rt: ReturnType = parse_quote!(-> impl futures::Stream<Item = MyMsg> + Send);
let item = extract_stream_item(&rt).expect("should extract Item");
assert_eq!(quote!(#item).to_string(), "MyMsg");
}
#[test]
fn extract_stream_item_returns_none_for_plain_type() {
let rt: ReturnType = parse_quote!(-> u64);
assert!(extract_stream_item(&rt).is_none());
}
// `impl Stream` without an `Item = ...` binding gives nothing to extract.
#[test]
fn extract_stream_item_returns_none_when_item_binding_missing() {
let rt: ReturnType = parse_quote!(-> impl Stream + Send);
assert!(extract_stream_item(&rt).is_none());
}
// ---- expand: stream output ------------------------------------------------------
// A stream procedure is described as a GET subscription whose IR output type is the
// stream's item type.
#[test]
fn expand_stream_emits_subscription_descriptor() {
let item = quote! {
async fn ticks(input: TicksInput) -> impl futures::Stream<Item = u64> + Send + 'static {
::futures::stream::empty()
}
};
let out = expand(quote!(stream), item).unwrap().to_string();
assert!(out.contains("__taut_proc_ticks"));
assert!(out.contains("ProcedureDescriptor"));
assert!(
out.contains("ProcKindRuntime :: Subscription"),
"expected runtime kind Subscription, got: {out}"
);
assert!(
out.contains("ProcKind :: Subscription"),
"expected IR kind Subscription, got: {out}"
);
assert!(
out.contains("HttpMethod :: Get"),
"expected HttpMethod::Get for subscription, got: {out}"
);
assert!(
out.contains("ProcedureBody :: Stream"),
"expected ProcedureBody::Stream wrapping, got: {out}"
);
assert!(
out.contains("async_stream :: stream"),
"expected async_stream::stream! invocation, got: {out}"
);
assert!(
out.contains("< u64 as :: taut_rpc :: TautType > :: ir_type_ref"),
"expected ir_type_ref<u64> for the Item type, got: {out}"
);
}
// A parameterless stream procedure still checks that the incoming value is null and
// reports a violation as an error frame.
#[test]
fn expand_stream_emits_descriptor_for_zero_input_fn() {
let item = quote! {
async fn server_time() -> impl futures::Stream<Item = String> + Send + 'static {
::futures::stream::empty()
}
};
let out = expand(quote!(stream), item).unwrap().to_string();
assert!(out.contains("__taut_proc_server_time"));
assert!(
out.contains("is_null"),
"expected zero-arg path to assert input_value.is_null(), got: {out}"
);
assert!(
out.contains("StreamFrame :: Error"),
"expected StreamFrame::Error emission for invalid zero-arg input, got: {out}"
);
}
#[test]
fn expand_stream_rejects_non_async() {
let item = quote! {
fn ticks() -> impl Stream<Item = u64> { ::futures::stream::empty() }
};
let err = expand(quote!(stream), item).unwrap_err();
assert!(
err.to_string().contains("requires an async fn"),
"expected user-friendly async-fn error, got: {err}"
);
}
#[test]
fn expand_stream_rejects_non_stream_return() {
let item = quote! {
async fn ticks() -> u64 { 0 }
};
let err = expand(quote!(stream), item).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("impl Stream<Item = T>"),
"expected error pointing at `impl Stream<Item = T>`, got: {msg}"
);
}
// ---- input validation hooks -----------------------------------------------------
// A typed input is passed through `Validate::validate`; failure is an HTTP 400
// `validation_error` carrying an `errors` payload.
#[test]
fn expand_unary_emits_validate_call_for_typed_input() {
let item = quote! {
async fn add(input: AddInput) -> i32 { 0 }
};
let out = expand(quote!(), item).unwrap().to_string();
assert!(
out.contains("< AddInput as :: taut_rpc :: Validate > :: validate (& __input)"),
"expected `<AddInput as ::taut_rpc::Validate>::validate(&__input)` call, got: {out}"
);
assert!(
out.contains("\"validation_error\""),
"expected `validation_error` code in emitted tokens, got: {out}"
);
assert!(
out.contains("http_status : 400"),
"expected HTTP 400 status for validation failure, got: {out}"
);
assert!(
out.contains("\"errors\""),
"expected `errors` field in validation_error payload, got: {out}"
);
}
// Without an input there is nothing to validate, so no `Validate` call is emitted.
#[test]
fn expand_unary_skips_validate_for_zero_input_fn() {
let item = quote! {
async fn ping() -> String { String::new() }
};
let out = expand(quote!(), item).unwrap().to_string();
assert!(
!out.contains(":: taut_rpc :: Validate"),
"expected no Validate call for zero-input fn, got: {out}"
);
}
// Stream counterpart of the unary validation test: failure is an error frame.
#[test]
fn expand_stream_emits_validate_call_for_typed_input() {
let item = quote! {
async fn ticks(input: TicksInput) -> impl futures::Stream<Item = u64> + Send + 'static {
::futures::stream::empty()
}
};
let out = expand(quote!(stream), item).unwrap().to_string();
assert!(
out.contains("< TicksInput as :: taut_rpc :: Validate > :: validate (& __input)"),
"expected `<TicksInput as ::taut_rpc::Validate>::validate(&__input)` call, got: {out}"
);
assert!(
out.contains("StreamFrame :: Error"),
"expected StreamFrame::Error emission for validation failure, got: {out}"
);
assert!(
out.contains("\"validation_error\""),
"expected `validation_error` code in emitted tokens, got: {out}"
);
}
#[test]
fn expand_stream_skips_validate_for_zero_input_fn() {
let item = quote! {
async fn server_time() -> impl futures::Stream<Item = String> + Send + 'static {
::futures::stream::empty()
}
};
let out = expand(quote!(stream), item).unwrap().to_string();
assert!(
!out.contains(":: taut_rpc :: Validate"),
"expected no Validate call for zero-input subscription, got: {out}"
);
}
}