Skip to main content

roam_macros_parse/
lib.rs

1//! Parser grammar for roam RPC service trait definitions.
2//!
3//! # This Is Just a Grammar
4//!
5//! This crate contains **only** the [unsynn] grammar for parsing Rust trait definitions
6//! that define roam RPC services. It does not:
7//!
8//! - Generate any code
9//! - Perform validation
10//! - Know anything about roam's wire protocol
11//! - Have opinions about how services should be implemented
12//!
13//! It simply parses syntax like:
14//!
15//! ```ignore
16//! pub trait Calculator {
17//!     /// Add two numbers.
18//!     async fn add(&self, a: i32, b: i32) -> i32;
19//! }
20//! ```
21//!
22//! ...and produces an AST ([`ServiceTrait`]) that downstream crates can inspect.
23//!
24//! # Why a Separate Crate?
25//!
26//! The grammar is extracted into its own crate so that:
27//!
28//! 1. **It can be tested independently** — We use [datatest-stable] + [insta] for
29//!    snapshot testing the parsed AST, which isn't possible in a proc-macro crate.
30//!
31//! 2. **It's reusable** — Other tools (linters, documentation generators, IDE plugins)
32//!    can parse service definitions without pulling in proc-macro dependencies.
33//!
34//! 3. **Separation of concerns** — The grammar is pure parsing; [`roam-macros`] handles
35//!    the proc-macro machinery; [`roam-codegen`] handles actual code generation.
36//!
37//! # The Bigger Picture
38//!
39//! ```text
40//! roam-macros-parse     roam-macros              roam-codegen
41//! ┌──────────────┐     ┌──────────────┐         ┌──────────────┐
42//! │              │     │              │         │              │
43//! │  unsynn      │────▶│  #[service]  │────────▶│  build.rs    │
44//! │  grammar     │     │  proc macro  │         │  code gen    │
45//! │              │     │              │         │              │
46//! └──────────────┘     └──────────────┘         └──────────────┘
47//!    just parsing         emit metadata          Rust, TS, Go...
48//! ```
49//!
50//! [unsynn]: https://docs.rs/unsynn
51//! [datatest-stable]: https://docs.rs/datatest-stable
52//! [insta]: https://docs.rs/insta
53//! [`roam-macros`]: https://docs.rs/roam-service-macros
54//! [`roam-codegen`]: https://docs.rs/roam-codegen
55
56pub use unsynn::Error as ParseError;
57pub use unsynn::ToTokens;
58
59use proc_macro2::TokenStream as TokenStream2;
60use unsynn::operator::names::{
61    Assign, Colon, Comma, Gt, LifetimeTick, Lt, PathSep, Pound, RArrow, Semicolon,
62};
63use unsynn::{
64    Any, BraceGroupContaining, BracketGroupContaining, CommaDelimitedVec, Cons, Either,
65    EndOfStream, Except, Ident, LiteralString, Many, Optional, ParenthesisGroupContaining, Parse,
66    ToTokenIter, TokenStream, keyword, unsynn,
67};
68
69keyword! {
70    pub KAsync = "async";
71    pub KFn = "fn";
72    pub KTrait = "trait";
73    pub KSelfKw = "self";
74    pub KMut = "mut";
75    pub KDoc = "doc";
76    pub KPub = "pub";
77    pub KWhere = "where";
78}
79
80/// Parses tokens and groups until `C` is found, handling `<...>` correctly.
81type VerbatimUntil<C> = Many<Cons<Except<C>, AngleTokenTree>>;
82
83unsynn! {
84    /// Parses either a `TokenTree` or `<...>` grouping.
85    #[derive(Clone)]
86    pub struct AngleTokenTree(
87        pub Either<Cons<Lt, Vec<Cons<Except<Gt>, AngleTokenTree>>, Gt>, unsynn::TokenTree>,
88    );
89
90    pub struct RawAttribute {
91        pub _pound: Pound,
92        pub body: BracketGroupContaining<TokenStream>,
93    }
94
95    pub struct DocAttribute {
96        pub _doc: KDoc,
97        pub _assign: Assign,
98        pub value: LiteralString,
99    }
100
101    pub enum Visibility {
102        Pub(KPub),
103        PubRestricted(Cons<KPub, ParenthesisGroupContaining<TokenStream>>),
104    }
105
106    pub struct RefSelf {
107        pub _amp: unsynn::operator::names::And,
108        pub mutability: Option<KMut>,
109        pub name: KSelfKw,
110    }
111
112    pub struct MethodParam {
113        pub name: Ident,
114        pub _colon: Colon,
115        pub ty: Type,
116    }
117
118    pub struct GenericParams {
119        pub _lt: Lt,
120        pub params: VerbatimUntil<Gt>,
121        pub _gt: Gt,
122    }
123
124    #[derive(Clone)]
125    pub struct TypePath {
126        pub leading: Option<PathSep>,
127        pub first: Ident,
128        pub rest: Any<Cons<PathSep, Ident>>,
129    }
130
131    #[derive(Clone)]
132    pub struct Lifetime {
133        pub _apo: LifetimeTick,
134        pub ident: Ident,
135    }
136
137    #[derive(Clone)]
138    pub enum GenericArgument {
139        Lifetime(Lifetime),
140        Type(Type),
141    }
142
143    #[derive(Clone)]
144    pub enum Type {
145        Reference(TypeRef),
146        Tuple(TypeTuple),
147        PathWithGenerics(PathWithGenerics),
148        Path(TypePath),
149    }
150
151    #[derive(Clone)]
152    pub struct TypeRef {
153        pub _amp: unsynn::operator::names::And,
154        pub lifetime: Option<Cons<LifetimeTick, Ident>>,
155        pub mutable: Option<KMut>,
156        pub inner: Box<Type>,
157    }
158
159    #[derive(Clone)]
160    pub struct TypeTuple(
161        pub ParenthesisGroupContaining<CommaDelimitedVec<Type>>,
162    );
163
164    #[derive(Clone)]
165    pub struct PathWithGenerics {
166        pub path: TypePath,
167        pub _lt: Lt,
168        pub args: CommaDelimitedVec<GenericArgument>,
169        pub _gt: Gt,
170    }
171
172    pub struct ReturnType {
173        pub _arrow: RArrow,
174        pub ty: Type,
175    }
176
177    pub struct WhereClause {
178        pub _where: KWhere,
179        pub bounds: VerbatimUntil<Semicolon>,
180    }
181
182    pub struct MethodParams {
183        pub receiver: RefSelf,
184        pub rest: Optional<Cons<Comma, CommaDelimitedVec<MethodParam>>>,
185    }
186
187    pub struct ServiceMethod {
188        pub attributes: Any<RawAttribute>,
189        pub _async: KAsync,
190        pub _fn: KFn,
191        pub name: Ident,
192        pub generics: Optional<GenericParams>,
193        pub params: ParenthesisGroupContaining<MethodParams>,
194        pub return_type: Optional<ReturnType>,
195        pub where_clause: Optional<WhereClause>,
196        pub _semi: Semicolon,
197    }
198
199    pub struct ServiceTrait {
200        pub attributes: Any<RawAttribute>,
201        pub vis: Optional<Visibility>,
202        pub _trait: KTrait,
203        pub name: Ident,
204        pub generics: Optional<GenericParams>,
205        pub body: BraceGroupContaining<Any<ServiceMethod>>,
206        pub _eos: EndOfStream,
207    }
208}
209
210// ============================================================================
211// Helper methods for GenericArgument
212// ============================================================================
213
214impl GenericArgument {
215    pub fn has_lifetime(&self) -> bool {
216        match self {
217            GenericArgument::Lifetime(_) => true,
218            GenericArgument::Type(ty) => ty.has_lifetime(),
219        }
220    }
221
222    pub fn has_named_lifetime(&self, name: &str) -> bool {
223        match self {
224            GenericArgument::Lifetime(lifetime) => lifetime.ident == name,
225            GenericArgument::Type(ty) => ty.has_named_lifetime(name),
226        }
227    }
228
229    pub fn has_non_named_lifetime(&self, name: &str) -> bool {
230        match self {
231            GenericArgument::Lifetime(lifetime) => lifetime.ident != name,
232            GenericArgument::Type(ty) => ty.has_non_named_lifetime(name),
233        }
234    }
235
236    pub fn has_elided_reference_lifetime(&self) -> bool {
237        match self {
238            GenericArgument::Lifetime(_) => false,
239            GenericArgument::Type(ty) => ty.has_elided_reference_lifetime(),
240        }
241    }
242
243    pub fn contains_channel(&self) -> bool {
244        match self {
245            GenericArgument::Lifetime(_) => false,
246            GenericArgument::Type(ty) => ty.contains_channel(),
247        }
248    }
249}
250
251// ============================================================================
252// Helper methods for Type
253// ============================================================================
254
255impl Type {
256    /// Extract Ok and Err types if this is Result<T, E>
257    pub fn as_result(&self) -> Option<(&Type, &Type)> {
258        match self {
259            Type::PathWithGenerics(PathWithGenerics { path, args, .. })
260                if path.last_segment().as_str() == "Result" && args.len() == 2 =>
261            {
262                let args_slice = args.as_slice();
263                match (&args_slice[0].value, &args_slice[1].value) {
264                    (GenericArgument::Type(ok), GenericArgument::Type(err)) => Some((ok, err)),
265                    _ => None,
266                }
267            }
268            _ => None,
269        }
270    }
271
272    /// Check if type contains a lifetime anywhere in the tree
273    pub fn has_lifetime(&self) -> bool {
274        match self {
275            Type::Reference(TypeRef {
276                lifetime: Some(_), ..
277            }) => true,
278            Type::Reference(TypeRef { inner, .. }) => inner.has_lifetime(),
279            Type::PathWithGenerics(PathWithGenerics { args, .. }) => {
280                args.iter().any(|t| t.value.has_lifetime())
281            }
282            Type::Tuple(TypeTuple(group)) => group.content.iter().any(|t| t.value.has_lifetime()),
283            Type::Path(_) => false,
284        }
285    }
286
287    /// Check if type contains the named lifetime anywhere in the tree.
288    pub fn has_named_lifetime(&self, name: &str) -> bool {
289        match self {
290            Type::Reference(TypeRef {
291                lifetime: Some(lifetime),
292                ..
293            }) => lifetime.second == name,
294            Type::Reference(TypeRef { inner, .. }) => inner.has_named_lifetime(name),
295            Type::PathWithGenerics(PathWithGenerics { args, .. }) => {
296                args.iter().any(|t| t.value.has_named_lifetime(name))
297            }
298            Type::Tuple(TypeTuple(group)) => group
299                .content
300                .iter()
301                .any(|t| t.value.has_named_lifetime(name)),
302            Type::Path(_) => false,
303        }
304    }
305
306    /// Check if type contains any named lifetime other than `name`.
307    pub fn has_non_named_lifetime(&self, name: &str) -> bool {
308        match self {
309            Type::Reference(TypeRef {
310                lifetime: Some(lifetime),
311                ..
312            }) => lifetime.second != name,
313            Type::Reference(TypeRef { inner, .. }) => inner.has_non_named_lifetime(name),
314            Type::PathWithGenerics(PathWithGenerics { args, .. }) => {
315                args.iter().any(|t| t.value.has_non_named_lifetime(name))
316            }
317            Type::Tuple(TypeTuple(group)) => group
318                .content
319                .iter()
320                .any(|t| t.value.has_non_named_lifetime(name)),
321            Type::Path(_) => false,
322        }
323    }
324
325    /// Check if type contains any `&T` reference without an explicit lifetime.
326    ///
327    /// We require explicit `'roam` for borrowed RPC return payloads.
328    pub fn has_elided_reference_lifetime(&self) -> bool {
329        match self {
330            Type::Reference(TypeRef { lifetime: None, .. }) => true,
331            Type::Reference(TypeRef { inner, .. }) => inner.has_elided_reference_lifetime(),
332            Type::PathWithGenerics(PathWithGenerics { args, .. }) => {
333                args.iter().any(|t| t.value.has_elided_reference_lifetime())
334            }
335            Type::Tuple(TypeTuple(group)) => group
336                .content
337                .iter()
338                .any(|t| t.value.has_elided_reference_lifetime()),
339            Type::Path(_) => false,
340        }
341    }
342
343    /// Check if type contains Tx or Rx at any nesting level
344    ///
345    /// Note: This is a heuristic based on type names. Proper validation should
346    /// happen at codegen time when we can resolve types properly.
347    pub fn contains_channel(&self) -> bool {
348        match self {
349            Type::Reference(TypeRef { inner, .. }) => inner.contains_channel(),
350            Type::Tuple(TypeTuple(group)) => {
351                group.content.iter().any(|t| t.value.contains_channel())
352            }
353            Type::PathWithGenerics(PathWithGenerics { path, args, .. }) => {
354                let seg = path.last_segment();
355                if seg == "Tx" || seg == "Rx" {
356                    return true;
357                }
358                args.iter().any(|t| t.value.contains_channel())
359            }
360            Type::Path(path) => {
361                let seg = path.last_segment();
362                seg == "Tx" || seg == "Rx"
363            }
364        }
365    }
366}
367
368// ============================================================================
369// Helper methods for TypePath
370// ============================================================================
371
372impl TypePath {
373    /// Get the last segment (e.g., "Result" from "std::result::Result")
374    pub fn last_segment(&self) -> String {
375        self.rest
376            .iter()
377            .last()
378            .map(|seg| seg.value.second.to_string())
379            .unwrap_or_else(|| self.first.to_string())
380    }
381}
382
383// ============================================================================
384// Helper methods for ServiceTrait
385// ============================================================================
386
387impl ServiceTrait {
388    /// Get the trait name as a string.
389    pub fn name(&self) -> String {
390        self.name.to_string()
391    }
392
393    /// Get the trait's doc string (collected from #[doc = "..."] attributes).
394    pub fn doc(&self) -> Option<String> {
395        collect_doc_string(&self.attributes)
396    }
397
398    /// Get an iterator over the methods.
399    pub fn methods(&self) -> impl Iterator<Item = &ServiceMethod> {
400        self.body.content.iter().map(|entry| &entry.value)
401    }
402}
403
404// ============================================================================
405// Helper methods for ServiceMethod
406// ============================================================================
407
408impl ServiceMethod {
409    /// Get the method name as a string.
410    pub fn name(&self) -> String {
411        self.name.to_string()
412    }
413
414    /// Get the method's doc string (collected from #[doc = "..."] attributes).
415    pub fn doc(&self) -> Option<String> {
416        collect_doc_string(&self.attributes)
417    }
418
419    /// Get an iterator over the method's parameters (excluding &self).
420    pub fn args(&self) -> impl Iterator<Item = &MethodParam> {
421        self.params
422            .content
423            .rest
424            .iter()
425            .flat_map(|rest| rest.value.second.iter().map(|entry| &entry.value))
426    }
427
428    /// Get the return type, defaulting to () if not specified.
429    pub fn return_type(&self) -> Type {
430        self.return_type
431            .iter()
432            .next()
433            .map(|r| r.value.ty.clone())
434            .unwrap_or_else(unit_type)
435    }
436
437    /// Check if receiver is &mut self (not allowed for service methods).
438    pub fn is_mut_receiver(&self) -> bool {
439        self.params.content.receiver.mutability.is_some()
440    }
441
442    /// Check if method has generics.
443    pub fn has_generics(&self) -> bool {
444        !self.generics.is_empty()
445    }
446}
447
448// ============================================================================
449// Helper methods for MethodParam
450// ============================================================================
451
452impl MethodParam {
453    /// Get the parameter name as a string.
454    pub fn name(&self) -> String {
455        self.name.to_string()
456    }
457}
458
459// ============================================================================
460// Helper functions
461// ============================================================================
462
463/// Extract Ok and Err types from a return type.
464/// Returns (ok_type, Some(err_type)) for Result<T, E>, or (type, None) otherwise.
465pub fn method_ok_and_err_types(return_ty: &Type) -> (&Type, Option<&Type>) {
466    if let Some((ok, err)) = return_ty.as_result() {
467        (ok, Some(err))
468    } else {
469        (return_ty, None)
470    }
471}
472
473/// Returns the unit type `()`.
474fn unit_type() -> Type {
475    let mut iter = "()".to_token_iter();
476    Type::parse(&mut iter).expect("unit type should always parse")
477}
478
479/// Collect doc strings from attributes.
480fn collect_doc_string(attrs: &Any<RawAttribute>) -> Option<String> {
481    let mut docs = Vec::new();
482
483    for attr in attrs.iter() {
484        let mut body_iter = attr.value.body.content.clone().to_token_iter();
485        if let Ok(doc_attr) = DocAttribute::parse(&mut body_iter) {
486            let line = doc_attr
487                .value
488                .as_str()
489                .replace("\\\"", "\"")
490                .replace("\\'", "'");
491            docs.push(line);
492        }
493    }
494
495    if docs.is_empty() {
496        None
497    } else {
498        Some(docs.join("\n"))
499    }
500}
501
502/// Parse a trait definition from a token stream.
503#[allow(clippy::result_large_err)] // unsynn::Error is external, we can't box it
504pub fn parse_trait(tokens: &TokenStream2) -> Result<ServiceTrait, unsynn::Error> {
505    let mut iter = tokens.clone().to_token_iter();
506    ServiceTrait::parse(&mut iter)
507}
508
509#[cfg(test)]
510mod tests {
511    use super::*;
512
513    fn parse(src: &str) -> ServiceTrait {
514        let ts: TokenStream2 = src.parse().expect("tokenstream parse");
515        parse_trait(&ts).expect("trait parse")
516    }
517
518    #[test]
519    fn parse_trait_exposes_docs_methods_and_args() {
520        let trait_def = parse(
521            r#"
522            #[doc = "Calculator service."]
523            pub trait Calculator {
524                #[doc = "Adds two numbers."]
525                async fn add(&self, a: i32, b: i32) -> Result<i64, String>;
526            }
527            "#,
528        );
529
530        assert_eq!(trait_def.name(), "Calculator");
531        assert_eq!(trait_def.doc(), Some("Calculator service.".to_string()));
532
533        let method = trait_def.methods().next().expect("method");
534        assert_eq!(method.name(), "add");
535        assert_eq!(method.doc(), Some("Adds two numbers.".to_string()));
536        assert_eq!(
537            method.args().map(|arg| arg.name()).collect::<Vec<_>>(),
538            vec!["a", "b"]
539        );
540
541        let ret = method.return_type();
542        let (ok, err) = method_ok_and_err_types(&ret);
543        assert!(ok.as_result().is_none());
544        assert!(err.is_some());
545    }
546
547    #[test]
548    fn return_type_defaults_to_unit_when_omitted() {
549        let trait_def = parse(
550            r#"
551            trait Svc {
552                async fn ping(&self);
553            }
554            "#,
555        );
556        let method = trait_def.methods().next().expect("method");
557        let ret = method.return_type();
558        match ret {
559            Type::Tuple(TypeTuple(group)) => assert!(group.content.is_empty()),
560            other => panic!(
561                "expected unit tuple return, got {}",
562                other.to_token_stream()
563            ),
564        }
565    }
566
567    #[test]
568    fn method_helpers_detect_generics_and_mut_receiver() {
569        let trait_def = parse(
570            r#"
571            trait Svc {
572                async fn bad<T>(&mut self, value: T) -> T;
573            }
574            "#,
575        );
576        let method = trait_def.methods().next().expect("method");
577        assert!(method.has_generics());
578        assert!(method.is_mut_receiver());
579    }
580
581    #[test]
582    fn type_helpers_detect_result_lifetime_and_channel_nesting() {
583        let trait_def = parse(
584            r#"
585            trait Svc {
586                async fn stream(&self, input: &'static str) -> Result<Option<Tx<Vec<u8>>>, Rx<u32>>;
587            }
588            "#,
589        );
590        let method = trait_def.methods().next().expect("method");
591        let arg = method.args().next().expect("arg");
592        assert!(arg.ty.has_lifetime());
593        assert!(!arg.ty.contains_channel());
594
595        let ret = method.return_type();
596        let (ok, err) = method_ok_and_err_types(&ret);
597        assert!(ok.contains_channel());
598        assert!(err.expect("result err type").contains_channel());
599    }
600
601    #[test]
602    fn type_helpers_detect_named_and_elided_lifetimes() {
603        let trait_def = parse(
604            r#"
605            trait Svc {
606                async fn borrowed(&self) -> Result<&'roam str, Error>;
607                async fn bad_lifetime(&self) -> Result<&'a str, Error>;
608                async fn elided(&self) -> Result<&str, Error>;
609            }
610            "#,
611        );
612        let mut methods = trait_def.methods();
613
614        let borrowed = methods.next().expect("borrowed method").return_type();
615        let (borrowed_ok, _) = method_ok_and_err_types(&borrowed);
616        assert!(borrowed_ok.has_named_lifetime("roam"));
617        assert!(!borrowed_ok.has_non_named_lifetime("roam"));
618        assert!(!borrowed_ok.has_elided_reference_lifetime());
619
620        let bad_lifetime = methods.next().expect("bad_lifetime method").return_type();
621        let (bad_ok, _) = method_ok_and_err_types(&bad_lifetime);
622        assert!(!bad_ok.has_named_lifetime("roam"));
623        assert!(bad_ok.has_non_named_lifetime("roam"));
624        assert!(!bad_ok.has_elided_reference_lifetime());
625
626        let elided = methods.next().expect("elided method").return_type();
627        let (elided_ok, _) = method_ok_and_err_types(&elided);
628        assert!(!elided_ok.has_named_lifetime("roam"));
629        assert!(!elided_ok.has_non_named_lifetime("roam"));
630        assert!(elided_ok.has_elided_reference_lifetime());
631    }
632
633    #[test]
634    fn type_path_last_segment_uses_trailing_segment() {
635        let trait_def = parse(
636            r#"
637            trait Svc {
638                async fn f(&self) -> std::result::Result<u8, u8>;
639            }
640            "#,
641        );
642        let method = trait_def.methods().next().expect("method");
643        let ret = method.return_type();
644        let Type::PathWithGenerics(path_with_generics) = ret else {
645            panic!("expected path with generics");
646        };
647        assert_eq!(path_with_generics.path.last_segment(), "Result");
648    }
649}