1use proc_macro::TokenStream;
2use quote::quote;
3use std::fs;
4use std::path::PathBuf;
5use syn::{Item, ItemFn, LitStr, parse_macro_input};
6
7#[proc_macro]
25pub fn discover_endpoints(input: TokenStream) -> TokenStream {
26 let path_lit = parse_macro_input!(input as LitStr);
27 let endpoints_path = path_lit.value();
28
29 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
31
32 let full_path = PathBuf::from(manifest_dir).join(&endpoints_path);
33
34 struct EndpointInfo {
36 module_name: String,
37 handlers: Vec<String>,
38 }
39
40 let mut endpoints = Vec::new();
41
42 if full_path.exists() && full_path.is_dir() {
43 match fs::read_dir(&full_path) {
44 Ok(entries) => {
45 for entry in entries.flatten() {
46 let path = entry.path();
47
48 if path.is_file()
49 && let Some(file_name) = path.file_name()
50 && let Some(file_name_str) = file_name.to_str()
51 {
52 if file_name_str.ends_with(".rs") && file_name_str != "mod.rs" {
54 let module_name = &file_name_str[..file_name_str.len() - 3];
56
57 if let Ok(content) = fs::read_to_string(&path)
59 && let Ok(syntax_tree) = syn::parse_file(&content)
60 {
61 let mut handlers = Vec::new();
62
63 for item in syntax_tree.items {
64 if let Item::Fn(func) = item {
65 if has_utoipa_path_attr(&func) {
67 handlers.push(func.sig.ident.to_string());
68 }
69 }
70 }
71
72 if !handlers.is_empty() {
73 endpoints.push(EndpointInfo {
74 module_name: module_name.to_string(),
75 handlers,
76 });
77 }
78 }
79 }
80 }
81 }
82 }
83 Err(e) => {
84 return syn::Error::new(
85 path_lit.span(),
86 format!("Failed to read directory '{}': {}", full_path.display(), e),
87 )
88 .to_compile_error()
89 .into();
90 }
91 }
92 } else {
93 return syn::Error::new(
94 path_lit.span(),
95 format!("Directory '{}' does not exist", full_path.display()),
96 )
97 .to_compile_error()
98 .into();
99 }
100
101 endpoints.sort_by(|a, b| a.module_name.cmp(&b.module_name));
103
104 if endpoints.is_empty() {
105 return syn::Error::new(
106 path_lit.span(),
107 format!("No endpoint modules found in '{}'", full_path.display()),
108 )
109 .to_compile_error()
110 .into();
111 }
112
113 let module_idents: Vec<_> = endpoints
115 .iter()
116 .map(|ep| syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site()))
117 .collect();
118
119 let module_decls = module_idents.iter().map(|ident| {
120 quote! {
121 pub mod #ident;
122 }
123 });
124
125 let register_calls = endpoints.iter().map(|ep| {
127 let module_ident = syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site());
128 let handler_idents: Vec<_> = ep
129 .handlers
130 .iter()
131 .map(|h| syn::Ident::new(h, proc_macro2::Span::call_site()))
132 .collect();
133
134 quote! {
135 if let Some(db) = &service.database {
136 let router = ::utoipa_axum::router::OpenApiRouter::new()
137 .routes(::utoipa_axum::routes!(#(#module_ident::#handler_idents),*))
138 .with_state(db.clone());
139 service.add_route(router);
140 }
141 }
142 });
143
144 let expanded = quote! {
146 #(#module_decls)*
147
148 pub fn init_endpoints(
153 service: &mut microkit::MicroKit
154 ) -> anyhow::Result<()> {
155 #(#register_calls)*
156 Ok(())
157 }
158 };
159
160 TokenStream::from(expanded)
161}
162
163fn has_utoipa_path_attr(func: &ItemFn) -> bool {
165 for attr in &func.attrs {
166 if attr.path().segments.len() == 2 {
168 let segments: Vec<_> = attr.path().segments.iter().collect();
169 if segments[0].ident == "utoipa" && segments[1].ident == "path" {
170 return true;
171 }
172 }
173 }
174 false
175}
176
177#[proc_macro]
186pub fn register_endpoints(input: TokenStream) -> TokenStream {
187 use syn::{
188 Ident, Token,
189 parse::{Parse, ParseStream},
190 punctuated::Punctuated,
191 };
192
193 struct RegisterEndpointsInput {
194 service: Ident,
195 db: Ident,
196 module: Ident,
197 endpoints: Vec<Ident>,
198 }
199
200 impl Parse for RegisterEndpointsInput {
201 fn parse(input: ParseStream) -> syn::Result<Self> {
202 let service: Ident = input.parse()?;
203 input.parse::<Token![,]>()?;
204 let db: Ident = input.parse()?;
205 input.parse::<Token![,]>()?;
206 let module: Ident = input.parse()?;
207 input.parse::<Token![=>]>()?;
208
209 let content;
210 syn::bracketed!(content in input);
211 let endpoints_punct = Punctuated::<Ident, Token![,]>::parse_terminated(&content)?;
212 let endpoints = endpoints_punct.into_iter().collect();
213
214 Ok(RegisterEndpointsInput {
215 service,
216 db,
217 module,
218 endpoints,
219 })
220 }
221 }
222
223 let RegisterEndpointsInput {
224 service,
225 db,
226 module,
227 endpoints,
228 } = parse_macro_input!(input as RegisterEndpointsInput);
229
230 let register_calls = endpoints.iter().map(|name| {
231 quote! {
232 #service.add_route(#module::#name::api(&#db)?);
233 }
234 });
235
236 let expanded = quote! {
237 #(#register_calls)*
238 };
239
240 TokenStream::from(expanded)
241}