Skip to main content

telepath_macros/
lib.rs

1//! Proc-macro crate for Telepath.
2//!
3//! Provides the `#[command]` attribute macro that generates a type-erased shim
4//! function and a `CommandMetadata` const from a plain Rust function definition.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use std::collections::HashMap;
10use std::sync::{Mutex, OnceLock};
11use syn::{parse_macro_input, FnArg, ItemFn, LitInt, Pat, ReturnType, Token, Type, TypeReference};
12use telepath_wire::cmd_id::derive_cmd_id as compute_cmd_id;
13
/// Arguments accepted inside `#[command(...)]`.
///
/// Two spellings are valid: bare `#[command]`, or `#[command(cmd_id = 0xFFFE)]`
/// to pin the command ID to a literal value.
struct CommandArgs {
    /// The literal from `cmd_id = N`, when one was written. `None` means the
    /// command ID is derived from the function signature instead.
    explicit_cmd_id: Option<u16>,
}
22
23impl syn::parse::Parse for CommandArgs {
24    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
25        if input.is_empty() {
26            return Ok(CommandArgs {
27                explicit_cmd_id: None,
28            });
29        }
30        let key: syn::Ident = input.parse()?;
31        if key != "cmd_id" {
32            return Err(syn::Error::new_spanned(
33                key,
34                "#[command]: unknown attribute key (expected `cmd_id`)",
35            ));
36        }
37        let _eq: Token![=] = input.parse()?;
38        let lit: LitInt = input.parse()?;
39        let value: u16 = lit.base10_parse().map_err(|_| {
40            syn::Error::new_spanned(&lit, "#[command(cmd_id = ...)]: value must fit in u16")
41        })?;
42        Ok(CommandArgs {
43            explicit_cmd_id: Some(value),
44        })
45    }
46}
47
/// Process-wide map from command IDs that have already been handed out to the
/// name of the function that claimed each one.
///
/// The map sits in a `static`, so it lives as long as the process that loaded
/// this proc-macro crate. It accumulates entries across every `#[command]`
/// expansion that process performs.
fn seen_cmd_ids() -> &'static Mutex<HashMap<u16, String>> {
    static REGISTRY: OnceLock<Mutex<HashMap<u16, String>>> = OnceLock::new();
    REGISTRY.get_or_init(Default::default)
}
52
53/// Marks a function as a Telepath RPC command.
54///
55/// # What it generates
56///
57/// For every annotated function the macro emits five additional items:
58///
59/// 1. **`fn __telepath_shim_<name>(input: &[u8], output: &mut [u8], resources: &ResourceRegistry) -> Result<usize, DispatchError>`** —
60///    deserializes `input` via postcard, resolves `#[resource]`-annotated arguments from
61///    `resources`, calls the original function, and serializes the result into `output`.
62/// 2. **`fn __telepath_args_schema_<name>(out: &mut [u8]) -> Result<usize, ()>`** —
63///    writes a postcard-encoded `postcard_schema::schema::NamedType` for the argument tuple
64///    into `out` and returns the byte count.
65/// 3. **`fn __telepath_ret_schema_<name>(out: &mut [u8]) -> Result<usize, ()>`** —
66///    same for the return type.
67/// 4. **`pub const __TELEPATH_CMD_<NAME>: CommandMetadata`** — a `CommandMetadata` const whose
68///    `id` is derived deterministically from the function's signature via
69///    `derive_cmd_id` at build time.
70/// 5. **`#[linkme] static __TELEPATH_REG_<NAME>`** — registers the metadata in
71///    [`telepath_server::TELEPATH_COMMANDS`] at link time.
72///
73/// The original function body is preserved unchanged so it remains directly callable.
74///
75/// # Requirements on the calling crate
76///
77/// The calling crate must declare the following direct dependencies:
78/// - `telepath-server` — provides `CommandMetadata`, `DispatchError`, and re-exports
79///   `postcard_schema` and `linkme` for use in generated code.
80/// - `postcard` — used in the generated shim for (de)serialization
81///
82/// All argument types and the return type must implement
83/// `postcard_schema::Schema`. Built-in primitives (`u8`, `u32`, `()`,
84/// standard tuples, etc.) already implement it. For user-defined types,
85/// add `#[derive(postcard_schema::Schema)]`.
86///
87/// # Restrictions
88///
89/// The macro rejects functions that are:
90/// - `async fn` (RPC dispatch is synchronous)
91/// - `unsafe fn`
92/// - Generic (`<T>` / `where` clauses)
93/// - Methods (`self` receiver)
94/// - Functions with reference arguments or reference return types
95/// - Functions with pattern-destructured arguments
96///
97/// # Example
98///
99/// ```rust,ignore
100/// use telepath_server::{command, CommandMetadata};
101///
102/// #[command]
103/// fn ping() -> u32 {
104///     0xDEAD_BEEF
105/// }
106///
107/// static COMMANDS: [CommandMetadata; 1] = [__TELEPATH_CMD_PING];
108/// ```
109#[proc_macro_attribute]
110pub fn command(attr: TokenStream, item: TokenStream) -> TokenStream {
111    let args = match syn::parse2::<CommandArgs>(TokenStream2::from(attr)) {
112        Ok(a) => a,
113        Err(e) => return e.to_compile_error().into(),
114    };
115    let input = parse_macro_input!(item as ItemFn);
116    match expand_command(input, args.explicit_cmd_id) {
117        Ok(ts) => ts.into(),
118        Err(e) => e.to_compile_error().into(),
119    }
120}
121
122fn expand_command(
123    func: ItemFn,
124    explicit_cmd_id: Option<u16>,
125) -> syn::Result<proc_macro2::TokenStream> {
126    let fn_ident = &func.sig.ident;
127    let fn_name_str = fn_ident.to_string();
128
129    // --- Validation ---
130
131    if let Some(tok) = &func.sig.asyncness {
132        return Err(syn::Error::new_spanned(
133            tok,
134            "#[command] does not support async fn",
135        ));
136    }
137    if let Some(tok) = &func.sig.unsafety {
138        return Err(syn::Error::new_spanned(
139            tok,
140            "#[command] does not support unsafe fn",
141        ));
142    }
143    if !func.sig.generics.params.is_empty() {
144        return Err(syn::Error::new_spanned(
145            &func.sig.generics,
146            "#[command] does not support generic functions",
147        ));
148    }
149    if let Some(wc) = &func.sig.generics.where_clause {
150        return Err(syn::Error::new_spanned(
151            wc,
152            "#[command] does not support where clauses",
153        ));
154    }
155
156    // --- Parse arguments ---
157
158    // Wire arguments: deserialized from the postcard request payload.
159    let mut wire_idents = Vec::new();
160    let mut wire_types: Vec<Box<Type>> = Vec::new();
161    let mut wire_type_strs = Vec::new();
162
163    // Resource arguments: injected from the ResourceRegistry.
164    struct ResourceArg {
165        ident: syn::Ident,
166        inner_ty: Box<Type>,
167        is_mut: bool,
168    }
169    let mut resource_args: Vec<ResourceArg> = Vec::new();
170
171    // All argument idents in declaration order, for calling the original function.
172    let mut all_arg_idents: Vec<syn::Ident> = Vec::new();
173
174    for fn_arg in &func.sig.inputs {
175        match fn_arg {
176            FnArg::Receiver(recv) => {
177                return Err(syn::Error::new_spanned(
178                    recv,
179                    "#[command] cannot be applied to methods",
180                ));
181            }
182            FnArg::Typed(pat_type) => {
183                let ident = match pat_type.pat.as_ref() {
184                    Pat::Ident(pi) => pi.ident.clone(),
185                    other => {
186                        return Err(syn::Error::new_spanned(
187                            other,
188                            "#[command] requires simple named arguments (patterns not supported)",
189                        ));
190                    }
191                };
192
193                let is_resource = pat_type.attrs.iter().any(|a| a.path().is_ident("resource"));
194
195                if is_resource {
196                    let Type::Reference(TypeReference {
197                        elem, mutability, ..
198                    }) = pat_type.ty.as_ref()
199                    else {
200                        return Err(syn::Error::new_spanned(
201                            &pat_type.ty,
202                            "#[resource] arguments must be &T or &mut T",
203                        ));
204                    };
205
206                    // Best-effort compile-time uniqueness check via token-string comparison.
207                    // Type aliases or differently-spelled paths for the same concrete type
208                    // may slip through; `ResourceRegistry::insert` panics at runtime as a
209                    // fallback in those cases.
210                    let inner_str = quote! { #elem }.to_string();
211                    for existing in &resource_args {
212                        let existing_ty = &existing.inner_ty;
213                        let existing_str = quote! { #existing_ty }.to_string();
214                        if existing_str == inner_str {
215                            return Err(syn::Error::new_spanned(
216                                &pat_type.ty,
217                                "duplicate #[resource] type; each resource type may appear at most once",
218                            ));
219                        }
220                    }
221
222                    resource_args.push(ResourceArg {
223                        ident: ident.clone(),
224                        inner_ty: elem.clone(),
225                        is_mut: mutability.is_some(),
226                    });
227                    all_arg_idents.push(ident);
228                } else {
229                    if let Type::Reference(r) = pat_type.ty.as_ref() {
230                        return Err(syn::Error::new_spanned(
231                            r,
232                            "#[command] does not support reference arguments \
233                             (use #[resource] for injected references)",
234                        ));
235                    }
236                    let ty = &*pat_type.ty;
237                    wire_type_strs.push(quote! { #ty }.to_string());
238                    wire_idents.push(ident.clone());
239                    wire_types.push(pat_type.ty.clone());
240                    all_arg_idents.push(ident);
241                }
242            }
243        }
244    }
245
246    // --- Parse return type ---
247
248    let ret_type_str = match &func.sig.output {
249        ReturnType::Default => "()".to_string(),
250        ReturnType::Type(_, ty) => {
251            if let Type::Reference(r) = ty.as_ref() {
252                return Err(syn::Error::new_spanned(
253                    r,
254                    "#[command] does not support reference return types",
255                ));
256            }
257            quote! { #ty }.to_string()
258        }
259    };
260
261    // --- Build arg_names_str ---
262    // Comma-joined wire argument names for runtime introspection (e.g. "a,b").
263    // Resource arguments are excluded — they are server-side only.
264    let arg_names_str: String = wire_idents
265        .iter()
266        .map(|id| id.to_string())
267        .collect::<Vec<_>>()
268        .join(",");
269
270    // --- Build args_type_str ---
271    // Canonical tuple format of wire arguments matching Rust syntax: "()" for 0-arg,
272    // "(T,)" for 1-arg, "(T1, T2)" for 2-arg. Resource arguments are excluded.
273
274    let args_type_str = if wire_type_strs.is_empty() {
275        "()".to_string()
276    } else if wire_type_strs.len() == 1 {
277        format!("({},)", wire_type_strs[0])
278    } else {
279        format!("({})", wire_type_strs.join(", "))
280    };
281
282    // --- Duplicate cmd_id detection ---
283    //
284    // Compute the cmd_id at macro-expansion time so we can:
285    // 1. Check for same-crate collisions via an in-process registry → compile_error!
286    // 2. Emit a link-time guard symbol (export_name keyed on the hex id) that causes
287    //    a "multiple definition" linker error when two commands from different crates
288    //    happen to share the same id in the final binary.
289
290    let cmd_id_value = explicit_cmd_id
291        .unwrap_or_else(|| compute_cmd_id(&fn_name_str, &args_type_str, &ret_type_str));
292
293    {
294        let mut seen = seen_cmd_ids().lock().unwrap();
295        if let Some(existing) = seen.get(&cmd_id_value) {
296            return Err(syn::Error::new_spanned(
297                fn_ident,
298                format!(
299                    "#[command] cmd_id collision: `{}` and `{}` both map to 0x{:04X}. \
300                     Rename one of the commands to avoid the collision.",
301                    fn_name_str, existing, cmd_id_value
302                ),
303            ));
304        }
305        seen.insert(cmd_id_value, fn_name_str.clone());
306    }
307
308    // If the caller supplied an explicit `cmd_id = N`, embed that literal
309    // directly. Otherwise emit a `__derive_cmd_id(...)` const-fn call so
310    // the derivation is independently verifiable in expanded output.
311    let cmd_id_expr: proc_macro2::TokenStream = if explicit_cmd_id.is_some() {
312        let v = cmd_id_value;
313        quote! { #v }
314    } else {
315        quote! {
316            ::telepath_server::__derive_cmd_id(
317                #fn_name_str,
318                #args_type_str,
319                #ret_type_str,
320            )
321        }
322    };
323
324    let collision_export = format!("__telepath_cmd_id_{:04X}", cmd_id_value);
325    let guard_ident = format_ident!("__TELEPATH_CMDID_GUARD_{}", fn_name_str.to_uppercase());
326
327    // --- Generated identifiers ---
328
329    let shim_ident = format_ident!("__telepath_shim_{}", fn_name_str);
330    let args_schema_ident = format_ident!("__telepath_args_schema_{}", fn_name_str);
331    let ret_schema_ident = format_ident!("__telepath_ret_schema_{}", fn_name_str);
332    let static_ident = format_ident!("__TELEPATH_CMD_{}", fn_name_str.to_uppercase());
333    let reg_ident = format_ident!("__TELEPATH_REG_{}", fn_name_str.to_uppercase());
334
335    // --- Compute args tuple type and ret type tokens for schema writers ---
336    // Only wire arguments participate in schemas and CmdID derivation.
337
338    let args_schema_type = if wire_types.is_empty() {
339        quote! { () }
340    } else if wire_types.len() == 1 {
341        let t = &*wire_types[0];
342        quote! { (#t,) }
343    } else {
344        quote! { (#(#wire_types),*) }
345    };
346
347    let ret_schema_type = match &func.sig.output {
348        ReturnType::Default => quote! { () },
349        ReturnType::Type(_, ty) => quote! { #ty },
350    };
351
352    // --- Build shim body ---
353
354    // Wire-arg deserialization
355    let wire_deser = if wire_idents.is_empty() {
356        quote! {
357            if !input.is_empty() {
358                return ::core::result::Result::Err(
359                    ::telepath_server::DispatchError::DeserializeError
360                );
361            }
362        }
363    } else {
364        let wire_tuple_type = if wire_types.len() == 1 {
365            let t = &*wire_types[0];
366            quote! { (#t,) }
367        } else {
368            quote! { (#(#wire_types),*) }
369        };
370        let wire_pat = if wire_idents.len() == 1 {
371            let id = &wire_idents[0];
372            quote! { (#id,) }
373        } else {
374            quote! { (#(#wire_idents),*) }
375        };
376        quote! {
377            let #wire_pat: #wire_tuple_type = match ::postcard::from_bytes(input) {
378                Ok(v) => v,
379                Err(_) => return ::core::result::Result::Err(
380                    ::telepath_server::DispatchError::DeserializeError
381                ),
382            };
383        }
384    };
385
386    // Resource lookups
387    let resource_lookups: Vec<_> = resource_args
388        .iter()
389        .map(|ra| {
390            let ident = &ra.ident;
391            let inner_ty = &ra.inner_ty;
392            if ra.is_mut {
393                quote! {
394                    let #ident: &mut #inner_ty = unsafe {
395                        &mut *__resources.get_ptr::<#inner_ty>()
396                            .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
397                    };
398                }
399            } else {
400                quote! {
401                    let #ident: &#inner_ty = unsafe {
402                        &*__resources.get_ptr::<#inner_ty>()
403                            .ok_or(::telepath_server::DispatchError::ResourceUnavailable)?
404                    };
405                }
406            }
407        })
408        .collect();
409
410    // Call arguments in declaration order
411    let call_args: Vec<_> = all_arg_idents
412        .iter()
413        .map(|ident| quote! { #ident })
414        .collect();
415
416    let shim_body = quote! {
417        #wire_deser
418        #(#resource_lookups)*
419        let __ret = #fn_ident(#(#call_args),*);
420        match ::postcard::to_slice(&__ret, output) {
421            Ok(s) => ::core::result::Result::Ok(s.len()),
422            Err(_) => ::core::result::Result::Err(
423                ::telepath_server::DispatchError::SerializeError
424            ),
425        }
426    };
427
428    // Strip #[resource] attributes from the original function so that
429    // it compiles as a normal function with reference parameters.
430    let mut clean_func = func.clone();
431    for fn_arg in &mut clean_func.sig.inputs {
432        if let FnArg::Typed(pat_type) = fn_arg {
433            pat_type.attrs.retain(|a| !a.path().is_ident("resource"));
434        }
435    }
436
437    Ok(quote! {
438        #clean_func
439
440        #[allow(non_snake_case)]
441        fn #shim_ident(
442            input: &[u8],
443            output: &mut [u8],
444            __resources: &::telepath_server::ResourceRegistry,
445        ) -> ::core::result::Result<usize, ::telepath_server::DispatchError> {
446            #shim_body
447        }
448
449        #[allow(non_snake_case)]
450        fn #args_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
451            ::postcard::to_slice(
452                <#args_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
453                out,
454            )
455            .map(|s| s.len())
456            .map_err(|_| ())
457        }
458
459        #[allow(non_snake_case)]
460        fn #ret_schema_ident(out: &mut [u8]) -> ::core::result::Result<usize, ()> {
461            ::postcard::to_slice(
462                <#ret_schema_type as ::telepath_server::__postcard_schema::Schema>::SCHEMA,
463                out,
464            )
465            .map(|s| s.len())
466            .map_err(|_| ())
467        }
468
469        pub const #static_ident: ::telepath_server::CommandMetadata =
470            ::telepath_server::CommandMetadata {
471                name: #fn_name_str,
472                id: #cmd_id_expr,
473                invoke: #shim_ident,
474                args_schema: #args_schema_ident,
475                ret_schema: #ret_schema_ident,
476                arg_names: #arg_names_str,
477            };
478
479        #[allow(non_upper_case_globals, non_snake_case)]
480        #[::telepath_server::__linkme::distributed_slice(::telepath_server::TELEPATH_COMMANDS)]
481        #[linkme(crate = ::telepath_server::__linkme)]
482        static #reg_ident: ::telepath_server::CommandMetadata = #static_ident;
483
484        // Link-time duplicate cmd_id guard.
485        //
486        // If two #[command] functions in the same binary (possibly from different
487        // crates) share the same cmd_id, the linker will emit a "multiple
488        // definition" error for `__telepath_cmd_id_XXXX`, stopping the build
489        // before the firmware is ever flashed.
490        //
491        // The in-process check above already catches same-crate collisions as a
492        // nicer compile_error!; this symbol is the defense-in-depth for
493        // incremental builds and cross-crate collisions.
494        #[doc(hidden)]
495        #[allow(non_upper_case_globals, dead_code)]
496        #[used]
497        #[export_name = #collision_export]
498        pub static #guard_ident: u8 = 0;
499
500    })
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Serializes all tests that touch the global seen_cmd_ids() registry.
    static TEST_GUARD: Mutex<()> = Mutex::new(());

    /// Exclusive access to an empty `seen_cmd_ids()` registry for one test.
    ///
    /// Holding this value serializes against the other registry tests. Its
    /// `Drop` clears the registry again even when the test panics. Poisoned
    /// locks are recovered rather than unwrapped, so one failing test does not
    /// cascade into `PoisonError` failures in every later test.
    struct IsolatedRegistry {
        _serial: MutexGuard<'static, ()>,
    }

    impl IsolatedRegistry {
        fn acquire() -> Self {
            let serial = TEST_GUARD.lock().unwrap_or_else(PoisonError::into_inner);
            clear_registry();
            IsolatedRegistry { _serial: serial }
        }
    }

    impl Drop for IsolatedRegistry {
        fn drop(&mut self) {
            // `drop` runs before the `_serial` field is released, so the next
            // test always starts from an empty registry.
            clear_registry();
        }
    }

    fn clear_registry() {
        seen_cmd_ids()
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
            .clear();
    }

    fn parse_fn(src: &str) -> ItemFn {
        syn::parse_str(src).unwrap()
    }

    #[test]
    fn same_crate_collision_is_rejected() {
        let _registry = IsolatedRegistry::acquire();
        // cmd_446() -> u32 and cmd_470() -> u32 both map to 0x43AE (verified by brute force).
        assert!(expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None).is_ok());
        let err = expand_command(parse_fn("fn cmd_470() -> u32 { 0 }"), None)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error, got: {err}"
        );
        assert!(
            err.contains("0x43AE"),
            "expected hex id 0x43AE in error, got: {err}"
        );
        assert!(
            err.contains("cmd_446") && err.contains("cmd_470"),
            "expected both command names in error, got: {err}"
        );
    }

    #[test]
    fn guard_symbol_has_correct_export_name() {
        let _registry = IsolatedRegistry::acquire();
        let ts = expand_command(parse_fn("fn cmd_446() -> u32 { 0 }"), None)
            .unwrap()
            .to_string();
        // Guard static export_name encodes the cmd_id as uppercase hex.
        assert!(
            ts.contains("__telepath_cmd_id_43AE"),
            "guard symbol export_name not found in generated code: {ts}"
        );
    }

    #[test]
    fn distinct_commands_do_not_collide() {
        let _registry = IsolatedRegistry::acquire();
        assert!(expand_command(parse_fn("fn ping() -> u32 { 0 }"), None).is_ok());
        assert!(expand_command(parse_fn("fn echo(x: u32) -> u32 { x }"), None).is_ok());
    }

    #[test]
    fn explicit_cmd_id_overrides_derive() {
        let _registry = IsolatedRegistry::acquire();
        let ts = expand_command(parse_fn("fn get_metrics() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap()
            .to_string();
        // The generated CommandMetadata.id must be the literal 0xFFFE, not a __derive_cmd_id call.
        assert!(
            ts.contains("65534"), // 0xFFFE == 65534 in decimal token
            "explicit cmd_id 0xFFFE not found as literal in generated code: {ts}"
        );
        // Guard symbol must encode the explicit id.
        assert!(
            ts.contains("__telepath_cmd_id_FFFE"),
            "guard symbol for explicit cmd_id not found in generated code: {ts}"
        );
    }

    #[test]
    fn explicit_cmd_id_collision_rejected() {
        let _registry = IsolatedRegistry::acquire();
        assert!(expand_command(parse_fn("fn foo() -> u32 { 0 }"), Some(0xFFFE)).is_ok());
        let err = expand_command(parse_fn("fn bar() -> u32 { 0 }"), Some(0xFFFE))
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("cmd_id collision"),
            "expected collision error for duplicate explicit cmd_id, got: {err}"
        );
    }
}