query_flow_macros/
lib.rs

1//! Procedural macros for query-flow.
2//!
3//! This crate provides attribute macros for defining queries and asset keys
4//! with minimal boilerplate.
5//!
6//! # Query Example
7//!
8//! ```ignore
9//! use query_flow::{query, QueryContext, QueryError};
10//!
11//! #[query]
12//! pub fn add(ctx: &mut QueryContext, a: i32, b: i32) -> Result<i32, QueryError> {
13//!     Ok(a + b)
14//! }
15//!
16//! // Generates:
//! // #[derive(Clone, Debug)]
//! // pub struct Add { pub a: i32, pub b: i32 }
//! // impl Add { pub fn new(a: i32, b: i32) -> Self { ... } }
//! // impl Query for Add { ... }
19//! ```
20//!
21//! # Asset Key Example
22//!
23//! ```ignore
24//! use query_flow::asset_key;
25//!
26//! #[asset_key(asset = String)]
27//! pub struct ConfigFile(pub PathBuf);
28//!
29//! #[asset_key(asset = String, durability = constant)]
30//! pub struct BundledAsset(pub PathBuf);
31//!
32//! // Generates:
//! // #[derive(Clone, Debug, PartialEq, Eq, Hash)]
//! // pub struct ConfigFile(pub PathBuf);
//! // impl AssetKey for ConfigFile { type Asset = String; ... }
34//! ```
35
36use darling::{ast::NestedMeta, FromMeta};
37use heck::ToUpperCamelCase;
38use proc_macro::TokenStream;
39use proc_macro2::TokenStream as TokenStream2;
40use quote::{format_ident, quote};
41use syn::{
42    parse_macro_input, spanned::Spanned, Error, FnArg, Ident, ItemFn, ItemStruct, Pat, PatType,
43    ReturnType, Type, Visibility,
44};
45
46/// Wrapper for parsing a list of identifiers: `keys(a, b, c)`
47#[derive(Debug, Default)]
48struct Keys(Vec<Ident>);
49
50impl FromMeta for Keys {
51    fn from_list(items: &[NestedMeta]) -> darling::Result<Self> {
52        let mut idents = Vec::new();
53        for item in items {
54            match item {
55                NestedMeta::Meta(syn::Meta::Path(path)) => {
56                    if let Some(ident) = path.get_ident() {
57                        idents.push(ident.clone());
58                    } else {
59                        return Err(darling::Error::custom("expected identifier").with_span(path));
60                    }
61                }
62                _ => {
63                    return Err(darling::Error::custom("expected identifier"));
64                }
65            }
66        }
67        Ok(Keys(idents))
68    }
69}
70
71/// Output equality option: `output_eq` or `output_eq = path`
72#[derive(Debug, Default)]
73enum OutputEq {
74    #[default]
75    None,
76    /// `output_eq` - use PartialEq
77    PartialEq,
78    /// `output_eq = path` - use custom function
79    Custom(syn::Path),
80}
81
82impl FromMeta for OutputEq {
83    fn from_word() -> darling::Result<Self> {
84        Ok(OutputEq::PartialEq)
85    }
86
87    fn from_value(value: &syn::Lit) -> darling::Result<Self> {
88        Err(darling::Error::unexpected_lit_type(value))
89    }
90
91    fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
92        match item {
93            syn::Meta::Path(_) => Ok(OutputEq::PartialEq),
94            syn::Meta::NameValue(nv) => {
95                if let syn::Expr::Path(expr_path) = &nv.value {
96                    Ok(OutputEq::Custom(expr_path.path.clone()))
97                } else {
98                    Err(darling::Error::custom("expected path").with_span(&nv.value))
99                }
100            }
101            syn::Meta::List(_) => Err(darling::Error::unsupported_format("list")),
102        }
103    }
104}
105
/// Options for the `#[query]` attribute.
///
/// Parsed by darling from the attribute's argument list, e.g.
/// `#[query(durability = 2, keys(id), name = "FetchUser")]`.
#[derive(Debug, Default, FromMeta)]
struct QueryAttr {
    /// Durability level (0-255). When omitted, no `durability()` method is
    /// generated, so the `Query` trait's own default applies.
    /// NOTE(review): this was previously documented as "Default: 0 (volatile)";
    /// confirm that against the trait definition.
    #[darling(default)]
    durability: Option<u8>,

