1use proc_macro::TokenStream;
2use quote::quote;
3use std::fs;
4use std::path::PathBuf;
5use syn::{parse_macro_input, Item, ItemFn, LitStr};
6
7#[proc_macro]
25pub fn discover_endpoints(input: TokenStream) -> TokenStream {
26 let path_lit = parse_macro_input!(input as LitStr);
27 let endpoints_path = path_lit.value();
28
29 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
31
32 let full_path = PathBuf::from(manifest_dir).join(&endpoints_path);
33
34 struct EndpointInfo {
36 module_name: String,
37 handlers: Vec<String>,
38 }
39
40 let mut endpoints = Vec::new();
41
42 if full_path.exists() && full_path.is_dir() {
43 match fs::read_dir(&full_path) {
44 Ok(entries) => {
45 for entry in entries.flatten() {
46 let path = entry.path();
47
48 if path.is_file() {
49 if let Some(file_name) = path.file_name() {
50 if let Some(file_name_str) = file_name.to_str() {
51 if file_name_str.ends_with(".rs") && file_name_str != "mod.rs" {
53 let module_name = &file_name_str[..file_name_str.len() - 3];
55
56 if let Ok(content) = fs::read_to_string(&path) {
58 if let Ok(syntax_tree) = syn::parse_file(&content) {
59 let mut handlers = Vec::new();
60
61 for item in syntax_tree.items {
62 if let Item::Fn(func) = item {
63 if has_utoipa_path_attr(&func) {
65 handlers.push(func.sig.ident.to_string());
66 }
67 }
68 }
69
70 if !handlers.is_empty() {
71 endpoints.push(EndpointInfo {
72 module_name: module_name.to_string(),
73 handlers,
74 });
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82 }
83 }
84 Err(e) => {
85 return syn::Error::new(
86 path_lit.span(),
87 format!("Failed to read directory '{}': {}", full_path.display(), e),
88 )
89 .to_compile_error()
90 .into();
91 }
92 }
93 } else {
94 return syn::Error::new(
95 path_lit.span(),
96 format!("Directory '{}' does not exist", full_path.display()),
97 )
98 .to_compile_error()
99 .into();
100 }
101
102 endpoints.sort_by(|a, b| a.module_name.cmp(&b.module_name));
104
105 if endpoints.is_empty() {
106 return syn::Error::new(
107 path_lit.span(),
108 format!("No endpoint modules found in '{}'", full_path.display()),
109 )
110 .to_compile_error()
111 .into();
112 }
113
114 let module_idents: Vec<_> = endpoints
116 .iter()
117 .map(|ep| syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site()))
118 .collect();
119
120 let module_decls = module_idents.iter().map(|ident| {
121 quote! {
122 pub mod #ident;
123 }
124 });
125
126 let register_calls = endpoints.iter().map(|ep| {
128 let module_ident = syn::Ident::new(&ep.module_name, proc_macro2::Span::call_site());
129 let handler_idents: Vec<_> = ep
130 .handlers
131 .iter()
132 .map(|h| syn::Ident::new(h, proc_macro2::Span::call_site()))
133 .collect();
134
135 quote! {
136 if let Some(db) = &service.database {
137 let router = ::utoipa_axum::router::OpenApiRouter::new()
138 .routes(::utoipa_axum::routes!(#(#module_ident::#handler_idents),*))
139 .with_state(db.clone());
140 service.add_route(router);
141 }
142 }
143 });
144
145 let expanded = quote! {
147 #(#module_decls)*
148
149 pub fn init_endpoints(
154 service: &mut microkit::MicroKit
155 ) -> anyhow::Result<()> {
156 #(#register_calls)*
157 Ok(())
158 }
159 };
160
161 TokenStream::from(expanded)
162}
163
164fn has_utoipa_path_attr(func: &ItemFn) -> bool {
166 for attr in &func.attrs {
167 if attr.path().segments.len() == 2 {
169 let segments: Vec<_> = attr.path().segments.iter().collect();
170 if segments[0].ident == "utoipa" && segments[1].ident == "path" {
171 return true;
172 }
173 }
174 }
175 false
176}
177
178#[proc_macro]
187pub fn register_endpoints(input: TokenStream) -> TokenStream {
188 use syn::{
189 parse::{Parse, ParseStream},
190 punctuated::Punctuated,
191 Ident, Token,
192 };
193
194 struct RegisterEndpointsInput {
195 service: Ident,
196 db: Ident,
197 module: Ident,
198 endpoints: Vec<Ident>,
199 }
200
201 impl Parse for RegisterEndpointsInput {
202 fn parse(input: ParseStream) -> syn::Result<Self> {
203 let service: Ident = input.parse()?;
204 input.parse::<Token![,]>()?;
205 let db: Ident = input.parse()?;
206 input.parse::<Token![,]>()?;
207 let module: Ident = input.parse()?;
208 input.parse::<Token![=>]>()?;
209
210 let content;
211 syn::bracketed!(content in input);
212 let endpoints_punct = Punctuated::<Ident, Token![,]>::parse_terminated(&content)?;
213 let endpoints = endpoints_punct.into_iter().collect();
214
215 Ok(RegisterEndpointsInput {
216 service,
217 db,
218 module,
219 endpoints,
220 })
221 }
222 }
223
224 let RegisterEndpointsInput {
225 service,
226 db,
227 module,
228 endpoints,
229 } = parse_macro_input!(input as RegisterEndpointsInput);
230
231 let register_calls = endpoints.iter().map(|name| {
232 quote! {
233 #service.add_route(#module::#name::api(&#db)?);
234 }
235 });
236
237 let expanded = quote! {
238 #(#register_calls)*
239 };
240
241 TokenStream::from(expanded)
242}