1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11use syn::{
12 parse_macro_input, FnArg, GenericArgument, ItemFn, LitInt, Pat, PathArguments, ReturnType,
13 Token, Type, TypeReference,
14};
15use telepath_wire::cmd_id::derive_cmd_id as compute_cmd_id;
16
/// Arguments accepted inside `#[command(...)]`.
struct CommandArgs {
    /// The `u16` given as `cmd_id = N`, if any. When `None`, the command id is
    /// derived from the function name, its wire-argument tuple type and its
    /// return type.
    explicit_cmd_id: Option<u16>,
}
25
26impl syn::parse::Parse for CommandArgs {
27 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
28 if input.is_empty() {
29 return Ok(CommandArgs {
30 explicit_cmd_id: None,
31 });
32 }
33 let key: syn::Ident = input.parse()?;
34 if key != "cmd_id" {
35 return Err(syn::Error::new_spanned(
36 key,
37 "#[command]: unknown attribute key (expected `cmd_id`)",
38 ));
39 }
40 let _eq: Token![=] = input.parse()?;
41 let lit: LitInt = input.parse()?;
42 let value: u16 = lit.base10_parse().map_err(|_| {
43 syn::Error::new_spanned(&lit, "#[command(cmd_id = ...)]: value must fit in u16")
44 })?;
45 Ok(CommandArgs {
46 explicit_cmd_id: Some(value),
47 })
48 }
49}
50
/// Process-wide record of the cmd_ids claimed by `#[command]` expansions so
/// far, mapping each id to the name of the command that claimed it.
///
/// Created lazily on first use and shared by every expansion that runs in the
/// same process as this macro.
fn seen_cmd_ids() -> &'static Mutex<HashMap<u16, String>> {
    static SEEN: OnceLock<Mutex<HashMap<u16, String>>> = OnceLock::new();
    SEEN.get_or_init(Mutex::default)
}
55
56#[proc_macro_attribute]
113pub fn command(attr: TokenStream, item: TokenStream) -> TokenStream {
114 let args = match syn::parse2::<CommandArgs>(TokenStream2::from(attr)) {
115 Ok(a) => a,
116 Err(e) => return e.to_compile_error().into(),
117 };
118 let input = parse_macro_input!(item as ItemFn);
119 match expand_command(input, args.explicit_cmd_id) {
120 Ok(ts) => ts.into(),
121 Err(e) => e.to_compile_error().into(),
122 }
123}
124
125fn expand_command(
126 func: ItemFn,
127 explicit_cmd_id: Option<u16>,
128) -> syn::Result<proc_macro2::TokenStream> {
129 let fn_ident = &func.sig.ident;
130 let fn_name_str = fn_ident.to_string();
131
132 if let Some(tok) = &func.sig.asyncness {
135 return Err(syn::Error::new_spanned(
136 tok,
137 "#[command] does not support async fn",
138 ));
139 }
140 if let Some(tok) = &func.sig.unsafety {
141 return Err(syn::Error::new_spanned(
142 tok,
143 "#[command] does not support unsafe fn",
144 ));
145 }
146 if !func.sig.generics.params.is_empty() {
147 return Err(syn::Error::new_spanned(
148 &func.sig.generics,
149 "#[command] does not support generic functions",
150 ));
151 }
152 if let Some(wc) = &func.sig.generics.where_clause {
153 return Err(syn::Error::new_spanned(
154 wc,
155 "#[command] does not support where clauses",
156 ));
157 }
158
159 let mut wire_idents = Vec::new();
163 let mut wire_types: Vec<Box<Type>> = Vec::new();
164 let mut wire_type_strs = Vec::new();
165
166 struct ResourceArg {
168 ident: syn::Ident,
169 inner_ty: Box<Type>,
170 is_mut: bool,
171 }
172 let mut resource_args: Vec<ResourceArg> = Vec::new();
173
174 let mut all_arg_idents: Vec<syn::Ident> = Vec::new();
176
177 for fn_arg in &func.sig.inputs {
178 match fn_arg {
179 FnArg::Receiver(recv) => {
180 return Err(syn::Error::new_spanned(
181 recv,
182 "#[command] cannot be applied to methods",
183 ));
184 }
185 FnArg::Typed(pat_type) => {
186 let ident = match pat_type.pat.as_ref() {
187 Pat::Ident(pi) => pi.ident.clone(),
188 other => {
189 return Err(syn::Error::new_spanned(
190 other,
191 "#[command] requires simple named arguments (patterns not supported)",
192 ));
193 }
194 };
195
196 let is_resource = pat_type.attrs.iter().any(|a| a.path().is_ident("resource"));
197
198 if is_resource {
199 let Type::Reference(TypeReference {
200 elem, mutability, ..
201 }) = pat_type.ty.as_ref()
202 else {
203 return Err(syn::Error::new_spanned(
204 &pat_type.ty,
205 "#[resource] arguments must be &T or &mut T",
206 ));
207 };
208
209 let inner_str = quote! { #elem }.to_string();
214 for existing in &resource_args {
215 let existing_ty = &existing.inner_ty;
216 let existing_str = quote! { #existing_ty }.to_string();
217 if existing_str == inner_str {
218 return Err(syn::Error::new_spanned(
219 &pat_type.ty,
220 "duplicate #[resource] type; each resource type may appear at most once",
221 ));
222 }
223 }
224
225 resource_args.push(ResourceArg {
226 ident: ident.clone(),
227 inner_ty: elem.clone(),
228 is_mut: mutability.is_some(),
229 });
230 all_arg_idents.push(ident);
231 } else {
232 if let Type::Reference(r) = pat_type.ty.as_ref() {
233 return Err(syn::Error::new_spanned(
234 r,
235 "#[command] does not support reference arguments \
236 (use #[resource] for injected references)",
237 ));
238 }
239 let ty = &*pat_type.ty;
240 wire_type_strs.push(quote! { #ty }.to_string());
241 wire_idents.push(ident.clone());
242 wire_types.push(pat_type.ty.clone());
243 all_arg_idents.push(ident);
244 }
245 }
246 }
247 }
248
249 let ret_type_str = match &func.sig.output {
252 ReturnType::Default => "()".to_string(),
253 ReturnType::Type(_, ty) => {
254 if let Type::Reference(r) = ty.as_ref() {
255 return Err(syn::Error::new_spanned(
256 r,
257 "#[command] does not support reference return types",
258 ));
259 }
260 quote! { #ty }.to_string()
261 }
262 };
263
264 let returns_app_error = match &func.sig.output {
269 ReturnType::Default => false,
270 ReturnType::Type(_, ty) => {
271 if is_result_outer(ty) && !is_result_app_error(ty) {
272 return Err(syn::Error::new_spanned(
273 ty,
274 "#[command] supports `Result<T, AppErrorPayload>` for fallible commands. \
275 A `Result` with any other error type is not supported — use \
276 `telepath_wire::AppErrorPayload` as the Err variant, or return a plain \
277 value `T` for an infallible command. \
278 Note: type aliases for AppErrorPayload are not detected; spell it out \
279 literally.",
280 ));
281 }
282 is_result_app_error(ty)
283 }
284 };
285
286 let arg_names_str: String = wire_idents
290 .iter()
291 .map(|id| id.to_string())
292 .collect::<Vec<_>>()
293 .join(",");
294
295 let args_type_str = if wire_type_strs.is_empty() {
300 "()".to_string()
301 } else if wire_type_strs.len() == 1 {
302 format!("({},)", wire_type_strs[0])
303 } else {
304 format!("({})", wire_type_strs.join(", "))
305 };
306
307 let cmd_id_value = explicit_cmd_id
316 .unwrap_or_else(|| compute_cmd_id(&fn_name_str, &args_type_str, &ret_type_str));
317
318 {
319 let mut seen = seen_cmd_ids().lock().unwrap();
320 if let Some(existing) = seen.get(&cmd_id_value) {
321 return Err(syn::Error::new_spanned(
322 fn_ident,
323 format!(
324 "#[command] cmd_id collision: `{}` and `{}` both map to 0x{:04X}. \
325 Rename one of the commands to avoid the collision.",
326 fn_name_str, existing, cmd_id_value
327 ),
328 ));
329 }
330 seen.insert(cmd_id_value, fn_name_str.clone());
331 }
332
333 let cmd_id_expr: proc_macro2::TokenStream = if explicit_cmd_id.is_some() {
337 let v = cmd_id_value;
338 quote! { #v }
339 } else {
340 quote! {
341 ::telepath_server::__derive_cmd_id(
342 #fn_name_str,
343 #args_type_str,
344 #ret_type_str,
345 )
346 }
347 };
348
349 let collision_export = format!("__telepath_cmd_id_{:04X}", cmd_id_value);
350 let guard_ident = format_ident!("__TELEPATH_CMDID_GUARD_{}", fn_name_str.to_uppercase());
351
352 let shim_ident = format_ident!("__telepath_shim_{}", fn_name_str);
355 let args_schema_ident = format_ident!("__telepath_args_schema_{}", fn_name_str);
356 let ret_schema_ident = format_ident!("__telepath_ret_schema_{}", fn_name_str);
357 let static_ident = format_ident!("__TELEPATH_CMD_{}", fn_name_str.to_uppercase());
358 let reg_ident = format_ident!("__TELEPATH_REG_{}", fn_name_str.to_uppercase());
359
360 let args_schema_type = if wire_types.is_empty() {
364 quote! { () }
365 } else if wire_types.len() == 1 {
366 let t = &*wire_types[0];
367 quote! { (#t,) }
368 } else {
369 quote! { (#(#wire_types),*) }
370 };
371
372 let ret_schema_type = match &func.sig.output {
376 ReturnType::Default => quote! { () },
377 ReturnType::Type(_, ty) => {
378 if returns_app_error {
379 let ok_ty = extract_ok_type(ty);
380 quote! { #ok_ty }
381 } else {
382 quote! { #ty }
383 }
384 }
385 };
386
387 let wire_deser = if wire_idents.is_empty() {
391 quote! {
392 if !input.is_empty() {
393 return ::core::result::Result::Err(
394 ::telepath_server::DispatchError::DeserializeError
395 );
396 }
397 }
398 } else {
399 let wire_tuple_type = if wire_types.len() == 1 {
400 let t = &*wire_types[0];
401 quote! { (#t,) }
402 } else {
403 quote! { (#(#wire_types),*) }
404 };
405 let wire_pat = if wire_idents.len() == 1 {
406 let id = &wire_idents[0];
407 quote! { (#id,) }
408 } else {
409 quote! { (#(#wire_idents),*) }
410 };
411 quote! {
412 let #wire_pat: #wire_tuple_type = match ::postcard::from_bytes(input) {
413 Ok(v) => v,
414 Err(_) => return ::core::result::Result::Err(
415 ::telepath_server::DispatchError::DeserializeError
416 ),
417 };
418 }
419 };
420
421 let resource_lookups: Vec<_> = resource_args
423 .iter()
424 .map(|ra| {
425 let ident = &ra.ident;
426 let inner_ty = &ra.inner_ty;
427 if ra.is_mut {
428 quote! {
429 let #ident: &mut #inner_ty = unsafe {
430 &mut *__resources.get_ptr::<#inner_ty>()
431 .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
432 };
433 }
434 } else {
435 quote! {
436 let #ident: &#inner_ty = unsafe {
437 &*__resources.get_ptr::<#inner_ty>()
438 .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
439 };
440 }
441 }
442 })
443 .collect();
444
445 let call_args: Vec<_> = all_arg_idents
447 .iter()
448 .map(|ident| quote! { #ident })
449 .collect();
450
451 let shim_body = if returns_app_error {
452 quote! {
456 #wire_deser
457 #(#resource_lookups)*
458 let __ret = #fn_ident(#(#call_args),*);
459 match __ret {
460 ::core::result::Result::Ok(__ok) => {
461 match ::postcard::to_slice(&__ok, output) {
462 Ok(s) => ::core::result::Result::Ok(
463 ::telepath_server::DispatchOutcome::Ok(s.len())
464 ),
465 Err(_) => ::core::result::Result::Err(
466 ::telepath_server::DispatchError::SerializeError
467 ),
468 }
469 }
470 ::core::result::Result::Err(__err) => {
471 match ::telepath_server::__encode_app_error(&__err, output) {
472 Ok(n) => ::core::result::Result::Ok(
473 ::telepath_server::DispatchOutcome::AppError(n)
474 ),
475 Err(_) => ::core::result::Result::Err(
476 ::telepath_server::DispatchError::SerializeError
477 ),
478 }
479 }
480 }
481 }
482 } else {
483 quote! {
486 #wire_deser
487 #(#resource_lookups)*
488 let __ret = #fn_ident(#(#call_args),*);
489 match ::postcard::to_slice(&__ret, output) {
490 Ok(s) => ::core::result::Result::Ok(
491 ::telepath_server::DispatchOutcome::Ok(s.len())
492 ),
493 Err(_) => ::core::result::Result::Err(
494 ::telepath_server::DispatchError::SerializeError
495 ),
496 }
497 }
498 };
499
500 let mut clean_func = func.clone();
503 for fn_arg in &mut clean_func.sig.inputs {
504 if let FnArg::Typed(pat_type) = fn_arg {
505 pat_type.attrs.retain(|a| !a.path().is_ident("resource"));
506 }
507 }
508
509 Ok(quote! {
510 #clean_func
511
512 #[allow(non_snake_case)]
513 fn #shim_ident(
514 input: &[u8],
515 output: &mut [u8],
516 __resources: &::telepath_server::ResourceRegistry,
517 ) -> ::core::result::Result<
518 ::telepath_server::DispatchOutcome,
519 ::telepath_server::DispatchError,
520 > {
521 #shim_body
522 }
523
524 #[allow(non_snake_case)]
525 fn #args_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
526 ::postcard::to_slice(
527 <#args_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
528 out,
529 )
530 .map(|s| s.len())
531 .map_err(|_| ())
532 }
533
534 #[allow(non_snake_case)]
535 fn #ret_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
536 ::postcard::to_slice(
537 <#ret_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
538 out,
539 )
540 .map(|s| s.len())
541 .map_err(|_| ())
542 }
543
544 pub const #static_ident: ::telepath_server::CommandMetadata =
545 ::telepath_server::CommandMetadata {
546 name: #fn_name_str,
547 id: #cmd_id_expr,
548 invoke: #shim_ident,
549 args_schema: #args_schema_ident,
550 ret_schema: #ret_schema_ident,
551 arg_names: #arg_names_str,
552 };
553
554 #[allow(non_upper_case_globals, non_snake_case)]
555 #[::telepath_server::__linkme::distributed_slice(::telepath_server::TELEPATH_COMMANDS)]
556 #[linkme(crate = ::telepath_server::__linkme)]
557 static #reg_ident: ::telepath_server::CommandMetadata = #static_ident;
558
559 #[doc(hidden)]
570 #[allow(non_upper_case_globals, dead_code)]
571 #[used]
572 #[export_name = #collision_export]
573 pub static #guard_ident: u8 = 0;
574
575 })
576}
577
578fn is_result_outer(ty: &Type) -> bool {
587 let Type::Path(tp) = ty else { return false };
588 let Some(seg) = tp.path.segments.last() else {
589 return false;
590 };
591 if seg.ident != "Result" {
592 return false;
593 }
594 let PathArguments::AngleBracketed(args) = &seg.arguments else {
595 return false;
596 };
597 let type_args: Vec<&Type> = args
598 .args
599 .iter()
600 .filter_map(|a| match a {
601 GenericArgument::Type(t) => Some(t),
602 _ => None,
603 })
604 .collect();
605 type_args.len() == 2
606}
607
608fn is_result_app_error(ty: &Type) -> bool {
615 if !is_result_outer(ty) {
616 return false;
617 }
618 let Type::Path(tp) = ty else { return false };
619 let Some(seg) = tp.path.segments.last() else {
620 return false;
621 };
622 let PathArguments::AngleBracketed(args) = &seg.arguments else {
623 return false;
624 };
625 let type_args: Vec<&Type> = args
626 .args
627 .iter()
628 .filter_map(|a| match a {
629 GenericArgument::Type(t) => Some(t),
630 _ => None,
631 })
632 .collect();
633 let err_ty = type_args[1];
634 let Type::Path(err_tp) = err_ty else {
635 return false;
636 };
637 err_tp
638 .path
639 .segments
640 .last()
641 .map(|s| s.ident == "AppErrorPayload")
642 .unwrap_or(false)
643}
644
645fn extract_ok_type(ty: &Type) -> &Type {
652 let Type::Path(tp) = ty else {
653 panic!("extract_ok_type: expected Type::Path");
654 };
655 let seg = tp.path.segments.last().expect("empty path");
656 let PathArguments::AngleBracketed(args) = &seg.arguments else {
657 panic!("extract_ok_type: expected angle-bracketed args");
658 };
659 args.args
660 .iter()
661 .filter_map(|a| match a {
662 GenericArgument::Type(t) => Some(t),
663 _ => None,
664 })
665 .next()
666 .expect("extract_ok_type: no type arg")
667}
668
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises the tests below. They all share the process-global
    /// `seen_cmd_ids()` map, and the test harness runs tests in parallel.
    static TEST_GUARD: Mutex<()> = Mutex::new(());

    /// Exclusive access to an empty `seen_cmd_ids()` map for one test.
    ///
    /// The map is cleared on creation and again on drop. Drop also runs when
    /// the test panics, and `TEST_GUARD` is released only after the clear.
    struct Isolated {
        _guard: MutexGuard<'static, ()>,
    }

    impl Drop for Isolated {
        fn drop(&mut self) {
            clear_seen_cmd_ids();
        }
    }

    /// Empties the global cmd_id map, tolerating a poisoned lock.
    fn clear_seen_cmd_ids() {
        seen_cmd_ids()
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .clear();
    }

    /// Starts an isolated test.
    ///
    /// Poisoning is recovered from so that one failing test reports its own
    /// failure instead of making every later test fail with a `PoisonError`.
    fn isolate() -> Isolated {
        let guard = TEST_GUARD.lock().unwrap_or_else(PoisonError::into_inner);
        clear_seen_cmd_ids();
        Isolated { _guard: guard }
    }

    fn parse_fn(src: &str) -> ItemFn {
        syn::parse_str(src).expect("test source must parse as a function item")
    }

    #[test]
    fn same_crate_collision_is_rejected() {
        let _iso = isolate();
        // `cmd_446` and `cmd_470` derive the same id, 0x43AE.
        assert!(expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None).is_ok());
        let err = expand_command(parse_fn("fn cmd_470() -> u32 { 0 }"), None)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error, got: {err}"
        );
        assert!(
            err.contains("0x43AE"),
            "expected hex id 0x43AE in error, got: {err}"
        );
        assert!(
            err.contains("cmd_446") && err.contains("cmd_470"),
            "expected both command names in error, got: {err}"
        );
    }

    #[test]
    fn guard_symbol_has_correct_export_name() {
        let _iso = isolate();
        let ts = expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None)
            .unwrap()
            .to_string();
        assert!(
            ts.contains("__telepath_cmd_id_43AE"),
            "guard symbol export_name not found in generated code: {ts}"
        );
    }

    #[test]
    fn distinct_commands_do_not_collide() {
        let _iso = isolate();
        assert!(expand_command(parse_fn("fn ping() -> u32 { 0 }"), None).is_ok());
        assert!(expand_command(parse_fn("fn echo(x: u32) -> u32 { x }"), None).is_ok());
    }

    #[test]
    fn explicit_cmd_id_overrides_derive() {
        let _iso = isolate();
        let ts = expand_command(parse_fn("fn get_metrics() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap()
            .to_string();
        assert!(
            ts.contains("65534"),
            "explicit cmd_id 0xFFFE not found as literal in generated code: {ts}"
        );
        assert!(
            ts.contains("__telepath_cmd_id_FFFE"),
            "guard symbol for explicit cmd_id not found in generated code: {ts}"
        );
    }

    #[test]
    fn explicit_cmd_id_collision_rejected() {
        let _iso = isolate();
        assert!(expand_command(parse_fn("fn foo() -> u32 { 0 }"), Some(0xFFFE)).is_ok());
        let err = expand_command(parse_fn("fn bar() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error for duplicate explicit cmd_id, got: {err}"
        );
    }
}