    /// Output equality for early cutoff optimization.
    /// Default (and bare `output_eq`): uses PartialEq (`old == new`).
    /// `output_eq = path`: uses custom function for types without PartialEq.
    #[darling(default)]
    output_eq: OutputEq,

    /// Params that form the cache key. An empty/omitted list means all params
    /// except ctx.
    #[darling(default)]
    keys: Keys,

    /// Override the generated struct name. Default: PascalCase of function name.
    #[darling(default)]
    name: Option<String>,
}
127
/// A parsed function parameter (one of the parameters after `ctx`).
///
/// Each one becomes a public field of the generated query struct.
struct Param {
    /// Binding name taken from the simple identifier pattern (a `mut` marker is not kept).
    name: Ident,
    /// Declared type, copied verbatim into the struct field.
    ty: Type,
}
133
/// Parsed function information extracted from a `#[query]` function.
struct ParsedFn {
    /// Visibility of the original function; reused for the struct and its `new()`.
    vis: Visibility,
    /// Original function name; PascalCased to form the default struct name.
    name: Ident,
    /// Parameters after `ctx`, in declaration order.
    params: Vec<Param>,
    /// The `T` extracted from the declared `Result<T, ...>` return type.
    output_ty: Type,
    /// Function body block as tokens, spliced into the generated `query` method.
    body: TokenStream2,
    /// Attributes from the original function (e.g., #[tracing::instrument]);
    /// re-emitted on the generated `query` method.
    attrs: Vec<syn::Attribute>,
}
144
145/// Define a query from a function.
146///
147/// # Attributes
148///
149/// - `durability = N`: Set durability level (0-255, default 0)
150/// - `output_eq = path`: Custom equality function (default: PartialEq)
151/// - `keys(a, b, ...)`: Specify which params form the cache key
152/// - `name = "Name"`: Override generated struct name
153///
154/// # Example
155///
156/// ```ignore
157/// use query_flow::{query, QueryContext, QueryError};
158///
159/// // Basic query - all params are keys
160/// #[query]
161/// fn add(ctx: &mut QueryContext, a: i32, b: i32) -> Result<i32, QueryError> {
162///     Ok(a + b)
163/// }
164///
165/// // With options
166/// #[query(durability = 2, keys(id))]
167/// pub fn fetch_user(ctx: &mut QueryContext, id: u64, include_deleted: bool) -> Result<User, QueryError> {
168///     // include_deleted is NOT part of the cache key
169///     Ok(load_user(id, include_deleted))
170/// }
171/// ```
172#[proc_macro_attribute]
173pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
174    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
175        Ok(v) => v,
176        Err(e) => return TokenStream::from(e.to_compile_error()),
177    };
178
179    let attr = match QueryAttr::from_list(&attr_args) {
180        Ok(v) => v,
181        Err(e) => return TokenStream::from(e.write_errors()),
182    };
183
184    let input_fn = parse_macro_input!(item as ItemFn);
185
186    match generate_query(attr, input_fn) {
187        Ok(tokens) => tokens.into(),
188        Err(e) => e.to_compile_error().into(),
189    }
190}
191
192fn generate_query(attr: QueryAttr, input_fn: ItemFn) -> Result<TokenStream2, Error> {
193    // Parse the function
194    let parsed = parse_function(&input_fn)?;
195
196    // Determine struct name
197    let struct_name = match &attr.name {
198        Some(name) => format_ident!("{}", name),
199        None => format_ident!("{}", parsed.name.to_string().to_upper_camel_case()),
200    };
201
202    // Determine which params are keys
203    let key_params: Vec<&Param> = if attr.keys.0.is_empty() {
204        // Default: all params are keys
205        parsed.params.iter().collect()
206    } else {
207        // Validate that specified keys exist
208        for key in &attr.keys.0 {
209            if !parsed.params.iter().any(|p| p.name == *key) {
210                return Err(Error::new(
211                    key.span(),
212                    format!("unknown parameter `{}` in keys", key),
213                ));
214            }
215        }
216        parsed
217            .params
218            .iter()
219            .filter(|p| attr.keys.0.contains(&p.name))
220            .collect()
221    };
222
223    // Generate struct definition
224    let struct_def = generate_struct(&parsed, &struct_name);
225
226    // Generate Query impl
227    let query_impl = generate_query_impl(&parsed, &struct_name, &key_params, &attr)?;
228
229    Ok(quote! {
230        #struct_def
231        #query_impl
232    })
233}
234
235fn parse_function(input_fn: &ItemFn) -> Result<ParsedFn, Error> {
236    let vis = input_fn.vis.clone();
237    let name = input_fn.sig.ident.clone();
238    // Extract function attributes (e.g., #[tracing::instrument], #[inline])
239    let attrs = input_fn.attrs.clone();
240
241    // Check that first param is ctx: &mut QueryContext
242    let mut iter = input_fn.sig.inputs.iter();
243    let first_param = iter.next().ok_or_else(|| {
244        Error::new(
245            input_fn.sig.span(),
246            "query function must have `ctx: &mut QueryContext` as first parameter",
247        )
248    })?;
249
250    validate_ctx_param(first_param)?;
251
252    // Parse remaining params
253    let mut params = Vec::new();
254    for arg in iter {
255        match arg {
256            FnArg::Typed(pat_type) => {
257                let param = parse_param(pat_type)?;
258                params.push(param);
259            }
260            FnArg::Receiver(_) => {
261                return Err(Error::new(arg.span(), "query functions cannot have `self`"));
262            }
263        }
264    }
265
266    // Parse return type - must be Result<T, QueryError>
267    let output_ty = parse_return_type(&input_fn.sig.output)?;
268
269    // Get function body
270    let body = &input_fn.block;
271    let body_tokens = quote! { #body };
272
273    Ok(ParsedFn {
274        vis,
275        name,
276        params,
277        output_ty,
278        body: body_tokens,
279        attrs,
280    })
281}
282
283fn validate_ctx_param(arg: &FnArg) -> Result<(), Error> {
284    match arg {
285        FnArg::Typed(pat_type) => {
286            // Check parameter name is 'ctx'
287            if let Pat::Ident(pat_ident) = &*pat_type.pat {
288                if pat_ident.ident != "ctx" {
289                    return Err(Error::new(
290                        pat_ident.ident.span(),
291                        "first parameter must be named `ctx`",
292                    ));
293                }
294            }
295            // Type checking is complex, we'll trust the user
296            Ok(())
297        }
298        FnArg::Receiver(_) => Err(Error::new(
299            arg.span(),
300            "first parameter must be `ctx: &mut QueryContext`, not `self`",
301        )),
302    }
303}
304
305fn parse_param(pat_type: &PatType) -> Result<Param, Error> {
306    let name = match &*pat_type.pat {
307        Pat::Ident(pat_ident) => pat_ident.ident.clone(),
308        _ => {
309            return Err(Error::new(
310                pat_type.pat.span(),
311                "expected simple identifier pattern",
312            ))
313        }
314    };
315
316    let ty = (*pat_type.ty).clone();
317
318    Ok(Param { name, ty })
319}
320
321fn parse_return_type(ret: &ReturnType) -> Result<Type, Error> {
322    match ret {
323        ReturnType::Default => Err(Error::new(
324            ret.span(),
325            "query function must return `Result<T, QueryError>`",
326        )),
327        ReturnType::Type(_, ty) => {
328            // Extract T from Result<T, QueryError>
329            // We need to parse the type and extract the first generic arg
330            extract_result_ok_type(ty)
331        }
332    }
333}
334
335fn extract_result_ok_type(ty: &Type) -> Result<Type, Error> {
336    if let Type::Path(type_path) = ty {
337        if let Some(segment) = type_path.path.segments.last() {
338            if segment.ident == "Result" {
339                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
340                    if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
341                        return Ok(ok_ty.clone());
342                    }
343                }
344            }
345        }
346    }
347    Err(Error::new(
348        ty.span(),
349        "expected `Result<T, QueryError>` return type",
350    ))
351}
352
353fn generate_struct(parsed: &ParsedFn, struct_name: &Ident) -> TokenStream2 {
354    let vis = &parsed.vis;
355    let fields: Vec<_> = parsed
356        .params
357        .iter()
358        .map(|p| {
359            let name = &p.name;
360            let ty = &p.ty;
361            quote! { pub #name: #ty }
362        })
363        .collect();
364
365    let field_names: Vec<_> = parsed.params.iter().map(|p| &p.name).collect();
366    let field_types: Vec<_> = parsed.params.iter().map(|p| &p.ty).collect();
367
368    let new_impl = if parsed.params.is_empty() {
369        quote! {
370            impl #struct_name {
371                /// Create a new query instance.
372                #vis fn new() -> Self {
373                    Self {}
374                }
375            }
376
377            impl ::std::default::Default for #struct_name {
378                fn default() -> Self {
379                    Self::new()
380                }
381            }
382        }
383    } else {
384        quote! {
385            impl #struct_name {
386                /// Create a new query instance.
387                #vis fn new(#( #field_names: #field_types ),*) -> Self {
388                    Self { #( #field_names ),* }
389                }
390            }
391        }
392    };
393
394    quote! {
395        #[derive(Clone, Debug)]
396        #vis struct #struct_name {
397            #( #fields ),*
398        }
399
400        #new_impl
401    }
402}
403
404fn generate_query_impl(
405    parsed: &ParsedFn,
406    struct_name: &Ident,
407    key_params: &[&Param],
408    attr: &QueryAttr,
409) -> Result<TokenStream2, Error> {
410    let output_ty = &parsed.output_ty;
411
412    // Generate CacheKey type
413    let cache_key_ty = match key_params.len() {
414        0 => quote! { () },
415        1 => {
416            let ty = &key_params[0].ty;
417            quote! { #ty }
418        }
419        _ => {
420            let types: Vec<_> = key_params.iter().map(|p| &p.ty).collect();
421            quote! { ( #( #types ),* ) }
422        }
423    };
424
425    // Generate cache_key() body
426    let cache_key_body = match key_params.len() {
427        0 => quote! { () },
428        1 => {
429            let name = &key_params[0].name;
430            quote! { self.#name.clone() }
431        }
432        _ => {
433            let names: Vec<_> = key_params.iter().map(|p| &p.name).collect();
434            quote! { ( #( self.#names.clone() ),* ) }
435        }
436    };
437
438    // Generate query() body - bind fields to local variables
439    let field_bindings: Vec<_> = parsed
440        .params
441        .iter()
442        .map(|p| {
443            let name = &p.name;
444            quote! { let #name = &self.#name; }
445        })
446        .collect();
447
448    let fn_body = &parsed.body;
449
450    // Generate optional trait methods
451    let durability_impl = attr.durability.map(|d| {
452        quote! {
453            fn durability(&self) -> u8 {
454                #d
455            }
456        }
457    });
458
459    let output_eq_impl = match &attr.output_eq {
460        // Default: use PartialEq
461        OutputEq::None | OutputEq::PartialEq => quote! {
462            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
463                old == new
464            }
465        },
466        // `output_eq = path`: use custom function
467        OutputEq::Custom(custom_fn) => quote! {
468            fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
469                #custom_fn(old, new)
470            }
471        },
472    };
473
474    // Attributes from the original function to apply to the query method
475    let fn_attrs = &parsed.attrs;
476
477    Ok(quote! {
478        impl ::query_flow::Query for #struct_name {
479            type CacheKey = #cache_key_ty;
480            type Output = #output_ty;
481
482            fn cache_key(&self) -> Self::CacheKey {
483                #cache_key_body
484            }
485
486            #( #fn_attrs )*
487            fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
488                #( #field_bindings )*
489                #fn_body
490            }
491
492            #durability_impl
493            #output_eq_impl
494        }
495    })
496}
497
498// ============================================================================
499// Asset Key Macro
500// ============================================================================
501
502/// Named durability levels for assets.
503#[derive(Debug, Clone, Copy, Default)]
504enum DurabilityAttr {
505    #[default]
506    Volatile,
507    Session,
508    Stable,
509    Constant,
510}
511
512impl FromMeta for DurabilityAttr {
513    fn from_string(value: &str) -> darling::Result<Self> {
514        Self::parse_str(value)
515    }
516
517    fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
518        // Handle identifier like `constant` without quotes
519        if let syn::Expr::Path(expr_path) = expr {
520            if let Some(ident) = expr_path.path.get_ident() {
521                return Self::parse_str(&ident.to_string());
522            }
523        }
524        Err(darling::Error::custom(
525            "expected durability level: volatile, session, stable, or constant",
526        ))
527    }
528}
529
530impl DurabilityAttr {
531    fn parse_str(value: &str) -> darling::Result<Self> {
532        match value.to_lowercase().as_str() {
533            "volatile" => Ok(DurabilityAttr::Volatile),
534            "session" => Ok(DurabilityAttr::Session),
535            "stable" => Ok(DurabilityAttr::Stable),
536            "constant" => Ok(DurabilityAttr::Constant),
537            _ => Err(darling::Error::unknown_value(value)),
538        }
539    }
540}
541
542/// Wrapper for parsing a type from attribute: `asset = String` or `asset = Vec<u8>`
543#[derive(Debug)]
544struct TypeWrapper(syn::Type);
545
546impl FromMeta for TypeWrapper {
547    fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
548        // Convert expression to token stream and parse as type
549        let tokens = quote! { #expr };
550        syn::parse2::<syn::Type>(tokens)
551            .map(TypeWrapper)
552            .map_err(|e| darling::Error::custom(format!("invalid type: {}", e)))
553    }
554}
555
/// Options for the `#[asset_key]` attribute.
///
/// Parsed by darling from the attribute's argument list, e.g.
/// `#[asset_key(asset = String, durability = constant)]`.
#[derive(Debug, FromMeta)]
struct AssetKeyAttr {
    /// The asset type this key loads (required). It has no `#[darling(default)]`,
    /// so darling reports an error when it is missing.
    asset: TypeWrapper,

    /// Durability level. Default: volatile.
    #[darling(default)]
    durability: DurabilityAttr,

    /// Asset equality for early cutoff optimization.
    /// Default (and bare `asset_eq`): uses PartialEq (`old == new`).
    /// `asset_eq = path`: uses custom function for types without PartialEq.
    #[darling(default)]
    asset_eq: OutputEq,
}
572
573/// Define an asset key type.
574///
575/// # Attributes
576///
577/// - `asset = Type`: The asset type this key loads (required)
578/// - `durability = volatile|session|stable|constant`: Durability level (default: volatile)
579/// - `asset_eq`: Use PartialEq for asset comparison (default)
580/// - `asset_eq = path`: Use custom function for asset comparison
581///
582/// # Example
583///
584/// ```ignore
585/// use query_flow::asset_key;
586/// use std::path::PathBuf;
587///
588/// // Default: volatile durability
589/// #[asset_key(asset = String)]
590/// pub struct ConfigFile(pub PathBuf);
591///
592/// // Explicit constant durability for bundled assets
593/// #[asset_key(asset = String, durability = constant)]
594/// pub struct BundledFile(pub PathBuf);
595///
596/// // Custom equality
597/// #[asset_key(asset = ImageData, asset_eq = image_bytes_eq)]
598/// pub struct TexturePath(pub String);
599/// ```
600#[proc_macro_attribute]
601pub fn asset_key(attr: TokenStream, item: TokenStream) -> TokenStream {
602    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
603        Ok(v) => v,
604        Err(e) => return TokenStream::from(e.to_compile_error()),
605    };
606
607    let attr = match AssetKeyAttr::from_list(&attr_args) {
608        Ok(v) => v,
609        Err(e) => return TokenStream::from(e.write_errors()),
610    };
611
612    let input_struct = parse_macro_input!(item as ItemStruct);
613
614    match generate_asset_key(attr, input_struct) {
615        Ok(tokens) => tokens.into(),
616        Err(e) => e.to_compile_error().into(),
617    }
618}
619
620fn generate_asset_key(attr: AssetKeyAttr, input_struct: ItemStruct) -> Result<TokenStream2, Error> {
621    let struct_name = &input_struct.ident;
622    let asset_ty = &attr.asset.0;
623
624    // Generate durability method
625    let durability_impl = match attr.durability {
626        DurabilityAttr::Volatile => quote! {
627            fn durability(&self) -> ::query_flow::DurabilityLevel {
628                ::query_flow::DurabilityLevel::Volatile
629            }
630        },
631        DurabilityAttr::Session => quote! {
632            fn durability(&self) -> ::query_flow::DurabilityLevel {
633                ::query_flow::DurabilityLevel::Session
634            }
635        },
636        DurabilityAttr::Stable => quote! {
637            fn durability(&self) -> ::query_flow::DurabilityLevel {
638                ::query_flow::DurabilityLevel::Stable
639            }
640        },
641        DurabilityAttr::Constant => quote! {
642            fn durability(&self) -> ::query_flow::DurabilityLevel {
643                ::query_flow::DurabilityLevel::Constant
644            }
645        },
646    };
647
648    // Generate asset_eq method
649    let asset_eq_impl = match &attr.asset_eq {
650        OutputEq::None | OutputEq::PartialEq => quote! {
651            fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
652                old == new
653            }
654        },
655        OutputEq::Custom(custom_fn) => quote! {
656            fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
657                #custom_fn(old, new)
658            }
659        },
660    };
661
662    Ok(quote! {
663        #[derive(Clone, Debug, PartialEq, Eq, Hash)]
664        #input_struct
665
666        impl ::query_flow::AssetKey for #struct_name {
667            type Asset = #asset_ty;
668
669            #asset_eq_impl
670            #durability_impl
671        }
672    })
673}
674
#[cfg(test)]
mod tests {
    // Token-level tests for the `#[query]` expansion. Each test runs a function
    // through `generate_query` and compares the result with a hand-written
    // expected expansion. Both sides go through `normalize_tokens`, so the
    // comparison is exact on tokens but insensitive to whitespace.
    use super::*;
    use quote::quote;

