1use darling::{ast::NestedMeta, FromMeta};
37use heck::ToUpperCamelCase;
38use proc_macro::TokenStream;
39use proc_macro2::TokenStream as TokenStream2;
40use quote::{format_ident, quote};
41use syn::{
42 parse_macro_input, spanned::Spanned, Error, FnArg, Ident, ItemFn, ItemStruct, Pat, PatType,
43 ReturnType, Type, Visibility,
44};
45
46#[derive(Debug, Default)]
48struct Keys(Vec<Ident>);
49
50impl FromMeta for Keys {
51 fn from_list(items: &[NestedMeta]) -> darling::Result<Self> {
52 let mut idents = Vec::new();
53 for item in items {
54 match item {
55 NestedMeta::Meta(syn::Meta::Path(path)) => {
56 if let Some(ident) = path.get_ident() {
57 idents.push(ident.clone());
58 } else {
59 return Err(darling::Error::custom("expected identifier").with_span(path));
60 }
61 }
62 _ => {
63 return Err(darling::Error::custom("expected identifier"));
64 }
65 }
66 }
67 Ok(Keys(idents))
68 }
69}
70
71#[derive(Debug, Default)]
73enum OutputEq {
74 #[default]
75 None,
76 PartialEq,
78 Custom(syn::Path),
80}
81
82impl FromMeta for OutputEq {
83 fn from_word() -> darling::Result<Self> {
84 Ok(OutputEq::PartialEq)
85 }
86
87 fn from_value(value: &syn::Lit) -> darling::Result<Self> {
88 Err(darling::Error::unexpected_lit_type(value))
89 }
90
91 fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
92 match item {
93 syn::Meta::Path(_) => Ok(OutputEq::PartialEq),
94 syn::Meta::NameValue(nv) => {
95 if let syn::Expr::Path(expr_path) = &nv.value {
96 Ok(OutputEq::Custom(expr_path.path.clone()))
97 } else {
98 Err(darling::Error::custom("expected path").with_span(&nv.value))
99 }
100 }
101 syn::Meta::List(_) => Err(darling::Error::unsupported_format("list")),
102 }
103 }
104}
105
106#[derive(Debug, Default, FromMeta)]
108struct QueryAttr {
109 #[darling(default)]
111 durability: Option<u8>,
112
113 #[darling(default)]
117 output_eq: OutputEq,
118
119 #[darling(default)]
121 keys: Keys,
122
123 #[darling(default)]
125 name: Option<String>,
126}
127
128struct Param {
130 name: Ident,
131 ty: Type,
132}
133
134struct ParsedFn {
136 vis: Visibility,
137 name: Ident,
138 params: Vec<Param>,
139 output_ty: Type,
140 body: TokenStream2,
141}
142
143#[proc_macro_attribute]
171pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
172 let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
173 Ok(v) => v,
174 Err(e) => return TokenStream::from(e.to_compile_error()),
175 };
176
177 let attr = match QueryAttr::from_list(&attr_args) {
178 Ok(v) => v,
179 Err(e) => return TokenStream::from(e.write_errors()),
180 };
181
182 let input_fn = parse_macro_input!(item as ItemFn);
183
184 match generate_query(attr, input_fn) {
185 Ok(tokens) => tokens.into(),
186 Err(e) => e.to_compile_error().into(),
187 }
188}
189
190fn generate_query(attr: QueryAttr, input_fn: ItemFn) -> Result<TokenStream2, Error> {
191 let parsed = parse_function(&input_fn)?;
193
194 let struct_name = match &attr.name {
196 Some(name) => format_ident!("{}", name),
197 None => format_ident!("{}", parsed.name.to_string().to_upper_camel_case()),
198 };
199
200 let key_params: Vec<&Param> = if attr.keys.0.is_empty() {
202 parsed.params.iter().collect()
204 } else {
205 for key in &attr.keys.0 {
207 if !parsed.params.iter().any(|p| p.name == *key) {
208 return Err(Error::new(
209 key.span(),
210 format!("unknown parameter `{}` in keys", key),
211 ));
212 }
213 }
214 parsed
215 .params
216 .iter()
217 .filter(|p| attr.keys.0.contains(&p.name))
218 .collect()
219 };
220
221 let struct_def = generate_struct(&parsed, &struct_name);
223
224 let query_impl = generate_query_impl(&parsed, &struct_name, &key_params, &attr)?;
226
227 Ok(quote! {
228 #struct_def
229 #query_impl
230 })
231}
232
233fn parse_function(input_fn: &ItemFn) -> Result<ParsedFn, Error> {
234 let vis = input_fn.vis.clone();
235 let name = input_fn.sig.ident.clone();
236
237 let mut iter = input_fn.sig.inputs.iter();
239 let first_param = iter.next().ok_or_else(|| {
240 Error::new(
241 input_fn.sig.span(),
242 "query function must have `ctx: &mut QueryContext` as first parameter",
243 )
244 })?;
245
246 validate_ctx_param(first_param)?;
247
248 let mut params = Vec::new();
250 for arg in iter {
251 match arg {
252 FnArg::Typed(pat_type) => {
253 let param = parse_param(pat_type)?;
254 params.push(param);
255 }
256 FnArg::Receiver(_) => {
257 return Err(Error::new(arg.span(), "query functions cannot have `self`"));
258 }
259 }
260 }
261
262 let output_ty = parse_return_type(&input_fn.sig.output)?;
264
265 let body = &input_fn.block;
267 let body_tokens = quote! { #body };
268
269 Ok(ParsedFn {
270 vis,
271 name,
272 params,
273 output_ty,
274 body: body_tokens,
275 })
276}
277
278fn validate_ctx_param(arg: &FnArg) -> Result<(), Error> {
279 match arg {
280 FnArg::Typed(pat_type) => {
281 if let Pat::Ident(pat_ident) = &*pat_type.pat {
283 if pat_ident.ident != "ctx" {
284 return Err(Error::new(
285 pat_ident.ident.span(),
286 "first parameter must be named `ctx`",
287 ));
288 }
289 }
290 Ok(())
292 }
293 FnArg::Receiver(_) => Err(Error::new(
294 arg.span(),
295 "first parameter must be `ctx: &mut QueryContext`, not `self`",
296 )),
297 }
298}
299
300fn parse_param(pat_type: &PatType) -> Result<Param, Error> {
301 let name = match &*pat_type.pat {
302 Pat::Ident(pat_ident) => pat_ident.ident.clone(),
303 _ => {
304 return Err(Error::new(
305 pat_type.pat.span(),
306 "expected simple identifier pattern",
307 ))
308 }
309 };
310
311 let ty = (*pat_type.ty).clone();
312
313 Ok(Param { name, ty })
314}
315
316fn parse_return_type(ret: &ReturnType) -> Result<Type, Error> {
317 match ret {
318 ReturnType::Default => Err(Error::new(
319 ret.span(),
320 "query function must return `Result<T, QueryError>`",
321 )),
322 ReturnType::Type(_, ty) => {
323 extract_result_ok_type(ty)
326 }
327 }
328}
329
330fn extract_result_ok_type(ty: &Type) -> Result<Type, Error> {
331 if let Type::Path(type_path) = ty {
332 if let Some(segment) = type_path.path.segments.last() {
333 if segment.ident == "Result" {
334 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
335 if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
336 return Ok(ok_ty.clone());
337 }
338 }
339 }
340 }
341 }
342 Err(Error::new(
343 ty.span(),
344 "expected `Result<T, QueryError>` return type",
345 ))
346}
347
348fn generate_struct(parsed: &ParsedFn, struct_name: &Ident) -> TokenStream2 {
349 let vis = &parsed.vis;
350 let fields: Vec<_> = parsed
351 .params
352 .iter()
353 .map(|p| {
354 let name = &p.name;
355 let ty = &p.ty;
356 quote! { pub #name: #ty }
357 })
358 .collect();
359
360 let field_names: Vec<_> = parsed.params.iter().map(|p| &p.name).collect();
361 let field_types: Vec<_> = parsed.params.iter().map(|p| &p.ty).collect();
362
363 let new_impl = if parsed.params.is_empty() {
364 quote! {
365 impl #struct_name {
366 #vis fn new() -> Self {
368 Self {}
369 }
370 }
371
372 impl ::std::default::Default for #struct_name {
373 fn default() -> Self {
374 Self::new()
375 }
376 }
377 }
378 } else {
379 quote! {
380 impl #struct_name {
381 #vis fn new(#( #field_names: #field_types ),*) -> Self {
383 Self { #( #field_names ),* }
384 }
385 }
386 }
387 };
388
389 quote! {
390 #[derive(Clone, Debug)]
391 #vis struct #struct_name {
392 #( #fields ),*
393 }
394
395 #new_impl
396 }
397}
398
399fn generate_query_impl(
400 parsed: &ParsedFn,
401 struct_name: &Ident,
402 key_params: &[&Param],
403 attr: &QueryAttr,
404) -> Result<TokenStream2, Error> {
405 let output_ty = &parsed.output_ty;
406
407 let cache_key_ty = match key_params.len() {
409 0 => quote! { () },
410 1 => {
411 let ty = &key_params[0].ty;
412 quote! { #ty }
413 }
414 _ => {
415 let types: Vec<_> = key_params.iter().map(|p| &p.ty).collect();
416 quote! { ( #( #types ),* ) }
417 }
418 };
419
420 let cache_key_body = match key_params.len() {
422 0 => quote! { () },
423 1 => {
424 let name = &key_params[0].name;
425 quote! { self.#name.clone() }
426 }
427 _ => {
428 let names: Vec<_> = key_params.iter().map(|p| &p.name).collect();
429 quote! { ( #( self.#names.clone() ),* ) }
430 }
431 };
432
433 let field_bindings: Vec<_> = parsed
435 .params
436 .iter()
437 .map(|p| {
438 let name = &p.name;
439 quote! { let #name = &self.#name; }
440 })
441 .collect();
442
443 let fn_body = &parsed.body;
444
445 let durability_impl = attr.durability.map(|d| {
447 quote! {
448 fn durability(&self) -> u8 {
449 #d
450 }
451 }
452 });
453
454 let output_eq_impl = match &attr.output_eq {
455 OutputEq::None | OutputEq::PartialEq => quote! {
457 fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
458 old == new
459 }
460 },
461 OutputEq::Custom(custom_fn) => quote! {
463 fn output_eq(old: &Self::Output, new: &Self::Output) -> bool {
464 #custom_fn(old, new)
465 }
466 },
467 };
468
469 Ok(quote! {
470 impl ::query_flow::Query for #struct_name {
471 type CacheKey = #cache_key_ty;
472 type Output = #output_ty;
473
474 fn cache_key(&self) -> Self::CacheKey {
475 #cache_key_body
476 }
477
478 fn query(&self, ctx: &mut ::query_flow::QueryContext) -> ::std::result::Result<Self::Output, ::query_flow::QueryError> {
479 #( #field_bindings )*
480 #fn_body
481 }
482
483 #durability_impl
484 #output_eq_impl
485 }
486 })
487}
488
489#[derive(Debug, Clone, Copy, Default)]
495enum DurabilityAttr {
496 #[default]
497 Volatile,
498 Session,
499 Stable,
500 Constant,
501}
502
503impl FromMeta for DurabilityAttr {
504 fn from_string(value: &str) -> darling::Result<Self> {
505 Self::parse_str(value)
506 }
507
508 fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
509 if let syn::Expr::Path(expr_path) = expr {
511 if let Some(ident) = expr_path.path.get_ident() {
512 return Self::parse_str(&ident.to_string());
513 }
514 }
515 Err(darling::Error::custom("expected durability level: volatile, session, stable, or constant"))
516 }
517}
518
519impl DurabilityAttr {
520 fn parse_str(value: &str) -> darling::Result<Self> {
521 match value.to_lowercase().as_str() {
522 "volatile" => Ok(DurabilityAttr::Volatile),
523 "session" => Ok(DurabilityAttr::Session),
524 "stable" => Ok(DurabilityAttr::Stable),
525 "constant" => Ok(DurabilityAttr::Constant),
526 _ => Err(darling::Error::unknown_value(value)),
527 }
528 }
529}
530
531#[derive(Debug)]
533struct TypeWrapper(syn::Type);
534
535impl FromMeta for TypeWrapper {
536 fn from_expr(expr: &syn::Expr) -> darling::Result<Self> {
537 let tokens = quote! { #expr };
539 syn::parse2::<syn::Type>(tokens)
540 .map(TypeWrapper)
541 .map_err(|e| darling::Error::custom(format!("invalid type: {}", e)))
542 }
543}
544
545#[derive(Debug, FromMeta)]
547struct AssetKeyAttr {
548 asset: TypeWrapper,
550
551 #[darling(default)]
553 durability: DurabilityAttr,
554
555 #[darling(default)]
559 asset_eq: OutputEq,
560}
561
562#[proc_macro_attribute]
590pub fn asset_key(attr: TokenStream, item: TokenStream) -> TokenStream {
591 let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
592 Ok(v) => v,
593 Err(e) => return TokenStream::from(e.to_compile_error()),
594 };
595
596 let attr = match AssetKeyAttr::from_list(&attr_args) {
597 Ok(v) => v,
598 Err(e) => return TokenStream::from(e.write_errors()),
599 };
600
601 let input_struct = parse_macro_input!(item as ItemStruct);
602
603 match generate_asset_key(attr, input_struct) {
604 Ok(tokens) => tokens.into(),
605 Err(e) => e.to_compile_error().into(),
606 }
607}
608
609fn generate_asset_key(attr: AssetKeyAttr, input_struct: ItemStruct) -> Result<TokenStream2, Error> {
610 let struct_name = &input_struct.ident;
611 let asset_ty = &attr.asset.0;
612
613 let durability_impl = match attr.durability {
615 DurabilityAttr::Volatile => quote! {
616 fn durability(&self) -> ::query_flow::DurabilityLevel {
617 ::query_flow::DurabilityLevel::Volatile
618 }
619 },
620 DurabilityAttr::Session => quote! {
621 fn durability(&self) -> ::query_flow::DurabilityLevel {
622 ::query_flow::DurabilityLevel::Session
623 }
624 },
625 DurabilityAttr::Stable => quote! {
626 fn durability(&self) -> ::query_flow::DurabilityLevel {
627 ::query_flow::DurabilityLevel::Stable
628 }
629 },
630 DurabilityAttr::Constant => quote! {
631 fn durability(&self) -> ::query_flow::DurabilityLevel {
632 ::query_flow::DurabilityLevel::Constant
633 }
634 },
635 };
636
637 let asset_eq_impl = match &attr.asset_eq {
639 OutputEq::None | OutputEq::PartialEq => quote! {
640 fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
641 old == new
642 }
643 },
644 OutputEq::Custom(custom_fn) => quote! {
645 fn asset_eq(old: &Self::Asset, new: &Self::Asset) -> bool {
646 #custom_fn(old, new)
647 }
648 },
649 };
650
651 Ok(quote! {
652 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
653 #input_struct
654
655 impl ::query_flow::AssetKey for #struct_name {
656 type Asset = #asset_ty;
657
658 #asset_eq_impl
659 #durability_impl
660 }
661 })
662}