use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse::Parser, parse_macro_input, punctuated::Punctuated, Expr, ItemFn, Lit, Meta, Token,
};
/// Maximum permitted cache-key length, measured in bytes (as reported by `str::len`).
// `dead_code` is allowed so these limits never trigger warnings while unreferenced.
#[allow(dead_code)]
const MAX_CACHE_KEY_LENGTH: usize = 1_024;
/// Characters rejected anywhere in a cache key: NUL, line feed, and carriage return.
#[allow(dead_code)]
const FORBIDDEN_KEY_CHARS: &[char] = &['\u{0}', '\u{a}', '\u{d}'];
/// Checks whether `key` is acceptable as a cache key.
///
/// # Errors
///
/// Returns a human-readable message when the key is empty, is longer than
/// [`MAX_CACHE_KEY_LENGTH`] bytes, or contains any character from
/// [`FORBIDDEN_KEY_CHARS`]. Only the first offending character is reported.
#[allow(dead_code)]
fn validate_cache_key(key: &str) -> Result<(), String> {
    let byte_len = key.len();
    if byte_len == 0 {
        return Err("Cache key cannot be empty".to_string());
    }
    if byte_len > MAX_CACHE_KEY_LENGTH {
        return Err(format!(
            "Cache key exceeds maximum length of {} bytes (got {} bytes)",
            MAX_CACHE_KEY_LENGTH, byte_len
        ));
    }
    match key.chars().find(|ch| FORBIDDEN_KEY_CHARS.contains(ch)) {
        // Every forbidden character is ASCII, so the `u8` cast is lossless here.
        Some(bad) => Err(format!(
            "Cache key contains forbidden character '\\x{:02x}'",
            bad as u8
        )),
        None => Ok(()),
    }
}
#[proc_macro_attribute]
pub fn cached(args: TokenStream, item: TokenStream) -> TokenStream {
let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
let args = parser.parse(args).expect("Failed to parse arguments");
let input = parse_macro_input!(item as ItemFn);
let mut service_name = "default".to_string();
let mut ttl = quote! { None };
let mut key_pattern = None;
let mut key_prefix = None;
let mut key_generator_type = "default".to_string();
let mut cache_type = quote! { "two-level" };
for arg in args {
if let Meta::NameValue(nv) = arg {
if nv.path.is_ident("service") {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(lit) = expr_lit.lit {
service_name = lit.value();
}
}
} else if nv.path.is_ident("ttl") {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Int(lit) = expr_lit.lit {
let val = lit.base10_parse::<u64>().unwrap();
ttl = quote! { Some(#val) };
}
}
} else if nv.path.is_ident("key") {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(lit) = expr_lit.lit {
key_pattern = Some(lit.value());
}
}
} else if nv.path.is_ident("key_prefix") {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(lit) = expr_lit.lit {
key_prefix = Some(lit.value());
}
}
} else if nv.path.is_ident("key_generator") {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(lit) = expr_lit.lit {
key_generator_type = lit.value();
}
}
} else if nv.path.is_ident("cache_type") {
if let Expr::Lit(expr_lit) = nv.value {
if let Lit::Str(lit) = expr_lit.lit {
let val = lit.value();
cache_type = quote! { #val };
}
}
}
}
}
let fn_name = &input.sig.ident;
let fn_args = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_block = &input.block;
let vis = &input.vis;
let arg_names: Vec<_> = fn_args
.iter()
.filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
return Some(&pat_ident.ident);
}
}
None
})
.collect();
let key_gen = if let Some(pattern) = key_pattern {
quote! {
format!(#pattern)
}
} else if key_generator_type != "default" {
match key_generator_type.as_str() {
"simple" => {
if arg_names.is_empty() {
quote! {
oxcache::KeyGenerator::simple(#service_name, stringify!(#fn_name))
}
} else {
quote! {
oxcache::KeyGenerator::simple_with_args(
#service_name,
stringify!(#fn_name),
&(#(#arg_names),*)
)
}
}
}
"md5" => {
if arg_names.is_empty() {
quote! {
oxcache::KeyGenerator::md5(#service_name, stringify!(#fn_name), "")
}
} else {
quote! {
oxcache::KeyGenerator::md5_with_args(
#service_name,
stringify!(#fn_name),
&(#(#arg_names),*)
)
}
}
}
"murmur3" => {
if arg_names.is_empty() {
quote! {
oxcache::KeyGenerator::murmur3(#service_name, stringify!(#fn_name), "")
}
} else {
quote! {
oxcache::KeyGenerator::murmur3_with_args(
#service_name,
stringify!(#fn_name),
&(#(#arg_names),*)
)
}
}
}
"namespace" => {
if arg_names.is_empty() {
quote! {
oxcache::KeyGenerator::namespace(
#service_name,
#key_prefix.unwrap_or("default"),
stringify!(#fn_name),
""
)
}
} else {
quote! {
oxcache::KeyGenerator::namespace_with_args(
#service_name,
#key_prefix.unwrap_or("default"),
stringify!(#fn_name),
&(#(#arg_names),*)
)
}
}
}
_ => {
if arg_names.is_empty() {
quote! { format!("{}:{}", #service_name, stringify!(#fn_name)) }
} else {
quote! {
format!("{}:{}:{:?}", #service_name, stringify!(#fn_name), (#(#arg_names),*))
}
}
}
}
} else if let Some(prefix) = key_prefix {
if arg_names.is_empty() {
quote! { format!("{}:{}:{}", #service_name, #prefix, stringify!(#fn_name)) }
} else {
quote! {
format!("{}:{}:{}:{:?}", #service_name, #prefix, stringify!(#fn_name), (#(#arg_names),*))
}
}
} else {
if arg_names.is_empty() {
quote! { format!("{}:{}", #service_name, stringify!(#fn_name)) }
} else {
quote! {
format!("{}:{}:{:?}", #service_name, stringify!(#fn_name), (#(#arg_names),*))
}
}
};
let output = quote! {
#vis async fn #fn_name(#fn_args) #fn_output {
use oxcache::{get_client, CacheOps};
let cache_key = #key_gen;
let key_len = cache_key.len();
if key_len > 1024 {
tracing::warn!(
"Cache key too long ({} bytes), falling back to uncached execution",
key_len
);
return async { #fn_block }.await;
}
if cache_key.bytes().any(|b| b == 0 || b == 10 || b == 13) {
tracing::warn!(
"Cache key contains invalid characters, falling back to uncached execution"
);
return async { #fn_block }.await;
}
let client = match get_client(#service_name) {
Ok(c) => c,
Err(_) => return async { #fn_block }.await,
};
if let Ok(Some(bytes)) = client.get_bytes(&cache_key).await {
use oxcache::serialization::Serializer;
if let Ok(val) = client.serializer().deserialize(&bytes) {
return Ok(val);
}
}
let result = async { #fn_block }.await;
if let Ok(ref val) = result {
use oxcache::serialization::Serializer;
if let Ok(bytes) = client.serializer().serialize(val) {
let _ = match #cache_type {
"l1-only" => client.set_l1_bytes(&cache_key, bytes, #ttl).await,
"l2-only" => client.set_l2_bytes(&cache_key, bytes, #ttl).await,
_ => client.set_bytes(&cache_key, bytes, #ttl).await,
};
}
}
result
}
};
output.into()
}