    /// Render a token stream and collapse every whitespace run to a single space.
    fn normalize_tokens(tokens: TokenStream2) -> String {
        tokens
            .to_string()
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Attributes on the original function (`#[allow]`, `#[inline]`) must be
    /// carried onto the generated `query` method. Also checks that a single key
    /// yields a bare (non-tuple) `CacheKey`.
    #[test]
    fn test_query_macro_preserves_attributes() {
        let input_fn: ItemFn = syn::parse_quote! {
            #[allow(unused_variables)]
            #[inline]
            fn my_query(ctx: &mut QueryContext, x: i32) -> Result<i32, QueryError> {
                let unused = 42;
                Ok(x * 2)
            }
        };

        let attr = QueryAttr::default();
        let output = generate_query(attr, input_fn).unwrap();

        let expected = quote! {
            #[derive(Clone, Debug)]
            struct MyQuery {
                pub x: i32
            }

            impl MyQuery {
                #[doc = r" Create a new query instance."]
                fn new(x: i32) -> Self {
                    Self { x }
                }
            }

            impl ::query_flow::Query for MyQuery {
                type CacheKey = i32;
                type Output = i32;

                fn cache_key(&self) -> Self::CacheKey {
                    self.x.clone()
                }

                #[allow(unused_variables)]
                #[inline]
                fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
                    let x = &self.x;
                    {
                        let unused = 42;
                        Ok(x * 2)
                    }
                }

                fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                    old == new
                }
            }
        };

