1use darling::{ast::NestedMeta, FromMeta};
37use heck::ToUpperCamelCase;
38use proc_macro::TokenStream;
39use proc_macro2::TokenStream as TokenStream2;
40use quote::{format_ident, quote};
41use syn::{
42 parse_macro_input, spanned::Spanned, Error, FnArg, Ident, ItemFn, ItemStruct, Pat, PatType,
43 ReturnType, Type, Visibility,
44};
45
46#[derive(Debug, Default)]
48struct Keys(Vec<Ident>);
49
50impl FromMeta for Keys {
51 fn from_list(items: &[NestedMeta]) -> darling::Result<Self> {
52 let mut idents = Vec::new();
53 for item in items {
54 match item {
55 NestedMeta::Meta(syn::Meta::Path(path)) => {
56 if let Some(ident) = path.get_ident() {
57 idents.push(ident.clone());
58 } else {
59 return Err(darling::Error::custom("expected identifier").with_span(path));
60 }
61 }
62 _ => {
63 return Err(darling::Error::custom("expected identifier"));
64 }
65 }
66 }
67 Ok(Keys(idents))
68 }
69}
70
71#[derive(Debug, Default)]
73enum OutputEq {
74 #[default]
75 None,
76 PartialEq,
78 Custom(syn::Path),
80}
81
82impl FromMeta for OutputEq {
83 fn from_word() -> darling::Result<Self> {
84 Ok(OutputEq::PartialEq)
85 }
86
87 fn from_value(value: &syn::Lit) -> darling::Result<Self> {
88 Err(darling::Error::unexpected_lit_type(value))
89 }
90
91 fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
92 match item {
93 syn::Meta::Path(_) => Ok(OutputEq::PartialEq),
94 syn::Meta::NameValue(nv) => {
95 if let syn::Expr::Path(expr_path) = &nv.value {
96 Ok(OutputEq::Custom(expr_path.path.clone()))
97 } else {
98 Err(darling::Error::custom("expected path").with_span(&nv.value))
99 }
100 }
101 syn::Meta::List(_) => Err(darling::Error::unsupported_format("list")),
102 }
103 }
104}
105
/// Arguments accepted by the `#[query(...)]` attribute, parsed via darling.
///
/// Every field is optional; an empty attribute yields `QueryAttr::default()`.
#[derive(Debug, Default, FromMeta)]
struct QueryAttr {
    /// When set, a `durability(&self) -> u8` method returning this value is
    /// emitted in the generated `Query` impl; when unset, no such method is
    /// emitted.
    #[darling(default)]
    durability: Option<u8>,

    /// Selects the body of the generated `output_eq` function
    /// (see `OutputEq`).
    #[darling(default)]
    output_eq: OutputEq,

    /// Names of the function parameters that form the cache key. When empty,
    /// all parameters after `ctx` are used.
    #[darling(default)]
    keys: Keys,

