// smart_cache_macro/lib.rs

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use sha2::{Digest, Sha256};
use syn::{
    parse_macro_input, FnArg, GenericParam, Ident, ItemFn, Pat, PatIdent, ReturnType, Signature,
    Type,
};

7/// A procedural macro that automatically caches function results based on its input parameters.
8///
9/// This macro implements function-level caching by serializing the function's input parameters
10/// and using them as a cache key. The function's result is then stored and can be retrieved
11/// when the function is called again with the same parameters.
12///
13/// # Requirements
14///
15/// - All function parameters must implement `Archive`, `Serialize`, and `Deserialize` from `rkyv`
16/// - The return type must implement `Archive`, `Serialize`, and `Deserialize` from `rkyv`
17/// - The function must be pure (no mutable references allowed)
18///
19/// # Examples
20///
21/// ```rust
22/// use smart_cache_macro::cached;
23///
24/// #[cached]
25/// fn fibonacci(n: u64) -> u64 {
26///     if n <= 1 {
27///         return n;
28///     }
29///     fibonacci(n - 1) + fibonacci(n - 2)
30/// }
31///
32/// // First call will compute and cache the result
33/// let result1 = fibonacci(10);
34///
35/// // Second call will retrieve from cache
36/// let result2 = fibonacci(10);
37///
38/// assert_eq!(result1, result2);
39/// ```
40///
41/// Works with multiple parameters and reference types:
42///
43/// ```rust
44/// use smart_cache_macro::cached;
45///
46/// #[cached]
47/// fn process_data(data: &[u8], threshold: u32) -> Vec<u8> {
48///     // Expensive computation here...
49///     data.iter()
50///         .filter(|&&x| x as u32 > threshold)
51///         .copied()
52///         .collect()
53/// }
54/// ```
55///
56/// # How it works
57///
58/// The macro:
59/// 1. Creates a unique cache key from the function's parameters and a hash of the function body
60/// 2. Checks if a result exists in the cache for this key
61/// 3. If found, deserializes and returns the cached result
62/// 4. If not found, executes the function, caches the result, and returns it
63///
64fn hash_token_stream(tokens: &proc_macro2::TokenStream) -> [u8; 32] {
65    // Convert TokenStream to a string representation
66    let token_string = tokens.to_string();
67
68    // Create a new SHA-256 hasher
69    let mut hasher = Sha256::new();
70
71    // Update hasher with token string bytes
72    hasher.update(token_string.as_bytes());
73
74    // Finalize and return the hash as bytes
75    hasher.finalize().into()
76}
77
78fn check_for_mutable_refs(
79    fn_inputs: &syn::punctuated::Punctuated<FnArg, syn::token::Comma>,
80) -> Result<(), syn::Error> {
81    for arg in fn_inputs {
82        let FnArg::Typed(pat_type) = arg else {
83            continue;
84        };
85
86        let Type::Reference(type_ref) = &*pat_type.ty else {
87            continue;
88        };
89
90        let Some(mutability) = &type_ref.mutability else {
91            continue;
92        };
93
94        return Err(syn::Error::new_spanned(
95            mutability,
96            "cached functions must be pure - mutable references are not allowed",
97        ));
98    }
99    Ok(())
100}
101
102fn get_param_type(ty: &Type) -> &Type {
103    if let Type::Reference(type_ref) = ty {
104        &type_ref.elem
105    } else {
106        ty
107    }
108}
109
110#[proc_macro_attribute]
111pub fn cached(_attr: TokenStream, item: TokenStream) -> TokenStream {
112    let input_fn = parse_macro_input!(item as ItemFn);
113
114    // Check for mutable references and return the original function with error if found
115    if let Err(err) = check_for_mutable_refs(&input_fn.sig.inputs) {
116        let compiler_err = err.to_compile_error();
117
118        return quote! {
119            #input_fn
120
121            #compiler_err
122        }
123        .into();
124    }
125
126    let mut input_fn = input_fn;
127
128    let mut fn_with_name_inner = input_fn.clone();
129    fn_with_name_inner.sig.ident = Ident::new("inner", Span::call_site());
130
131    let fn_with_name_inner_tokens = quote! {
132        #fn_with_name_inner
133    };
134
135    let inner_fn_hash = hash_token_stream(&fn_with_name_inner_tokens);
136
137    // Convert the [u8; 32] to a literal array expression
138    let inner_fn_hash_literal = quote! {
139        [
140            #(#inner_fn_hash,)*
141        ]
142    };
143
144    let fn_inputs = &input_fn.sig.inputs;
145    let fn_output = match &input_fn.sig.output {
146        ReturnType::Default => quote!(()),
147        ReturnType::Type(_, ty) => quote!(#ty),
148    };
149
150    let param_names: Vec<_> = fn_inputs
151        .iter()
152        .filter_map(|arg| match arg {
153            FnArg::Typed(pat_type) => {
154                if let Pat::Ident(pat_ident) = &*pat_type.pat {
155                    Some(&pat_ident.ident)
156                } else {
157                    None
158                }
159            }
160            _ => None,
161        })
162        .collect();
163
164    let param_types: Vec<_> = fn_inputs
165        .iter()
166        .filter_map(|arg| match arg {
167            FnArg::Typed(pat_type) => Some(get_param_type(&pat_type.ty)),
168            _ => None,
169        })
170        .collect();
171
172    let new_block = quote! {{
173        #fn_with_name_inner
174
175        use rkyv::{with::InlineAsBox, Archive, Deserialize, Serialize};
176
177        #[derive(Archive, Serialize, Deserialize, Debug)]
178        struct CacheKey<'a> {
179            #(
180                #[rkyv(with = InlineAsBox)]
181                #param_names: &'a #param_types,
182            )*
183            _function_hash: [u8; 32],
184        }
185
186        let key = CacheKey {
187            #(#param_names: &#param_names,)*
188            _function_hash: #inner_fn_hash_literal,
189        };
190        println!("{key:?}");
191        let key_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&key).unwrap();
192
193        if let Some(cached_result) = smart_cache::get_cached(&*key_bytes) {
194            let cached_result = &*cached_result;
195            let cached_result: &rkyv::Archived<#fn_output> = rkyv::access::<_, rkyv::rancor::Error>(cached_result).unwrap();
196            let cached_result: #fn_output = rkyv::deserialize::<#fn_output, rkyv::rancor::Error>(cached_result).unwrap();
197            return cached_result;
198        }
199
200        let result = inner(#(#param_names,)*);
201
202        let value_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&result).unwrap();
203        let _ = smart_cache::set_cached(&key_bytes, &value_bytes);
204
205        result
206    }};
207
208    input_fn.block = syn::parse2(new_block).unwrap();
209
210    TokenStream::from(quote! {
211        #input_fn
212    })
213}