1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::{
8 parse::Parse, parse_macro_input, parse_quote, Attribute, Error, FnArg, GenericArgument, Ident,
9 Item, ItemEnum, ItemFn, ItemStruct, LitStr, Pat, PatTupleStruct, PathArguments, PathSegment,
10 Token, Type,
11};
12
13struct RouteArgs {
14 method: RouteMethod,
15 path: LitStr,
16 auto_validate: bool,
17}
18
/// HTTP method named by the first `#[route(...)]` argument (written in lower case there).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RouteMethod {
    /// Written as `get`.
    Get,
    /// Written as `post`.
    Post,
    /// Written as `put`.
    Put,
    /// Written as `patch`.
    Patch,
    /// Written as `delete`.
    Delete,
}
27
/// Extractor wrapper types that `auto_validate` knows how to upgrade to a validated form.
#[derive(Clone, Copy)]
enum ExtractorKind {
    Json,
    Query,
    Path,
}

impl ExtractorKind {
    /// Every supported extractor, in declaration order.
    const ALL: [Self; 3] = [Self::Json, Self::Query, Self::Path];

    /// Maps a type name such as `"Json"` to its kind (exact, case-sensitive match);
    /// returns `None` for any other name.
    fn parse(name: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|kind| kind.source_ident() == name)
    }

    /// Name of the original extractor type, e.g. `"Json"`.
    fn source_ident(self) -> &'static str {
        self.names().0
    }

    /// Name of the validated replacement type, e.g. `"ValidatedJson"`.
    fn validated_ident(self) -> &'static str {
        self.names().1
    }

    /// Single source of truth for the `(source, validated)` type-name pair of each kind.
    fn names(self) -> (&'static str, &'static str) {
        match self {
            Self::Json => ("Json", "ValidatedJson"),
            Self::Query => ("Query", "ValidatedQuery"),
            Self::Path => ("Path", "ValidatedPath"),
        }
    }
}
61
62impl Parse for RouteArgs {
63 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
64 let method_ident: Ident = input.parse()?;
65 let method = match method_ident.to_string().as_str() {
66 "get" => RouteMethod::Get,
67 "post" => RouteMethod::Post,
68 "put" => RouteMethod::Put,
69 "patch" => RouteMethod::Patch,
70 "delete" => RouteMethod::Delete,
71 _ => {
72 return Err(Error::new(
73 method_ident.span(),
74 "unsupported method; use one of: get, post, put, patch, delete",
75 ))
76 }
77 };
78 if input.is_empty() {
79 return Err(Error::new(method_ident.span(), "route path is required"));
80 }
81 input.parse::<Token![,]>()?;
82
83 let path: LitStr = input
84 .parse()
85 .map_err(|_| Error::new(input.span(), "route path must be a string literal"))?;
86
87 let mut auto_validate = false;
88 while !input.is_empty() {
89 input.parse::<Token![,]>()?;
90 let flag: Ident = input.parse()?;
91 match flag.to_string().as_str() {
92 "auto_validate" => auto_validate = true,
93 _ => return Err(Error::new(flag.span(), format!("unknown flag `{}`", flag))),
94 }
95 }
96
97 Ok(Self {
98 method,
99 path,
100 auto_validate,
101 })
102 }
103}
104
105#[proc_macro_attribute]
106pub fn route(args: TokenStream, item: TokenStream) -> TokenStream {
107 let parsed = parse_macro_input!(args as RouteArgs);
108 let mut item_fn = parse_macro_input!(item as ItemFn);
109
110 if parsed.auto_validate {
111 let meld_crate = match resolve_meld_server_path() {
112 Ok(path) => path,
113 Err(err) => return err.to_compile_error().into(),
114 };
115 if let Err(err) = apply_auto_validate(&mut item_fn, &meld_crate) {
116 return err.to_compile_error().into();
117 }
118 }
119
120 let _ = (parsed.method, parsed.path);
121
122 TokenStream::from(quote! { #item_fn })
123}
124
125#[proc_macro_attribute]
126pub fn dto(args: TokenStream, item: TokenStream) -> TokenStream {
127 if !args.is_empty() {
128 return Error::new(
129 Span::call_site(),
130 "`#[dto]` does not accept arguments; use it as `#[dto]`",
131 )
132 .to_compile_error()
133 .into();
134 }
135
136 let mut item = parse_macro_input!(item as Item);
137 let meld_crate = match resolve_meld_server_path() {
138 Ok(path) => path,
139 Err(err) => return err.to_compile_error().into(),
140 };
141
142 let apply_result = match &mut item {
143 Item::Struct(ItemStruct { attrs, .. }) => ensure_dto_derives(attrs, &meld_crate),
144 Item::Enum(ItemEnum { attrs, .. }) => ensure_dto_derives(attrs, &meld_crate),
145 _ => Err(Error::new(
146 item.span(),
147 "`#[dto]` can only be used on structs or enums",
148 )),
149 };
150
151 match apply_result {
152 Ok(()) => TokenStream::from(quote!(#item)),
153 Err(err) => err.to_compile_error().into(),
154 }
155}
156
157fn resolve_meld_server_path() -> syn::Result<syn::Path> {
158 let found = crate_name("meld-server").or_else(|_| crate_name("alloy-server"));
159 match found {
160 Ok(FoundCrate::Itself) => Ok(parse_quote!(crate)),
161 Ok(FoundCrate::Name(name)) => {
162 let sanitized = name.replace('-', "_");
163 let ident = Ident::new(&sanitized, Span::call_site());
164 Ok(parse_quote!(::#ident))
165 }
166 Err(_) => Err(Error::new(
167 Span::call_site(),
168 "failed to resolve `meld-server` crate for `#[route(..., auto_validate)]`; \
169 ensure `meld-server` (or legacy `alloy-server`) is present in Cargo.toml dependencies",
170 )),
171 }
172}
173
174fn ensure_dto_derives(attrs: &mut Vec<Attribute>, meld_crate: &syn::Path) -> syn::Result<()> {
175 let required: [syn::Path; 3] = [
176 parse_quote!(#meld_crate::serde::Deserialize),
177 parse_quote!(#meld_crate::validator::Validate),
178 parse_quote!(#meld_crate::utoipa::ToSchema),
179 ];
180 let mut existing_last_segments = std::collections::BTreeSet::new();
181 let mut first_derive: Option<(usize, Punctuated<syn::Path, Token![,]>)> = None;
182
183 for (idx, attr) in attrs.iter().enumerate() {
184 if !attr.path().is_ident("derive") {
185 continue;
186 }
187
188 let derives = attr.parse_args_with(Punctuated::<syn::Path, Token![,]>::parse_terminated)?;
189 for path in &derives {
190 if let Some(last) = path.segments.last() {
191 existing_last_segments.insert(last.ident.to_string());
192 }
193 }
194 if first_derive.is_none() {
195 first_derive = Some((idx, derives));
196 }
197 }
198
199 let mut missing = Vec::new();
200 for path in required {
201 if let Some(last) = path.segments.last() {
202 if !existing_last_segments.contains(&last.ident.to_string()) {
203 missing.push(path);
204 }
205 }
206 }
207
208 if missing.is_empty() {
209 return Ok(());
210 }
211
212 if let Some((idx, mut derive_paths)) = first_derive {
213 for path in missing {
214 derive_paths.push(path);
215 }
216 attrs[idx] = parse_quote!(#[derive(#derive_paths)]);
217 } else {
218 attrs.insert(0, parse_quote!(#[derive(#(#missing),*)]));
219 }
220
221 Ok(())
222}
223
224fn apply_auto_validate(item_fn: &mut ItemFn, meld_crate: &syn::Path) -> syn::Result<()> {
225 let mut errors: Option<syn::Error> = None;
226
227 for input in &mut item_fn.sig.inputs {
228 if let FnArg::Typed(arg) = input {
229 if let Err(err) = maybe_rewrite_typed_arg(arg, meld_crate) {
230 if let Some(existing) = &mut errors {
231 existing.combine(err);
232 } else {
233 errors = Some(err);
234 }
235 }
236 }
237 }
238
239 match errors {
240 Some(err) => Err(err),
241 None => Ok(()),
242 }
243}
244
245fn maybe_rewrite_typed_arg(arg: &mut syn::PatType, meld_crate: &syn::Path) -> syn::Result<()> {
246 let (kind, original_segment, inner_ty) = match extract_rewrite_target(&arg.ty)? {
247 Some(values) => values,
248 None => return Ok(()),
249 };
250
251 let validated_path = validated_extractor_path(meld_crate, kind);
252 let rewritten_ty: Type = parse_quote!(#validated_path<#inner_ty>);
253 *arg.ty = rewritten_ty;
254
255 rewrite_pattern(&mut arg.pat, kind, &validated_path, &original_segment)
256}
257
258fn extract_rewrite_target(ty: &Type) -> syn::Result<Option<(ExtractorKind, PathSegment, Type)>> {
259 let Type::Path(type_path) = ty else {
260 return Ok(None);
261 };
262
263 let Some(segment) = type_path.path.segments.last() else {
264 return Ok(None);
265 };
266
267 let Some(kind) = ExtractorKind::parse(segment.ident.to_string().as_str()) else {
268 return Ok(None);
269 };
270
271 let inner_ty = extract_single_generic_type(&segment.arguments).map_err(|err| {
272 Error::new(
273 segment.ident.span(),
274 format!(
275 "`{}` extractor in auto_validate must have exactly one type parameter: {err}",
276 kind.source_ident()
277 ),
278 )
279 })?;
280
281 Ok(Some((kind, segment.clone(), inner_ty)))
282}
283
284fn rewrite_pattern(
285 pat: &mut Box<Pat>,
286 kind: ExtractorKind,
287 validated_path: &syn::Path,
288 original_segment: &PathSegment,
289) -> syn::Result<()> {
290 match pat.as_mut() {
291 Pat::TupleStruct(PatTupleStruct { path, .. }) => {
292 let Some(last) = path.segments.last() else {
293 return Err(Error::new(
294 path.span(),
295 format!(
296 "unsupported `{}` pattern in auto_validate; use `{name}(value)` or `value: {name}<T>`",
297 kind.source_ident(),
298 name = kind.source_ident()
299 ),
300 ));
301 };
302
303 let last_name = last.ident.to_string();
304 let source = kind.source_ident();
305 let validated = kind.validated_ident();
306 if last_name != source && last_name != validated {
307 return Err(Error::new(
308 last.ident.span(),
309 format!(
310 "pattern `{}` does not match extractor `{}` in auto_validate; expected `{}` pattern",
311 last_name, source, source
312 ),
313 ));
314 }
315
316 *path = validated_path.clone();
317 Ok(())
318 }
319 Pat::Ident(ident_pat) => {
320 if ident_pat.by_ref.is_some() || ident_pat.subpat.is_some() {
321 return Err(Error::new(
322 ident_pat.span(),
323 format!(
324 "unsupported `{}` binding form in auto_validate; use simple binding like `value: {}<T>`",
325 kind.source_ident(),
326 kind.source_ident()
327 ),
328 ));
329 }
330
331 let ident = ident_pat.ident.clone();
332 let new_pat: Pat = if ident_pat.mutability.is_some() {
333 parse_quote!(#validated_path(mut #ident))
334 } else {
335 parse_quote!(#validated_path(#ident))
336 };
337 **pat = new_pat;
338 Ok(())
339 }
340 Pat::Wild(_) => {
341 let new_pat: Pat = parse_quote!(#validated_path(_));
342 **pat = new_pat;
343 Ok(())
344 }
345 _ => Err(Error::new(
346 pat.span(),
347 format!(
348 "unsupported pattern for `{}` in auto_validate; use `{}` destructuring (`{}(value)`) or simple binding (`value: {}<T>`)",
349 original_segment.ident,
350 kind.source_ident(),
351 kind.source_ident(),
352 kind.source_ident(),
353 ),
354 )),
355 }
356}
357
358fn validated_extractor_path(meld_crate: &syn::Path, kind: ExtractorKind) -> syn::Path {
359 match kind {
360 ExtractorKind::Json => parse_quote!(#meld_crate::api::ValidatedJson),
361 ExtractorKind::Query => parse_quote!(#meld_crate::api::ValidatedQuery),
362 ExtractorKind::Path => parse_quote!(#meld_crate::api::ValidatedPath),
363 }
364}
365
366fn extract_single_generic_type(arguments: &PathArguments) -> syn::Result<Type> {
367 let PathArguments::AngleBracketed(args) = arguments else {
368 return Err(Error::new(Span::call_site(), "missing generic parameter"));
369 };
370 if args.args.len() != 1 {
371 return Err(Error::new(
372 Span::call_site(),
373 "expected exactly one generic parameter",
374 ));
375 }
376 match args.args.first() {
377 Some(GenericArgument::Type(ty)) => Ok(ty.clone()),
378 _ => Err(Error::new(
379 Span::call_site(),
380 "generic parameter must be a concrete type",
381 )),
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{parse_quote, parse_str};

    /// Crate path that every rewrite/derive test resolves against.
    fn meld() -> syn::Path {
        parse_quote!(::meld_server)
    }

    /// Parses `src` as route arguments and returns the error; panics if parsing succeeds.
    fn route_args_err(src: &str) -> syn::Error {
        match parse_str::<RouteArgs>(src) {
            Ok(_) => panic!("`{src}` must fail to parse"),
            Err(err) => err,
        }
    }

    /// Runs `apply_auto_validate` and returns the (first) error message; panics on success.
    fn auto_validate_err(mut item_fn: ItemFn) -> String {
        apply_auto_validate(&mut item_fn, &meld())
            .expect_err("rewrite must fail")
            .to_string()
    }

    #[test]
    fn parses_method_path_and_auto_validate_flag() {
        let parsed = parse_str::<RouteArgs>(r#"post, "/notes", auto_validate"#)
            .expect("route args should parse");

        assert_eq!(parsed.method, RouteMethod::Post);
        assert_eq!(parsed.path.value(), "/notes");
        assert!(parsed.auto_validate);
    }

    #[test]
    fn parses_without_auto_validate() {
        let parsed = parse_str::<RouteArgs>(r#"get, "/health""#).expect("route args should parse");

        assert_eq!(parsed.method, RouteMethod::Get);
        assert_eq!(parsed.path.value(), "/health");
        assert!(!parsed.auto_validate);
    }

    #[test]
    fn parses_every_supported_method() {
        let cases = [
            ("get", RouteMethod::Get),
            ("post", RouteMethod::Post),
            ("put", RouteMethod::Put),
            ("patch", RouteMethod::Patch),
            ("delete", RouteMethod::Delete),
        ];
        for (name, expected) in cases {
            let parsed = parse_str::<RouteArgs>(&format!(r#"{name}, "/notes""#))
                .expect("route args should parse");
            assert_eq!(parsed.method, expected, "method `{name}`");
        }
    }

    #[test]
    fn rejects_unsupported_method() {
        let err = route_args_err(r#"options, "/notes""#);
        assert!(err.to_string().contains("unsupported method"));
    }

    #[test]
    fn rejects_unknown_flag() {
        let err = route_args_err(r#"post, "/notes", unknown_flag"#);
        assert!(err.to_string().contains("unknown flag"));
    }

    #[test]
    fn rejects_missing_path() {
        let err = route_args_err("post");
        assert!(err.to_string().contains("path"));
    }

    #[test]
    fn rejects_non_string_path() {
        let err = route_args_err("post, 10");
        assert!(err.to_string().contains("string"));
    }

    #[test]
    fn auto_validate_rewrites_json_query_and_path_extractors() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(
                Query(q): Query<ListQuery>,
                Json(body): Json<CreateNote>,
                Path(path): Path<NotePath>,
            ) {}
        };
        apply_auto_validate(&mut item_fn, &meld()).expect("rewrite should work");

        let expected = ["ValidatedQuery", "ValidatedJson", "ValidatedPath"];
        assert_eq!(item_fn.sig.inputs.len(), expected.len());
        for (arg, name) in item_fn.sig.inputs.iter().zip(expected.iter()) {
            assert_eq!(arg_type_ident(arg), Some(name.to_string()));
            assert_eq!(arg_pat_ident(arg), Some(name.to_string()));
        }
    }

    #[test]
    fn auto_validate_rewrites_identifier_pattern_to_destructure() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(query: Query<ListQuery>) {}
        };
        apply_auto_validate(&mut item_fn, &meld()).expect("rewrite should work");

        let first = item_fn.sig.inputs.first().expect("arg should exist");
        assert_eq!(arg_type_ident(first), Some("ValidatedQuery".to_string()));
        assert_eq!(arg_pat_ident(first), Some("ValidatedQuery".to_string()));
    }

    #[test]
    fn auto_validate_preserves_mut_binding() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn update_note(mut body: Json<UpdateNote>) {}
        };
        apply_auto_validate(&mut item_fn, &meld()).expect("rewrite should work");

        let arg = item_fn.sig.inputs.first().expect("arg should exist");
        assert_eq!(arg_type_ident(arg), Some("ValidatedJson".to_string()));
        let Some(Pat::Ident(binding)) = first_tuple_elem(arg) else {
            panic!("expected `ValidatedJson(mut body)` pattern");
        };
        assert!(binding.mutability.is_some());
        assert_eq!(binding.ident, "body");
    }

    #[test]
    fn auto_validate_rewrites_wildcard_pattern() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn ping(_: Query<PingQuery>) {}
        };
        apply_auto_validate(&mut item_fn, &meld()).expect("rewrite should work");

        let arg = item_fn.sig.inputs.first().expect("arg should exist");
        assert_eq!(arg_type_ident(arg), Some("ValidatedQuery".to_string()));
        assert_eq!(arg_pat_ident(arg), Some("ValidatedQuery".to_string()));
        assert!(matches!(first_tuple_elem(arg), Some(Pat::Wild(_))));
    }

    #[test]
    fn auto_validate_leaves_receiver_and_non_extractor_arguments_untouched() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(&self, state: State<AppState>, Json(body): Json<CreateNote>) {}
        };
        apply_auto_validate(&mut item_fn, &meld()).expect("rewrite should work");

        let args: Vec<&FnArg> = item_fn.sig.inputs.iter().collect();
        assert_eq!(args.len(), 3);
        assert!(matches!(args[0], FnArg::Receiver(_)));
        assert_eq!(arg_type_ident(args[1]), Some("State".to_string()));
        assert_eq!(arg_pat_ident(args[1]), None);
        assert_eq!(arg_type_ident(args[2]), Some("ValidatedJson".to_string()));
        assert_eq!(arg_pat_ident(args[2]), Some("ValidatedJson".to_string()));
    }

    #[test]
    fn auto_validate_rejects_mismatched_tuple_pattern() {
        let message = auto_validate_err(parse_quote! {
            async fn list_notes(Json(q): Query<ListQuery>) {}
        });
        assert!(message.contains("does not match extractor `Query`"));
    }

    #[test]
    fn auto_validate_requires_exactly_one_type_parameter() {
        let missing = auto_validate_err(parse_quote! { async fn f(body: Json) {} });
        assert!(missing.contains("exactly one type parameter"));
        assert!(missing.contains("missing generic parameter"));

        let too_many = auto_validate_err(parse_quote! { async fn f(body: Json<A, B>) {} });
        assert!(too_many.contains("expected exactly one generic parameter"));

        let lifetime = auto_validate_err(parse_quote! { async fn f(body: Json<'static>) {} });
        assert!(lifetime.contains("concrete type"));
    }

    #[test]
    fn auto_validate_reports_actionable_error_for_unsupported_pattern() {
        let message = auto_validate_err(parse_quote! {
            async fn create_note((query): Query<ListQuery>) {}
        });
        assert!(message.contains("unsupported pattern"));
    }

    #[test]
    fn auto_validate_reports_every_failing_argument() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note((query): Query<ListQuery>, (body): Json<CreateNote>) {}
        };
        let err = apply_auto_validate(&mut item_fn, &meld()).expect_err("must fail");
        assert_eq!(err.into_iter().count(), 2);
    }

    #[test]
    fn without_auto_validate_keeps_original_extractors() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(Query(q): Query<ListQuery>, Json(body): Json<CreateNote>) {}
        };
        let args = parse_str::<RouteArgs>(r#"post, "/notes""#).expect("route args should parse");

        if args.auto_validate {
            apply_auto_validate(&mut item_fn, &meld()).expect("rewrite should work");
        }

        let rendered = quote!(#item_fn).to_string();
        assert!(rendered.contains("Query"));
        assert!(rendered.contains("Json"));
        assert!(!rendered.contains("ValidatedQuery"));
        assert!(!rendered.contains("ValidatedJson"));
    }

    #[test]
    fn resolved_path_uses_callsite_crate_alias() {
        let path: syn::Path = match FoundCrate::Name("meld_api".to_string()) {
            FoundCrate::Name(name) => {
                let ident = Ident::new(&name.replace('-', "_"), Span::call_site());
                parse_quote!(::#ident)
            }
            FoundCrate::Itself => parse_quote!(crate),
        };

        let rendered = quote!(#path).to_string();
        assert_eq!(rendered, ":: meld_api");
    }

    #[test]
    fn dto_injects_deserialize_validate_and_schema_derives() {
        let mut item: ItemStruct = parse_quote! {
            struct Payload {
                #[validate(length(min = 1))]
                name: String,
            }
        };
        ensure_dto_derives(&mut item.attrs, &meld()).expect("dto derives should be injected");

        let rendered = quote!(#item).to_string();
        assert!(rendered.contains(":: meld_server :: serde :: Deserialize"));
        assert!(rendered.contains(":: meld_server :: validator :: Validate"));
        assert!(rendered.contains(":: meld_server :: utoipa :: ToSchema"));
    }

    #[test]
    fn dto_keeps_existing_derive_and_appends_missing() {
        let mut item: ItemStruct = parse_quote! {
            #[derive(Debug, serde::Deserialize)]
            struct Payload {
                #[validate(length(min = 1))]
                name: String,
            }
        };
        ensure_dto_derives(&mut item.attrs, &meld()).expect("dto derives should be injected");

        let rendered = quote!(#item).to_string();
        assert!(rendered.contains("Debug"));
        assert!(rendered.contains("serde :: Deserialize"));
        assert!(rendered.contains(":: meld_server :: validator :: Validate"));
        assert!(rendered.contains(":: meld_server :: utoipa :: ToSchema"));
    }

    #[test]
    fn dto_is_noop_when_all_derives_present() {
        let mut item: ItemStruct = parse_quote! {
            #[derive(Deserialize, Validate, ToSchema)]
            struct Payload {
                name: String,
            }
        };
        let before = quote!(#item).to_string();
        ensure_dto_derives(&mut item.attrs, &meld()).expect("nothing to inject");

        assert_eq!(quote!(#item).to_string(), before);
    }

    #[test]
    fn dto_appends_missing_derives_to_first_derive_only() {
        let mut item: ItemStruct = parse_quote! {
            #[derive(Debug)]
            #[derive(Deserialize)]
            struct Payload {
                name: String,
            }
        };
        ensure_dto_derives(&mut item.attrs, &meld()).expect("dto derives should be injected");

        assert_eq!(item.attrs.len(), 2);
        let first = &item.attrs[0];
        let first_rendered = quote!(#first).to_string();
        assert!(first_rendered.contains("Debug"));
        assert!(first_rendered.contains(":: meld_server :: validator :: Validate"));
        assert!(first_rendered.contains(":: meld_server :: utoipa :: ToSchema"));
        // `Deserialize` is already derived by the second attribute, so it is not injected.
        assert!(!quote!(#item)
            .to_string()
            .contains(":: meld_server :: serde :: Deserialize"));
    }

    /// Last segment of a typed argument's type path, e.g. `"ValidatedJson"`.
    fn arg_type_ident(arg: &FnArg) -> Option<String> {
        let FnArg::Typed(arg) = arg else {
            return None;
        };
        let Type::Path(type_path) = arg.ty.as_ref() else {
            return None;
        };
        type_path.path.segments.last().map(|s| s.ident.to_string())
    }

    /// Last segment of a typed argument's tuple-struct pattern path; `None` for other patterns.
    fn arg_pat_ident(arg: &FnArg) -> Option<String> {
        let FnArg::Typed(arg) = arg else {
            return None;
        };
        let Pat::TupleStruct(tuple_struct) = arg.pat.as_ref() else {
            return None;
        };
        tuple_struct
            .path
            .segments
            .last()
            .map(|s| s.ident.to_string())
    }

    /// First inner pattern of a typed argument's tuple-struct pattern, e.g. `body` in `Json(body)`.
    fn first_tuple_elem(arg: &FnArg) -> Option<&Pat> {
        let FnArg::Typed(arg) = arg else {
            return None;
        };
        let Pat::TupleStruct(tuple_struct) = arg.pat.as_ref() else {
            return None;
        };
        tuple_struct.elems.first()
    }
}