1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{parse_macro_input, ItemFn};
5
6use cachelito_macro_utils::{generate_key_expr, parse_async_attributes, AsyncCacheAttributes};
8
9fn parse_attributes(attr: TokenStream) -> AsyncCacheAttributes {
11 let attr_stream: TokenStream2 = attr.into();
12 match parse_async_attributes(attr_stream) {
13 Ok(attrs) => attrs,
14 Err(err) => {
15 panic!("Failed to parse attributes: {}", err);
18 }
19 }
20}
21fn generate_insert_call(max_memory_expr: &TokenStream2) -> TokenStream2 {
23 let max_memory_str = max_memory_expr.to_string();
24 let has_max_memory = !max_memory_str.contains("None");
25
26 if has_max_memory {
27 quote! { __cache.insert_with_memory(&__key, __result.clone()); }
28 } else {
29 quote! { __cache.insert(&__key, __result.clone()); }
30 }
31}
32
33fn generate_cache_logic_block(
35 key_expr: &TokenStream2,
36 cache_ident: &syn::Ident,
37 order_ident: &syn::Ident,
38 stats_ident: &syn::Ident,
39 limit_expr: &TokenStream2,
40 max_memory_expr: &TokenStream2,
41 policy_expr: &TokenStream2,
42 ttl_expr: &TokenStream2,
43 frequency_weight_expr: &TokenStream2,
44 window_ratio_expr: &TokenStream2,
45 invalidation_check: &TokenStream2,
46 block: &syn::Block,
47 cache_insert: &TokenStream2,
48) -> TokenStream2 {
49 quote! {
50 let __key = #key_expr;
52
53 let __cache = cachelito_core::AsyncGlobalCache::new(
55 &*#cache_ident,
56 &*#order_ident,
57 #limit_expr,
58 #max_memory_expr,
59 #policy_expr,
60 #ttl_expr,
61 #frequency_weight_expr,
62 #window_ratio_expr,
63 &*#stats_ident,
64 );
65
66 if let Some(__cached) = __cache.get(&__key) {
68 #invalidation_check
69 }
70
71 let __result = (async #block).await;
73
74 #cache_insert
76
77 __result
78 }
79}
80
81#[proc_macro_attribute]
230pub fn cache_async(attr: TokenStream, item: TokenStream) -> TokenStream {
231 let input = parse_macro_input!(item as ItemFn);
232 let attrs = parse_attributes(attr);
233
234 let vis = &input.vis;
236 let sig = &input.sig;
237 let block = &input.block;
238 let fn_name = &sig.ident;
239 let fn_name_string = fn_name.to_string();
240 let fn_name_str = attrs.custom_name.as_ref().unwrap_or(&fn_name_string);
241
242 let ret_type = match &sig.output {
244 syn::ReturnType::Default => quote! { () },
245 syn::ReturnType::Type(_, ty) => quote! { #ty },
246 };
247
248 let mut has_self = false;
250 let mut arg_pats = Vec::new();
251
252 for arg in &sig.inputs {
253 match arg {
254 syn::FnArg::Receiver(_) => {
255 has_self = true;
256 }
257 syn::FnArg::Typed(pat_type) => {
258 let pat = &pat_type.pat;
259 arg_pats.push(quote! { #pat });
260 }
261 }
262 }
263
264 let cache_ident = syn::Ident::new(
266 &format!("__CACHE_{}", fn_name.to_string().to_uppercase()),
267 fn_name.span(),
268 );
269 let order_ident = syn::Ident::new(
270 &format!("__ORDER_{}", fn_name.to_string().to_uppercase()),
271 fn_name.span(),
272 );
273 let stats_ident = syn::Ident::new(
274 &format!("__STATS_{}", fn_name.to_string().to_uppercase()),
275 fn_name.span(),
276 );
277
278 let key_expr = generate_key_expr(has_self, &arg_pats);
280
281 let (is_result, _cache_value_type) = {
283 let s = quote!(#ret_type).to_string().replace(' ', "");
284 if s.starts_with("Result<") || s.starts_with("std::result::Result<") {
285 (true, ret_type.clone())
289 } else {
290 (false, ret_type.clone())
291 }
292 };
293
294 let limit_expr = &attrs.limit;
295 let policy_str = &attrs.policy;
296 let ttl_expr = &attrs.ttl;
297 let max_memory_expr = &attrs.max_memory;
298 let frequency_weight_expr = &attrs.frequency_weight;
299 let window_ratio_expr = &attrs.window_ratio;
300
301 let policy_expr = quote! {
303 cachelito_core::EvictionPolicy::from(#policy_str)
304 };
305
306 let invalidation_check = if let Some(pred_fn) = &attrs.invalidate_on {
308 quote! {
309 if !#pred_fn(&__key, &__cached) {
312 return __cached;
314 }
315 }
317 } else {
318 quote! {
319 return __cached;
320 }
321 };
322
323 let cache_logic = {
325 let insert_call = generate_insert_call(max_memory_expr);
326
327 let cache_insert = if let Some(pred_fn) = &attrs.cache_if {
329 quote! {
330 if #pred_fn(&__key, &__result) {
332 #insert_call
333 }
334 }
335 } else if is_result {
336 quote! {
338 if __result.is_ok() {
339 #insert_call
340 }
341 }
342 } else {
343 insert_call
345 };
346
347 generate_cache_logic_block(
349 &key_expr,
350 &cache_ident,
351 &order_ident,
352 &stats_ident,
353 limit_expr,
354 max_memory_expr,
355 &policy_expr,
356 ttl_expr,
357 frequency_weight_expr,
358 window_ratio_expr,
359 &invalidation_check,
360 block,
361 &cache_insert,
362 )
363 };
364
365 let invalidation_registration = if !attrs.tags.is_empty()
367 || !attrs.events.is_empty()
368 || !attrs.dependencies.is_empty()
369 {
370 let tags = &attrs.tags;
371 let events = &attrs.events;
372 let deps = &attrs.dependencies;
373
374 quote! {
375 static INVALIDATION_REGISTERED: once_cell::sync::OnceCell<()> = once_cell::sync::OnceCell::new();
377 INVALIDATION_REGISTERED.get_or_init(|| {
378 let metadata = cachelito_core::InvalidationMetadata::new(
379 vec![#(#tags.to_string()),*],
380 vec![#(#events.to_string()),*],
381 vec![#(#deps.to_string()),*],
382 );
383 cachelito_core::InvalidationRegistry::global().register(#fn_name_str, metadata);
384
385 cachelito_core::InvalidationRegistry::global().register_callback(
387 #fn_name_str,
388 move || {
389 #cache_ident.clear();
390 #order_ident.lock().clear();
391 }
392 );
393 });
394 }
395 } else {
396 quote! {}
397 };
398
399 let invalidation_callback_registration = quote! {
401 static INVALIDATION_CHECK_REGISTERED: once_cell::sync::OnceCell<()> = once_cell::sync::OnceCell::new();
403 INVALIDATION_CHECK_REGISTERED.get_or_init(|| {
404 cachelito_core::InvalidationRegistry::global().register_invalidation_callback(
405 #fn_name_str,
406 move |invalidation_check: &dyn Fn(&str) -> bool| {
407 let keys_to_remove: Vec<String> = #cache_ident
409 .iter()
410 .filter(|entry| invalidation_check(entry.key().as_str()))
411 .map(|entry| entry.key().clone())
412 .collect();
413
414 let mut order_write = #order_ident.lock();
416 for key in &keys_to_remove {
417 #cache_ident.remove(key);
418 if let Some(pos) = order_write.iter().position(|k| k == key) {
419 order_write.remove(pos);
420 }
421 }
422 }
423 );
424 });
425 };
426
427 let expanded = quote! {
429 #vis #sig {
430 use std::collections::VecDeque;
431
432 static #cache_ident: once_cell::sync::Lazy<dashmap::DashMap<String, (#ret_type, u64, u64)>> =
434 once_cell::sync::Lazy::new(|| dashmap::DashMap::new());
435 static #order_ident: once_cell::sync::Lazy<parking_lot::Mutex<VecDeque<String>>> =
436 once_cell::sync::Lazy::new(|| parking_lot::Mutex::new(VecDeque::new()));
437 static #stats_ident: once_cell::sync::Lazy<cachelito_core::CacheStats> =
438 once_cell::sync::Lazy::new(|| cachelito_core::CacheStats::new());
439
440 static STATS_REGISTERED: once_cell::sync::OnceCell<()> = once_cell::sync::OnceCell::new();
442 STATS_REGISTERED.get_or_init(|| {
443 cachelito_core::stats_registry::register(#fn_name_str, &#stats_ident);
444 });
445
446 #invalidation_registration
447 #invalidation_callback_registration
448
449 #cache_logic
450 }
451 };
452
453 TokenStream::from(expanded)
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    type Tokens = proc_macro2::TokenStream;

    /// Builds a call-site identifier, as the macro would receive it.
    fn ident(name: &str) -> syn::Ident {
        syn::Ident::new(name, proc_macro2::Span::call_site())
    }

    /// Stringifies the insertion statement chosen for `max_memory`.
    fn insert_for(max_memory: Tokens) -> String {
        generate_insert_call(&max_memory).to_string()
    }

    /// Complete argument set for `generate_cache_logic_block`.
    ///
    /// `LogicInputs::new` supplies a baseline configuration, so each test only
    /// overrides the pieces it is checking.
    struct LogicInputs {
        key: Tokens,
        cache: syn::Ident,
        order: syn::Ident,
        stats: syn::Ident,
        limit: Tokens,
        max_memory: Tokens,
        policy: Tokens,
        ttl: Tokens,
        frequency_weight: Tokens,
        window_ratio: Tokens,
        invalidation: Tokens,
        block: syn::Block,
        insert: Tokens,
    }

    impl LogicInputs {
        /// Baseline settings:
        /// - no entry limit, no memory cap, LRU policy and no TTL;
        /// - untyped-`None` frequency weight and window ratio;
        /// - an unconditional `return __cached;` on a hit and a plain `insert` on a miss.
        ///
        /// `names` gives the cache, order and stats statics; `body` must parse as a block.
        fn new(key: Tokens, names: [&str; 3], body: Tokens) -> Self {
            let [cache, order, stats] = names;
            LogicInputs {
                key,
                cache: ident(cache),
                order: ident(order),
                stats: ident(stats),
                limit: quote! { None },
                max_memory: quote! { None },
                policy: quote! { cachelito_core::EvictionPolicy::LRU },
                ttl: quote! { None },
                frequency_weight: quote! { Option::<f64>::None },
                window_ratio: quote! { Option::<f64>::None },
                invalidation: quote! { return __cached; },
                block: syn::parse2(body).unwrap(),
                insert: quote! { __cache.insert(&__key, __result.clone()); },
            }
        }

        /// Generates the cache logic for these inputs and stringifies it.
        fn render(&self) -> String {
            generate_cache_logic_block(
                &self.key,
                &self.cache,
                &self.order,
                &self.stats,
                &self.limit,
                &self.max_memory,
                &self.policy,
                &self.ttl,
                &self.frequency_weight,
                &self.window_ratio,
                &self.invalidation,
                &self.block,
                &self.insert,
            )
            .to_string()
        }
    }

    #[test]
    fn test_generate_insert_call_without_max_memory() {
        let rendered = insert_for(quote! { None });

        assert!(rendered.contains("__cache") && rendered.contains("insert"));
        assert!(!rendered.contains("insert_with_memory"));
    }

    #[test]
    fn test_generate_insert_call_with_max_memory() {
        assert!(insert_for(quote! { Some(1024 * 1024) }).contains("insert_with_memory"));
    }

    #[test]
    fn test_generate_cache_logic_block_structure() {
        let mut inputs = LogicInputs::new(
            quote! { format!("{:?}", arg1) },
            ["__CACHE_TEST", "__ORDER_TEST", "__STATS_TEST"],
            quote! { { 42 } },
        );
        inputs.limit = quote! { Some(100) };
        let rendered = inputs.render();

        assert!(rendered.contains("let __key"));
        assert!(rendered.contains("AsyncGlobalCache") && rendered.contains("new"));
        assert!(rendered.contains("__cache") && rendered.contains("get"));
        assert!(rendered.contains("let __result"));
        for static_name in ["__CACHE_TEST", "__ORDER_TEST", "__STATS_TEST"] {
            assert!(rendered.contains(static_name), "missing `{}`", static_name);
        }
    }

    #[test]
    fn test_generate_cache_logic_includes_all_parameters() {
        let mut inputs = LogicInputs::new(
            quote! { format!("{:?}", x) },
            ["__CACHE_FN", "__ORDER_FN", "__STATS_FN"],
            quote! { { expensive_computation() } },
        );
        inputs.limit = quote! { Some(50) };
        inputs.max_memory = quote! { Some(1024) };
        inputs.policy = quote! { cachelito_core::EvictionPolicy::FIFO };
        inputs.ttl = quote! { Some(60) };
        inputs.invalidation = quote! { if !check_fn(&__key, &__cached) { return __cached; } };
        inputs.insert = quote! {
            if should_cache(&__key, &__result) {
                __cache.insert(&__key, __result.clone());
            }
        };
        let rendered = inputs.render();

        for needle in [
            "Some (50)",
            "Some (1024)",
            "Some (60)",
            "check_fn",
            "should_cache",
            "expensive_computation",
        ] {
            assert!(rendered.contains(needle), "missing `{}`", needle);
        }
    }

    #[test]
    fn test_insert_call_format() {
        assert_eq!(
            insert_for(quote! { None }),
            "__cache . insert (& __key , __result . clone ()) ;"
        );
        assert_eq!(
            insert_for(quote! { Some(2048) }),
            "__cache . insert_with_memory (& __key , __result . clone ()) ;"
        );
    }

    #[test]
    fn test_cache_logic_block_contains_invalidation_check() {
        let mut inputs = LogicInputs::new(
            quote! { key },
            ["CACHE", "ORDER", "STATS"],
            quote! { { compute() } },
        );
        inputs.invalidation = quote! {
            if my_custom_check(&__key, &__cached) {
                return __cached;
            }
        };

        assert!(inputs.render().contains("my_custom_check"));
    }

    #[test]
    fn test_cache_logic_block_contains_cache_insert() {
        let mut inputs = LogicInputs::new(
            quote! { key },
            ["CACHE", "ORDER", "STATS"],
            quote! { { compute() } },
        );
        inputs.insert = quote! {
            if predicate(&__key, &__result) {
                __cache.insert(&__key, __result.clone());
            }
        };
        let rendered = inputs.render();

        assert!(rendered.contains("predicate"));
        assert!(rendered.contains("if"));
    }

    #[test]
    fn test_generate_insert_call_detects_none_correctly() {
        for none_expr in [quote! { None }, quote! { ::std::option::Option::None }] {
            let rendered = generate_insert_call(&none_expr).to_string();
            assert!(
                rendered.contains("insert") && !rendered.contains("insert_with_memory"),
                "Failed for None variant: {}",
                none_expr
            );
        }
    }

    #[test]
    fn test_generate_insert_call_detects_some_correctly() {
        for some_expr in [
            quote! { Some(100) },
            quote! { Some(1024 * 1024) },
            quote! { Some(MAX_SIZE) },
        ] {
            let rendered = generate_insert_call(&some_expr).to_string();
            assert!(
                rendered.contains("insert_with_memory"),
                "Failed for Some variant: {}",
                some_expr
            );
        }
    }

    #[test]
    fn test_cache_logic_block_async_execution() {
        let inputs = LogicInputs::new(
            quote! { format!("{:?}", id) },
            ["CACHE", "ORDER", "STATS"],
            quote! {
                {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    42
                }
            },
        );
        let rendered = inputs.render();

        assert!(rendered.contains("(async"));
        assert!(rendered.contains(") . await"));
        assert!(rendered.contains("tokio :: time :: sleep"));
    }

    #[test]
    fn test_cache_logic_block_contains_async_global_cache_initialization() {
        let mut inputs = LogicInputs::new(
            quote! { key },
            ["TEST_CACHE", "TEST_ORDER", "TEST_STATS"],
            quote! { { value } },
        );
        inputs.limit = quote! { Some(200) };
        inputs.max_memory = quote! { Some(4096) };
        inputs.policy = quote! { cachelito_core::EvictionPolicy::ARC };
        inputs.ttl = quote! { Some(120) };
        inputs.window_ratio = quote! { Some(0.3) };
        inputs.insert = quote! { __cache.insert_with_memory(&__key, __result.clone()); };
        let rendered = inputs.render();

        for needle in [
            "AsyncGlobalCache :: new",
            "& * TEST_CACHE",
            "& * TEST_ORDER",
            "Some (200)",
            "Some (4096)",
            "EvictionPolicy :: ARC",
            "Some (120)",
            "& * TEST_STATS",
            "Some (0.3)",
        ] {
            assert!(rendered.contains(needle), "missing `{}`", needle);
        }
    }
}