use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{
Attribute, DeriveInput, Expr, ExprCall, ExprLit, ExprMethodCall, ExprPath, ExprStruct, FnArg,
Ident, ItemFn, Lit, LitStr, Meta, PatType, Path, ReturnType, Token, Type, TypePath,
parenthesized, parse_macro_input,
};
/// Parsed arguments of the `#[subscriber(..)]` attribute.
struct SubscriberArgs {
/// The first attribute argument: the expression naming the subscription source.
source: Expr,
/// The string literal given inside an optional trailing `publish("..")` argument.
publish: Option<LitStr>,
}
impl Parse for SubscriberArgs {
    /// Parses `SOURCE` optionally followed by `, publish("reply-topic")`.
    ///
    /// A single trailing comma is accepted after either argument, matching the
    /// usual Rust convention for comma-separated lists.
    ///
    /// # Errors
    ///
    /// Fails when the source is not an expression, when the second argument is
    /// not the identifier `publish`, or when `publish` is not followed by a
    /// parenthesised string literal.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let source: Expr = input.parse()?;
        let mut publish = None;
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
            // Nothing after the comma means it was only a trailing separator.
            if !input.is_empty() {
                let keyword: Ident = input.parse()?;
                if keyword != "publish" {
                    return Err(syn::Error::new(
                        keyword.span(),
                        "expected `publish(\"reply-topic\")`",
                    ));
                }
                let content;
                parenthesized!(content in input);
                publish = Some(content.parse()?);
                // Tolerate a trailing comma after `publish(..)` as well.
                if input.peek(Token![,]) {
                    input.parse::<Token![,]>()?;
                }
            }
        }
        Ok(Self { source, publish })
    }
}
/// Turns the source argument into a `(source type, source value)` token pair.
///
/// A string literal becomes a `::ruststream::Name` built from that literal.
/// Any other expression is emitted unchanged, with its type recovered
/// syntactically by `source_type`.
fn source_tokens(expr: &Expr) -> syn::Result<(TokenStream2, TokenStream2)> {
    match expr {
        Expr::Lit(ExprLit {
            lit: Lit::Str(name),
            ..
        }) => Ok((
            quote!(::ruststream::Name),
            quote!(::ruststream::Name::new(#name)),
        )),
        other => {
            let ty = source_type(other)?;
            Ok((quote!(#ty), quote!(#other)))
        }
    }
}
/// Recovers the type of a source expression from its syntax alone.
///
/// Supported shapes:
/// - `Type::new(..)` style calls through an unqualified path: the type is the
///   path without its final segment;
/// - `Type { .. }` struct literals: the type is the literal's path;
/// - method-call chains: the type is that of the innermost receiver;
/// - any of the above wrapped in parentheses or in an invisible group (as
///   produced when the expression is forwarded through `macro_rules!`).
///
/// # Errors
///
/// Returns the `unsupported_source` error for every other expression shape,
/// and propagates the error from `type_from_constructor_path` when a called
/// path has fewer than two segments.
fn source_type(expr: &Expr) -> syn::Result<Type> {
    match expr {
        // Grouping does not change the type, so look straight through it.
        Expr::Paren(syn::ExprParen { expr: inner, .. }) => source_type(inner),
        Expr::Group(syn::ExprGroup { expr: inner, .. }) => source_type(inner),
        Expr::Call(ExprCall { func, .. }) => match &**func {
            Expr::Path(ExprPath {
                path, qself: None, ..
            }) => type_from_constructor_path(path),
            _ => Err(unsupported_source(expr)),
        },
        Expr::Struct(ExprStruct { path, .. }) => Ok(Type::Path(TypePath {
            qself: None,
            path: path.clone(),
        })),
        // A builder chain keeps the type of the value the chain started from.
        Expr::MethodCall(ExprMethodCall { receiver, .. }) => source_type(receiver),
        _ => Err(unsupported_source(expr)),
    }
}
/// Derives the owning type from a constructor path such as `Type::new`.
///
/// The final segment is dropped and the remaining segments, together with any
/// leading `::`, form the returned type path.
///
/// # Errors
///
/// Fails when the path has fewer than two segments, since there is then no
/// type segment in front of the constructor name.
fn type_from_constructor_path(path: &Path) -> syn::Result<Type> {
    let Some(type_segments) = path
        .segments
        .len()
        .checked_sub(1)
        .filter(|&count| count > 0)
    else {
        return Err(syn::Error::new_spanned(
            path,
            "expected `Type::new(..)`: the path must name a type and an associated constructor",
        ));
    };
    let owner = Path {
        leading_colon: path.leading_colon,
        segments: path.segments.iter().take(type_segments).cloned().collect(),
    };
    Ok(Type::Path(TypePath {
        qself: None,
        path: owner,
    }))
}
/// Extracts `T` from a return type written as `Result<T, HandlerResult>`.
///
/// The match is purely syntactic: the last path segment must be `Result` with
/// exactly two type arguments, and the second must be an unqualified path
/// whose last segment is `HandlerResult`. Anything else yields `None`.
fn publish_result_reply(ty: &Type) -> Option<&Type> {
    let path = match ty {
        Type::Path(TypePath { qself: None, path }) => path,
        _ => return None,
    };
    let result_segment = path
        .segments
        .last()
        .filter(|segment| segment.ident == "Result")?;
    let generics = match &result_segment.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => &bracketed.args,
        _ => return None,
    };
    if generics.len() != 2 {
        return None;
    }
    let (syn::GenericArgument::Type(ok_ty), syn::GenericArgument::Type(err_ty)) =
        (&generics[0], &generics[1])
    else {
        return None;
    };
    match err_ty {
        Type::Path(TypePath {
            qself: None,
            path: err_path,
        }) => err_path
            .segments
            .last()
            .filter(|segment| segment.ident == "HandlerResult")
            .map(|_| ok_ty),
        _ => None,
    }
}
fn unsupported_source(expr: &Expr) -> syn::Error {
syn::Error::new_spanned(
expr,
"expected a string literal name, `Type::new(..)`, `Type { .. }`, or a builder chain on \
one of those - a free function does not expose its type to the macro",
)
}
/// Attribute macro turning a handler function into a subscriber definition.
///
/// The attribute arguments are parsed as `SubscriberArgs` and the annotated
/// item as a function; any expansion error is emitted as a compile error.
#[proc_macro_attribute]
pub fn subscriber(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as SubscriberArgs);
    let func = parse_macro_input!(item as ItemFn);
    match expand(&args, &func) {
        Ok(tokens) => tokens,
        Err(err) => err.to_compile_error().into(),
    }
}
#[proc_macro_attribute]
pub fn app(attr: TokenStream, item: TokenStream) -> TokenStream {
let func = parse_macro_input!(item as ItemFn);
expand_app(&attr.into(), &func).unwrap_or_else(|err| err.to_compile_error().into())
}
/// Expands `#[ruststream::app]`: re-emits the builder function unchanged and
/// appends a `fn main` that passes it to `::ruststream::runtime::cli::run_main`.
///
/// # Errors
///
/// Rejects, with an error spanned on the offending tokens:
/// - any attribute arguments;
/// - an `async` builder;
/// - a builder that takes parameters;
/// - a builder named `main`, which would clash with the generated `fn main`.
fn expand_app(attr: &TokenStream2, func: &ItemFn) -> syn::Result<TokenStream> {
    if !attr.is_empty() {
        return Err(syn::Error::new_spanned(
            attr,
            "#[ruststream::app] takes no arguments",
        ));
    }
    if let Some(asyncness) = func.sig.asyncness {
        return Err(syn::Error::new_spanned(
            asyncness,
            "#[ruststream::app] requires a synchronous builder returning `RustStream`",
        ));
    }
    if !func.sig.inputs.is_empty() {
        return Err(syn::Error::new_spanned(
            &func.sig.inputs,
            "#[ruststream::app] builder must take no arguments",
        ));
    }
    let name = &func.sig.ident;
    // The expansion defines its own `fn main` next to the builder, so a
    // builder called `main` could only ever fail with a duplicate-definition
    // error; report the real cause instead.
    if name == "main" {
        return Err(syn::Error::new_spanned(
            name,
            "#[ruststream::app] generates `fn main` itself; give the builder a different name",
        ));
    }
    Ok(quote! {
        #func
        fn main() -> ::std::process::ExitCode {
            ::ruststream::runtime::cli::run_main(#name)
        }
    }
    .into())
}
/// Pieces of a `#[subscriber]` handler shared by the subscribing and
/// publishing expansions, as assembled by `handler_parts`.
struct HandlerParts<'a> {
/// Visibility of the handler function, reused for the generated struct.
vis: &'a syn::Visibility,
/// Name of the handler function, reused as the generated struct's name.
name: &'a Ident,
/// Body of the handler function.
block: &'a syn::Block,
/// Binding pattern of the first (message) parameter.
pat: &'a syn::Pat,
/// The `T` of the message parameter's `&T` type.
input_ty: &'a Type,
/// `Option` expression built from the handler's doc attributes.
description: TokenStream2,
/// Tokens naming the type of the subscription source.
source_ty: TokenStream2,
/// Tokens of the expression producing the subscription source.
source_expr: TokenStream2,
/// Generated `input_schema` method.
input_schema: TokenStream2,
/// Generated `message_name` and `message_description` methods.
message_meta: TokenStream2,
/// Binding pattern of the second parameter, or `_ctx` when there is none.
ctx_param: TokenStream2,
}
/// Validates a `#[subscriber]` handler signature and collects the pieces that
/// both expansions need.
///
/// The first parameter is the message and must be typed `&T`; `T` becomes the
/// input type. An optional second parameter contributes only its binding
/// pattern for the context argument: the generated methods always declare that
/// argument as `&mut ::ruststream::runtime::Context<'_>`, whatever type was
/// written on the handler.
///
/// # Errors
///
/// Fails when the handler has no parameters, when the first parameter is
/// `self` or is not a reference, when more than two parameters are declared
/// (the generated method has no slot for them, so they would otherwise be
/// dropped silently), or when the source argument is not supported by
/// `source_tokens`.
fn handler_parts<'a>(args: &SubscriberArgs, func: &'a ItemFn) -> syn::Result<HandlerParts<'a>> {
    let inputs = &func.sig.inputs;
    let first = inputs.first().ok_or_else(|| {
        syn::Error::new_spanned(
            &func.sig,
            "a #[subscriber] handler must take exactly one message parameter",
        )
    })?;
    let FnArg::Typed(PatType { pat, ty, .. }) = first else {
        return Err(syn::Error::new_spanned(
            first,
            "a #[subscriber] handler cannot take `self`",
        ));
    };
    let Type::Reference(reference) = &**ty else {
        return Err(syn::Error::new_spanned(
            ty,
            "the message parameter must be a reference `&T`",
        ));
    };
    // Only the message and context parameters are carried into the generated
    // method; anything further would vanish from the signature.
    if let Some(extra) = inputs.get(2) {
        return Err(syn::Error::new_spanned(
            extra,
            "a #[subscriber] handler takes a message parameter and at most one context parameter",
        ));
    }
    let input_ty = &*reference.elem;
    let description = doc_description(&func.attrs);
    let (source_ty, source_expr) = source_tokens(&args.source)?;
    let input_schema = quote! {
        fn input_schema(&self) -> ::core::option::Option<::std::string::String> {
            #[allow(unused_imports)]
            use ::ruststream::__private::NoSchemaProbe as _;
            ::ruststream::__private::Probe::<#input_ty>::new().schema_json()
        }
    };
    let message_meta = quote! {
        fn message_name(&self) -> ::core::option::Option<&'static str> {
            #[allow(unused_imports)]
            use ::ruststream::__private::NoMessageProbe as _;
            ::ruststream::__private::Probe::<#input_ty>::new().message_name()
        }
        fn message_description(&self) -> ::core::option::Option<&'static str> {
            #[allow(unused_imports)]
            use ::ruststream::__private::NoMessageProbe as _;
            ::ruststream::__private::Probe::<#input_ty>::new().message_description()
        }
    };
    // Reuse the handler's own name for the context binding so the body can
    // refer to it; fall back to an unused placeholder otherwise.
    let ctx_param = match inputs.get(1) {
        Some(FnArg::Typed(PatType { pat, .. })) => quote!(#pat),
        _ => quote!(_ctx),
    };
    Ok(HandlerParts {
        vis: &func.vis,
        name: &func.sig.ident,
        block: &func.block,
        pat,
        input_ty,
        description,
        source_ty,
        source_expr,
        input_schema,
        message_meta,
        ctx_param,
    })
}
fn expand(args: &SubscriberArgs, func: &ItemFn) -> syn::Result<TokenStream> {
let parts = handler_parts(args, func)?;
let body = if let Some(reply_topic) = &args.publish {
expand_publishing(&parts, func, reply_topic)?
} else {
expand_subscribing(&parts)
};
Ok(body.into())
}
fn expand_publishing(
parts: &HandlerParts<'_>,
func: &ItemFn,
reply_topic: &LitStr,
) -> syn::Result<TokenStream2> {
let HandlerParts {
vis,
name,
block,
pat,
input_ty,
description,
source_ty,
source_expr,
input_schema,
message_meta,
ctx_param,
} = parts;
let declared_ty = match &func.sig.output {
ReturnType::Type(_, ty) => &**ty,
ReturnType::Default => {
return Err(syn::Error::new_spanned(
&func.sig,
"a publishing handler must return the reply value",
));
}
};
let (reply_ty, call_body) = match publish_result_reply(declared_ty) {
Some(reply_ty) => (reply_ty, quote!((async move #block).await)),
None => (
declared_ty,
quote!(::core::result::Result::Ok((async move #block).await)),
),
};
Ok(quote! {
#[allow(non_camel_case_types)]
#vis struct #name;
impl ::ruststream::runtime::PublishingDef for #name {
type Input = #input_ty;
type Reply = #reply_ty;
type Source = #source_ty;
fn source(&self) -> Self::Source { #source_expr }
fn reply_name(&self) -> &str { #reply_topic }
fn description(&self) -> ::core::option::Option<&str> {
#description
}
#input_schema
#message_meta
async fn call(
&self,
#pat: &#input_ty,
#ctx_param: &mut ::ruststream::runtime::Context<'_>,
) -> ::core::result::Result<#reply_ty, ::ruststream::runtime::HandlerResult> {
#call_body
}
}
})
}
/// Generates the unit struct plus its `Handler` and `SubscriberDef` impls for
/// a handler that does not publish a reply.
///
/// The handler body runs inside an async block and its value is converted
/// through `IntoHandlerResult::into_handler_result`.
fn expand_subscribing(parts: &HandlerParts<'_>) -> TokenStream2 {
    let HandlerParts {
        vis,
        name: handler,
        block,
        pat: message_pat,
        input_ty,
        description,
        source_ty,
        source_expr,
        input_schema,
        message_meta,
        ctx_param,
    } = parts;

    let awaited = quote!((async move #block).await);

    quote! {
        #[derive(Clone, Copy)]
        #[allow(non_camel_case_types)]
        #vis struct #handler;
        impl ::ruststream::runtime::Handler<#input_ty> for #handler {
            async fn handle(
                &self,
                #message_pat: &#input_ty,
                #ctx_param: &mut ::ruststream::runtime::Context<'_>,
            ) -> ::ruststream::runtime::HandlerResult {
                ::ruststream::runtime::IntoHandlerResult::into_handler_result(
                    #awaited,
                )
            }
        }
        impl ::ruststream::runtime::SubscriberDef for #handler {
            type Input = #input_ty;
            type Handler = Self;
            type Source = #source_ty;
            fn source(&self) -> Self::Source { #source_expr }
            fn description(&self) -> ::core::option::Option<&str> {
                #description
            }
            #input_schema
            #message_meta
            fn into_handler(self) -> Self { self }
        }
    }
}
/// Derives `::ruststream::Message` for a type.
///
/// `NAME` is the type's identifier as a string and `DESCRIPTION` is built from
/// the type's doc attributes (`None` when there are none).
#[proc_macro_derive(Message)]
pub fn derive_message(item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as DeriveInput);
    let ident = &input.ident;
    let message_name = ident.to_string();
    let description = doc_description(&input.attrs);
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let expanded = quote! {
        impl #impl_generics ::ruststream::Message for #ident #ty_generics #where_clause {
            const NAME: &'static str = #message_name;
            const DESCRIPTION: ::core::option::Option<&'static str> = #description;
        }
    };
    expanded.into()
}
/// Builds an `Option` expression from the `#[doc = ".."]` attributes in `attrs`.
///
/// Each string-literal doc attribute contributes one trimmed line; the lines
/// are joined with `\n` into `Some(..)`. When no such attribute exists the
/// result is the tokens for `None`.
fn doc_description(attrs: &[Attribute]) -> TokenStream2 {
    let mut lines: Vec<String> = Vec::new();
    for attr in attrs {
        if !attr.path().is_ident("doc") {
            continue;
        }
        let Meta::NameValue(name_value) = &attr.meta else {
            continue;
        };
        if let Expr::Lit(ExprLit {
            lit: Lit::Str(text),
            ..
        }) = &name_value.value
        {
            lines.push(text.value().trim().to_owned());
        }
    }
    if lines.is_empty() {
        return quote!(::core::option::Option::None);
    }
    let joined = lines.join("\n");
    quote!(::core::option::Option::Some(#joined))
}