1use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{Ident, ItemMod, ItemTrait, LitStr, Token};
7
8#[proc_macro_attribute]
46pub fn oxapi(attr: TokenStream, item: TokenStream) -> TokenStream {
47 match do_oxapi(attr.into(), item.into()) {
48 Ok(tokens) => tokens.into(),
49 Err(err) => err.to_compile_error().into(),
50 }
51}
52
53fn do_oxapi(attr: TokenStream2, item: TokenStream2) -> syn::Result<TokenStream2> {
54 let args = syn::parse2::<MacroArgs>(attr)?;
56
57 let spec_path = resolve_spec_path(&args.spec_path)?;
59 let generator = oxapi_impl::Generator::from_file(&spec_path).map_err(|e| {
60 syn::Error::new(args.spec_path.span(), format!("failed to load spec: {}", e))
61 })?;
62
63 if let Ok(trait_item) = syn::parse2::<ItemTrait>(item.clone()) {
65 let processor = TraitProcessor::new(generator, args.framework, trait_item)?;
66 processor.generate()
67 } else if let Ok(mod_item) = syn::parse2::<ItemMod>(item) {
68 let processor = ModuleProcessor::new(generator, args.framework, mod_item)?;
69 processor.generate()
70 } else {
71 Err(syn::Error::new(
72 proc_macro2::Span::call_site(),
73 "expected trait or mod item",
74 ))
75 }
76}
77
78fn resolve_spec_path(lit: &LitStr) -> syn::Result<std::path::PathBuf> {
80 let dir = std::env::var("CARGO_MANIFEST_DIR")
81 .map_err(|_| syn::Error::new(lit.span(), "CARGO_MANIFEST_DIR not set"))?;
82
83 let path = std::path::Path::new(&dir).join(lit.value());
84 Ok(path)
85}
86
87struct MacroArgs {
89 framework: Framework,
90 spec_path: LitStr,
91}
92
93impl syn::parse::Parse for MacroArgs {
94 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
95 let framework: Ident = input.parse()?;
96 input.parse::<Token![,]>()?;
97 let spec_path: LitStr = input.parse()?;
98
99 let framework = match framework.to_string().as_str() {
100 "axum" => Framework::Axum,
101 other => {
102 return Err(syn::Error::new(
103 framework.span(),
104 format!("unsupported framework: {}", other),
105 ));
106 }
107 };
108
109 Ok(MacroArgs {
110 framework,
111 spec_path,
112 })
113 }
114}
115
/// Web framework the generated code targets, chosen by the first macro argument.
///
/// Only axum is supported so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Framework {
    /// The `axum` framework.
    Axum,
}
120
121struct TraitProcessor {
123 generator: oxapi_impl::Generator,
124 #[allow(dead_code)]
125 framework: Framework,
126 trait_item: ItemTrait,
127}
128
129impl TraitProcessor {
130 fn new(
131 generator: oxapi_impl::Generator,
132 framework: Framework,
133 trait_item: ItemTrait,
134 ) -> syn::Result<Self> {
135 Ok(Self {
136 generator,
137 framework,
138 trait_item,
139 })
140 }
141
142 fn generate(self) -> syn::Result<TokenStream2> {
143 use std::collections::HashMap;
144
145 let trait_name = &self.trait_item.ident;
146 let types_mod_name = syn::Ident::new(
147 &format!("{}_types", heck::AsSnakeCase(trait_name.to_string())),
148 trait_name.span(),
149 );
150
151 let mut covered: HashMap<(oxapi_impl::HttpMethod, String), ()> = HashMap::new();
153 let mut map_method: Option<&syn::TraitItemFn> = None;
154 let mut handler_methods: Vec<(&syn::TraitItemFn, oxapi_impl::HttpMethod, String)> =
155 Vec::new();
156
157 for item in &self.trait_item.items {
158 if let syn::TraitItem::Fn(method) = item {
159 if let Some(attr) = find_oxapi_attr(&method.attrs)? {
160 match attr {
161 OxapiAttr::Map => {
162 map_method = Some(method);
163 }
164 OxapiAttr::Route {
165 method: http_method,
166 path,
167 } => {
168 covered.insert((http_method, path.clone()), ());
169 handler_methods.push((method, http_method, path));
170 }
171 }
172 } else {
173 return Err(syn::Error::new_spanned(
174 method,
175 "all trait methods must have #[oxapi(...)] attribute",
176 ));
177 }
178 }
179 }
180
181 self.generator
183 .validate_coverage(&covered)
184 .map_err(|e| syn::Error::new_spanned(&self.trait_item, e.to_string()))?;
185
186 let types = self.generator.generate_types();
188 let responses = self.generator.generate_responses();
189
190 let mut transformed_methods = Vec::new();
192
193 if let Some(map_fn) = map_method {
195 let router_gen = oxapi_impl::RouterGenerator::new(&self.generator);
196 let map_body = router_gen.generate_map_routes(
197 &handler_methods
198 .iter()
199 .map(|(m, method, path)| (m.sig.ident.clone(), *method, path.clone()))
200 .collect::<Vec<_>>(),
201 );
202
203 let sig = &map_fn.sig;
204 transformed_methods.push(quote! {
205 #sig {
206 #map_body
207 }
208 });
209 }
210
211 let method_transformer =
213 oxapi_impl::MethodTransformer::new(&self.generator, &types_mod_name);
214 for (method, http_method, path) in &handler_methods {
215 let op = self
216 .generator
217 .get_operation(*http_method, path)
218 .ok_or_else(|| {
219 syn::Error::new_spanned(
220 method,
221 format!("operation not found: {} {}", http_method, path),
222 )
223 })?;
224
225 let transformed = method_transformer.transform(method, op)?;
226 transformed_methods.push(transformed);
227 }
228
229 let vis = &self.trait_item.vis;
231 let trait_attrs: Vec<_> = self
232 .trait_item
233 .attrs
234 .iter()
235 .filter(|a| !a.path().is_ident("oxapi"))
236 .collect();
237
238 let output = quote! {
239 #vis mod #types_mod_name {
240 use super::*;
241
242 #types
243 #responses
244 }
245
246 #(#trait_attrs)*
247 #vis trait #trait_name: 'static {
248 #(#transformed_methods)*
249 }
250 };
251
252 Ok(output)
253 }
254}
255
256struct ModuleProcessor {
258 generator: oxapi_impl::Generator,
259 #[allow(dead_code)]
260 framework: Framework,
261 mod_item: ItemMod,
262}
263
264impl ModuleProcessor {
265 fn new(
266 generator: oxapi_impl::Generator,
267 framework: Framework,
268 mod_item: ItemMod,
269 ) -> syn::Result<Self> {
270 if mod_item.content.is_none() {
272 return Err(syn::Error::new_spanned(
273 &mod_item,
274 "module must have inline content, not just a declaration",
275 ));
276 }
277 Ok(Self {
278 generator,
279 framework,
280 mod_item,
281 })
282 }
283
284 fn generate(self) -> syn::Result<TokenStream2> {
285 use std::collections::HashMap;
286
287 let mod_name = &self.mod_item.ident;
288 let mod_vis = &self.mod_item.vis;
289 let (_, content) = self.mod_item.content.as_ref().unwrap();
290
291 let mut traits: Vec<&ItemTrait> = Vec::new();
293 let mut other_items: Vec<&syn::Item> = Vec::new();
294
295 for item in content {
296 match item {
297 syn::Item::Trait(t) => traits.push(t),
298 other => other_items.push(other),
299 }
300 }
301
302 if traits.is_empty() {
303 return Err(syn::Error::new_spanned(
304 &self.mod_item,
305 "module must contain at least one trait",
306 ));
307 }
308
309 let mut all_covered: HashMap<(oxapi_impl::HttpMethod, String), ()> = HashMap::new();
311
312 struct TraitInfo<'a> {
314 trait_item: &'a ItemTrait,
315 map_method: Option<&'a syn::TraitItemFn>,
316 handler_methods: Vec<(&'a syn::TraitItemFn, oxapi_impl::HttpMethod, String)>,
317 }
318
319 let mut trait_infos: Vec<TraitInfo> = Vec::new();
320
321 for trait_item in &traits {
322 let mut map_method: Option<&syn::TraitItemFn> = None;
323 let mut handler_methods: Vec<(&syn::TraitItemFn, oxapi_impl::HttpMethod, String)> =
324 Vec::new();
325
326 for item in &trait_item.items {
327 if let syn::TraitItem::Fn(method) = item {
328 if let Some(attr) = find_oxapi_attr(&method.attrs)? {
329 match attr {
330 OxapiAttr::Map => {
331 map_method = Some(method);
332 }
333 OxapiAttr::Route {
334 method: http_method,
335 path,
336 } => {
337 let key = (http_method, path.clone());
339 if all_covered.contains_key(&key) {
340 return Err(syn::Error::new_spanned(
341 method,
342 format!(
343 "operation {} {} is already defined in another trait",
344 http_method, path
345 ),
346 ));
347 }
348 all_covered.insert(key, ());
349 handler_methods.push((method, http_method, path));
350 }
351 }
352 } else {
353 return Err(syn::Error::new_spanned(
354 method,
355 "all trait methods must have #[oxapi(...)] attribute",
356 ));
357 }
358 }
359 }
360
361 trait_infos.push(TraitInfo {
362 trait_item,
363 map_method,
364 handler_methods,
365 });
366 }
367
368 self.generator
370 .validate_coverage(&all_covered)
371 .map_err(|e| syn::Error::new_spanned(&self.mod_item, e.to_string()))?;
372
373 let types = self.generator.generate_types();
375 let responses = self.generator.generate_responses();
376
377 let types_mod_name = syn::Ident::new("types", proc_macro2::Span::call_site());
379 let method_transformer =
380 oxapi_impl::MethodTransformer::new(&self.generator, &types_mod_name);
381
382 let mut generated_traits = Vec::new();
383
384 for info in &trait_infos {
385 let trait_name = &info.trait_item.ident;
386 let _trait_vis = &info.trait_item.vis;
387 let trait_attrs: Vec<_> = info
388 .trait_item
389 .attrs
390 .iter()
391 .filter(|a| !a.path().is_ident("oxapi"))
392 .collect();
393
394 let mut transformed_methods = Vec::new();
395
396 if let Some(map_fn) = info.map_method {
398 let router_gen = oxapi_impl::RouterGenerator::new(&self.generator);
399 let map_body = router_gen.generate_map_routes(
400 &info
401 .handler_methods
402 .iter()
403 .map(|(m, method, path)| (m.sig.ident.clone(), *method, path.clone()))
404 .collect::<Vec<_>>(),
405 );
406
407 let sig = &map_fn.sig;
408 transformed_methods.push(quote! {
409 #sig {
410 #map_body
411 }
412 });
413 }
414
415 for (method, http_method, path) in &info.handler_methods {
417 let op = self
418 .generator
419 .get_operation(*http_method, path)
420 .ok_or_else(|| {
421 syn::Error::new_spanned(
422 method,
423 format!("operation not found: {} {}", http_method, path),
424 )
425 })?;
426
427 let transformed = method_transformer.transform(method, op)?;
428 transformed_methods.push(transformed);
429 }
430
431 generated_traits.push(quote! {
433 #(#trait_attrs)*
434 pub trait #trait_name: 'static {
435 #(#transformed_methods)*
436 }
437 });
438 }
439
440 let output = quote! {
442 #mod_vis mod #mod_name {
443 use super::*;
444
445 pub mod #types_mod_name {
446 use super::*;
447
448 #types
449 #responses
450 }
451
452 #(#generated_traits)*
453 }
454 };
455
456 Ok(output)
457 }
458}
459
460fn find_oxapi_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<OxapiAttr>> {
462 for attr in attrs {
463 if attr.path().is_ident("oxapi") {
464 return attr.parse_args::<OxapiAttr>().map(Some);
465 }
466 }
467 Ok(None)
468}
469
470enum OxapiAttr {
472 Map,
473 Route {
474 method: oxapi_impl::HttpMethod,
475 path: String,
476 },
477}
478
479impl syn::parse::Parse for OxapiAttr {
480 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
481 let ident: Ident = input.parse()?;
482
483 match ident.to_string().as_str() {
484 "map" => Ok(OxapiAttr::Map),
485 _ => {
486 let method = parse_http_method(&ident)?;
487 input.parse::<Token![,]>()?;
488 let path: LitStr = input.parse()?;
489 Ok(OxapiAttr::Route {
490 method,
491 path: path.value(),
492 })
493 }
494 }
495 }
496}
497
498fn parse_http_method(ident: &Ident) -> syn::Result<oxapi_impl::HttpMethod> {
499 match ident.to_string().as_str() {
500 "get" => Ok(oxapi_impl::HttpMethod::Get),
501 "post" => Ok(oxapi_impl::HttpMethod::Post),
502 "put" => Ok(oxapi_impl::HttpMethod::Put),
503 "delete" => Ok(oxapi_impl::HttpMethod::Delete),
504 "patch" => Ok(oxapi_impl::HttpMethod::Patch),
505 "head" => Ok(oxapi_impl::HttpMethod::Head),
506 "options" => Ok(oxapi_impl::HttpMethod::Options),
507 other => Err(syn::Error::new(
508 ident.span(),
509 format!("unknown HTTP method: {}", other),
510 )),
511 }
512}