    /// Overrides the generated struct name. When unset, the name is the
    /// function name converted to UpperCamelCase.
    #[darling(default)]
    name: Option<String>,
}
127
/// One parameter of the annotated function (excluding the leading `ctx`).
struct Param {
    /// Binding name; becomes the field name in the generated struct.
    name: Ident,
    /// Declared type; becomes the field type in the generated struct.
    ty: Type,
}
133
/// The pieces of the annotated function that the code generators need.
struct ParsedFn {
    /// Visibility of the original function; reused for the generated struct
    /// and its `new` constructor.
    vis: Visibility,
    /// Name of the original function.
    name: Ident,
    /// All parameters after the leading `ctx` parameter, in declaration order.
    params: Vec<Param>,
    /// The `T` extracted from the function's `Result<T, _>` return type.
    output_ty: Type,
    /// Tokens of the original function body (the block, including braces).
    body: TokenStream2,
    /// Attributes written on the original function; re-emitted on the
    /// generated `query` method.
    attrs: Vec<syn::Attribute>,
}
144
145#[proc_macro_attribute]
173pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
174 let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
175 Ok(v) => v,
176 Err(e) => return TokenStream::from(e.to_compile_error()),
177 };
178
179 let attr = match QueryAttr::from_list(&attr_args) {
180 Ok(v) => v,
181 Err(e) => return TokenStream::from(e.write_errors()),
182 };
183
184 let input_fn = parse_macro_input!(item as ItemFn);
185
186 match generate_query(attr, input_fn) {
187 Ok(tokens) => tokens.into(),
188 Err(e) => e.to_compile_error().into(),
189 }
190}
191
192fn generate_query(attr: QueryAttr, input_fn: ItemFn) -> Result<TokenStream2, Error> {
193 let parsed = parse_function(&input_fn)?;
195
196 let struct_name = match &attr.name {
198 Some(name) => format_ident!("{}", name),
199 None => format_ident!("{}", parsed.name.to_string().to_upper_camel_case()),
200 };
201
202 let key_params: Vec<&Param> = if attr.keys.0.is_empty() {
204 parsed.params.iter().collect()
206 } else {
207 for key in &attr.keys.0 {
209 if !parsed.params.iter().any(|p| p.name == *key) {
210 return Err(Error::new(
211 key.span(),
212 format!("unknown parameter `{}` in keys", key),
213 ));
214 }
215 }
216 parsed
217 .params
218 .iter()
219 .filter(|p| attr.keys.0.contains(&p.name))
220 .collect()
221 };
222
223 let struct_def = generate_struct(&parsed, &struct_name);
225
226 let query_impl = generate_query_impl(&parsed, &struct_name, &key_params, &attr)?;
228
229 Ok(quote! {
230 #struct_def
231 #query_impl
232 })
233}
234
235fn parse_function(input_fn: &ItemFn) -> Result<ParsedFn, Error> {
236 let vis = input_fn.vis.clone();
237 let name = input_fn.sig.ident.clone();
238 let attrs = input_fn.attrs.clone();
240
241 let mut iter = input_fn.sig.inputs.iter();
243 let first_param = iter.next().ok_or_else(|| {
244 Error::new(
245 input_fn.sig.span(),
246 "query function must have `ctx: &mut QueryContext` as first parameter",
247 )
248 })?;
249
250 validate_ctx_param(first_param)?;
251
252 let mut params = Vec::new();
254 for arg in iter {
255 match arg {
256 FnArg::Typed(pat_type) => {
257 let param = parse_param(pat_type)?;
258 params.push(param);
259 }
260 FnArg::Receiver(_) => {
261 return Err(Error::new(arg.span(), "query functions cannot have `self`"));
262 }
263 }
264 }
265
266 let output_ty = parse_return_type(&input_fn.sig.output)?;
268
269 let body = &input_fn.block;
271 let body_tokens = quote! { #body };
272
273 Ok(ParsedFn {
274 vis,
275 name,
276 params,
277 output_ty,
278 body: body_tokens,
279 attrs,
280 })
281}
282
283fn validate_ctx_param(arg: &FnArg) -> Result<(), Error> {
284 match arg {
285 FnArg::Typed(pat_type) => {
286 if let Pat::Ident(pat_ident) = &*pat_type.pat {
288 if pat_ident.ident != "ctx" {
289 return Err(Error::new(
290 pat_ident.ident.span(),
291 "first parameter must be named `ctx`",
292 ));
293 }
294 }
295 Ok(())
297 }
298 FnArg::Receiver(_) => Err(Error::new(
299 arg.span(),
300 "first parameter must be `ctx: &mut QueryContext`, not `self`",
301 )),
302 }
303}
304
305fn parse_param(pat_type: &PatType) -> Result<Param, Error> {
306 let name = match &*pat_type.pat {
307 Pat::Ident(pat_ident) => pat_ident.ident.clone(),
308 _ => {
309 return Err(Error::new(
310 pat_type.pat.span(),
311 "expected simple identifier pattern",
312 ))
313 }
314 };
315
316 let ty = (*pat_type.ty).clone();
317
318 Ok(Param { name, ty })
319}
320
321fn parse_return_type(ret: &ReturnType) -> Result<Type, Error> {
322 match ret {
323 ReturnType::Default => Err(Error::new(
324 ret.span(),
325 "query function must return `Result<T, QueryError>`",
326 )),
327 ReturnType::Type(_, ty) => {
328 extract_result_ok_type(ty)
331 }
332 }
333}
334
335fn extract_result_ok_type(ty: &Type) -> Result<Type, Error> {
336 if let Type::Path(type_path) = ty {
337 if let Some(segment) = type_path.path.segments.last() {
338 if segment.ident == "Result" {
339 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
340 if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
341 return Ok(ok_ty.clone());
342 }
343 }
344 }
345 }
346 }
347 Err(Error::new(
348 ty.span(),
349 "expected `Result<T, QueryError>` return type",
350 ))
351}
352
353fn generate_struct(parsed: &ParsedFn, struct_name: &Ident) -> TokenStream2 {
354 let vis = &parsed.vis;
355 let fields: Vec<_> = parsed
356 .params
357 .iter()
358 .map(|p| {
359 let name = &p.name;
360 let ty = &p.ty;
361 quote! { pub #name: #ty }
362 })
363 .collect();
364
365 let field_names: Vec<_> = parsed.params.iter().map(|p| &p.name).collect();
366 let field_types: Vec<_> = parsed.params.iter().map(|p| &p.ty).collect();
367
368 let new_impl = if parsed.params.is_empty() {
369 quote! {
370 impl #struct_name {
371 #vis fn new() -> Self {
373 Self {}
374 }
375 }
376
377 impl ::std::default::Default for #struct_name {
378 fn default() -> Self {
379 Self::new()
380 }
381 }
382 }
383 } else {
384 quote! {
385 impl #struct_name {
386 #vis fn new(#( #field_names: #field_types ),*) -> Self {
388 Self { #( #field_names ),* }
389 }
390 }
391 }
392 };
393
394 quote! {
395 #[derive(Clone, Debug)]
396 #vis struct #struct_name {
397 #( #fields ),*
398 }
399
400 #new_impl
401 }
402}
403
404fn generate_query_impl(
405 parsed: &ParsedFn,
406 struct_name: &Ident,
407 key_params: &[&Param],
408 attr: &QueryAttr,
409) -> Result<TokenStream2, Error> {
410 let output_ty = &parsed.output_ty;
411
412 let cache_key_ty = match key_params.len() {
414 0 => quote! { () },
415 1 => {
416 let ty = &key_params[0].ty;
417 quote! { #ty }
418 }
419 _ => {
420 let types: Vec<_> = key_params.iter().map(|p| &p.ty).collect();
421 quote! { ( #( #types ),* ) }
422 }
423 };
424
425 let cache_key_body = match key_params.len() {
427 0 => quote! { () },
428 1 => {
429 let name = &key_params[0].name;
430 quote! { self.#name.clone() }
431 }
432 _ => {
433 let names: Vec<_> = key_params.iter().map(|p| &p.name).collect();
434 quote! { ( #( self.#names.clone() ),* ) }
435 }
436 };
437
438 let field_bindings: Vec<_> = parsed
440 .params
441 .iter()
442 .map(|p| {
443 let name = &p.name;
444 quote! { let #name = &self.#name; }
445 })
446 .collect();
447
448 let fn_body = &parsed.body;
449
450 let durability_impl = attr.durability.map(|d| {
452 quote! {
453 fn durability(&self) -> u8 {
454 #d
455 }
456 }
457 });
458
459 let output_eq_impl = match &attr.output_eq {
460 OutputEq::None | OutputEq::PartialEq => quote! {
462 fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
463 old == new
464 }
465 },
466 OutputEq::Custom(custom_fn) => quote! {
468 fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
469 #custom_fn(old, new)
470 }
471 },
472 };
473
474 let fn_attrs = &parsed.attrs;
476
477 Ok(quote! {
478 impl ::query_flow::Query for #struct_name {
479 type CacheKey = #cache_key_ty;
480 type Output = #output_ty;
481
482 fn cache_key(&self) -> Self::CacheKey {
483 #cache_key_body
484 }
485
486 #( #fn_attrs )*
487 fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
488 #( #field_bindings )*
489 #fn_body
490 }
491
492 #durability_impl
493 #output_eq_impl
494 }
495 })
496}
497
498#[derive(Debug, Clone, Copy, Default)]
504enum DurabilityAttr {
505 #[default]
506 Volatile,
507 Session,
508 Stable,
509 Constant,
510}
511
512impl FromMeta for DurabilityAttr {
513 fn from_string(value: &str) -> darling::Result<Self> {
514 Self::parse_str(value)
515 }
516
517 fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
518 if let syn::Expr::Path(expr_path) = expr {
520 if let Some(ident) = expr_path.path.get_ident() {
521 return Self::parse_str(&ident.to_string());
522 }
523 }
524 Err(darling::Error::custom(
525 "expected durability level: volatile, session, stable, or constant",
526 ))
527 }
528}
529
530impl DurabilityAttr {
531 fn parse_str(value: &str) -> darling::Result<Self> {
532 match value.to_lowercase().as_str() {
533 "volatile" => Ok(DurabilityAttr::Volatile),
534 "session" => Ok(DurabilityAttr::Session),
535 "stable" => Ok(DurabilityAttr::Stable),
536 "constant" => Ok(DurabilityAttr::Constant),
537 _ => Err(darling::Error::unknown_value(value)),
538 }
539 }
540}
541
542#[derive(Debug)]
544struct TypeWrapper(syn::Type);
545
546impl FromMeta for TypeWrapper {
547 fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
548 let tokens = quote! { #expr };
550 syn::parse2::<syn::Type>(tokens)
551 .map(TypeWrapper)
552 .map_err(|e| darling::Error::custom(format!("invalid type: {}", e)))
553 }
554}
555
/// Arguments accepted by the `#[asset_key(...)]` attribute, parsed via darling.
#[derive(Debug, FromMeta)]
struct AssetKeyAttr {
    /// Required: the type used as `type Asset` in the generated impl.
    asset: TypeWrapper,

    /// Durability level returned by the generated `durability` method;
    /// defaults to `DurabilityAttr::Volatile`.
    #[darling(default)]
    durability: DurabilityAttr,

    /// Selects the body of the generated `asset_eq` function
    /// (see `OutputEq`).
    #[darling(default)]
    asset_eq: OutputEq,
}
572
573#[proc_macro_attribute]
601pub fn asset_key(attr: TokenStream, item: TokenStream) -> TokenStream {
602 let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
603 Ok(v) => v,
604 Err(e) => return TokenStream::from(e.to_compile_error()),
605 };
606
607 let attr = match AssetKeyAttr::from_list(&attr_args) {
608 Ok(v) => v,
609 Err(e) => return TokenStream::from(e.write_errors()),
610 };
611
612 let input_struct = parse_macro_input!(item as ItemStruct);
613
614 match generate_asset_key(attr, input_struct) {
615 Ok(tokens) => tokens.into(),
616 Err(e) => e.to_compile_error().into(),
617 }
618}
619
620fn generate_asset_key(attr: AssetKeyAttr, input_struct: ItemStruct) -> Result<TokenStream2, Error> {
621 let struct_name = &input_struct.ident;
622 let asset_ty = &attr.asset.0;
623
624 let durability_impl = match attr.durability {
626 DurabilityAttr::Volatile => quote! {
627 fn durability(&self) -> ::query_flow::DurabilityLevel {
628 ::query_flow::DurabilityLevel::Volatile
629 }
630 },
631 DurabilityAttr::Session => quote! {
632 fn durability(&self) -> ::query_flow::DurabilityLevel {
633 ::query_flow::DurabilityLevel::Session
634 }
635 },
636 DurabilityAttr::Stable => quote! {
637 fn durability(&self) -> ::query_flow::DurabilityLevel {
638 ::query_flow::DurabilityLevel::Stable
639 }
640 },
641 DurabilityAttr::Constant => quote! {
642 fn durability(&self) -> ::query_flow::DurabilityLevel {
643 ::query_flow::DurabilityLevel::Constant
644 }
645 },
646 };
647
648 let asset_eq_impl = match &attr.asset_eq {
650 OutputEq::None | OutputEq::PartialEq => quote! {
651 fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
652 old == new
653 }
654 },
655 OutputEq::Custom(custom_fn) => quote! {
656 fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
657 #custom_fn(old, new)
658 }
659 },
660 };
661
662 Ok(quote! {
663 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
664 #input_struct
665
666 impl ::query_flow::AssetKey for #struct_name {
667 type Asset = #asset_ty;
668
669 #asset_eq_impl
670 #durability_impl
671 }
672 })
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Renders a token stream as text with every run of whitespace collapsed
    /// to a single space, so formatting differences do not affect comparison.
    fn normalize_tokens(tokens: TokenStream2) -> String {
        let rendered = tokens.to_string();
        let words: Vec<&str> = rendered.split_whitespace().collect();
        words.join(" ")
    }

    /// Expands `input_fn` with default attribute options and checks that the
    /// result matches `expected` after whitespace normalization.
    fn assert_default_expansion(input_fn: ItemFn, expected: TokenStream2) {
        let output = generate_query(QueryAttr::default(), input_fn).unwrap();
        assert_eq!(normalize_tokens(output), normalize_tokens(expected));
    }

    /// Attributes on the source function must reappear on the generated
    /// `query` method.
    #[test]
    fn test_query_macro_preserves_attributes() {
        let input_fn: ItemFn = syn::parse_quote! {
            #[allow(unused_variables)]
            #[inline]
            fn my_query(ctx: &mut QueryContext, x: i32) -> Result<i32, QueryError> {
                let unused = 42;
                Ok(x * 2)
            }
        };

        let expected = quote! {
            #[derive(Clone, Debug)]
            struct MyQuery {
                pub x: i32
            }

            impl MyQuery {
                #[doc = r" Create a new query instance."]
                fn new(x: i32) -> Self {
                    Self { x }
                }
            }

            impl ::query_flow::Query for MyQuery {
                type CacheKey = i32;
                type Output = i32;

                fn cache_key(&self) -> Self::CacheKey {
                    self.x.clone()
                }

                #[allow(unused_variables)]
                #[inline]
                fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
                    let x = &self.x;
                    {
                        let unused = 42;
                        Ok(x * 2)
                    }
                }

                fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                    old == new
                }
            }
        };

        assert_default_expansion(input_fn, expected);
    }

    /// A function without attributes and with two parameters produces a tuple
    /// cache key and no extra attributes on `query`.
    #[test]
    fn test_query_macro_without_attributes() {
        let input_fn: ItemFn = syn::parse_quote! {
            fn simple(ctx: &mut QueryContext, a: i32, b: i32) -> Result<i32, QueryError> {
                Ok(a + b)
            }
        };

        let expected = quote! {
            #[derive(Clone, Debug)]
            struct Simple {
                pub a: i32,
                pub b: i32
            }

            impl Simple {
                #[doc = r" Create a new query instance."]
                fn new(a: i32, b: i32) -> Self {
                    Self { a, b }
                }
            }

            impl ::query_flow::Query for Simple {
                type CacheKey = (i32, i32);
                type Output = i32;

                fn cache_key(&self) -> Self::CacheKey {
                    (self.a.clone(), self.b.clone())
                }

                fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
                    let a = &self.a;
                    let b = &self.b;
                    {
                        Ok(a + b)
                    }
                }

                fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
                    old == new
                }
            }
        };

        assert_default_expansion(input_fn, expected);
    }
}