query_flow_macros/
lib.rs

1//! Procedural macros for query-flow.
2//!
3//! This crate provides attribute macros for defining queries and asset keys
4//! with minimal boilerplate.
5//!
6//! # Query Example
7//!
8//! ```ignore
9//! use query_flow::{query, QueryContext, QueryError};
10//!
11//! #[query]
12//! pub fn add(ctx: &mut QueryContext, a: i32, b: i32) -> Result<i32, QueryError> {
13//!     Ok(a + b)
14//! }
15//!
//! // Generates:
//! // #[derive(Clone, Debug)] pub struct Add { pub a: i32, pub b: i32 }
//! // impl Add { pub fn new(a: i32, b: i32) -> Self { ... } }
//! // impl Query for Add { ... }
19//! ```
20//!
21//! # Asset Key Example
22//!
23//! ```ignore
24//! use query_flow::asset_key;
25//!
26//! #[asset_key(asset = String)]
27//! pub struct ConfigFile(pub PathBuf);
28//!
29//! #[asset_key(asset = String, durability = constant)]
30//! pub struct BundledAsset(pub PathBuf);
31//!
//! // Generates:
//! // #[derive(Clone, Debug, PartialEq, Eq, Hash)] on each annotated struct, plus
//! // impl AssetKey for ConfigFile { type Asset = String; ... }
34//! ```
35
36use darling::{ast::NestedMeta, FromMeta};
37use heck::ToUpperCamelCase;
38use proc_macro::TokenStream;
39use proc_macro2::TokenStream as TokenStream2;
40use quote::{format_ident, quote};
41use syn::{
42    parse_macro_input, spanned::Spanned, Error, FnArg, Ident, ItemFn, ItemStruct, Pat, PatType,
43    ReturnType, Type, Visibility,
44};
45
46/// Wrapper for parsing a list of identifiers: `keys(a, b, c)`
47#[derive(Debug, Default)]
48struct Keys(Vec<Ident>);
49
50impl FromMeta for Keys {
51    fn from_list(items: &[NestedMeta]) -> darling::Result<Self> {
52        let mut idents = Vec::new();
53        for item in items {
54            match item {
55                NestedMeta::Meta(syn::Meta::Path(path)) => {
56                    if let Some(ident) = path.get_ident() {
57                        idents.push(ident.clone());
58                    } else {
59                        return Err(darling::Error::custom("expected identifier").with_span(path));
60                    }
61                }
62                _ => {
63                    return Err(darling::Error::custom("expected identifier"));
64                }
65            }
66        }
67        Ok(Keys(idents))
68    }
69}
70
71/// Output equality option: `output_eq` or `output_eq = path`
72#[derive(Debug, Default)]
73enum OutputEq {
74    #[default]
75    None,
76    /// `output_eq` - use PartialEq
77    PartialEq,
78    /// `output_eq = path` - use custom function
79    Custom(syn::Path),
80}
81
82impl FromMeta for OutputEq {
83    fn from_word() -> darling::Result<Self> {
84        Ok(OutputEq::PartialEq)
85    }
86
87    fn from_value(value: &syn::Lit) -> darling::Result<Self> {
88        Err(darling::Error::unexpected_lit_type(value))
89    }
90
91    fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
92        match item {
93            syn::Meta::Path(_) => Ok(OutputEq::PartialEq),
94            syn::Meta::NameValue(nv) => {
95                if let syn::Expr::Path(expr_path) = &nv.value {
96                    Ok(OutputEq::Custom(expr_path.path.clone()))
97                } else {
98                    Err(darling::Error::custom("expected path").with_span(&nv.value))
99                }
100            }
101            syn::Meta::List(_) => Err(darling::Error::unsupported_format("list")),
102        }
103    }
104}
105
106/// Options for the `#[query]` attribute.
107#[derive(Debug, Default, FromMeta)]
108struct QueryAttr {
109    /// Durability level (0-255). Default: 0 (volatile).
110    #[darling(default)]
111    durability: Option<u8>,
112
113    /// Output equality for early cutoff optimization.
114    /// Default: uses PartialEq (`old == new`).
115    /// `output_eq = path`: uses custom function for types without PartialEq.
116    #[darling(default)]
117    output_eq: OutputEq,
118
119    /// Params that form the cache key. Default: all params except ctx.
120    #[darling(default)]
121    keys: Keys,
122
123    /// Override the generated struct name. Default: PascalCase of function name.
124    #[darling(default)]
125    name: Option<String>,
126}
127
128/// A parsed function parameter.
129struct Param {
130    name: Ident,
131    ty: Type,
132}
133
134/// Parsed function information.
135struct ParsedFn {
136    vis: Visibility,
137    name: Ident,
138    params: Vec<Param>,
139    output_ty: Type,
140    body: TokenStream2,
141}
142
143/// Define a query from a function.
144///
145/// # Attributes
146///
147/// - `durability = N`: Set durability level (0-255, default 0)
148/// - `output_eq = path`: Custom equality function (default: PartialEq)
149/// - `keys(a, b, ...)`: Specify which params form the cache key
150/// - `name = "Name"`: Override generated struct name
151///
152/// # Example
153///
154/// ```ignore
155/// use query_flow::{query, QueryContext, QueryError};
156///
157/// // Basic query - all params are keys
158/// #[query]
159/// fn add(ctx: &mut QueryContext, a: i32, b: i32) -> Result<i32, QueryError> {
160///     Ok(a + b)
161/// }
162///
163/// // With options
164/// #[query(durability = 2, keys(id))]
165/// pub fn fetch_user(ctx: &mut QueryContext, id: u64, include_deleted: bool) -> Result<User, QueryError> {
166///     // include_deleted is NOT part of the cache key
167///     Ok(load_user(id, include_deleted))
168/// }
169/// ```
170#[proc_macro_attribute]
171pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
172    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
173        Ok(v) => v,
174        Err(e) => return TokenStream::from(e.to_compile_error()),
175    };
176
177    let attr = match QueryAttr::from_list(&attr_args) {
178        Ok(v) => v,
179        Err(e) => return TokenStream::from(e.write_errors()),
180    };
181
182    let input_fn = parse_macro_input!(item as ItemFn);
183
184    match generate_query(attr, input_fn) {
185        Ok(tokens) => tokens.into(),
186        Err(e) => e.to_compile_error().into(),
187    }
188}
189
190fn generate_query(attr: QueryAttr, input_fn: ItemFn) -> Result<TokenStream2, Error> {
191    // Parse the function
192    let parsed = parse_function(&input_fn)?;
193
194    // Determine struct name
195    let struct_name = match &attr.name {
196        Some(name) => format_ident!("{}", name),
197        None => format_ident!("{}", parsed.name.to_string().to_upper_camel_case()),
198    };
199
200    // Determine which params are keys
201    let key_params: Vec<&Param> = if attr.keys.0.is_empty() {
202        // Default: all params are keys
203        parsed.params.iter().collect()
204    } else {
205        // Validate that specified keys exist
206        for key in &attr.keys.0 {
207            if !parsed.params.iter().any(|p| p.name == *key) {
208                return Err(Error::new(
209                    key.span(),
210                    format!("unknown parameter `{}` in keys", key),
211                ));
212            }
213        }
214        parsed
215            .params
216            .iter()
217            .filter(|p| attr.keys.0.contains(&p.name))
218            .collect()
219    };
220
221    // Generate struct definition
222    let struct_def = generate_struct(&parsed, &struct_name);
223
224    // Generate Query impl
225    let query_impl = generate_query_impl(&parsed, &struct_name, &key_params, &attr)?;
226
227    Ok(quote! {
228        #struct_def
229        #query_impl
230    })
231}
232
233fn parse_function(input_fn: &ItemFn) -> Result<ParsedFn, Error> {
234    let vis = input_fn.vis.clone();
235    let name = input_fn.sig.ident.clone();
236
237    // Check that first param is ctx: &mut QueryContext
238    let mut iter = input_fn.sig.inputs.iter();
239    let first_param = iter.next().ok_or_else(|| {
240        Error::new(
241            input_fn.sig.span(),
242            "query function must have `ctx: &mut QueryContext` as first parameter",
243        )
244    })?;
245
246    validate_ctx_param(first_param)?;
247
248    // Parse remaining params
249    let mut params = Vec::new();
250    for arg in iter {
251        match arg {
252            FnArg::Typed(pat_type) => {
253                let param = parse_param(pat_type)?;
254                params.push(param);
255            }
256            FnArg::Receiver(_) => {
257                return Err(Error::new(arg.span(), "query functions cannot have `self`"));
258            }
259        }
260    }
261
262    // Parse return type - must be Result<T, QueryError>
263    let output_ty = parse_return_type(&input_fn.sig.output)?;
264
265    // Get function body
266    let body = &input_fn.block;
267    let body_tokens = quote! { #body };
268
269    Ok(ParsedFn {
270        vis,
271        name,
272        params,
273        output_ty,
274        body: body_tokens,
275    })
276}
277
278fn validate_ctx_param(arg: &FnArg) -> Result<(), Error> {
279    match arg {
280        FnArg::Typed(pat_type) => {
281            // Check parameter name is 'ctx'
282            if let Pat::Ident(pat_ident) = &*pat_type.pat {
283                if pat_ident.ident != "ctx" {
284                    return Err(Error::new(
285                        pat_ident.ident.span(),
286                        "first parameter must be named `ctx`",
287                    ));
288                }
289            }
290            // Type checking is complex, we'll trust the user
291            Ok(())
292        }
293        FnArg::Receiver(_) => Err(Error::new(
294            arg.span(),
295            "first parameter must be `ctx: &mut QueryContext`, not `self`",
296        )),
297    }
298}
299
300fn parse_param(pat_type: &PatType) -> Result<Param, Error> {
301    let name = match &*pat_type.pat {
302        Pat::Ident(pat_ident) => pat_ident.ident.clone(),
303        _ => {
304            return Err(Error::new(
305                pat_type.pat.span(),
306                "expected simple identifier pattern",
307            ))
308        }
309    };
310
311    let ty = (*pat_type.ty).clone();
312
313    Ok(Param { name, ty })
314}
315
316fn parse_return_type(ret: &ReturnType) -> Result<Type, Error> {
317    match ret {
318        ReturnType::Default => Err(Error::new(
319            ret.span(),
320            "query function must return `Result<T, QueryError>`",
321        )),
322        ReturnType::Type(_, ty) => {
323            // Extract T from Result<T, QueryError>
324            // We need to parse the type and extract the first generic arg
325            extract_result_ok_type(ty)
326        }
327    }
328}
329
330fn extract_result_ok_type(ty: &Type) -> Result<Type, Error> {
331    if let Type::Path(type_path) = ty {
332        if let Some(segment) = type_path.path.segments.last() {
333            if segment.ident == "Result" {
334                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
335                    if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
336                        return Ok(ok_ty.clone());
337                    }
338                }
339            }
340        }
341    }
342    Err(Error::new(
343        ty.span(),
344        "expected `Result<T, QueryError>` return type",
345    ))
346}
347
348fn generate_struct(parsed: &ParsedFn, struct_name: &Ident) -> TokenStream2 {
349    let vis = &parsed.vis;
350    let fields: Vec<_> = parsed
351        .params
352        .iter()
353        .map(|p| {
354            let name = &p.name;
355            let ty = &p.ty;
356            quote! { pub #name: #ty }
357        })
358        .collect();
359
360    let field_names: Vec<_> = parsed.params.iter().map(|p| &p.name).collect();
361    let field_types: Vec<_> = parsed.params.iter().map(|p| &p.ty).collect();
362
363    let new_impl = if parsed.params.is_empty() {
364        quote! {
365            impl #struct_name {
366                /// Create a new query instance.
367                #vis fn new() -> Self {
368                    Self {}
369                }
370            }
371
372            impl ::std::default::Default for #struct_name {
373                fn default() -> Self {
374                    Self::new()
375                }
376            }
377        }
378    } else {
379        quote! {
380            impl #struct_name {
381                /// Create a new query instance.
382                #vis fn new(#( #field_names: #field_types ),*) -> Self {
383                    Self { #( #field_names ),* }
384                }
385            }
386        }
387    };
388
389    quote! {
390        #[derive(Clone, Debug)]
391        #vis struct #struct_name {
392            #( #fields ),*
393        }
394
395        #new_impl
396    }
397}
398
399fn generate_query_impl(
400    parsed: &ParsedFn,
401    struct_name: &Ident,
402    key_params: &[&Param],
403    attr: &QueryAttr,
404) -> Result<TokenStream2, Error> {
405    let output_ty = &parsed.output_ty;
406
407    // Generate CacheKey type
408    let cache_key_ty = match key_params.len() {
409        0 => quote! { () },
410        1 => {
411            let ty = &key_params[0].ty;
412            quote! { #ty }
413        }
414        _ => {
415            let types: Vec<_> = key_params.iter().map(|p| &p.ty).collect();
416            quote! { ( #( #types ),* ) }
417        }
418    };
419
420    // Generate cache_key() body
421    let cache_key_body = match key_params.len() {
422        0 => quote! { () },
423        1 => {
424            let name = &key_params[0].name;
425            quote! { self.#name.clone() }
426        }
427        _ => {
428            let names: Vec<_> = key_params.iter().map(|p| &p.name).collect();
429            quote! { ( #( self.#names.clone() ),* ) }
430        }
431    };
432
433    // Generate query() body - bind fields to local variables
434    let field_bindings: Vec<_> = parsed
435        .params
436        .iter()
437        .map(|p| {
438            let name = &p.name;
439            quote! { let #name = &self.#name; }
440        })
441        .collect();
442
443    let fn_body = &parsed.body;
444
445    // Generate optional trait methods
446    let durability_impl = attr.durability.map(|d| {
447        quote! {
448            fn durability(&self) -> u8 {
449                #d
450            }
451        }
452    });
453
454    let output_eq_impl = match &attr.output_eq {
455        // Default: use PartialEq
456        OutputEq::None | OutputEq::PartialEq => quote! {
457            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
458                old == new
459            }
460        },
461        // `output_eq = path`: use custom function
462        OutputEq::Custom(custom_fn) => quote! {
463            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
464                #custom_fn(old, new)
465            }
466        },
467    };
468
469    Ok(quote! {
470        impl ::query_flow::Query for #struct_name {
471            type CacheKey = #cache_key_ty;
472            type Output = #output_ty;
473
474            fn cache_key(&self) -> Self::CacheKey {
475                #cache_key_body
476            }
477
478            fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
479                #( #field_bindings )*
480                #fn_body
481            }
482
483            #durability_impl
484            #output_eq_impl
485        }
486    })
487}
488
489// ============================================================================
490// Asset Key Macro
491// ============================================================================
492
493/// Named durability levels for assets.
494#[derive(Debug, Clone, Copy, Default)]
495enum DurabilityAttr {
496    #[default]
497    Volatile,
498    Session,
499    Stable,
500    Constant,
501}
502
503impl FromMeta for DurabilityAttr {
504    fn from_string(value: &str) -> darling::Result<Self> {
505        Self::parse_str(value)
506    }
507
508    fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
509        // Handle identifier like `constant` without quotes
510        if let syn::Expr::Path(expr_path) = expr {
511            if let Some(ident) = expr_path.path.get_ident() {
512                return Self::parse_str(&ident.to_string());
513            }
514        }
515        Err(darling::Error::custom("expected durability level: volatile, session, stable, or constant"))
516    }
517}
518
519impl DurabilityAttr {
520    fn parse_str(value: &str) -> darling::Result<Self> {
521        match value.to_lowercase().as_str() {
522            "volatile" => Ok(DurabilityAttr::Volatile),
523            "session" => Ok(DurabilityAttr::Session),
524            "stable" => Ok(DurabilityAttr::Stable),
525            "constant" => Ok(DurabilityAttr::Constant),
526            _ => Err(darling::Error::unknown_value(value)),
527        }
528    }
529}
530
531/// Wrapper for parsing a type from attribute: `asset = String` or `asset = Vec<u8>`
532#[derive(Debug)]
533struct TypeWrapper(syn::Type);
534
535impl FromMeta for TypeWrapper {
536    fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
537        // Convert expression to token stream and parse as type
538        let tokens = quote! { #expr };
539        syn::parse2::<syn::Type>(tokens)
540            .map(TypeWrapper)
541            .map_err(|e| darling::Error::custom(format!("invalid type: {}", e)))
542    }
543}
544
545/// Options for the `#[asset_key]` attribute.
546#[derive(Debug, FromMeta)]
547struct AssetKeyAttr {
548    /// The asset type this key loads (required).
549    asset: TypeWrapper,
550
551    /// Durability level. Default: volatile.
552    #[darling(default)]
553    durability: DurabilityAttr,
554
555    /// Asset equality for early cutoff optimization.
556    /// Default: uses PartialEq (`old == new`).
557    /// `asset_eq = path`: uses custom function for types without PartialEq.
558    #[darling(default)]
559    asset_eq: OutputEq,
560}
561
562/// Define an asset key type.
563///
564/// # Attributes
565///
566/// - `asset = Type`: The asset type this key loads (required)
567/// - `durability = volatile|session|stable|constant`: Durability level (default: volatile)
568/// - `asset_eq`: Use PartialEq for asset comparison (default)
569/// - `asset_eq = path`: Use custom function for asset comparison
570///
571/// # Example
572///
573/// ```ignore
574/// use query_flow::asset_key;
575/// use std::path::PathBuf;
576///
577/// // Default: volatile durability
578/// #[asset_key(asset = String)]
579/// pub struct ConfigFile(pub PathBuf);
580///
581/// // Explicit constant durability for bundled assets
582/// #[asset_key(asset = String, durability = constant)]
583/// pub struct BundledFile(pub PathBuf);
584///
585/// // Custom equality
586/// #[asset_key(asset = ImageData, asset_eq = image_bytes_eq)]
587/// pub struct TexturePath(pub String);
588/// ```
589#[proc_macro_attribute]
590pub fn asset_key(attr: TokenStream, item: TokenStream) -> TokenStream {
591    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
592        Ok(v) => v,
593        Err(e) => return TokenStream::from(e.to_compile_error()),
594    };
595
596    let attr = match AssetKeyAttr::from_list(&attr_args) {
597        Ok(v) => v,
598        Err(e) => return TokenStream::from(e.write_errors()),
599    };
600
601    let input_struct = parse_macro_input!(item as ItemStruct);
602
603    match generate_asset_key(attr, input_struct) {
604        Ok(tokens) => tokens.into(),
605        Err(e) => e.to_compile_error().into(),
606    }
607}
608
609fn generate_asset_key(attr: AssetKeyAttr, input_struct: ItemStruct) -> Result<TokenStream2, Error> {
610    let struct_name = &input_struct.ident;
611    let asset_ty = &attr.asset.0;
612
613    // Generate durability method
614    let durability_impl = match attr.durability {
615        DurabilityAttr::Volatile => quote! {
616            fn durability(&self) -> ::query_flow::DurabilityLevel {
617                ::query_flow::DurabilityLevel::Volatile
618            }
619        },
620        DurabilityAttr::Session => quote! {
621            fn durability(&self) -> ::query_flow::DurabilityLevel {
622                ::query_flow::DurabilityLevel::Session
623            }
624        },
625        DurabilityAttr::Stable => quote! {
626            fn durability(&self) -> ::query_flow::DurabilityLevel {
627                ::query_flow::DurabilityLevel::Stable
628            }
629        },
630        DurabilityAttr::Constant => quote! {
631            fn durability(&self) -> ::query_flow::DurabilityLevel {
632                ::query_flow::DurabilityLevel::Constant
633            }
634        },
635    };
636
637    // Generate asset_eq method
638    let asset_eq_impl = match &attr.asset_eq {
639        OutputEq::None | OutputEq::PartialEq => quote! {
640            fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
641                old == new
642            }
643        },
644        OutputEq::Custom(custom_fn) => quote! {
645            fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
646                #custom_fn(old, new)
647            }
648        },
649    };
650
651    Ok(quote! {
652        #[derive(Clone, Debug, PartialEq, Eq, Hash)]
653        #input_struct
654
655        impl ::query_flow::AssetKey for #struct_name {
656            type Asset = #asset_ty;
657
658            #asset_eq_impl
659            #durability_impl
660        }
661    })
662}