        assert_eq!(normalize_tokens(output), normalize_tokens(expected));
    }

    /// With default options, two params produce a tuple `CacheKey`, no
    /// `durability()` override, and the `PartialEq`-based `output_eq`.
    #[test]
    fn test_query_macro_without_attributes() {
        let input_fn: ItemFn = syn::parse_quote! {
            fn simple(ctx: &mut QueryContext, a: i32, b: i32) -> Result<i32, QueryError> {
                Ok(a + b)
            }
        };

        let attr = QueryAttr::default();
        let output = generate_query(attr, input_fn).unwrap();

        let expected = quote! {
            #[derive(Clone, Debug)]
            struct Simple {
                pub a: i32,
                pub b: i32
            }

            impl Simple {
                #[doc = r" Create a new query instance."]
                fn new(a: i32, b: i32) -> Self {
                    Self { a, b }
                }
            }

            impl ::query_flow::Query for Simple {
                type CacheKey = (i32, i32);
                type Output = i32;

                fn cache_key(&self) -> Self::CacheKey {
                    (self.a.clone(), self.b.clone())
                }

                fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
                    let a = &self.a;
                    let b = &self.b;
                    {
                        Ok(a + b)
                    }
                }

                fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                    old == new
                }
            }
        };

        assert_eq!(normalize_tokens(output), normalize_tokens(expected));
    }
}