Skip to main content

metaxy_cli/parser/
extract.rs

1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use globset::{GlobBuilder, GlobSet, GlobSetBuilder};
6use syn::{Attribute, File, FnArg, Item, ItemFn, ReturnType};
7use walkdir::WalkDir;
8
9use super::serde as serde_attr;
10use super::types::{extract_rust_type, extract_struct_fields, extract_tuple_fields};
11use crate::config::InputConfig;
12use crate::model::{
13    EnumDef, EnumVariant, Manifest, Procedure, ProcedureKind, RustType, StructDef, VariantKind,
14};
15
// RPC attribute names recognized by the parser. They are compared with
// `Path::is_ident`, so only bare single-segment attributes such as
// `#[rpc_query]` match.
/// Marks a query procedure: `#[rpc_query]`.
const RPC_QUERY_ATTR: &str = "rpc_query";
/// Marks a mutation procedure: `#[rpc_mutation]`; the only attribute inspected for `idempotent`.
const RPC_MUTATION_ATTR: &str = "rpc_mutation";
/// Marks a streaming procedure: `#[rpc_stream]`.
const RPC_STREAM_ATTR: &str = "rpc_stream";
20
21/// Builds a `GlobSet` from a list of glob pattern strings.
22fn build_glob_set(patterns: &[String]) -> Result<GlobSet> {
23    let mut builder = GlobSetBuilder::new();
24    for pattern in patterns {
25        let glob = GlobBuilder::new(pattern)
26            .literal_separator(false)
27            .build()
28            .with_context(|| format!("Invalid glob pattern: {pattern}"))?;
29        builder.add(glob);
30    }
31    builder.build().context("Failed to build glob set")
32}
33
34/// Scans `.rs` files in the configured directory and extracts RPC metadata.
35///
36/// Walks the directory recursively, applying `include`/`exclude` glob patterns
37/// from the config, then parsing each matching Rust source file for
38/// `#[rpc_query]` / `#[rpc_mutation]` annotated functions and `#[derive(Serialize)]` structs.
39pub fn scan_directory(input: &InputConfig) -> Result<Manifest> {
40    let mut manifest = Manifest::default();
41
42    let include_set = build_glob_set(&input.include)?;
43    let exclude_set = build_glob_set(&input.exclude)?;
44
45    let mut file_count = 0;
46    for entry in WalkDir::new(&input.dir)
47        .into_iter()
48        // Skip unreadable entries (e.g. permission denied); the scan should
49        // not abort because a single directory entry is inaccessible.
50        .filter_map(|e| e.ok())
51        .filter(|e| {
52            if e.path().extension().is_none_or(|ext| ext != "rs") {
53                return false;
54            }
55            let rel = e.path().strip_prefix(&input.dir).unwrap_or(e.path());
56            include_set.is_match(rel) && !exclude_set.is_match(rel)
57        })
58    {
59        file_count += 1;
60        let path = entry.path();
61        let file_manifest =
62            parse_file(path).with_context(|| format!("Failed to parse {}", path.display()))?;
63
64        manifest.procedures.extend(file_manifest.procedures);
65        manifest.structs.extend(file_manifest.structs);
66        manifest.enums.extend(file_manifest.enums);
67    }
68
69    if file_count == 0 {
70        anyhow::bail!("No .rs files found in {}", input.dir.display());
71    }
72
73    // Sort for deterministic output
74    manifest.procedures.sort_by(|a, b| a.name.cmp(&b.name));
75    manifest.structs.sort_by(|a, b| a.name.cmp(&b.name));
76    manifest.enums.sort_by(|a, b| a.name.cmp(&b.name));
77
78    Ok(manifest)
79}
80
81/// Parses a single Rust source file and extracts all RPC procedures and struct definitions.
82pub fn parse_file(path: &Path) -> Result<Manifest> {
83    let source =
84        fs::read_to_string(path).with_context(|| format!("Cannot read {}", path.display()))?;
85
86    let syntax: File =
87        syn::parse_file(&source).with_context(|| format!("Syntax error in {}", path.display()))?;
88
89    let mut manifest = Manifest::default();
90
91    for item in &syntax.items {
92        match item {
93            Item::Fn(func) => {
94                if let Some(procedure) = try_extract_procedure(func, path) {
95                    manifest.procedures.push(procedure);
96                }
97            }
98            Item::Struct(item_struct) => {
99                if has_serde_derive(&item_struct.attrs) {
100                    let generics = extract_generic_param_names(&item_struct.generics);
101                    let tuple_fields = extract_tuple_fields(&item_struct.fields);
102                    let fields = if tuple_fields.is_empty() {
103                        extract_struct_fields(&item_struct.fields)
104                    } else {
105                        vec![]
106                    };
107                    let docs = extract_docs(&item_struct.attrs);
108                    let rename_all = serde_attr::parse_rename_all(&item_struct.attrs);
109                    manifest.structs.push(StructDef {
110                        name: item_struct.ident.to_string(),
111                        generics,
112                        fields,
113                        tuple_fields,
114                        source_file: path.to_path_buf(),
115                        docs,
116                        rename_all,
117                    });
118                }
119            }
120            Item::Enum(item_enum) => {
121                if has_serde_derive(&item_enum.attrs) {
122                    let generics = extract_generic_param_names(&item_enum.generics);
123                    let rename_all = serde_attr::parse_rename_all(&item_enum.attrs);
124                    let tagging = serde_attr::parse_enum_tagging(&item_enum.attrs);
125                    let variants = extract_enum_variants(item_enum);
126                    let docs = extract_docs(&item_enum.attrs);
127                    manifest.enums.push(EnumDef {
128                        name: item_enum.ident.to_string(),
129                        generics,
130                        variants,
131                        source_file: path.to_path_buf(),
132                        docs,
133                        rename_all,
134                        tagging,
135                    });
136                }
137            }
138            _ => {}
139        }
140    }
141
142    Ok(manifest)
143}
144
145/// Extracts doc comments from `#[doc = "..."]` attributes (written as `///` in source).
146///
147/// Returns `None` if no doc comments are present.
148fn extract_docs(attrs: &[Attribute]) -> Option<String> {
149    let lines: Vec<String> = attrs
150        .iter()
151        .filter_map(|attr| {
152            if !attr.path().is_ident("doc") {
153                return None;
154            }
155            if let syn::Meta::NameValue(nv) = &attr.meta
156                && let syn::Expr::Lit(syn::ExprLit {
157                    lit: syn::Lit::Str(s),
158                    ..
159                }) = &nv.value
160            {
161                let text = s.value();
162                // `///` comments produce a leading space, strip it
163                return Some(text.strip_prefix(' ').unwrap_or(&text).to_string());
164            }
165            None
166        })
167        .collect();
168
169    if lines.is_empty() {
170        None
171    } else {
172        Some(lines.join("\n"))
173    }
174}
175
176/// Attempts to extract an RPC procedure from a function item.
177/// Returns `None` if the function doesn't have an RPC attribute.
178fn try_extract_procedure(func: &ItemFn, path: &Path) -> Option<Procedure> {
179    let kind = detect_rpc_kind(&func.attrs)?;
180    let name = func.sig.ident.to_string();
181    let docs = extract_docs(&func.attrs);
182
183    let input = func.sig.inputs.iter().find_map(|arg| {
184        let FnArg::Typed(pat) = arg else { return None };
185        // Skip the Headers parameter — it's not part of the RPC input.
186        if is_headers_type(&pat.ty) {
187            return None;
188        }
189        // Skip reference parameters — these are init-injected state (&T).
190        if matches!(&*pat.ty, syn::Type::Reference(_)) {
191            return None;
192        }
193        // Skip the StreamSender parameter — it's an internal streaming channel.
194        if is_stream_sender_type(&pat.ty) {
195            return None;
196        }
197        Some(extract_rust_type(&pat.ty))
198    });
199
200    // For streams, the output type comes from the StreamSender<T> parameter.
201    // For queries/mutations, it comes from the function return type.
202    let output = if kind == ProcedureKind::Stream {
203        func.sig.inputs.iter().find_map(|arg| {
204            let FnArg::Typed(pat) = arg else { return None };
205            extract_stream_chunk_type(&pat.ty)
206        })
207    } else {
208        match &func.sig.output {
209            ReturnType::Default => None,
210            ReturnType::Type(_, ty) => {
211                let rust_type = extract_rust_type(ty);
212                // Unwrap Result<T, _> to just T
213                if rust_type.name == "Result" && !rust_type.generics.is_empty() {
214                    rust_type.generics.into_iter().next()
215                } else {
216                    Some(rust_type)
217                }
218            }
219        }
220    };
221
222    let timeout_ms = extract_timeout_ms(&func.attrs);
223    let idempotent = extract_idempotent(&func.attrs);
224
225    Some(Procedure {
226        name,
227        kind,
228        input,
229        output,
230        source_file: path.to_path_buf(),
231        docs,
232        timeout_ms,
233        idempotent,
234    })
235}
236
237/// Checks function attributes for `#[rpc_query]`, `#[rpc_mutation]`, or `#[rpc_stream]`.
238fn detect_rpc_kind(attrs: &[Attribute]) -> Option<ProcedureKind> {
239    for attr in attrs {
240        if attr.path().is_ident(RPC_QUERY_ATTR) {
241            return Some(ProcedureKind::Query);
242        }
243        if attr.path().is_ident(RPC_MUTATION_ATTR) {
244            return Some(ProcedureKind::Mutation);
245        }
246        if attr.path().is_ident(RPC_STREAM_ATTR) {
247            return Some(ProcedureKind::Stream);
248        }
249    }
250    None
251}
252
253/// Extracts generic type parameter names from `syn::Generics`.
254///
255/// Only type parameters are extracted; lifetimes and const generics are skipped.
256fn extract_generic_param_names(generics: &syn::Generics) -> Vec<String> {
257    generics
258        .params
259        .iter()
260        .filter_map(|p| match p {
261            syn::GenericParam::Type(t) => Some(t.ident.to_string()),
262            _ => None,
263        })
264        .collect()
265}
266
267/// Extracts variants from a Rust enum into `EnumVariant` representations.
268fn extract_enum_variants(item_enum: &syn::ItemEnum) -> Vec<EnumVariant> {
269    item_enum
270        .variants
271        .iter()
272        .map(|v| {
273            let name = v.ident.to_string();
274            let rename = serde_attr::parse_rename(&v.attrs);
275            let kind = match &v.fields {
276                syn::Fields::Unit => VariantKind::Unit,
277                syn::Fields::Unnamed(fields) => {
278                    let types = fields
279                        .unnamed
280                        .iter()
281                        .map(|f| extract_rust_type(&f.ty))
282                        .collect();
283                    VariantKind::Tuple(types)
284                }
285                syn::Fields::Named(_) => {
286                    let fields = extract_struct_fields(&v.fields);
287                    VariantKind::Struct(fields)
288                }
289            };
290            EnumVariant { name, kind, rename }
291        })
292        .collect()
293}
294
295/// Returns `true` if the type path ends with `Headers` (e.g. `Headers`, `metaxy::Headers`).
296///
297/// Used to skip the `Headers` parameter when extracting RPC input types,
298/// since it carries request metadata rather than user-provided input.
299fn is_headers_type(ty: &syn::Type) -> bool {
300    if let syn::Type::Path(type_path) = ty
301        && let Some(segment) = type_path.path.segments.last()
302    {
303        return segment.ident == "Headers";
304    }
305    false
306}
307
308/// Returns `true` if the type path ends with `StreamSender` (e.g. `StreamSender`, `metaxy::StreamSender`).
309///
310/// Used to skip the `StreamSender` parameter when extracting RPC input types,
311/// since it is an internal streaming channel rather than user-provided input.
312fn is_stream_sender_type(ty: &syn::Type) -> bool {
313    if let syn::Type::Path(type_path) = ty
314        && let Some(segment) = type_path.path.segments.last()
315    {
316        return segment.ident == "StreamSender";
317    }
318    false
319}
320
321/// Extracts the chunk type `T` from `StreamSender<T>`.
322///
323/// Returns `None` for bare `StreamSender` (no type parameter).
324fn extract_stream_chunk_type(ty: &syn::Type) -> Option<RustType> {
325    let syn::Type::Path(type_path) = ty else {
326        return None;
327    };
328    let segment = type_path.path.segments.last()?;
329    if segment.ident != "StreamSender" {
330        return None;
331    }
332    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
333        return None;
334    };
335    for arg in &args.args {
336        if let syn::GenericArgument::Type(inner_ty) = arg {
337            return Some(extract_rust_type(inner_ty));
338        }
339    }
340    None
341}
342
343/// Extracts the `timeout` value from `#[rpc_query(timeout = "30s")]` or `#[rpc_mutation(timeout = "30s")]`.
344///
345/// Returns `Some(milliseconds)` if a valid timeout is found, `None` otherwise.
346/// Uses `Punctuated<Meta>` to handle mixed bare flags (e.g. `idempotent`) alongside key-value pairs.
347fn extract_timeout_ms(attrs: &[Attribute]) -> Option<u64> {
348    for attr in attrs {
349        if !attr.path().is_ident(RPC_QUERY_ATTR)
350            && !attr.path().is_ident(RPC_MUTATION_ATTR)
351            && !attr.path().is_ident(RPC_STREAM_ATTR)
352        {
353            continue;
354        }
355        let Ok(parsed) = attr.parse_args_with(
356            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
357        ) else {
358            continue;
359        };
360        for meta in &parsed {
361            if let syn::Meta::NameValue(nv) = meta
362                && nv.path.is_ident("timeout")
363                && let syn::Expr::Lit(syn::ExprLit {
364                    lit: syn::Lit::Str(s),
365                    ..
366                }) = &nv.value
367            {
368                return parse_duration_to_ms(&s.value());
369            }
370        }
371    }
372    None
373}
374
/// Parses a human-readable duration shorthand (`"30s"`, `"5m"`, `"2h"`, `"1d"`)
/// into milliseconds.
///
/// The input must be an unsigned integer immediately followed by exactly one
/// unit suffix: `s` (seconds), `m` (minutes), `h` (hours) or `d` (days).
///
/// Lenient: returns `None` instead of failing the scan in these cases:
/// - the suffix is missing or unknown;
/// - the number does not parse;
/// - the number is zero;
/// - the result would not fit in a `u64` number of milliseconds.
fn parse_duration_to_ms(s: &str) -> Option<u64> {
    let (num_str, multiplier) = if let Some(n) = s.strip_suffix('s') {
        (n, 1_000)
    } else if let Some(n) = s.strip_suffix('m') {
        (n, 60_000)
    } else if let Some(n) = s.strip_suffix('h') {
        (n, 3_600_000)
    } else if let Some(n) = s.strip_suffix('d') {
        (n, 86_400_000)
    } else {
        return None;
    };
    let num: u64 = num_str.parse().ok()?;
    if num == 0 {
        return None;
    }
    // A plain `*` would panic in debug builds and silently wrap in release
    // for absurdly large values; treat overflow as an invalid duration.
    num.checked_mul(multiplier)
}
396
397/// Extracts the bare `idempotent` flag from `#[rpc_mutation(idempotent)]`.
398///
399/// Only checks `RPC_MUTATION_ATTR` attributes. Lenient: silently ignores
400/// `rpc_query(idempotent)` (the proc macro rejects it at compile time).
401fn extract_idempotent(attrs: &[Attribute]) -> bool {
402    for attr in attrs {
403        if !attr.path().is_ident(RPC_MUTATION_ATTR) {
404            continue;
405        }
406        let Ok(parsed) = attr.parse_args_with(
407            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
408        ) else {
409            continue;
410        };
411        for meta in &parsed {
412            if let syn::Meta::Path(path) = meta
413                && path.is_ident("idempotent")
414            {
415                return true;
416            }
417        }
418    }
419    false
420}
421
422/// Checks if a struct has `#[derive(Serialize)]` or `#[derive(serde::Serialize)]`.
423fn has_serde_derive(attrs: &[Attribute]) -> bool {
424    attrs.iter().any(|attr| {
425        if !attr.path().is_ident("derive") {
426            return false;
427        }
428        attr.parse_args_with(
429            syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
430        )
431        .is_ok_and(|nested| {
432            nested.iter().any(|path| {
433                path.is_ident("Serialize")
434                    || path.segments.last().is_some_and(|s| s.ident == "Serialize")
435            })
436        })
437    })
438}