Skip to main content

oxapi_macro/
lib.rs

1//! Procedural macros for the oxapi OpenAPI server stub generator.
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{Ident, ItemMod, ItemTrait, LitStr, Token};
7
8/// Main attribute macro for generating server stubs from OpenAPI specs.
9///
10/// # Example (single trait)
11///
12/// ```ignore
13/// #[oxapi::oxapi(axum, "spec.json")]
14/// trait MyServer {
15///     #[oxapi(map)]
16///     fn map_routes(router: Router) -> Router;
17///
18///     #[oxapi(get, "/users")]
19///     async fn get_users(state: State<AppState>, query: Query<_>);
20/// }
21/// ```
22///
23/// # Example (module with multiple traits)
24///
25/// ```ignore
26/// #[oxapi::oxapi(axum, "spec.json")]
27/// mod my_api {
28///     trait PetService {
29///         #[oxapi(map)]
30///         fn map_routes(router: Router<PetState>) -> Router<PetState>;
31///
32///         #[oxapi(get, "/pet/{petId}")]
33///         async fn get_pet(state: State<PetState>, pet_id: Path<_>);
34///     }
35///
36///     trait StoreService {
37///         #[oxapi(map)]
38///         fn map_routes(router: Router<StoreState>) -> Router<StoreState>;
39///
40///         #[oxapi(get, "/store/inventory")]
41///         async fn get_inventory(state: State<StoreState>);
42///     }
43/// }
44/// ```
45#[proc_macro_attribute]
46pub fn oxapi(attr: TokenStream, item: TokenStream) -> TokenStream {
47    match do_oxapi(attr.into(), item.into()) {
48        Ok(tokens) => tokens.into(),
49        Err(err) => err.to_compile_error().into(),
50    }
51}
52
53fn do_oxapi(attr: TokenStream2, item: TokenStream2) -> syn::Result<TokenStream2> {
54    // Parse the attribute arguments: (framework, "spec_path")
55    let args = syn::parse2::<MacroArgs>(attr)?;
56
57    // Load the OpenAPI spec
58    let spec_path = resolve_spec_path(&args.spec_path)?;
59    let generator = oxapi_impl::Generator::from_file(&spec_path).map_err(|e| {
60        syn::Error::new(args.spec_path.span(), format!("failed to load spec: {}", e))
61    })?;
62
63    // Try to parse as trait first, then as module
64    if let Ok(trait_item) = syn::parse2::<ItemTrait>(item.clone()) {
65        let processor = TraitProcessor::new(generator, args.framework, trait_item)?;
66        processor.generate()
67    } else if let Ok(mod_item) = syn::parse2::<ItemMod>(item) {
68        let processor = ModuleProcessor::new(generator, args.framework, mod_item)?;
69        processor.generate()
70    } else {
71        Err(syn::Error::new(
72            proc_macro2::Span::call_site(),
73            "expected trait or mod item",
74        ))
75    }
76}
77
78/// Resolve the spec path relative to CARGO_MANIFEST_DIR.
79fn resolve_spec_path(lit: &LitStr) -> syn::Result<std::path::PathBuf> {
80    let dir = std::env::var("CARGO_MANIFEST_DIR")
81        .map_err(|_| syn::Error::new(lit.span(), "CARGO_MANIFEST_DIR not set"))?;
82
83    let path = std::path::Path::new(&dir).join(lit.value());
84    Ok(path)
85}
86
87/// Parsed macro arguments.
88struct MacroArgs {
89    framework: Framework,
90    spec_path: LitStr,
91}
92
93impl syn::parse::Parse for MacroArgs {
94    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
95        let framework: Ident = input.parse()?;
96        input.parse::<Token![,]>()?;
97        let spec_path: LitStr = input.parse()?;
98
99        let framework = match framework.to_string().as_str() {
100            "axum" => Framework::Axum,
101            other => {
102                return Err(syn::Error::new(
103                    framework.span(),
104                    format!("unsupported framework: {}", other),
105                ));
106            }
107        };
108
109        Ok(MacroArgs {
110            framework,
111            spec_path,
112        })
113    }
114}
115
116#[derive(Clone, Copy)]
117enum Framework {
118    Axum,
119}
120
121/// Processes a trait and generates the output.
122struct TraitProcessor {
123    generator: oxapi_impl::Generator,
124    #[allow(dead_code)]
125    framework: Framework,
126    trait_item: ItemTrait,
127}
128
129impl TraitProcessor {
130    fn new(
131        generator: oxapi_impl::Generator,
132        framework: Framework,
133        trait_item: ItemTrait,
134    ) -> syn::Result<Self> {
135        Ok(Self {
136            generator,
137            framework,
138            trait_item,
139        })
140    }
141
142    fn generate(self) -> syn::Result<TokenStream2> {
143        use std::collections::HashMap;
144
145        let trait_name = &self.trait_item.ident;
146        let types_mod_name = syn::Ident::new(
147            &format!("{}_types", heck::AsSnakeCase(trait_name.to_string())),
148            trait_name.span(),
149        );
150
151        // Parse all method attributes and collect coverage info
152        let mut covered: HashMap<(oxapi_impl::HttpMethod, String), ()> = HashMap::new();
153        let mut map_method: Option<&syn::TraitItemFn> = None;
154        let mut handler_methods: Vec<(&syn::TraitItemFn, oxapi_impl::HttpMethod, String)> =
155            Vec::new();
156
157        for item in &self.trait_item.items {
158            if let syn::TraitItem::Fn(method) = item {
159                if let Some(attr) = find_oxapi_attr(&method.attrs)? {
160                    match attr {
161                        OxapiAttr::Map => {
162                            map_method = Some(method);
163                        }
164                        OxapiAttr::Route {
165                            method: http_method,
166                            path,
167                        } => {
168                            covered.insert((http_method, path.clone()), ());
169                            handler_methods.push((method, http_method, path));
170                        }
171                    }
172                } else {
173                    return Err(syn::Error::new_spanned(
174                        method,
175                        "all trait methods must have #[oxapi(...)] attribute",
176                    ));
177                }
178            }
179        }
180
181        // Validate coverage
182        self.generator
183            .validate_coverage(&covered)
184            .map_err(|e| syn::Error::new_spanned(&self.trait_item, e.to_string()))?;
185
186        // Generate types module
187        let types = self.generator.generate_types();
188        let responses = self.generator.generate_responses();
189
190        // Generate transformed methods
191        let mut transformed_methods = Vec::new();
192
193        // Generate map_routes if present
194        if let Some(map_fn) = map_method {
195            let router_gen = oxapi_impl::RouterGenerator::new(&self.generator);
196            let map_body = router_gen.generate_map_routes(
197                &handler_methods
198                    .iter()
199                    .map(|(m, method, path)| (m.sig.ident.clone(), *method, path.clone()))
200                    .collect::<Vec<_>>(),
201            );
202
203            let sig = &map_fn.sig;
204            transformed_methods.push(quote! {
205                #sig {
206                    #map_body
207                }
208            });
209        }
210
211        // Generate handler methods
212        let method_transformer =
213            oxapi_impl::MethodTransformer::new(&self.generator, &types_mod_name);
214        for (method, http_method, path) in &handler_methods {
215            let op = self
216                .generator
217                .get_operation(*http_method, path)
218                .ok_or_else(|| {
219                    syn::Error::new_spanned(
220                        method,
221                        format!("operation not found: {} {}", http_method, path),
222                    )
223                })?;
224
225            let transformed = method_transformer.transform(method, op)?;
226            transformed_methods.push(transformed);
227        }
228
229        // Generate the full output
230        let vis = &self.trait_item.vis;
231        let trait_attrs: Vec<_> = self
232            .trait_item
233            .attrs
234            .iter()
235            .filter(|a| !a.path().is_ident("oxapi"))
236            .collect();
237
238        let output = quote! {
239            #vis mod #types_mod_name {
240                use super::*;
241
242                #types
243                #responses
244            }
245
246            #(#trait_attrs)*
247            #vis trait #trait_name: 'static {
248                #(#transformed_methods)*
249            }
250        };
251
252        Ok(output)
253    }
254}
255
256/// Processes a module containing multiple traits.
257struct ModuleProcessor {
258    generator: oxapi_impl::Generator,
259    #[allow(dead_code)]
260    framework: Framework,
261    mod_item: ItemMod,
262}
263
264impl ModuleProcessor {
265    fn new(
266        generator: oxapi_impl::Generator,
267        framework: Framework,
268        mod_item: ItemMod,
269    ) -> syn::Result<Self> {
270        // Module must have content (not just a declaration)
271        if mod_item.content.is_none() {
272            return Err(syn::Error::new_spanned(
273                &mod_item,
274                "module must have inline content, not just a declaration",
275            ));
276        }
277        Ok(Self {
278            generator,
279            framework,
280            mod_item,
281        })
282    }
283
284    fn generate(self) -> syn::Result<TokenStream2> {
285        use std::collections::HashMap;
286
287        let mod_name = &self.mod_item.ident;
288        let mod_vis = &self.mod_item.vis;
289        let (_, content) = self.mod_item.content.as_ref().unwrap();
290
291        // Find all traits in the module
292        let mut traits: Vec<&ItemTrait> = Vec::new();
293        let mut other_items: Vec<&syn::Item> = Vec::new();
294
295        for item in content {
296            match item {
297                syn::Item::Trait(t) => traits.push(t),
298                other => other_items.push(other),
299            }
300        }
301
302        if traits.is_empty() {
303            return Err(syn::Error::new_spanned(
304                &self.mod_item,
305                "module must contain at least one trait",
306            ));
307        }
308
309        // Collect all operations across all traits for coverage validation
310        let mut all_covered: HashMap<(oxapi_impl::HttpMethod, String), ()> = HashMap::new();
311
312        // Process each trait and collect handler info
313        struct TraitInfo<'a> {
314            trait_item: &'a ItemTrait,
315            map_method: Option<&'a syn::TraitItemFn>,
316            handler_methods: Vec<(&'a syn::TraitItemFn, oxapi_impl::HttpMethod, String)>,
317        }
318
319        let mut trait_infos: Vec<TraitInfo> = Vec::new();
320
321        for trait_item in &traits {
322            let mut map_method: Option<&syn::TraitItemFn> = None;
323            let mut handler_methods: Vec<(&syn::TraitItemFn, oxapi_impl::HttpMethod, String)> =
324                Vec::new();
325
326            for item in &trait_item.items {
327                if let syn::TraitItem::Fn(method) = item {
328                    if let Some(attr) = find_oxapi_attr(&method.attrs)? {
329                        match attr {
330                            OxapiAttr::Map => {
331                                map_method = Some(method);
332                            }
333                            OxapiAttr::Route {
334                                method: http_method,
335                                path,
336                            } => {
337                                // Check for duplicates across traits
338                                let key = (http_method, path.clone());
339                                if all_covered.contains_key(&key) {
340                                    return Err(syn::Error::new_spanned(
341                                        method,
342                                        format!(
343                                            "operation {} {} is already defined in another trait",
344                                            http_method, path
345                                        ),
346                                    ));
347                                }
348                                all_covered.insert(key, ());
349                                handler_methods.push((method, http_method, path));
350                            }
351                        }
352                    } else {
353                        return Err(syn::Error::new_spanned(
354                            method,
355                            "all trait methods must have #[oxapi(...)] attribute",
356                        ));
357                    }
358                }
359            }
360
361            trait_infos.push(TraitInfo {
362                trait_item,
363                map_method,
364                handler_methods,
365            });
366        }
367
368        // Validate coverage against spec
369        self.generator
370            .validate_coverage(&all_covered)
371            .map_err(|e| syn::Error::new_spanned(&self.mod_item, e.to_string()))?;
372
373        // Generate shared types module
374        let types = self.generator.generate_types();
375        let responses = self.generator.generate_responses();
376
377        // Generate each trait
378        let types_mod_name = syn::Ident::new("types", proc_macro2::Span::call_site());
379        let method_transformer =
380            oxapi_impl::MethodTransformer::new(&self.generator, &types_mod_name);
381
382        let mut generated_traits = Vec::new();
383
384        for info in &trait_infos {
385            let trait_name = &info.trait_item.ident;
386            let _trait_vis = &info.trait_item.vis;
387            let trait_attrs: Vec<_> = info
388                .trait_item
389                .attrs
390                .iter()
391                .filter(|a| !a.path().is_ident("oxapi"))
392                .collect();
393
394            let mut transformed_methods = Vec::new();
395
396            // Generate map_routes if present
397            if let Some(map_fn) = info.map_method {
398                let router_gen = oxapi_impl::RouterGenerator::new(&self.generator);
399                let map_body = router_gen.generate_map_routes(
400                    &info
401                        .handler_methods
402                        .iter()
403                        .map(|(m, method, path)| (m.sig.ident.clone(), *method, path.clone()))
404                        .collect::<Vec<_>>(),
405                );
406
407                let sig = &map_fn.sig;
408                transformed_methods.push(quote! {
409                    #sig {
410                        #map_body
411                    }
412                });
413            }
414
415            // Generate handler methods
416            for (method, http_method, path) in &info.handler_methods {
417                let op = self
418                    .generator
419                    .get_operation(*http_method, path)
420                    .ok_or_else(|| {
421                        syn::Error::new_spanned(
422                            method,
423                            format!("operation not found: {} {}", http_method, path),
424                        )
425                    })?;
426
427                let transformed = method_transformer.transform(method, op)?;
428                transformed_methods.push(transformed);
429            }
430
431            // Traits in module are always pub (for external use)
432            generated_traits.push(quote! {
433                #(#trait_attrs)*
434                pub trait #trait_name: 'static {
435                    #(#transformed_methods)*
436                }
437            });
438        }
439
440        // Generate the module
441        let output = quote! {
442            #mod_vis mod #mod_name {
443                use super::*;
444
445                pub mod #types_mod_name {
446                    use super::*;
447
448                    #types
449                    #responses
450                }
451
452                #(#generated_traits)*
453            }
454        };
455
456        Ok(output)
457    }
458}
459
460/// Find and parse the #[oxapi(...)] attribute on a method.
461fn find_oxapi_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<OxapiAttr>> {
462    for attr in attrs {
463        if attr.path().is_ident("oxapi") {
464            return attr.parse_args::<OxapiAttr>().map(Some);
465        }
466    }
467    Ok(None)
468}
469
470/// Parsed #[oxapi(...)] attribute.
471enum OxapiAttr {
472    Map,
473    Route {
474        method: oxapi_impl::HttpMethod,
475        path: String,
476    },
477}
478
479impl syn::parse::Parse for OxapiAttr {
480    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
481        let ident: Ident = input.parse()?;
482
483        match ident.to_string().as_str() {
484            "map" => Ok(OxapiAttr::Map),
485            _ => {
486                let method = parse_http_method(&ident)?;
487                input.parse::<Token![,]>()?;
488                let path: LitStr = input.parse()?;
489                Ok(OxapiAttr::Route {
490                    method,
491                    path: path.value(),
492                })
493            }
494        }
495    }
496}
497
498fn parse_http_method(ident: &Ident) -> syn::Result<oxapi_impl::HttpMethod> {
499    match ident.to_string().as_str() {
500        "get" => Ok(oxapi_impl::HttpMethod::Get),
501        "post" => Ok(oxapi_impl::HttpMethod::Post),
502        "put" => Ok(oxapi_impl::HttpMethod::Put),
503        "delete" => Ok(oxapi_impl::HttpMethod::Delete),
504        "patch" => Ok(oxapi_impl::HttpMethod::Patch),
505        "head" => Ok(oxapi_impl::HttpMethod::Head),
506        "options" => Ok(oxapi_impl::HttpMethod::Options),
507        other => Err(syn::Error::new(
508            ident.span(),
509            format!("unknown HTTP method: {}", other),
510        )),
511    }
512}