// metaxy_cli/parser/extract.rs

1use std::fs;
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use globset::{GlobBuilder, GlobSet, GlobSetBuilder};
6use syn::{Attribute, File, FnArg, Item, ItemFn, ReturnType};
7use walkdir::WalkDir;
8
9use super::serde as serde_attr;
10use super::types::{extract_rust_type, extract_struct_fields, extract_tuple_fields};
11use crate::config::InputConfig;
12use crate::model::{
13    EnumDef, EnumVariant, Manifest, Procedure, ProcedureKind, StructDef, VariantKind,
14};
15
/// Attribute name that marks a function as an RPC query (`#[rpc_query]`).
const RPC_QUERY_ATTR: &str = "rpc_query";
/// Attribute name that marks a function as an RPC mutation (`#[rpc_mutation]`).
const RPC_MUTATION_ATTR: &str = "rpc_mutation";
19
20/// Builds a `GlobSet` from a list of glob pattern strings.
21fn build_glob_set(patterns: &[String]) -> Result<GlobSet> {
22    let mut builder = GlobSetBuilder::new();
23    for pattern in patterns {
24        let glob = GlobBuilder::new(pattern)
25            .literal_separator(false)
26            .build()
27            .with_context(|| format!("Invalid glob pattern: {pattern}"))?;
28        builder.add(glob);
29    }
30    builder.build().context("Failed to build glob set")
31}
32
33/// Scans `.rs` files in the configured directory and extracts RPC metadata.
34///
35/// Walks the directory recursively, applying `include`/`exclude` glob patterns
36/// from the config, then parsing each matching Rust source file for
37/// `#[rpc_query]` / `#[rpc_mutation]` annotated functions and `#[derive(Serialize)]` structs.
38pub fn scan_directory(input: &InputConfig) -> Result<Manifest> {
39    let mut manifest = Manifest::default();
40
41    let include_set = build_glob_set(&input.include)?;
42    let exclude_set = build_glob_set(&input.exclude)?;
43
44    let mut file_count = 0;
45    for entry in WalkDir::new(&input.dir)
46        .into_iter()
47        // Skip unreadable entries (e.g. permission denied); the scan should
48        // not abort because a single directory entry is inaccessible.
49        .filter_map(|e| e.ok())
50        .filter(|e| {
51            if e.path().extension().is_none_or(|ext| ext != "rs") {
52                return false;
53            }
54            let rel = e.path().strip_prefix(&input.dir).unwrap_or(e.path());
55            include_set.is_match(rel) && !exclude_set.is_match(rel)
56        })
57    {
58        file_count += 1;
59        let path = entry.path();
60        let file_manifest =
61            parse_file(path).with_context(|| format!("Failed to parse {}", path.display()))?;
62
63        manifest.procedures.extend(file_manifest.procedures);
64        manifest.structs.extend(file_manifest.structs);
65        manifest.enums.extend(file_manifest.enums);
66    }
67
68    if file_count == 0 {
69        anyhow::bail!("No .rs files found in {}", input.dir.display());
70    }
71
72    // Sort for deterministic output
73    manifest.procedures.sort_by(|a, b| a.name.cmp(&b.name));
74    manifest.structs.sort_by(|a, b| a.name.cmp(&b.name));
75    manifest.enums.sort_by(|a, b| a.name.cmp(&b.name));
76
77    Ok(manifest)
78}
79
80/// Parses a single Rust source file and extracts all RPC procedures and struct definitions.
81pub fn parse_file(path: &Path) -> Result<Manifest> {
82    let source =
83        fs::read_to_string(path).with_context(|| format!("Cannot read {}", path.display()))?;
84
85    let syntax: File =
86        syn::parse_file(&source).with_context(|| format!("Syntax error in {}", path.display()))?;
87
88    let mut manifest = Manifest::default();
89
90    for item in &syntax.items {
91        match item {
92            Item::Fn(func) => {
93                if let Some(procedure) = try_extract_procedure(func, path) {
94                    manifest.procedures.push(procedure);
95                }
96            }
97            Item::Struct(item_struct) => {
98                if has_serde_derive(&item_struct.attrs) {
99                    let generics = extract_generic_param_names(&item_struct.generics);
100                    let tuple_fields = extract_tuple_fields(&item_struct.fields);
101                    let fields = if tuple_fields.is_empty() {
102                        extract_struct_fields(&item_struct.fields)
103                    } else {
104                        vec![]
105                    };
106                    let docs = extract_docs(&item_struct.attrs);
107                    let rename_all = serde_attr::parse_rename_all(&item_struct.attrs);
108                    manifest.structs.push(StructDef {
109                        name: item_struct.ident.to_string(),
110                        generics,
111                        fields,
112                        tuple_fields,
113                        source_file: path.to_path_buf(),
114                        docs,
115                        rename_all,
116                    });
117                }
118            }
119            Item::Enum(item_enum) => {
120                if has_serde_derive(&item_enum.attrs) {
121                    let generics = extract_generic_param_names(&item_enum.generics);
122                    let rename_all = serde_attr::parse_rename_all(&item_enum.attrs);
123                    let tagging = serde_attr::parse_enum_tagging(&item_enum.attrs);
124                    let variants = extract_enum_variants(item_enum);
125                    let docs = extract_docs(&item_enum.attrs);
126                    manifest.enums.push(EnumDef {
127                        name: item_enum.ident.to_string(),
128                        generics,
129                        variants,
130                        source_file: path.to_path_buf(),
131                        docs,
132                        rename_all,
133                        tagging,
134                    });
135                }
136            }
137            _ => {}
138        }
139    }
140
141    Ok(manifest)
142}
143
144/// Extracts doc comments from `#[doc = "..."]` attributes (written as `///` in source).
145///
146/// Returns `None` if no doc comments are present.
147fn extract_docs(attrs: &[Attribute]) -> Option<String> {
148    let lines: Vec<String> = attrs
149        .iter()
150        .filter_map(|attr| {
151            if !attr.path().is_ident("doc") {
152                return None;
153            }
154            if let syn::Meta::NameValue(nv) = &attr.meta
155                && let syn::Expr::Lit(syn::ExprLit {
156                    lit: syn::Lit::Str(s),
157                    ..
158                }) = &nv.value
159            {
160                let text = s.value();
161                // `///` comments produce a leading space, strip it
162                return Some(text.strip_prefix(' ').unwrap_or(&text).to_string());
163            }
164            None
165        })
166        .collect();
167
168    if lines.is_empty() {
169        None
170    } else {
171        Some(lines.join("\n"))
172    }
173}
174
175/// Attempts to extract an RPC procedure from a function item.
176/// Returns `None` if the function doesn't have an RPC attribute.
177fn try_extract_procedure(func: &ItemFn, path: &Path) -> Option<Procedure> {
178    let kind = detect_rpc_kind(&func.attrs)?;
179    let name = func.sig.ident.to_string();
180    let docs = extract_docs(&func.attrs);
181
182    let input = func.sig.inputs.iter().find_map(|arg| {
183        let FnArg::Typed(pat) = arg else { return None };
184        // Skip the Headers parameter — it's not part of the RPC input.
185        if is_headers_type(&pat.ty) {
186            return None;
187        }
188        // Skip reference parameters — these are init-injected state (&T).
189        if matches!(&*pat.ty, syn::Type::Reference(_)) {
190            return None;
191        }
192        Some(extract_rust_type(&pat.ty))
193    });
194
195    let output = match &func.sig.output {
196        ReturnType::Default => None,
197        ReturnType::Type(_, ty) => {
198            let rust_type = extract_rust_type(ty);
199            // Unwrap Result<T, _> to just T
200            if rust_type.name == "Result" && !rust_type.generics.is_empty() {
201                rust_type.generics.into_iter().next()
202            } else {
203                Some(rust_type)
204            }
205        }
206    };
207
208    let timeout_ms = extract_timeout_ms(&func.attrs);
209    let idempotent = extract_idempotent(&func.attrs);
210
211    Some(Procedure {
212        name,
213        kind,
214        input,
215        output,
216        source_file: path.to_path_buf(),
217        docs,
218        timeout_ms,
219        idempotent,
220    })
221}
222
223/// Checks function attributes for `#[rpc_query]` or `#[rpc_mutation]`.
224fn detect_rpc_kind(attrs: &[Attribute]) -> Option<ProcedureKind> {
225    for attr in attrs {
226        if attr.path().is_ident(RPC_QUERY_ATTR) {
227            return Some(ProcedureKind::Query);
228        }
229        if attr.path().is_ident(RPC_MUTATION_ATTR) {
230            return Some(ProcedureKind::Mutation);
231        }
232    }
233    None
234}
235
236/// Extracts generic type parameter names from `syn::Generics`.
237///
238/// Only type parameters are extracted; lifetimes and const generics are skipped.
239fn extract_generic_param_names(generics: &syn::Generics) -> Vec<String> {
240    generics
241        .params
242        .iter()
243        .filter_map(|p| match p {
244            syn::GenericParam::Type(t) => Some(t.ident.to_string()),
245            _ => None,
246        })
247        .collect()
248}
249
250/// Extracts variants from a Rust enum into `EnumVariant` representations.
251fn extract_enum_variants(item_enum: &syn::ItemEnum) -> Vec<EnumVariant> {
252    item_enum
253        .variants
254        .iter()
255        .map(|v| {
256            let name = v.ident.to_string();
257            let rename = serde_attr::parse_rename(&v.attrs);
258            let kind = match &v.fields {
259                syn::Fields::Unit => VariantKind::Unit,
260                syn::Fields::Unnamed(fields) => {
261                    let types = fields
262                        .unnamed
263                        .iter()
264                        .map(|f| extract_rust_type(&f.ty))
265                        .collect();
266                    VariantKind::Tuple(types)
267                }
268                syn::Fields::Named(_) => {
269                    let fields = extract_struct_fields(&v.fields);
270                    VariantKind::Struct(fields)
271                }
272            };
273            EnumVariant { name, kind, rename }
274        })
275        .collect()
276}
277
278/// Returns `true` if the type path ends with `Headers` (e.g. `Headers`, `metaxy::Headers`).
279///
280/// Used to skip the `Headers` parameter when extracting RPC input types,
281/// since it carries request metadata rather than user-provided input.
282fn is_headers_type(ty: &syn::Type) -> bool {
283    if let syn::Type::Path(type_path) = ty
284        && let Some(segment) = type_path.path.segments.last()
285    {
286        return segment.ident == "Headers";
287    }
288    false
289}
290
291/// Extracts the `timeout` value from `#[rpc_query(timeout = "30s")]` or `#[rpc_mutation(timeout = "30s")]`.
292///
293/// Returns `Some(milliseconds)` if a valid timeout is found, `None` otherwise.
294/// Uses `Punctuated<Meta>` to handle mixed bare flags (e.g. `idempotent`) alongside key-value pairs.
295fn extract_timeout_ms(attrs: &[Attribute]) -> Option<u64> {
296    for attr in attrs {
297        if !attr.path().is_ident(RPC_QUERY_ATTR) && !attr.path().is_ident(RPC_MUTATION_ATTR) {
298            continue;
299        }
300        let Ok(parsed) = attr.parse_args_with(
301            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
302        ) else {
303            continue;
304        };
305        for meta in &parsed {
306            if let syn::Meta::NameValue(nv) = meta
307                && nv.path.is_ident("timeout")
308                && let syn::Expr::Lit(syn::ExprLit {
309                    lit: syn::Lit::Str(s),
310                    ..
311                }) = &nv.value
312            {
313                return parse_duration_to_ms(&s.value());
314            }
315        }
316    }
317    None
318}
319
/// Parses a human-readable duration shorthand into milliseconds.
///
/// The input is an unsigned integer followed by one unit suffix:
/// `s` (seconds), `m` (minutes), `h` (hours), or `d` (days), e.g. `"30s"`.
///
/// Lenient: returns `None` on any parse error instead of failing the scan.
/// That covers a missing or unknown suffix, a non-integer number, a zero
/// duration, and a value whose millisecond count does not fit in a `u64`.
fn parse_duration_to_ms(s: &str) -> Option<u64> {
    /// Supported unit suffixes and their length in milliseconds.
    const UNITS: [(char, u64); 4] = [
        ('s', 1_000),
        ('m', 60_000),
        ('h', 3_600_000),
        ('d', 86_400_000),
    ];

    let (num_str, multiplier) = UNITS
        .iter()
        .find_map(|&(suffix, ms)| s.strip_suffix(suffix).map(|number| (number, ms)))?;

    let num: u64 = num_str.parse().ok()?;
    if num == 0 {
        return None;
    }
    // A plain `*` would panic in debug builds and wrap in release builds when
    // the product exceeds `u64::MAX`; treat that as one more parse error.
    num.checked_mul(multiplier)
}
341
342/// Extracts the bare `idempotent` flag from `#[rpc_mutation(idempotent)]`.
343///
344/// Only checks `RPC_MUTATION_ATTR` attributes. Lenient: silently ignores
345/// `rpc_query(idempotent)` (the proc macro rejects it at compile time).
346fn extract_idempotent(attrs: &[Attribute]) -> bool {
347    for attr in attrs {
348        if !attr.path().is_ident(RPC_MUTATION_ATTR) {
349            continue;
350        }
351        let Ok(parsed) = attr.parse_args_with(
352            syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
353        ) else {
354            continue;
355        };
356        for meta in &parsed {
357            if let syn::Meta::Path(path) = meta
358                && path.is_ident("idempotent")
359            {
360                return true;
361            }
362        }
363    }
364    false
365}
366
367/// Checks if a struct has `#[derive(Serialize)]` or `#[derive(serde::Serialize)]`.
368fn has_serde_derive(attrs: &[Attribute]) -> bool {
369    attrs.iter().any(|attr| {
370        if !attr.path().is_ident("derive") {
371            return false;
372        }
373        attr.parse_args_with(
374            syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
375        )
376        .is_ok_and(|nested| {
377            nested.iter().any(|path| {
378                path.is_ident("Serialize")
379                    || path.segments.last().is_some_and(|s| s.ident == "Serialize")
380            })
381        })
382    })
383}