use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use syn::{
parse_macro_input, FnArg, GenericArgument, ItemFn, LitInt, Pat, PathArguments, ReturnType,
Token, Type, TypeReference,
};
use telepath_wire::cmd_id::derive_cmd_id as compute_cmd_id;
/// Parsed arguments of the `#[command(...)]` attribute.
///
/// Accepted forms are the bare `#[command]` and `#[command(cmd_id = <integer>)]`.
struct CommandArgs {
    /// Wire id supplied by the user; `None` means the id is derived from the
    /// command's name and signature.
    explicit_cmd_id: Option<u16>,
}

impl syn::parse::Parse for CommandArgs {
    /// Parses either nothing or a single `cmd_id = <u16 literal>` pair.
    ///
    /// # Errors
    /// Fails on any key other than `cmd_id`, on a missing `=` or integer
    /// literal, and on a literal that does not fit in `u16`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let explicit_cmd_id = if input.is_empty() {
            // Bare `#[command]`: the id will be derived.
            None
        } else {
            let key = input.parse::<syn::Ident>()?;
            if key != "cmd_id" {
                return Err(syn::Error::new_spanned(
                    key,
                    "#[command]: unknown attribute key (expected `cmd_id`)",
                ));
            }
            input.parse::<Token![=]>()?;
            let lit = input.parse::<LitInt>()?;
            let id = lit.base10_parse::<u16>().map_err(|_| {
                syn::Error::new_spanned(&lit, "#[command(cmd_id = ...)]: value must fit in u16")
            })?;
            Some(id)
        };
        Ok(Self { explicit_cmd_id })
    }
}
/// Process-wide registry mapping every cmd_id claimed so far to the name of
/// the command that claimed it.
///
/// The map is created lazily on first access. Because it is a `static`, its
/// entries persist across all macro expansions performed by the same process.
fn seen_cmd_ids() -> &'static Mutex<HashMap<u16, String>> {
    static REGISTRY: OnceLock<Mutex<HashMap<u16, String>>> = OnceLock::new();
    REGISTRY.get_or_init(Default::default)
}
#[proc_macro_attribute]
pub fn command(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = match syn::parse2::<CommandArgs>(TokenStream2::from(attr)) {
Ok(a) => a,
Err(e) => return e.to_compile_error().into(),
};
let input = parse_macro_input!(item as ItemFn);
match expand_command(input, args.explicit_cmd_id) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Expands a `#[command]` function into the function itself plus the
/// dispatch machinery the server uses to find and invoke it.
///
/// Generated items (named after the function):
/// - the function, with `#[resource]` parameter attributes stripped;
/// - `__telepath_shim_<name>`: decodes the postcard-encoded wire arguments,
///   looks up `#[resource]` arguments in the `ResourceRegistry`, calls the
///   function and postcard-encodes its result (or its `AppErrorPayload`)
///   into the output buffer;
/// - `__telepath_args_schema_<name>` / `__telepath_ret_schema_<name>`: write
///   the postcard schema of the argument tuple / return value;
/// - `__TELEPATH_CMD_<NAME>`: the `CommandMetadata` constant, also registered
///   in the `TELEPATH_COMMANDS` distributed slice;
/// - a `#[used]` guard static exported as `__telepath_cmd_id_XXXX`, so two
///   commands with the same id linked into one binary fail at link time.
///
/// `explicit_cmd_id` overrides the id otherwise derived from the function
/// name and the stringified argument/return types.
///
/// # Errors
/// Returns a spanned `syn::Error` for unsupported signatures (async, unsafe,
/// generic, where-clauses, methods, pattern arguments, non-`#[resource]`
/// references, reference returns, `Result` with an error type other than
/// `AppErrorPayload`) and when the cmd_id was already claimed by a
/// differently-named command in this process.
fn expand_command(
    func: ItemFn,
    explicit_cmd_id: Option<u16>,
) -> syn::Result<proc_macro2::TokenStream> {
    let fn_ident = &func.sig.ident;
    // Generated identifiers and the wire name use the un-raw spelling:
    // `format_ident!` panics on a `String` containing `r#` (e.g. `fn r#match`).
    // The call into the user function below still uses `fn_ident` unchanged.
    let fn_name_str = syn::ext::IdentExt::unraw(fn_ident).to_string();

    // --- Signature validation ---------------------------------------------
    if let Some(tok) = &func.sig.asyncness {
        return Err(syn::Error::new_spanned(
            tok,
            "#[command] does not support async fn",
        ));
    }
    if let Some(tok) = &func.sig.unsafety {
        return Err(syn::Error::new_spanned(
            tok,
            "#[command] does not support unsafe fn",
        ));
    }
    if !func.sig.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &func.sig.generics,
            "#[command] does not support generic functions",
        ));
    }
    if let Some(wc) = &func.sig.generics.where_clause {
        return Err(syn::Error::new_spanned(
            wc,
            "#[command] does not support where clauses",
        ));
    }

    // --- Argument classification ------------------------------------------
    // Wire arguments are decoded from the request payload; `#[resource]`
    // arguments are borrowed from the `ResourceRegistry` instead.
    let mut wire_idents: Vec<syn::Ident> = Vec::new();
    let mut wire_types: Vec<Box<Type>> = Vec::new();
    let mut wire_type_strs: Vec<String> = Vec::new();
    struct ResourceArg {
        ident: syn::Ident,
        inner_ty: Box<Type>,
        /// Token string of `inner_ty`, cached for duplicate detection.
        ty_str: String,
        is_mut: bool,
    }
    let mut resource_args: Vec<ResourceArg> = Vec::new();
    // Every argument in declaration order, i.e. the order of the final call.
    let mut all_arg_idents: Vec<syn::Ident> = Vec::new();
    for fn_arg in &func.sig.inputs {
        let pat_type = match fn_arg {
            FnArg::Receiver(recv) => {
                return Err(syn::Error::new_spanned(
                    recv,
                    "#[command] cannot be applied to methods",
                ));
            }
            FnArg::Typed(pat_type) => pat_type,
        };
        let ident = match pat_type.pat.as_ref() {
            Pat::Ident(pi) => pi.ident.clone(),
            other => {
                return Err(syn::Error::new_spanned(
                    other,
                    "#[command] requires simple named arguments (patterns not supported)",
                ));
            }
        };
        if pat_type.attrs.iter().any(|a| a.path().is_ident("resource")) {
            let Type::Reference(TypeReference {
                elem, mutability, ..
            }) = pat_type.ty.as_ref()
            else {
                return Err(syn::Error::new_spanned(
                    &pat_type.ty,
                    "#[resource] arguments must be &T or &mut T",
                ));
            };
            // Types are compared textually: `&Foo` and `&mut Foo` count as the
            // same resource, while `Foo` and `crate::Foo` count as distinct.
            let ty_str = quote! { #elem }.to_string();
            if resource_args.iter().any(|existing| existing.ty_str == ty_str) {
                return Err(syn::Error::new_spanned(
                    &pat_type.ty,
                    "duplicate #[resource] type; each resource type may appear at most once",
                ));
            }
            resource_args.push(ResourceArg {
                ident: ident.clone(),
                inner_ty: elem.clone(),
                ty_str,
                is_mut: mutability.is_some(),
            });
        } else {
            if let Type::Reference(r) = pat_type.ty.as_ref() {
                return Err(syn::Error::new_spanned(
                    r,
                    "#[command] does not support reference arguments \
                     (use #[resource] for injected references)",
                ));
            }
            let ty = &*pat_type.ty;
            wire_type_strs.push(quote! { #ty }.to_string());
            wire_idents.push(ident.clone());
            wire_types.push(pat_type.ty.clone());
        }
        all_arg_idents.push(ident);
    }

    // --- Return type ------------------------------------------------------
    let ret_type_str = match &func.sig.output {
        ReturnType::Default => "()".to_string(),
        ReturnType::Type(_, ty) => {
            if let Type::Reference(r) = ty.as_ref() {
                return Err(syn::Error::new_spanned(
                    r,
                    "#[command] does not support reference return types",
                ));
            }
            quote! { #ty }.to_string()
        }
    };
    let returns_app_error = match &func.sig.output {
        ReturnType::Default => false,
        ReturnType::Type(_, ty) => {
            if is_result_outer(ty) && !is_result_app_error(ty) {
                return Err(syn::Error::new_spanned(
                    ty,
                    "#[command] supports `Result<T, AppErrorPayload>` for fallible commands. \
                     A `Result` with any other error type is not supported — use \
                     `telepath_wire::AppErrorPayload` as the Err variant, or return a plain \
                     value `T` for an infallible command. \
                     Note: type aliases for AppErrorPayload are not detected; spell it out \
                     literally.",
                ));
            }
            is_result_app_error(ty)
        }
    };

    // --- Command id ---------------------------------------------------------
    let arg_names_str: String = wire_idents
        .iter()
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join(",");
    let args_type_str = if wire_type_strs.is_empty() {
        "()".to_string()
    } else if wire_type_strs.len() == 1 {
        format!("({},)", wire_type_strs[0])
    } else {
        format!("({})", wire_type_strs.join(", "))
    };
    let cmd_id_value = explicit_cmd_id
        .unwrap_or_else(|| compute_cmd_id(&fn_name_str, &args_type_str, &ret_type_str));
    register_cmd_id(cmd_id_value, &fn_name_str, fn_ident)?;
    // Explicit ids are emitted as a literal. Derived ids are emitted as a call
    // to the server crate's derivation over the same strings, while the value
    // computed above names the guard symbol.
    // NOTE(review): both derivations must stay in lockstep — confirm that
    // `__derive_cmd_id` and `telepath_wire::cmd_id::derive_cmd_id` are the same function.
    let cmd_id_expr: proc_macro2::TokenStream = if explicit_cmd_id.is_some() {
        let v = cmd_id_value;
        quote! { #v }
    } else {
        quote! {
            ::telepath_server::__derive_cmd_id(
                #fn_name_str,
                #args_type_str,
                #ret_type_str,
            )
        }
    };

    // --- Generated names ----------------------------------------------------
    let collision_export = format!("__telepath_cmd_id_{:04X}", cmd_id_value);
    let guard_ident = format_ident!("__TELEPATH_CMDID_GUARD_{}", fn_name_str.to_uppercase());
    let shim_ident = format_ident!("__telepath_shim_{}", fn_name_str);
    let args_schema_ident = format_ident!("__telepath_args_schema_{}", fn_name_str);
    let ret_schema_ident = format_ident!("__telepath_ret_schema_{}", fn_name_str);
    let static_ident = format_ident!("__TELEPATH_CMD_{}", fn_name_str.to_uppercase());
    let reg_ident = format_ident!("__TELEPATH_REG_{}", fn_name_str.to_uppercase());

    // The wire arguments travel as one tuple; a single argument needs the
    // trailing comma to stay a 1-tuple.
    let args_schema_type = if wire_types.is_empty() {
        quote! { () }
    } else if wire_types.len() == 1 {
        let t = &*wire_types[0];
        quote! { (#t,) }
    } else {
        quote! { (#(#wire_types),*) }
    };
    // For `Result<T, AppErrorPayload>` the success schema is just `T`.
    let ret_schema_type = match &func.sig.output {
        ReturnType::Default => quote! { () },
        ReturnType::Type(_, ty) => {
            if returns_app_error {
                let ok_ty = extract_ok_type(ty);
                quote! { #ok_ty }
            } else {
                quote! { #ty }
            }
        }
    };

    // --- Shim body ----------------------------------------------------------
    // The shim's own bindings carry a `__telepath_` prefix. User argument names
    // are spliced unhygienically into the same scope, so a plain `output` or
    // `input` argument would otherwise shadow the shim's buffers.
    let wire_deser = if wire_idents.is_empty() {
        quote! {
            if !__telepath_input.is_empty() {
                return ::core::result::Result::Err(
                    ::telepath_server::DispatchError::DeserializeError
                );
            }
        }
    } else {
        let wire_pat = if wire_idents.len() == 1 {
            let id = &wire_idents[0];
            quote! { (#id,) }
        } else {
            quote! { (#(#wire_idents),*) }
        };
        quote! {
            let #wire_pat: #args_schema_type = match ::postcard::from_bytes(__telepath_input) {
                ::core::result::Result::Ok(v) => v,
                ::core::result::Result::Err(_) => return ::core::result::Result::Err(
                    ::telepath_server::DispatchError::DeserializeError
                ),
            };
        }
    };
    // Soundness of the generated `unsafe` derefs depends on the contract of
    // `ResourceRegistry::get_ptr`.
    // NOTE(review): confirm that the pointer is valid and not aliased for the
    // duration of the call.
    // The duplicate-type check above ensures the shim itself never creates two
    // references to one resource.
    let resource_lookups: Vec<_> = resource_args
        .iter()
        .map(|ra| {
            let ident = &ra.ident;
            let inner_ty = &ra.inner_ty;
            if ra.is_mut {
                quote! {
                    let #ident: &mut #inner_ty = unsafe {
                        &mut *__telepath_resources.get_ptr::<#inner_ty>()
                            .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
                    };
                }
            } else {
                quote! {
                    let #ident: &#inner_ty = unsafe {
                        &*__telepath_resources.get_ptr::<#inner_ty>()
                            .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
                    };
                }
            }
        })
        .collect();
    let shim_body = if returns_app_error {
        quote! {
            #wire_deser
            #(#resource_lookups)*
            let __ret = #fn_ident(#(#all_arg_idents),*);
            match __ret {
                ::core::result::Result::Ok(__ok) => {
                    match ::postcard::to_slice(&__ok, __telepath_output) {
                        ::core::result::Result::Ok(s) => ::core::result::Result::Ok(
                            ::telepath_server::DispatchOutcome::Ok(s.len())
                        ),
                        ::core::result::Result::Err(_) => ::core::result::Result::Err(
                            ::telepath_server::DispatchError::SerializeError
                        ),
                    }
                }
                ::core::result::Result::Err(__err) => {
                    match ::telepath_server::__encode_app_error(&__err, __telepath_output) {
                        ::core::result::Result::Ok(n) => ::core::result::Result::Ok(
                            ::telepath_server::DispatchOutcome::AppError(n)
                        ),
                        ::core::result::Result::Err(_) => ::core::result::Result::Err(
                            ::telepath_server::DispatchError::SerializeError
                        ),
                    }
                }
            }
        }
    } else {
        quote! {
            #wire_deser
            #(#resource_lookups)*
            let __ret = #fn_ident(#(#all_arg_idents),*);
            match ::postcard::to_slice(&__ret, __telepath_output) {
                ::core::result::Result::Ok(s) => ::core::result::Result::Ok(
                    ::telepath_server::DispatchOutcome::Ok(s.len())
                ),
                ::core::result::Result::Err(_) => ::core::result::Result::Err(
                    ::telepath_server::DispatchError::SerializeError
                ),
            }
        }
    };

    // `#[resource]` is only meaningful to this macro; strip it so rustc does
    // not see an unknown attribute on the emitted function.
    let mut clean_func = func.clone();
    for fn_arg in &mut clean_func.sig.inputs {
        if let FnArg::Typed(pat_type) = fn_arg {
            pat_type.attrs.retain(|a| !a.path().is_ident("resource"));
        }
    }
    Ok(quote! {
        #clean_func
        #[allow(non_snake_case)]
        fn #shim_ident(
            __telepath_input: &[u8],
            __telepath_output: &mut [u8],
            __telepath_resources: &::telepath_server::ResourceRegistry,
        ) -> ::core::result::Result<
            ::telepath_server::DispatchOutcome,
            ::telepath_server::DispatchError,
        > {
            #shim_body
        }
        #[allow(non_snake_case)]
        fn #args_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
            ::postcard::to_slice(
                <#args_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
                out,
            )
            .map(|s| s.len())
            .map_err(|_| ())
        }
        #[allow(non_snake_case)]
        fn #ret_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
            ::postcard::to_slice(
                <#ret_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
                out,
            )
            .map(|s| s.len())
            .map_err(|_| ())
        }
        pub const #static_ident: ::telepath_server::CommandMetadata =
            ::telepath_server::CommandMetadata {
                name: #fn_name_str,
                id: #cmd_id_expr,
                invoke: #shim_ident,
                args_schema: #args_schema_ident,
                ret_schema: #ret_schema_ident,
                arg_names: #arg_names_str,
            };
        #[allow(non_upper_case_globals, non_snake_case)]
        #[::telepath_server::__linkme::distributed_slice(::telepath_server::TELEPATH_COMMANDS)]
        #[linkme(crate = ::telepath_server::__linkme)]
        static #reg_ident: ::telepath_server::CommandMetadata = #static_ident;
        #[doc(hidden)]
        #[allow(non_upper_case_globals, dead_code)]
        #[used]
        #[export_name = #collision_export]
        pub static #guard_ident: u8 = 0;
    })
}

/// Records `cmd_id` as claimed by `fn_name` in the process-wide registry.
/// Fails if a *different* command already holds that id.
///
/// Registering the same name again is accepted. A long-lived expansion host
/// (e.g. an IDE's proc-macro server) keeps this registry alive and expands
/// the same item repeatedly, which must not be reported as a collision with
/// itself. Two same-named commands in different modules are still caught at
/// link time by the exported guard symbol.
fn register_cmd_id(cmd_id: u16, fn_name: &str, span_of: &syn::Ident) -> syn::Result<()> {
    // Recover from poisoning. A panic elsewhere while the lock was held would
    // otherwise make every later expansion panic. The map holds plain data
    // and is never left half-updated.
    let mut seen = seen_cmd_ids()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    if let Some(existing) = seen.get(&cmd_id) {
        if existing.as_str() != fn_name {
            return Err(syn::Error::new_spanned(
                span_of,
                format!(
                    "#[command] cmd_id collision: `{}` and `{}` both map to 0x{:04X}. \
                     Rename one of the commands to avoid the collision.",
                    fn_name, existing, cmd_id
                ),
            ));
        }
    }
    seen.insert(cmd_id, fn_name.to_owned());
    Ok(())
}
/// Returns `true` when `ty` is a path type whose last segment is `Result`
/// with exactly two type arguments, e.g. `Result<T, E>` or
/// `std::result::Result<T, E>`.
///
/// One-argument aliases such as `io::Result<T>` and a bare `fmt::Result`
/// yield `false`.
fn is_result_outer(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    match type_path.path.segments.last() {
        Some(last) if last.ident == "Result" => match &last.arguments {
            PathArguments::AngleBracketed(generics) => {
                generics
                    .args
                    .iter()
                    .filter(|arg| matches!(arg, GenericArgument::Type(_)))
                    .count()
                    == 2
            }
            _ => false,
        },
        _ => false,
    }
}
/// Returns `true` when `ty` is a two-argument `Result<_, E>` and `E` is a
/// path whose last segment is `AppErrorPayload`.
///
/// Detection is purely syntactic: a type alias for `AppErrorPayload` is not
/// recognized.
fn is_result_app_error(ty: &Type) -> bool {
    if !is_result_outer(ty) {
        return false;
    }
    let Type::Path(result_path) = ty else {
        return false;
    };
    let Some(PathArguments::AngleBracketed(generics)) =
        result_path.path.segments.last().map(|seg| &seg.arguments)
    else {
        return false;
    };
    // `is_result_outer` guarantees exactly two type arguments; the second is `E`.
    let err_ty = generics
        .args
        .iter()
        .filter_map(|arg| match arg {
            GenericArgument::Type(t) => Some(t),
            _ => None,
        })
        .nth(1);
    match err_ty {
        Some(Type::Path(err_path)) => err_path
            .path
            .segments
            .last()
            .is_some_and(|seg| seg.ident == "AppErrorPayload"),
        _ => false,
    }
}
/// Returns the first type argument (`T`) of a `Result<T, E>`-shaped path type.
///
/// # Panics
/// Panics if `ty` is not a path type, has an empty path, lacks angle-bracketed
/// arguments, or has no type argument. Callers only invoke it after
/// `is_result_app_error` has confirmed the shape.
fn extract_ok_type(ty: &Type) -> &Type {
    let path = match ty {
        Type::Path(tp) => &tp.path,
        _ => panic!("extract_ok_type: expected Type::Path"),
    };
    let last = path.segments.last().expect("empty path");
    match &last.arguments {
        PathArguments::AngleBracketed(generics) => generics
            .args
            .iter()
            .find_map(|arg| match arg {
                GenericArgument::Type(t) => Some(t),
                _ => None,
            })
            .expect("extract_ok_type: no type arg"),
        _ => panic!("extract_ok_type: expected angle-bracketed args"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that touch the process-wide `seen_cmd_ids` registry.
    static TEST_GUARD: Mutex<()> = Mutex::new(());

    /// Exclusive access to the registry for one test. The registry is emptied
    /// when the fixture is created and again when it is dropped, including
    /// during unwinding from a failed assertion.
    struct Isolated {
        _guard: MutexGuard<'static, ()>,
    }

    impl Drop for Isolated {
        fn drop(&mut self) {
            // Runs before `_guard` is released, so no other test sees stale ids.
            clear_registry();
        }
    }

    fn clear_registry() {
        seen_cmd_ids()
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .clear();
    }

    /// Takes the test lock and starts from an empty registry. A failing test
    /// poisons `TEST_GUARD`; recovering the guard keeps that one failure from
    /// cascading into every other test.
    fn isolate() -> Isolated {
        let guard = TEST_GUARD.lock().unwrap_or_else(PoisonError::into_inner);
        clear_registry();
        Isolated { _guard: guard }
    }

    fn parse_fn(src: &str) -> ItemFn {
        syn::parse_str(src).unwrap()
    }

    /// Expands `src` with a derived id and returns the error message.
    fn expand_err(src: &str) -> String {
        expand_command(parse_fn(src), None)
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn same_crate_collision_is_rejected() {
        let _iso = isolate();
        assert!(expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None).is_ok());
        let err = expand_err("fn cmd_470() -> u32 { 0 }");
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error, got: {err}"
        );
        assert!(
            err.contains("0x43AE"),
            "expected hex id 0x43AE in error, got: {err}"
        );
        assert!(
            err.contains("cmd_446") && err.contains("cmd_470"),
            "expected both command names in error, got: {err}"
        );
    }

    #[test]
    fn guard_symbol_has_correct_export_name() {
        let _iso = isolate();
        let ts = expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None)
            .unwrap()
            .to_string();
        assert!(
            ts.contains("__telepath_cmd_id_43AE"),
            "guard symbol export_name not found in generated code: {ts}"
        );
    }

    #[test]
    fn distinct_commands_do_not_collide() {
        let _iso = isolate();
        assert!(expand_command(parse_fn("fn ping() -> u32 { 0 }"), None).is_ok());
        assert!(expand_command(parse_fn("fn echo(x: u32) -> u32 { x }"), None).is_ok());
    }

    #[test]
    fn explicit_cmd_id_overrides_derive() {
        let _iso = isolate();
        let ts = expand_command(parse_fn("fn get_metrics() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap()
            .to_string();
        assert!(
            ts.contains("65534"),
            "explicit cmd_id 0xFFFE not found as literal in generated code: {ts}"
        );
        assert!(
            ts.contains("__telepath_cmd_id_FFFE"),
            "guard symbol for explicit cmd_id not found in generated code: {ts}"
        );
    }

    #[test]
    fn explicit_cmd_id_collision_rejected() {
        let _iso = isolate();
        assert!(expand_command(parse_fn("fn foo() -> u32 { 0 }"), Some(0xFFFE)).is_ok());
        let err = expand_command(parse_fn("fn bar() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error for duplicate explicit cmd_id, got: {err}"
        );
    }

    #[test]
    fn async_fn_is_rejected() {
        let _iso = isolate();
        let err = expand_err("async fn ping() -> u32 { 0 }");
        assert!(err.contains("does not support async fn"), "got: {err}");
    }

    #[test]
    fn result_with_foreign_error_is_rejected() {
        let _iso = isolate();
        let err = expand_err("fn ping() -> Result<u32, String> { Ok(0) }");
        assert!(err.contains("Result<T, AppErrorPayload>"), "got: {err}");
    }

    #[test]
    fn duplicate_resource_type_is_rejected() {
        let _iso = isolate();
        let err = expand_err(
            "fn ping(#[resource] a: &Store, #[resource] b: &mut Store) -> u32 { 0 }",
        );
        assert!(err.contains("duplicate #[resource] type"), "got: {err}");
    }

    #[test]
    fn plain_reference_argument_is_rejected() {
        let _iso = isolate();
        let err = expand_err("fn ping(x: &u32) -> u32 { *x }");
        assert!(err.contains("does not support reference arguments"), "got: {err}");
    }
}