Skip to main content

microkit_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use std::fs;
4use std::path::PathBuf;
5use syn::{parse_macro_input, Item, ItemFn, LitStr};
6
7/// Discovers and registers all endpoint modules in a directory
8///
9/// This macro scans the specified directory for .rs files (excluding mod.rs),
10/// parses each file to find handler functions with #[utoipa::path] attributes,
11/// and automatically generates everything needed for registration
12///
13/// # Example
14///
15/// In your `endpoints/mod.rs`:
16/// ```rust
17/// microkit::discover_endpoints!("src/endpoints");
18/// ```
19///
20/// Then in your main lib.rs:
21/// ```rust
22/// endpoints::init_endpoints(&mut service)?;
23/// ```
24#[proc_macro]
25pub fn discover_endpoints(input: TokenStream) -> TokenStream {
26    let path_lit = parse_macro_input!(input as LitStr);
27    let endpoints_path = path_lit.value();
28
29    // Get the manifest directory (the crate root)
30    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
31
32    let full_path = PathBuf::from(manifest_dir).join(&endpoints_path);
33
34    // Structure to hold endpoint information
35    struct EndpointInfo {
36        module_name: String,
37        handlers: Vec<String>,
38    }
39
40    let mut endpoints = Vec::new();
41
42    if full_path.exists() && full_path.is_dir() {
43        match fs::read_dir(&full_path) {
44            Ok(entries) => {
45                for entry in entries.flatten() {
46                    let path = entry.path();
47
48                    if path.is_file() {
49                        if let Some(file_name) = path.file_name() {
50                            if let Some(file_name_str) = file_name.to_str() {
51                                // Skip mod.rs and only process .rs files
52                                if file_name_str.ends_with(".rs") && file_name_str != "mod.rs" {
53                                    // Extract module name (remove .rs extension)
54                                    let module_name = &file_name_str[..file_name_str.len() - 3];
55
56                                    // Parse the file to find handler functions
57                                    if let Ok(content) = fs::read_to_string(&path) {
58                                        if let Ok(syntax_tree) = syn::parse_file(&content) {
59                                            let mut handlers = Vec::new();
60
61                                            for item in syntax_tree.items {
62                                                if let Item::Fn(func) = item {
63                                                    // Check if function has #[utoipa::path] attribute
64                                                    if has_utoipa_path_attr(&func) {
65                                                        handlers.push(func.sig.ident.to_string());
66                                                    }
67                                                }
68                                            }
69
70                                            if !handlers.is_empty() {
71                                                endpoints.push(EndpointInfo {
72                                                    module_name: module_name.to_string(),
73                                                    handlers,
74                                                });
75                                            }
76                                        }
77                                    }
78                                }
79                            }
80                        }
81                    }
82                }
83            }
84            Err(e) => {
85                return syn::Error::new(
86                    path_lit.span(),
87                    format!("Failed to read directory '{}': {}", full_path.display(), e),
88                )
89                .to_compile_error()
90                .into();
91            }
92        }
93    } else {
94        return syn::Error::new(
95            path_lit.span(),
96            format!("Directory '{}' does not exist", full_path.display()),
97        )
98        .to_compile_error()
99        .into();
100    }
101
102    // Sort for consistent output
103    endpoints.sort_by(|a, b| a.module_name.cmp(&b.module_name));
104
105    if endpoints.is_empty() {
106        return syn::Error::new(
107            path_lit.span(),
108            format!("No endpoint modules found in '{}'", full_path.display()),
109        )
110        .to_compile_error()
111        .into();
112    }
113
114    // Generate module declarations
115    let module_idents: Vec<_> = endpoints
116        .iter()
117        .map(|ep| syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site()))
118        .collect();
119
120    let module_decls = module_idents.iter().map(|ident| {
121        quote! {
122            pub mod #ident;
123        }
124    });
125
126    // Generate registration calls
127    let register_calls = endpoints.iter().map(|ep| {
128        let module_ident = syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site());
129        let handler_idents: Vec<_> = ep
130            .handlers
131            .iter()
132            .map(|h| syn::Ident::new(h, proc_macro2::Span::call_site()))
133            .collect();
134
135        quote! {
136            if let Some(db) = &service.database {
137                let router = ::utoipa_axum::router::OpenApiRouter::new()
138                    .routes(::utoipa_axum::routes!(#(#module_ident::#handler_idents),*))
139                    .with_state(db.clone());
140                service.add_route(router);
141            }
142        }
143    });
144
145    // Generate the complete code
146    let expanded = quote! {
147        #(#module_decls)*
148
149        /// Automatically registers all discovered endpoint modules
150        ///
151        /// This function is generated by the `discover_endpoints!` macro and will
152        /// register all handler functions found in each endpoint module
153        pub fn init_endpoints(
154            service: &mut microkit::MicroKit
155        ) -> anyhow::Result<()> {
156            #(#register_calls)*
157            Ok(())
158        }
159    };
160
161    TokenStream::from(expanded)
162}
163
164/// Check if a function has a #[utoipa::path] attribute
165fn has_utoipa_path_attr(func: &ItemFn) -> bool {
166    for attr in &func.attrs {
167        // Check if the attribute path matches "utoipa::path"
168        if attr.path().segments.len() == 2 {
169            let segments: Vec<_> = attr.path().segments.iter().collect();
170            if segments[0].ident == "utoipa" && segments[1].ident == "path" {
171                return true;
172            }
173        }
174    }
175    false
176}
177
178/// Registers endpoint modules with a MicroKit service
179///
180/// # Example
181///
182/// ```rust
183/// let db = &service.database;
184/// microkit::register_endpoints!(service, db, endpoints => [users, posts]);
185/// ```
186#[proc_macro]
187pub fn register_endpoints(input: TokenStream) -> TokenStream {
188    use syn::{
189        parse::{Parse, ParseStream},
190        punctuated::Punctuated,
191        Ident, Token,
192    };
193
194    struct RegisterEndpointsInput {
195        service: Ident,
196        db: Ident,
197        module: Ident,
198        endpoints: Vec<Ident>,
199    }
200
201    impl Parse for RegisterEndpointsInput {
202        fn parse(input: ParseStream) -> syn::Result<Self> {
203            let service: Ident = input.parse()?;
204            input.parse::<Token![,]>()?;
205            let db: Ident = input.parse()?;
206            input.parse::<Token![,]>()?;
207            let module: Ident = input.parse()?;
208            input.parse::<Token![=>]>()?;
209
210            let content;
211            syn::bracketed!(content in input);
212            let endpoints_punct = Punctuated::<Ident, Token![,]>::parse_terminated(&content)?;
213            let endpoints = endpoints_punct.into_iter().collect();
214
215            Ok(RegisterEndpointsInput {
216                service,
217                db,
218                module,
219                endpoints,
220            })
221        }
222    }
223
224    let RegisterEndpointsInput {
225        service,
226        db,
227        module,
228        endpoints,
229    } = parse_macro_input!(input as RegisterEndpointsInput);
230
231    let register_calls = endpoints.iter().map(|name| {
232        quote! {
233            #service.add_route(#module::#name::api(&#db)?);
234        }
235    });
236
237    let expanded = quote! {
238        #(#register_calls)*
239    };
240
241    TokenStream::from(expanded)
242}