use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
FnArg, GenericArgument, Ident, ImplItemFn, ItemFn, PatType, PathArguments, ReturnType,
Signature, Type,
};
use crate::attr_args::{self, CommandAttrArgs};
pub fn expand(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
let args = attr_args::parse(attr)?;
let func: ItemFn = syn::parse2(item)?;
if func.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&func.sig,
"#[command] requires an async fn",
));
}
let request_ty = extract_request_type(&func.sig)?;
let response_ty = extract_response_type(&func.sig)?;
let fn_name = &func.sig.ident;
let struct_ident = free_fn_struct_ident(fn_name);
let register_ident = format_ident!("register_{}", fn_name);
let vis = &func.vis;
let description = description_tokens(args.description.as_ref());
let id_lit = &args.id;
let call_expr = quote! { async move { #fn_name(request).await } };
let command_impl = emit_command_impl(
&struct_ident,
id_lit,
description,
&request_ty,
&response_ty,
call_expr,
);
let cmd_ipc = cmd_ipc_path();
Ok(quote! {
#func
#vis struct #struct_ident;
#command_impl
#vis async fn #register_ident(
registry: &#cmd_ipc::CommandRegistry,
) -> ::core::result::Result<(), #cmd_ipc::CommandError> {
registry.register_command(#struct_ident).await
}
})
}
/// Result of expanding a single `#[command]` method via `expand_method`.
#[derive(Debug, Clone)]
pub struct MethodExpansion {
    /// Generated items: the per-method command struct and its `Command` impl.
    pub items: TokenStream,
    /// Identifier of the command struct declared in `items`.
    pub struct_ident: Ident,
}
pub fn expand_method(
args: CommandAttrArgs,
method: &ImplItemFn,
host_ty: &TokenStream,
host_ident_for_naming: &Ident,
) -> syn::Result<MethodExpansion> {
if method.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&method.sig,
"#[command] requires an async fn",
));
}
expect_method_receiver(&method.sig)?;
let request_ty = extract_request_type(&method.sig)?;
let response_ty = extract_response_type(&method.sig)?;
let method_ident = &method.sig.ident;
let struct_ident = method_struct_ident(host_ident_for_naming, method_ident);
let description = description_tokens(args.description.as_ref());
let id_lit = &args.id;
let call_expr = quote! {
{
let host = ::std::sync::Arc::clone(&self.host);
async move { host.#method_ident(request).await }
}
};
let command_impl = emit_command_impl_owned(
&struct_ident,
host_ty,
id_lit,
description,
&request_ty,
&response_ty,
call_expr,
);
let items = quote! {
pub(super) struct #struct_ident {
pub(super) host: ::std::sync::Arc<#host_ty>,
}
#command_impl
};
Ok(MethodExpansion {
items,
struct_ident,
})
}
/// Emits `impl Command for #struct_ident`.
///
/// * `id_lit` / `description` become the `ID` / `DESCRIPTION` associated constants.
/// * `request_ty` / `response_ty` become the `Request` / `Response` associated types.
/// * `call_expr` is spliced verbatim as the body of `handle`. It may use `self` and the
///   `request` parameter, and must evaluate to a `Send` future yielding
///   `Result<Self::Response, CommandError>`.
///
/// `schema()` always returns `Some`. Each side is `None` when its type is literally `()`;
/// otherwise it is the normalized JSON value of `schemars::schema_for!` for that type.
fn emit_command_impl(
    struct_ident: &Ident,
    id_lit: &syn::LitStr,
    description: TokenStream,
    request_ty: &Type,
    response_ty: &Type,
    call_expr: TokenStream,
) -> TokenStream {
    let cmd_ipc = cmd_ipc_path();

    // Builds the `Option<schema>` expression for one side (request or response).
    let schema_expr = |ty: &Type, expect_msg: &str| -> TokenStream {
        if is_unit_type(ty) {
            quote! { ::core::option::Option::None }
        } else {
            quote! {
                ::core::option::Option::Some(
                    #cmd_ipc::normalize_schema(
                        #cmd_ipc::serde_json::to_value(
                            #cmd_ipc::schemars::schema_for!(#ty)
                        ).expect(#expect_msg),
                    ),
                )
            }
        }
    };
    let request_schema = schema_expr(request_ty, "request schema should serialize");
    let response_schema = schema_expr(response_ty, "response schema should serialize");

    quote! {
        impl #cmd_ipc::Command for #struct_ident {
            const ID: &'static str = #id_lit;
            const DESCRIPTION: ::core::option::Option<&'static str> = #description;
            type Request = #request_ty;
            type Response = #response_ty;
            fn handle(
                &self,
                request: Self::Request,
            ) -> impl ::core::future::Future<
                Output = ::core::result::Result<Self::Response, #cmd_ipc::CommandError>
            > + ::core::marker::Send {
                #call_expr
            }
            fn schema(&self) -> ::core::option::Option<#cmd_ipc::CommandSchema> {
                ::core::option::Option::Some(#cmd_ipc::CommandSchema {
                    request: #request_schema,
                    response: #response_schema,
                })
            }
        }
    }
}
/// Returns `true` only when `ty` is syntactically the empty tuple `()`.
///
/// Only a bare `Type::Tuple` with no elements qualifies; any other syntax form returns `false`.
fn is_unit_type(ty: &Type) -> bool {
    matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty())
}
/// `Command` impl emitter used for `#[command_service]` methods.
///
/// The host type is accepted but currently ignored. The generated impl is exactly what
/// `emit_command_impl` produces for the same arguments.
fn emit_command_impl_owned(
    struct_ident: &Ident,
    _host_ty: &TokenStream,
    id_lit: &syn::LitStr,
    description: TokenStream,
    request_ty: &Type,
    response_ty: &Type,
    call_expr: TokenStream,
) -> TokenStream {
    emit_command_impl(struct_ident, id_lit, description, request_ty, response_ty, call_expr)
}
/// Renders an optional description literal as an `Option<&'static str>` expression.
fn description_tokens(d: Option<&syn::LitStr>) -> TokenStream {
    if let Some(text) = d {
        quote! { ::core::option::Option::Some(#text) }
    } else {
        quote! { ::core::option::Option::None }
    }
}
/// Absolute path of the runtime crate (`::coralstack_cmd_ipc`) referenced by generated code.
fn cmd_ipc_path() -> TokenStream {
    quote!(::coralstack_cmd_ipc)
}
/// Text of `ident` with its first character upper-cased and any `r#` raw prefix removed.
///
/// `Ident::to_string` keeps the raw prefix (e.g. `r#type`). Upper-casing that would produce
/// `R#type...`, which `Ident::new` rejects with a panic, so the prefix is stripped first.
fn capitalized_ident_text(ident: &Ident) -> String {
    let text = ident.to_string();
    let bare = text.strip_prefix("r#").unwrap_or(&text);
    let mut chars = bare.chars();
    match chars.next() {
        Some(first) => first.to_uppercase().chain(chars).collect(),
        None => String::new(),
    }
}

/// Name of the unit struct generated for a free-fn command, e.g. `get_user` -> `Get_userCommand`.
fn free_fn_struct_ident(fn_name: &Ident) -> Ident {
    Ident::new(
        &format!("{}Command", capitalized_ident_text(fn_name)),
        fn_name.span(),
    )
}

/// Name of the struct generated for a `#[command_service]` method, e.g. `get_user` -> `Get_user`.
///
/// The host identifier is currently unused; the name derives from the method alone.
fn method_struct_ident(_host: &Ident, method: &Ident) -> Ident {
    Ident::new(&capitalized_ident_text(method), method.span())
}
fn extract_request_type(sig: &Signature) -> syn::Result<Type> {
let non_recv: Vec<&FnArg> = sig
.inputs
.iter()
.filter(|a| !matches!(a, FnArg::Receiver(_)))
.collect();
match non_recv.as_slice() {
[] => Ok(unit_type()),
[one] => {
let FnArg::Typed(PatType { ty, pat, .. }) = one else {
return Err(syn::Error::new_spanned(one, "unexpected receiver here"));
};
let _ = pat;
Ok((**ty).clone())
}
_ => Err(syn::Error::new_spanned(
&sig.inputs,
"#[command] handlers must take at most one argument (the typed request)",
)),
}
}
fn extract_response_type(sig: &Signature) -> syn::Result<Type> {
let ReturnType::Type(_, ty) = &sig.output else {
return Err(syn::Error::new_spanned(
&sig.output,
"#[command] handlers must return `Result<R, CommandError>`",
));
};
let Type::Path(tp) = &**ty else {
return Err(syn::Error::new_spanned(
ty,
"#[command] handlers must return `Result<R, CommandError>`",
));
};
let last = tp
.path
.segments
.last()
.ok_or_else(|| syn::Error::new_spanned(ty, "empty return type path"))?;
if last.ident != "Result" {
return Err(syn::Error::new_spanned(
&last.ident,
"#[command] handlers must return `Result<R, CommandError>`",
));
}
let PathArguments::AngleBracketed(args) = &last.arguments else {
return Err(syn::Error::new_spanned(
&last.arguments,
"expected `Result<R, CommandError>` with explicit generics",
));
};
let first = args.args.iter().find_map(|a| match a {
GenericArgument::Type(t) => Some(t.clone()),
_ => None,
});
first.ok_or_else(|| syn::Error::new_spanned(args, "missing response type in Result<_, _>"))
}
fn expect_method_receiver(sig: &Signature) -> syn::Result<()> {
match sig.inputs.first() {
Some(FnArg::Receiver(_)) => Ok(()),
_ => Err(syn::Error::new_spanned(
sig,
"#[command] inside an #[command_service] impl must be a method (first arg `&self`)",
)),
}
}
/// Syntax tree for the unit type `()`, used as the request type of parameterless handlers.
fn unit_type() -> Type {
    syn::parse_quote!(())
}