// maudit_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::{self, Parse, ParseStream, Parser as _, Result};
4use syn::{Expr, Ident, ItemStruct, Token, parse_macro_input, punctuated::Punctuated};
5
/// How a locale variant's URL is specified inside `locales(...)`.
enum LocaleKind {
    /// A complete path for the variant, from `en = "..."` or `en(path = "...")`.
    FullPath(Expr),
    /// A prefix that the generated code prepends to the route's base path,
    /// from `en(prefix = "...")`.
    Prefix(Expr),
}
10
/// One parsed entry of `#[route(..., locales(...))]`.
struct LocaleVariant {
    /// The locale name as written (e.g. `en`); emitted as a string literal in `variants()`.
    locale: Ident,
    /// Whether the entry gives a full path or a prefix.
    kind: LocaleKind,
}
15
16impl Parse for LocaleVariant {
17    fn parse(input: ParseStream) -> Result<Self> {
18        let locale = input.parse::<Ident>()?;
19
20        // Check if it's `locale = "path"`, `locale(path = "path")`, or `locale(prefix = "path")`
21        let lookahead = input.lookahead1();
22
23        let kind = if lookahead.peek(Token![=]) {
24            // Shorthand full path: `en = "/en/about"`
25            input.parse::<Token![=]>()?;
26            let path = input.parse::<Expr>()?;
27            LocaleKind::FullPath(path)
28        } else if lookahead.peek(syn::token::Paren) {
29            // Either `en(path = "...")` or `en(prefix = "...")`
30            let content;
31            syn::parenthesized!(content in input);
32
33            let key_ident: Ident = content.parse()?;
34            content.parse::<Token![=]>()?;
35            let value = content.parse::<Expr>()?;
36
37            if key_ident == "path" {
38                LocaleKind::FullPath(value)
39            } else if key_ident == "prefix" {
40                LocaleKind::Prefix(value)
41            } else {
42                return Err(content.error("expected 'path' or 'prefix'"));
43            }
44        } else {
45            return Err(lookahead.error());
46        };
47
48        Ok(LocaleVariant { locale, kind })
49    }
50}
51
/// Parsed contents of `sitemap(...)` in `#[route(...)]`.
/// Each field is `None` when the corresponding key was not given.
struct SitemapArgs {
    /// `exclude = true|false` (must be a boolean literal).
    exclude: Option<bool>,
    /// `changefreq = <expr>`; the expression is emitted as-is into `RouteSitemapMetadata`.
    changefreq: Option<Expr>,
    /// `priority = <expr>`; the expression is emitted as-is into `RouteSitemapMetadata`.
    priority: Option<Expr>,
}
57
58impl Parse for SitemapArgs {
59    fn parse(input: ParseStream) -> Result<Self> {
60        let mut exclude = None;
61        let mut changefreq = None;
62        let mut priority = None;
63
64        while !input.is_empty() {
65            let key: Ident = input.parse()?;
66            input.parse::<Token![=]>()?;
67
68            match key.to_string().as_str() {
69                "exclude" => {
70                    let value: syn::LitBool = input.parse()?;
71                    exclude = Some(value.value);
72                }
73                "changefreq" => {
74                    changefreq = Some(input.parse()?);
75                }
76                "priority" => {
77                    priority = Some(input.parse()?);
78                }
79                _ => {
80                    return Err(syn::Error::new_spanned(
81                        key,
82                        "unknown sitemap argument, expected 'exclude', 'changefreq', or 'priority'",
83                    ));
84                }
85            }
86
87            if input.peek(Token![,]) {
88                input.parse::<Token![,]>()?;
89            } else {
90                break;
91            }
92        }
93
94        Ok(SitemapArgs {
95            exclude,
96            changefreq,
97            priority,
98        })
99    }
100}
101
/// Parsed arguments of `#[route(...)]`.
struct RouteArgs {
    /// Base path expression (the first, positional argument);
    /// `None` when only named arguments were given.
    path: Option<Expr>,
    /// Locale variants from `locales(...)`; empty when not given.
    locales: Vec<LocaleVariant>,
    /// Sitemap overrides from `sitemap(...)`, if given.
    sitemap: Option<SitemapArgs>,
}
107
108impl Parse for RouteArgs {
109    fn parse(input: ParseStream) -> Result<Self> {
110        let mut path = None;
111        let mut locales = Vec::new();
112        let mut sitemap = None;
113
114        if input.is_empty() {
115            return Ok(RouteArgs {
116                path,
117                locales,
118                sitemap,
119            });
120        }
121
122        // First argument: either a path expression or a named argument like locales(...)
123        if input.peek(Ident) && input.peek2(syn::token::Paren) {
124            // If the first argument is a named one, that means there's no base path and this route should only have variants
125            let ident: Ident = input.parse()?;
126            let ident_str = ident.to_string();
127
128            if ident_str == "locales" {
129                let content;
130                syn::parenthesized!(content in input);
131                let variants = Punctuated::<LocaleVariant, Token![,]>::parse_terminated(&content)?;
132                locales = variants.into_iter().collect();
133            } else if ident_str == "sitemap" {
134                let content;
135                syn::parenthesized!(content in input);
136                sitemap = Some(content.parse()?);
137            } else {
138                return Err(syn::Error::new_spanned(
139                    ident,
140                    format!(
141                        "unknown argument '{}', expected 'locales' or 'sitemap'",
142                        ident_str
143                    ),
144                ));
145            }
146        } else {
147            // First argument is a path expression, e.g., "/about" so proceed as normal
148            path = Some(input.parse::<Expr>()?);
149        }
150
151        // Parse remaining named arguments (right now just locales(...))
152        while !input.is_empty() {
153            input.parse::<Token![,]>()?;
154
155            if input.is_empty() {
156                break;
157            }
158
159            // All subsequent arguments must be named (e.g., locales(...), the path must be first)
160            if input.peek(Ident) && input.peek2(syn::token::Paren) {
161                let ident: Ident = input.parse()?;
162                let ident_str = ident.to_string();
163
164                if ident_str == "locales" {
165                    if !locales.is_empty() {
166                        return Err(syn::Error::new_spanned(
167                            ident,
168                            "locales specified multiple times",
169                        ));
170                    }
171                    let content;
172                    syn::parenthesized!(content in input);
173                    let variants =
174                        Punctuated::<LocaleVariant, Token![,]>::parse_terminated(&content)?;
175                    locales = variants.into_iter().collect();
176                } else if ident_str == "sitemap" {
177                    if sitemap.is_some() {
178                        return Err(syn::Error::new_spanned(
179                            ident,
180                            "sitemap specified multiple times",
181                        ));
182                    }
183                    let content;
184                    syn::parenthesized!(content in input);
185                    sitemap = Some(content.parse()?);
186                } else {
187                    return Err(syn::Error::new_spanned(
188                        ident,
189                        format!("unknown argument '{}'", ident_str),
190                    ));
191                }
192            } else {
193                return Err(syn::Error::new(
194                    input.span(),
195                    "expected named argument (e.g., locales(...)), path must be first argument",
196                ));
197            }
198        }
199
200        // Check for duplicate locales
201        Self::check_duplicate_locales(&locales)?;
202
203        Ok(RouteArgs {
204            path,
205            locales,
206            sitemap,
207        })
208    }
209}
210
211impl RouteArgs {
212    fn check_duplicate_locales(locales: &[LocaleVariant]) -> Result<()> {
213        use std::collections::HashSet;
214        let mut seen = HashSet::new();
215
216        for variant in locales {
217            let locale_name = variant.locale.to_string();
218            if !seen.insert(locale_name.clone()) {
219                return Err(syn::Error::new_spanned(
220                    &variant.locale,
221                    format!("duplicate locale '{}' specified", locale_name),
222                ));
223            }
224        }
225
226        Ok(())
227    }
228}
229
230#[proc_macro_attribute]
231pub fn route(attrs: TokenStream, item: TokenStream) -> TokenStream {
232    // Parse the input tokens into a syntax tree
233    let item_struct = syn::parse_macro_input!(item as ItemStruct);
234    let args = syn::parse_macro_input!(attrs as RouteArgs);
235
236    let struct_name = &item_struct.ident;
237
238    // Generate variants method based on locales
239    let variant_method = if !args.locales.is_empty() {
240        let variant_tuples = args.locales.iter().map(|variant| {
241            let locale_name = variant.locale.to_string();
242
243            match &variant.kind {
244                LocaleKind::FullPath(path) => {
245                    quote! {
246                        (#locale_name.to_string(), #path.to_string())
247                    }
248                }
249                LocaleKind::Prefix(prefix) => {
250                    if args.path.is_none() {
251                        // Emit compile error if prefix is used without base path
252                        quote! {
253                            compile_error!("Cannot use locale prefix without a base route path")
254                        }
255                    } else {
256                        let base_path = args.path.as_ref().unwrap();
257                        quote! {
258                            (#locale_name.to_string(), format!("{}{}", #prefix, #base_path))
259                        }
260                    }
261                }
262            }
263        });
264
265        quote! {
266            fn variants(&self) -> Vec<(String, String)> {
267                vec![#(#variant_tuples),*]
268            }
269        }
270    } else {
271        quote! {
272            fn variants(&self) -> Vec<(String, String)> {
273                vec![]
274            }
275        }
276    };
277
278    // Generate route_raw implementation based on whether path is provided
279    let route_raw_impl = if let Some(path) = &args.path {
280        quote! {
281            fn route_raw(&self) -> Option<String> {
282                Some(#path.to_string())
283            }
284        }
285    } else {
286        quote! {
287            fn route_raw(&self) -> Option<String> {
288                None
289            }
290        }
291    };
292
293    // Generate sitemap metadata method
294    let sitemap_method = if let Some(sitemap_args) = &args.sitemap {
295        let exclude_impl = if let Some(exclude) = sitemap_args.exclude {
296            quote! { Some(#exclude) }
297        } else {
298            quote! { None }
299        };
300
301        let changefreq_impl = if let Some(changefreq) = &sitemap_args.changefreq {
302            quote! { Some(#changefreq) }
303        } else {
304            quote! { None }
305        };
306
307        let priority_impl = if let Some(priority) = &sitemap_args.priority {
308            quote! { Some(#priority) }
309        } else {
310            quote! { None }
311        };
312
313        quote! {
314            fn sitemap_metadata(&self) -> maudit::sitemap::RouteSitemapMetadata {
315                maudit::sitemap::RouteSitemapMetadata {
316                    exclude: #exclude_impl,
317                    changefreq: #changefreq_impl,
318                    priority: #priority_impl,
319                }
320            }
321        }
322    } else {
323        quote! {
324            fn sitemap_metadata(&self) -> maudit::sitemap::RouteSitemapMetadata {
325                maudit::sitemap::RouteSitemapMetadata::default()
326            }
327        }
328    };
329
330    let expanded = quote! {
331        impl maudit::route::InternalRoute for #struct_name {
332            #route_raw_impl
333
334            #variant_method
335
336            #sitemap_method
337        }
338
339        impl maudit::route::FullRoute for #struct_name {
340            fn render_internal(&self, ctx: &mut maudit::route::PageContext) -> Result<maudit::route::RenderResult, Box<dyn std::error::Error>> {
341                let result: maudit::route::RenderResult = self.render(ctx).into();
342                result.into()
343            }
344
345            fn pages_internal(&self, ctx: &mut maudit::route::DynamicRouteContext) -> Vec<(maudit::route::PageParams, Box<dyn std::any::Any + Send + Sync>, Box<dyn std::any::Any + Send + Sync>)> {
346                self.pages(ctx)
347                    .into_iter()
348                    .map(|route| {
349                        let raw_params: maudit::route::PageParams = (&route.params).into();
350                        let typed_params: Box<dyn std::any::Any + Send + Sync> = Box::new(route.params);
351                        let props: Box<dyn std::any::Any + Send + Sync> = Box::new(route.props);
352                        (raw_params, typed_params, props)
353                    })
354                    .collect()
355            }
356        }
357
358        #item_struct
359    };
360
361    TokenStream::from(expanded)
362}
363
364#[proc_macro_derive(Params)]
365pub fn derive_params(item: TokenStream) -> TokenStream {
366    let item_struct = syn::parse_macro_input!(item as ItemStruct);
367    let struct_name = &item_struct.ident;
368
369    let field_conversions = match &item_struct.fields {
370        syn::Fields::Named(fields) => fields
371            .named
372            .iter()
373            .map(|field| {
374                let field_name = field.ident.as_ref().unwrap();
375                let field_name_str = field_name.to_string();
376
377                // Check if the field type is Option<T>
378                if is_option_type(&field.ty) {
379                    quote! {
380                        map.insert(
381                            #field_name_str.to_string(),
382                            self.#field_name.as_ref().map(|v| v.to_string())
383                        );
384                    }
385                } else {
386                    quote! {
387                        map.insert(#field_name_str.to_string(), Some(self.#field_name.to_string()));
388                    }
389                }
390            })
391            .collect::<Vec<_>>(),
392        _ => panic!("Only named fields are supported"),
393    };
394
395    let expanded = quote! {
396        impl Into<PageParams> for #struct_name {
397            fn into(self) -> PageParams {
398                (&self).into()
399            }
400        }
401
402        impl Into<PageParams> for &#struct_name {
403            fn into(self) -> PageParams {
404                let mut map = maudit::FxHashMap::default();
405                #(#field_conversions)*
406                PageParams(map)
407            }
408        }
409    };
410
411    TokenStream::from(expanded)
412}
413
414fn is_option_type(ty: &syn::Type) -> bool {
415    if let syn::Type::Path(type_path) = ty
416        && let Some(segment) = type_path.path.segments.last()
417    {
418        return segment.ident == "Option";
419    }
420    false
421}
422
423#[proc_macro_attribute]
424// Helps implement a struct as a Markdown content entry.
425//
426// See complete documentation in `crates/maudit/src/content.rs`.
427pub fn markdown_entry(args: TokenStream, item: TokenStream) -> TokenStream {
428    let mut item_struct = syn::parse_macro_input!(item as ItemStruct);
429    let _ = parse_macro_input!(args as parse::Nothing);
430
431    let struct_name = &item_struct.ident;
432
433    // Add __internal_headings field
434    if let syn::Fields::Named(ref mut fields) = item_struct.fields {
435        fields.named.push(
436            syn::Field::parse_named
437                .parse2(quote! {
438                    #[serde(skip)]
439                    __internal_headings: Vec<maudit::content::MarkdownHeading>
440                })
441                .unwrap(),
442        );
443    }
444
445    let expanded = quote! {
446        #[derive(serde::Deserialize)]
447        #item_struct
448
449        impl maudit::content::MarkdownContent for #struct_name {
450            fn get_headings(&self) -> &Vec<maudit::content::MarkdownHeading> {
451                &self.__internal_headings
452            }
453        }
454
455        impl maudit::content::InternalMarkdownContent for #struct_name {
456            fn set_headings(&mut self, headings: Vec<maudit::content::MarkdownHeading>) {
457                self.__internal_headings = headings;
458            }
459        }
460    };
461
462    TokenStream::from(expanded)
463}