// microkit_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use std::fs;
4use std::path::PathBuf;
5use syn::{Item, ItemFn, LitStr, parse_macro_input};
6
7/// Discovers and registers all endpoint modules in a directory
8///
9/// This macro scans the specified directory for .rs files (excluding mod.rs),
10/// parses each file to find handler functions with #[utoipa::path] attributes,
11/// and automatically generates everything needed for registration
12///
13/// # Example
14///
15/// In your `endpoints/mod.rs`:
16/// ```rust
17/// microkit::discover_endpoints!("src/endpoints");
18/// ```
19///
20/// Then in your main lib.rs:
21/// ```rust
22/// endpoints::init_endpoints(&mut service)?;
23/// ```
24#[proc_macro]
25pub fn discover_endpoints(input: TokenStream) -> TokenStream {
26    let path_lit = parse_macro_input!(input as LitStr);
27    let endpoints_path = path_lit.value();
28
29    // Get the manifest directory (the crate root)
30    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
31
32    let full_path = PathBuf::from(manifest_dir).join(&endpoints_path);
33
34    // Structure to hold endpoint information
35    struct EndpointInfo {
36        module_name: String,
37        handlers: Vec<String>,
38    }
39
40    let mut endpoints = Vec::new();
41
42    if full_path.exists() && full_path.is_dir() {
43        match fs::read_dir(&full_path) {
44            Ok(entries) => {
45                for entry in entries.flatten() {
46                    let path = entry.path();
47
48                    if path.is_file()
49                        && let Some(file_name) = path.file_name()
50                        && let Some(file_name_str) = file_name.to_str()
51                    {
52                        // Skip mod.rs and only process .rs files
53                        if file_name_str.ends_with(".rs") && file_name_str != "mod.rs" {
54                            // Extract module name (remove .rs extension)
55                            let module_name = &file_name_str[..file_name_str.len() - 3];
56
57                            // Parse the file to find handler functions
58                            if let Ok(content) = fs::read_to_string(&path)
59                                && let Ok(syntax_tree) = syn::parse_file(&content)
60                            {
61                                let mut handlers = Vec::new();
62
63                                for item in syntax_tree.items {
64                                    if let Item::Fn(func) = item {
65                                        // Check if function has #[utoipa::path] attribute
66                                        if has_utoipa_path_attr(&func) {
67                                            handlers.push(func.sig.ident.to_string());
68                                        }
69                                    }
70                                }
71
72                                if !handlers.is_empty() {
73                                    endpoints.push(EndpointInfo {
74                                        module_name: module_name.to_string(),
75                                        handlers,
76                                    });
77                                }
78                            }
79                        }
80                    }
81                }
82            }
83            Err(e) => {
84                return syn::Error::new(
85                    path_lit.span(),
86                    format!("Failed to read directory '{}': {}", full_path.display(), e),
87                )
88                .to_compile_error()
89                .into();
90            }
91        }
92    } else {
93        return syn::Error::new(
94            path_lit.span(),
95            format!("Directory '{}' does not exist", full_path.display()),
96        )
97        .to_compile_error()
98        .into();
99    }
100
101    // Sort for consistent output
102    endpoints.sort_by(|a, b| a.module_name.cmp(&b.module_name));
103
104    if endpoints.is_empty() {
105        return syn::Error::new(
106            path_lit.span(),
107            format!("No endpoint modules found in '{}'", full_path.display()),
108        )
109        .to_compile_error()
110        .into();
111    }
112
113    // Generate module declarations
114    let module_idents: Vec<_> = endpoints
115        .iter()
116        .map(|ep| syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site()))
117        .collect();
118
119    let module_decls = module_idents.iter().map(|ident| {
120        quote! {
121            pub mod #ident;
122        }
123    });
124
125    // Generate registration calls
126    let register_calls = endpoints.iter().map(|ep| {
127        let module_ident = syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site());
128        let handler_idents: Vec<_> = ep
129            .handlers
130            .iter()
131            .map(|h| syn::Ident::new(h, proc_macro2::Span::call_site()))
132            .collect();
133
134        quote! {
135            if let Some(db) = &service.database {
136                let router = ::utoipa_axum::router::OpenApiRouter::new()
137                    .routes(::utoipa_axum::routes!(#(#module_ident::#handler_idents),*))
138                    .with_state(db.clone());
139                service.add_route(router);
140            }
141        }
142    });
143
144    // Generate the complete code
145    let expanded = quote! {
146        #(#module_decls)*
147
148        /// Automatically registers all discovered endpoint modules
149        ///
150        /// This function is generated by the `discover_endpoints!` macro and will
151        /// register all handler functions found in each endpoint module
152        pub fn init_endpoints(
153            service: &mut microkit::MicroKit
154        ) -> anyhow::Result<()> {
155            #(#register_calls)*
156            Ok(())
157        }
158    };
159
160    TokenStream::from(expanded)
161}
162
163/// Check if a function has a #[utoipa::path] attribute
164fn has_utoipa_path_attr(func: &ItemFn) -> bool {
165    for attr in &func.attrs {
166        // Check if the attribute path matches "utoipa::path"
167        if attr.path().segments.len() == 2 {
168            let segments: Vec<_> = attr.path().segments.iter().collect();
169            if segments[0].ident == "utoipa" && segments[1].ident == "path" {
170                return true;
171            }
172        }
173    }
174    false
175}
176
177/// Registers endpoint modules with a MicroKit service
178///
179/// # Example
180///
181/// ```rust
182/// let db = &service.database;
183/// microkit::register_endpoints!(service, db, endpoints => [users, posts]);
184/// ```
185#[proc_macro]
186pub fn register_endpoints(input: TokenStream) -> TokenStream {
187    use syn::{
188        Ident, Token,
189        parse::{Parse, ParseStream},
190        punctuated::Punctuated,
191    };
192
193    struct RegisterEndpointsInput {
194        service: Ident,
195        db: Ident,
196        module: Ident,
197        endpoints: Vec<Ident>,
198    }
199
200    impl Parse for RegisterEndpointsInput {
201        fn parse(input: ParseStream) -> syn::Result<Self> {
202            let service: Ident = input.parse()?;
203            input.parse::<Token![,]>()?;
204            let db: Ident = input.parse()?;
205            input.parse::<Token![,]>()?;
206            let module: Ident = input.parse()?;
207            input.parse::<Token![=>]>()?;
208
209            let content;
210            syn::bracketed!(content in input);
211            let endpoints_punct = Punctuated::<Ident, Token![,]>::parse_terminated(&content)?;
212            let endpoints = endpoints_punct.into_iter().collect();
213
214            Ok(RegisterEndpointsInput {
215                service,
216                db,
217                module,
218                endpoints,
219            })
220        }
221    }
222
223    let RegisterEndpointsInput {
224        service,
225        db,
226        module,
227        endpoints,
228    } = parse_macro_input!(input as RegisterEndpointsInput);
229
230    let register_calls = endpoints.iter().map(|name| {
231        quote! {
232            #service.add_route(#module::#name::api(&#db)?);
233        }
234    });
235
236    let expanded = quote! {
237        #(#register_calls)*
238    };
239
240    TokenStream::from(expanded)
241}