Skip to main content

meld_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::quote;
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::{
8    parse::Parse, parse_macro_input, parse_quote, Attribute, Error, FnArg, GenericArgument, Ident,
9    Item, ItemEnum, ItemFn, ItemStruct, LitStr, Pat, PatTupleStruct, PathArguments, PathSegment,
10    Token, Type,
11};
12
13struct RouteArgs {
14    method: RouteMethod,
15    path: LitStr,
16    auto_validate: bool,
17}
18
/// HTTP methods accepted by `#[route]` (written in lowercase at the attribute call site).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RouteMethod {
    Get,
    Post,
    Put,
    Patch,
    Delete,
}
27
/// Extractor wrappers that `auto_validate` knows how to replace with a validating
/// counterpart.
#[derive(Clone, Copy)]
enum ExtractorKind {
    Json,
    Query,
    Path,
}

impl ExtractorKind {
    /// Every variant, used to look a kind up by its source name.
    const ALL: [Self; 3] = [Self::Json, Self::Query, Self::Path];

    /// Maps an extractor type name (exact, case-sensitive match) to its kind.
    fn parse(name: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|kind| kind.source_ident() == name)
    }

    /// The wrapper name as written by the user, e.g. `Json`.
    fn source_ident(self) -> &'static str {
        self.names().0
    }

    /// The name of the validating replacement, e.g. `ValidatedJson`.
    fn validated_ident(self) -> &'static str {
        self.names().1
    }

    /// Single table pairing each kind with its `(source, validated)` names.
    fn names(self) -> (&'static str, &'static str) {
        match self {
            Self::Json => ("Json", "ValidatedJson"),
            Self::Query => ("Query", "ValidatedQuery"),
            Self::Path => ("Path", "ValidatedPath"),
        }
    }
}
61
62impl Parse for RouteArgs {
63    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
64        let method_ident: Ident = input.parse()?;
65        let method = match method_ident.to_string().as_str() {
66            "get" => RouteMethod::Get,
67            "post" => RouteMethod::Post,
68            "put" => RouteMethod::Put,
69            "patch" => RouteMethod::Patch,
70            "delete" => RouteMethod::Delete,
71            _ => {
72                return Err(Error::new(
73                    method_ident.span(),
74                    "unsupported method; use one of: get, post, put, patch, delete",
75                ))
76            }
77        };
78        if input.is_empty() {
79            return Err(Error::new(method_ident.span(), "route path is required"));
80        }
81        input.parse::<Token![,]>()?;
82
83        let path: LitStr = input
84            .parse()
85            .map_err(|_| Error::new(input.span(), "route path must be a string literal"))?;
86
87        let mut auto_validate = false;
88        while !input.is_empty() {
89            input.parse::<Token![,]>()?;
90            let flag: Ident = input.parse()?;
91            match flag.to_string().as_str() {
92                "auto_validate" => auto_validate = true,
93                _ => return Err(Error::new(flag.span(), format!("unknown flag `{}`", flag))),
94            }
95        }
96
97        Ok(Self {
98            method,
99            path,
100            auto_validate,
101        })
102    }
103}
104
105#[proc_macro_attribute]
106pub fn route(args: TokenStream, item: TokenStream) -> TokenStream {
107    let parsed = parse_macro_input!(args as RouteArgs);
108    let mut item_fn = parse_macro_input!(item as ItemFn);
109
110    if parsed.auto_validate {
111        let meld_crate = match resolve_meld_server_path() {
112            Ok(path) => path,
113            Err(err) => return err.to_compile_error().into(),
114        };
115        if let Err(err) = apply_auto_validate(&mut item_fn, &meld_crate) {
116            return err.to_compile_error().into();
117        }
118    }
119
120    let _ = (parsed.method, parsed.path);
121
122    TokenStream::from(quote! { #item_fn })
123}
124
125#[proc_macro_attribute]
126pub fn dto(args: TokenStream, item: TokenStream) -> TokenStream {
127    if !args.is_empty() {
128        return Error::new(
129            Span::call_site(),
130            "`#[dto]` does not accept arguments; use it as `#[dto]`",
131        )
132        .to_compile_error()
133        .into();
134    }
135
136    let mut item = parse_macro_input!(item as Item);
137    let meld_crate = match resolve_meld_server_path() {
138        Ok(path) => path,
139        Err(err) => return err.to_compile_error().into(),
140    };
141
142    let apply_result = match &mut item {
143        Item::Struct(ItemStruct { attrs, .. }) => ensure_dto_derives(attrs, &meld_crate),
144        Item::Enum(ItemEnum { attrs, .. }) => ensure_dto_derives(attrs, &meld_crate),
145        _ => Err(Error::new(
146            item.span(),
147            "`#[dto]` can only be used on structs or enums",
148        )),
149    };
150
151    match apply_result {
152        Ok(()) => TokenStream::from(quote!(#item)),
153        Err(err) => err.to_compile_error().into(),
154    }
155}
156
157fn resolve_meld_server_path() -> syn::Result<syn::Path> {
158    let found = crate_name("meld-server").or_else(|_| crate_name("alloy-server"));
159    match found {
160        Ok(FoundCrate::Itself) => Ok(parse_quote!(crate)),
161        Ok(FoundCrate::Name(name)) => {
162            let sanitized = name.replace('-', "_");
163            let ident = Ident::new(&sanitized, Span::call_site());
164            Ok(parse_quote!(::#ident))
165        }
166        Err(_) => Err(Error::new(
167            Span::call_site(),
168            "failed to resolve `meld-server` crate for `#[route(..., auto_validate)]`; \
169             ensure `meld-server` (or legacy `alloy-server`) is present in Cargo.toml dependencies",
170        )),
171    }
172}
173
174fn ensure_dto_derives(attrs: &mut Vec<Attribute>, meld_crate: &syn::Path) -> syn::Result<()> {
175    let required: [syn::Path; 3] = [
176        parse_quote!(#meld_crate::serde::Deserialize),
177        parse_quote!(#meld_crate::validator::Validate),
178        parse_quote!(#meld_crate::utoipa::ToSchema),
179    ];
180    let mut existing_last_segments = std::collections::BTreeSet::new();
181    let mut first_derive: Option<(usize, Punctuated<syn::Path, Token![,]>)> = None;
182
183    for (idx, attr) in attrs.iter().enumerate() {
184        if !attr.path().is_ident("derive") {
185            continue;
186        }
187
188        let derives = attr.parse_args_with(Punctuated::<syn::Path, Token![,]>::parse_terminated)?;
189        for path in &derives {
190            if let Some(last) = path.segments.last() {
191                existing_last_segments.insert(last.ident.to_string());
192            }
193        }
194        if first_derive.is_none() {
195            first_derive = Some((idx, derives));
196        }
197    }
198
199    let mut missing = Vec::new();
200    for path in required {
201        if let Some(last) = path.segments.last() {
202            if !existing_last_segments.contains(&last.ident.to_string()) {
203                missing.push(path);
204            }
205        }
206    }
207
208    if missing.is_empty() {
209        return Ok(());
210    }
211
212    if let Some((idx, mut derive_paths)) = first_derive {
213        for path in missing {
214            derive_paths.push(path);
215        }
216        attrs[idx] = parse_quote!(#[derive(#derive_paths)]);
217    } else {
218        attrs.insert(0, parse_quote!(#[derive(#(#missing),*)]));
219    }
220
221    Ok(())
222}
223
224fn apply_auto_validate(item_fn: &mut ItemFn, meld_crate: &syn::Path) -> syn::Result<()> {
225    let mut errors: Option<syn::Error> = None;
226
227    for input in &mut item_fn.sig.inputs {
228        if let FnArg::Typed(arg) = input {
229            if let Err(err) = maybe_rewrite_typed_arg(arg, meld_crate) {
230                if let Some(existing) = &mut errors {
231                    existing.combine(err);
232                } else {
233                    errors = Some(err);
234                }
235            }
236        }
237    }
238
239    match errors {
240        Some(err) => Err(err),
241        None => Ok(()),
242    }
243}
244
245fn maybe_rewrite_typed_arg(arg: &mut syn::PatType, meld_crate: &syn::Path) -> syn::Result<()> {
246    let (kind, original_segment, inner_ty) = match extract_rewrite_target(&arg.ty)? {
247        Some(values) => values,
248        None => return Ok(()),
249    };
250
251    let validated_path = validated_extractor_path(meld_crate, kind);
252    let rewritten_ty: Type = parse_quote!(#validated_path<#inner_ty>);
253    *arg.ty = rewritten_ty;
254
255    rewrite_pattern(&mut arg.pat, kind, &validated_path, &original_segment)
256}
257
258fn extract_rewrite_target(ty: &Type) -> syn::Result<Option<(ExtractorKind, PathSegment, Type)>> {
259    let Type::Path(type_path) = ty else {
260        return Ok(None);
261    };
262
263    let Some(segment) = type_path.path.segments.last() else {
264        return Ok(None);
265    };
266
267    let Some(kind) = ExtractorKind::parse(segment.ident.to_string().as_str()) else {
268        return Ok(None);
269    };
270
271    let inner_ty = extract_single_generic_type(&segment.arguments).map_err(|err| {
272        Error::new(
273            segment.ident.span(),
274            format!(
275                "`{}` extractor in auto_validate must have exactly one type parameter: {err}",
276                kind.source_ident()
277            ),
278        )
279    })?;
280
281    Ok(Some((kind, segment.clone(), inner_ty)))
282}
283
284fn rewrite_pattern(
285    pat: &mut Box<Pat>,
286    kind: ExtractorKind,
287    validated_path: &syn::Path,
288    original_segment: &PathSegment,
289) -> syn::Result<()> {
290    match pat.as_mut() {
291        Pat::TupleStruct(PatTupleStruct { path, .. }) => {
292            let Some(last) = path.segments.last() else {
293                return Err(Error::new(
294                    path.span(),
295                    format!(
296                        "unsupported `{}` pattern in auto_validate; use `{name}(value)` or `value: {name}<T>`",
297                        kind.source_ident(),
298                        name = kind.source_ident()
299                    ),
300                ));
301            };
302
303            let last_name = last.ident.to_string();
304            let source = kind.source_ident();
305            let validated = kind.validated_ident();
306            if last_name != source && last_name != validated {
307                return Err(Error::new(
308                    last.ident.span(),
309                    format!(
310                        "pattern `{}` does not match extractor `{}` in auto_validate; expected `{}` pattern",
311                        last_name, source, source
312                    ),
313                ));
314            }
315
316            *path = validated_path.clone();
317            Ok(())
318        }
319        Pat::Ident(ident_pat) => {
320            if ident_pat.by_ref.is_some() || ident_pat.subpat.is_some() {
321                return Err(Error::new(
322                    ident_pat.span(),
323                    format!(
324                        "unsupported `{}` binding form in auto_validate; use simple binding like `value: {}<T>`",
325                        kind.source_ident(),
326                        kind.source_ident()
327                    ),
328                ));
329            }
330
331            let ident = ident_pat.ident.clone();
332            let new_pat: Pat = if ident_pat.mutability.is_some() {
333                parse_quote!(#validated_path(mut #ident))
334            } else {
335                parse_quote!(#validated_path(#ident))
336            };
337            **pat = new_pat;
338            Ok(())
339        }
340        Pat::Wild(_) => {
341            let new_pat: Pat = parse_quote!(#validated_path(_));
342            **pat = new_pat;
343            Ok(())
344        }
345        _ => Err(Error::new(
346            pat.span(),
347            format!(
348                "unsupported pattern for `{}` in auto_validate; use `{}` destructuring (`{}(value)`) or simple binding (`value: {}<T>`)",
349                original_segment.ident,
350                kind.source_ident(),
351                kind.source_ident(),
352                kind.source_ident(),
353            ),
354        )),
355    }
356}
357
358fn validated_extractor_path(meld_crate: &syn::Path, kind: ExtractorKind) -> syn::Path {
359    match kind {
360        ExtractorKind::Json => parse_quote!(#meld_crate::api::ValidatedJson),
361        ExtractorKind::Query => parse_quote!(#meld_crate::api::ValidatedQuery),
362        ExtractorKind::Path => parse_quote!(#meld_crate::api::ValidatedPath),
363    }
364}
365
366fn extract_single_generic_type(arguments: &PathArguments) -> syn::Result<Type> {
367    let PathArguments::AngleBracketed(args) = arguments else {
368        return Err(Error::new(Span::call_site(), "missing generic parameter"));
369    };
370    if args.args.len() != 1 {
371        return Err(Error::new(
372            Span::call_site(),
373            "expected exactly one generic parameter",
374        ));
375    }
376    match args.args.first() {
377        Some(GenericArgument::Type(ty)) => Ok(ty.clone()),
378        _ => Err(Error::new(
379            Span::call_site(),
380            "generic parameter must be a concrete type",
381        )),
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{parse_quote, parse_str};

    /// Crate path used by the rewrite and dto tests (a plain `meld-server` dependency).
    fn meld_path() -> syn::Path {
        parse_quote!(::meld_server)
    }

    /// Returns the `index`-th argument of `item_fn`, panicking with a readable message.
    fn nth_arg(item_fn: &ItemFn, index: usize) -> &FnArg {
        item_fn
            .sig
            .inputs
            .iter()
            .nth(index)
            .unwrap_or_else(|| panic!("argument #{index} should exist"))
    }

    /// First element inside a tuple-struct pattern, e.g. `q` in `ValidatedQuery(q)`.
    fn first_destructured_pat(arg: &FnArg) -> Option<&Pat> {
        let FnArg::Typed(arg) = arg else {
            return None;
        };
        let Pat::TupleStruct(tuple_struct) = arg.pat.as_ref() else {
            return None;
        };
        tuple_struct.elems.first()
    }

    /// Number of paths listed inside a `#[derive(...)]` attribute.
    fn derive_count(attr: &Attribute) -> usize {
        attr.parse_args_with(Punctuated::<syn::Path, Token![,]>::parse_terminated)
            .expect("derive attribute should parse")
            .len()
    }

    #[test]
    fn parses_method_path_and_auto_validate_flag() {
        let parsed = parse_str::<RouteArgs>(r#"post, "/notes", auto_validate"#)
            .expect("route args should parse");

        assert_eq!(parsed.method, RouteMethod::Post);
        assert_eq!(parsed.path.value(), "/notes");
        assert!(parsed.auto_validate);
    }

    #[test]
    fn parses_without_auto_validate() {
        let parsed = parse_str::<RouteArgs>(r#"get, "/health""#).expect("route args should parse");

        assert_eq!(parsed.method, RouteMethod::Get);
        assert_eq!(parsed.path.value(), "/health");
        assert!(!parsed.auto_validate);
    }

    #[test]
    fn rejects_unsupported_method() {
        let err = match parse_str::<RouteArgs>(r#"options, "/notes""#) {
            Ok(_) => panic!("unsupported method must fail"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("unsupported method"));
    }

    #[test]
    fn rejects_unknown_flag() {
        let err = match parse_str::<RouteArgs>(r#"post, "/notes", unknown_flag"#) {
            Ok(_) => panic!("unknown flag must fail"),
            Err(err) => err,
        };

        let message = err.to_string();
        assert!(message.contains("unknown flag"));
    }

    #[test]
    fn rejects_missing_path() {
        let err = match parse_str::<RouteArgs>("post") {
            Ok(_) => panic!("missing path must fail"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("path"));
    }

    #[test]
    fn rejects_non_string_path() {
        let err = match parse_str::<RouteArgs>("post, 10") {
            Ok(_) => panic!("non-string path must fail"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("string"));
    }

    #[test]
    fn auto_validate_rewrites_json_query_and_path_extractors() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(
                Query(q): Query<ListQuery>,
                Json(body): Json<CreateNote>,
                Path(path): Path<NotePath>,
            ) {}
        };

        apply_auto_validate(&mut item_fn, &meld_path()).expect("rewrite should work");

        let expected = ["ValidatedQuery", "ValidatedJson", "ValidatedPath"];
        for (index, name) in expected.iter().enumerate() {
            let arg = nth_arg(&item_fn, index);
            assert_eq!(arg_type_ident(arg).as_deref(), Some(*name));
            assert_eq!(arg_pat_ident(arg).as_deref(), Some(*name));
        }
    }

    #[test]
    fn auto_validate_rewrites_identifier_pattern_to_destructure() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(query: Query<ListQuery>) {}
        };

        apply_auto_validate(&mut item_fn, &meld_path()).expect("rewrite should work");

        let first = nth_arg(&item_fn, 0);
        assert_eq!(arg_type_ident(first).as_deref(), Some("ValidatedQuery"));
        assert_eq!(arg_pat_ident(first).as_deref(), Some("ValidatedQuery"));
    }

    #[test]
    fn auto_validate_preserves_mut_binding_and_wildcard() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(mut query: Query<ListQuery>, _: Json<CreateNote>) {}
        };

        apply_auto_validate(&mut item_fn, &meld_path()).expect("rewrite should work");

        let first = nth_arg(&item_fn, 0);
        assert_eq!(arg_pat_ident(first).as_deref(), Some("ValidatedQuery"));
        match first_destructured_pat(first) {
            Some(Pat::Ident(binding)) => {
                assert!(
                    binding.mutability.is_some(),
                    "`mut` must be carried into the destructuring"
                );
                assert_eq!(binding.ident, "query");
            }
            _ => panic!("expected `ValidatedQuery(mut query)`"),
        }

        let second = nth_arg(&item_fn, 1);
        assert_eq!(arg_type_ident(second).as_deref(), Some("ValidatedJson"));
        assert!(matches!(first_destructured_pat(second), Some(Pat::Wild(_))));
    }

    #[test]
    fn auto_validate_reports_actionable_error_for_unsupported_pattern() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note((query): Query<ListQuery>) {}
        };

        let err = apply_auto_validate(&mut item_fn, &meld_path()).expect_err("must fail");
        assert!(err.to_string().contains("unsupported pattern"));
    }

    #[test]
    fn auto_validate_rejects_mismatched_destructuring() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(Json(body): Query<ListQuery>) {}
        };

        let err = apply_auto_validate(&mut item_fn, &meld_path()).expect_err("must fail");
        assert!(err.to_string().contains("does not match extractor"));
    }

    #[test]
    fn auto_validate_requires_a_type_parameter() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(body: Json) {}
        };

        let err = apply_auto_validate(&mut item_fn, &meld_path()).expect_err("must fail");
        assert!(err.to_string().contains("exactly one type parameter"));
    }

    #[test]
    fn auto_validate_reports_every_failing_argument() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note((query): Query<ListQuery>, body: Json) {}
        };

        let err = apply_auto_validate(&mut item_fn, &meld_path()).expect_err("must fail");
        assert_eq!(err.into_iter().count(), 2);
    }

    #[test]
    fn auto_validate_ignores_unrelated_arguments() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(state: State<AppState>, id: u64, raw: &str) {}
        };
        let before = quote!(#item_fn).to_string();

        apply_auto_validate(&mut item_fn, &meld_path()).expect("nothing to rewrite");

        assert_eq!(quote!(#item_fn).to_string(), before);
    }

    #[test]
    fn without_auto_validate_keeps_original_extractors() {
        let mut item_fn: ItemFn = parse_quote! {
            async fn create_note(Query(q): Query<ListQuery>, Json(body): Json<CreateNote>) {}
        };
        let args = parse_str::<RouteArgs>(r#"post, "/notes""#).expect("route args should parse");

        if args.auto_validate {
            apply_auto_validate(&mut item_fn, &meld_path()).expect("rewrite should work");
        }

        let rendered = quote!(#item_fn).to_string();
        assert!(rendered.contains("Query"));
        assert!(rendered.contains("Json"));
        assert!(!rendered.contains("ValidatedQuery"));
        assert!(!rendered.contains("ValidatedJson"));
    }

    #[test]
    fn resolved_path_uses_callsite_crate_alias() {
        let path: syn::Path = match FoundCrate::Name("meld_api".to_string()) {
            FoundCrate::Name(name) => {
                let ident = Ident::new(&name.replace('-', "_"), Span::call_site());
                parse_quote!(::#ident)
            }
            FoundCrate::Itself => parse_quote!(crate),
        };

        let rendered = quote!(#path).to_string();
        assert_eq!(rendered, ":: meld_api");
    }

    #[test]
    fn dto_injects_deserialize_validate_and_schema_derives() {
        let mut item: ItemStruct = parse_quote! {
            struct Payload {
                #[validate(length(min = 1))]
                name: String,
            }
        };
        ensure_dto_derives(&mut item.attrs, &meld_path()).expect("dto derives should be injected");

        let rendered = quote!(#item).to_string();
        assert!(rendered.contains(":: meld_server :: serde :: Deserialize"));
        assert!(rendered.contains(":: meld_server :: validator :: Validate"));
        assert!(rendered.contains(":: meld_server :: utoipa :: ToSchema"));
    }

    #[test]
    fn dto_keeps_existing_derive_and_appends_missing() {
        let mut item: ItemStruct = parse_quote! {
            #[derive(Debug, serde::Deserialize)]
            struct Payload {
                #[validate(length(min = 1))]
                name: String,
            }
        };
        ensure_dto_derives(&mut item.attrs, &meld_path()).expect("dto derives should be injected");

        let rendered = quote!(#item).to_string();
        assert!(rendered.contains("Debug"));
        assert!(rendered.contains("serde :: Deserialize"));
        assert!(rendered.contains(":: meld_server :: validator :: Validate"));
        assert!(rendered.contains(":: meld_server :: utoipa :: ToSchema"));
    }

    #[test]
    fn dto_leaves_complete_derive_untouched() {
        let mut item: ItemStruct = parse_quote! {
            #[derive(Deserialize, Validate, ToSchema)]
            struct Payload {
                name: String,
            }
        };
        let before = quote!(#item).to_string();

        ensure_dto_derives(&mut item.attrs, &meld_path()).expect("nothing to inject");

        assert_eq!(quote!(#item).to_string(), before);
    }

    #[test]
    fn dto_appends_to_first_derive_only() {
        let mut item: ItemStruct = parse_quote! {
            #[derive(Debug)]
            #[derive(Clone)]
            struct Payload {
                name: String,
            }
        };

        ensure_dto_derives(&mut item.attrs, &meld_path()).expect("dto derives should be injected");

        assert_eq!(item.attrs.len(), 2);
        assert_eq!(derive_count(&item.attrs[0]), 4);
        assert_eq!(derive_count(&item.attrs[1]), 1);
    }

    #[test]
    fn dto_inserts_derive_in_front_when_none_exists() {
        let mut item: ItemEnum = parse_quote! {
            #[serde(rename_all = "snake_case")]
            enum Kind {
                Draft,
                Published,
            }
        };

        ensure_dto_derives(&mut item.attrs, &meld_path()).expect("dto derives should be injected");

        assert_eq!(item.attrs.len(), 2);
        assert!(item.attrs[0].path().is_ident("derive"));
        assert_eq!(derive_count(&item.attrs[0]), 3);
        assert!(item.attrs[1].path().is_ident("serde"));
    }

    fn arg_type_ident(arg: &FnArg) -> Option<String> {
        let FnArg::Typed(arg) = arg else {
            return None;
        };
        let Type::Path(type_path) = arg.ty.as_ref() else {
            return None;
        };
        type_path.path.segments.last().map(|s| s.ident.to_string())
    }

    fn arg_pat_ident(arg: &FnArg) -> Option<String> {
        let FnArg::Typed(arg) = arg else {
            return None;
        };
        let Pat::TupleStruct(tuple_struct) = arg.pat.as_ref() else {
            return None;
        };
        tuple_struct
            .path
            .segments
            .last()
            .map(|s| s.ident.to_string())
    }
}