use darling::{Error, FromMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::__private::Span;
use quote::quote;
use std::ops::Deref;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_quote, parse_str, Attribute, Block, FnArg, GenericArgument, Pat, PatType, PathArguments,
ReturnType, Signature, Type,
};
/// Write-synchronization setting parsed from the macro attributes.
///
/// `Disabled` is the `Default::default()` value. The other two variants are the
/// enabled modes produced by the `FromMeta` implementation for this type.
#[derive(Debug, Default, PartialEq, Eq)]
pub(super) enum SyncWriteMode {
    /// Write synchronization is turned off (the default).
    #[default]
    Disabled,
    /// Enabled via a bare word, `true`, `"true"` or `"default"`.
    Default,
    /// Enabled via the `"by_key"` string form.
    ByKey,
}
impl FromMeta for SyncWriteMode {
fn from_word() -> darling::Result<Self> {
Ok(Self::Default)
}
fn from_bool(value: bool) -> darling::Result<Self> {
Ok(if value { Self::Default } else { Self::Disabled })
}
fn from_string(value: &str) -> darling::Result<Self> {
match value {
"default" | "true" => Ok(Self::Default),
"by_key" => Ok(Self::ByKey),
"false" => Ok(Self::Disabled),
_ => Err(Error::unknown_value(value)),
}
}
}
/// Checks that the configured number of `sync_writes_buckets` is usable.
///
/// Any positive count is accepted. Zero is rejected with an error attached to `span`.
pub(super) fn validate_sync_writes_buckets(
    buckets: usize,
    span: proc_macro2::Span,
) -> std::result::Result<(), syn::Error> {
    if buckets > 0 {
        return Ok(());
    }
    Err(syn::Error::new(
        span,
        "`sync_writes_buckets` must be greater than 0",
    ))
}
/// Generates statements that acquire one lock out of a fixed set of lock buckets,
/// choosing the bucket by hashing a key.
///
/// The emitted code hashes `key` with `std::collections::hash_map::DefaultHasher`.
/// It picks `locks[hash % locks.len()]`, clones that entry into a local `lock`, and
/// binds `lock.<lock_method>()` followed by the `await_if_async` tokens to `_key_lock`.
///
/// - `key`: tokens for the value to hash; the value must implement `Hash`.
///   NOTE(review): the tokens are spliced directly before `.hash(...)`, so they should be
///   an identifier/path or a parenthesized expression — confirm callers pass one.
/// - `locks`: tokens for an indexable collection with `.len()`. It is interpolated twice,
///   so it should be cheap and side-effect free (presumably a static or field — verify).
/// - `lock_method`: the method called on the selected lock.
/// - `await_if_async`: tokens appended after that call (presumably `.await` or empty).
///
/// The `%` panics at runtime if `locks` is empty. Presumably the bucket count is checked
/// with `validate_sync_writes_buckets` first — confirm at the call site.
///
/// The generated statements introduce the local names `lock` and `_key_lock` into the
/// surrounding scope.
pub(super) fn by_key_lock_block(
    key: TokenStream2,
    locks: TokenStream2,
    lock_method: TokenStream2,
    await_if_async: TokenStream2,
) -> TokenStream2 {
    quote! {
        let lock = {
            use std::hash::{Hash, Hasher};
            // Every `DefaultHasher::new()` instance hashes identically (per std docs), so
            // equal keys always select the same bucket.
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            #key.hash(&mut hasher);
            // `as usize` may truncate the hash on 32-bit targets; that only affects which
            // bucket is chosen, not correctness.
            #locks[(hasher.finish() as usize) % #locks.len()].clone()
        };
        // Bound to a named `_key_lock` (not `_`) so the value returned by `lock_method`
        // (presumably a guard) lives until the end of the enclosing scope instead of being
        // dropped immediately.
        let _key_lock = lock.#lock_method()#await_if_async;
    }
}
pub(super) fn get_mut_signature(signature: Signature) -> Signature {
let mut signature_no_muts = signature;
let mut sig_inputs = Punctuated::new();
for inp in &signature_no_muts.inputs {
let item = match inp {
FnArg::Receiver(_) => inp.clone(),
FnArg::Typed(pat_type) => {
let mut pt = pat_type.clone();
let pat = match_pattern_type(&pat_type);
pt.pat = pat;
FnArg::Typed(pt)
}
};
sig_inputs.push(item);
}
signature_no_muts.inputs = sig_inputs;
signature_no_muts
}
/// Returns a copy of `pat_type`'s pattern with a top-level `mut` removed.
///
/// Only a plain identifier pattern is rewritten (`mut x` -> `x`). Every other pattern
/// (tuple, struct, wildcard, ...) is returned unchanged, so a `mut` nested inside such
/// a pattern is kept.
///
/// Takes `&PatType`. Existing call sites that pass `&&PatType` still compile through
/// deref coercion.
pub(super) fn match_pattern_type(pat_type: &PatType) -> Box<Pat> {
    match pat_type.pat.deref() {
        Pat::Ident(pat_ident) => {
            // Clearing `mutability` is a no-op when it is already `None`, so one path
            // covers both the `mut x` and the plain `x` case.
            let mut without_mut = pat_ident.clone();
            without_mut.mutability = None;
            Box::new(Pat::Ident(without_mut))
        }
        _ => pat_type.pat.clone(),
    }
}
pub(super) fn find_value_type(
result: bool,
option: bool,
output: &ReturnType,
output_ty: TokenStream2,
) -> Result<TokenStream2, syn::Error> {
use syn::spanned::Spanned;
match (result, option) {
(false, false) => Ok(output_ty),
(true, true) => Err(syn::Error::new(
output_ty.span(),
"the `result` and `option` attributes are mutually exclusive",
)),
_ => match output.clone() {
ReturnType::Default => Err(syn::Error::new(
output_ty.span(),
"function must return something when `result` or `option` is set",
)),
ReturnType::Type(_, ty) => {
let span = ty.span();
if let Type::Path(typepath) = *ty {
let segments = typepath.path.segments;
if let Some(last_seg) = segments.last() {
if let PathArguments::AngleBracketed(brackets) = &last_seg.arguments {
if let Some(inner_ty) = brackets.args.first() {
Ok(quote! {#inner_ty})
} else {
Err(syn::Error::new(
span,
"function return type has no inner type",
))
}
} else {
Err(syn::Error::new(
span,
"function return type has no inner type",
))
}
} else {
Err(syn::Error::new(span, "function return type is too complex"))
}
} else {
Err(syn::Error::new(span, "function return type is too complex"))
}
}
},
}
}
pub(super) fn first_type_arg<'a>(
ty: &'a Type,
span: Span,
not_path: &str,
no_arg: &str,
) -> Result<&'a GenericArgument, syn::Error> {
let Type::Path(typepath) = ty else {
return Err(syn::Error::new(span, not_path));
};
let Some(segment) = typepath.path.segments.last() else {
return Err(syn::Error::new(span, no_arg));
};
let PathArguments::AngleBracketed(brackets) = &segment.arguments else {
return Err(syn::Error::new(span, no_arg));
};
brackets
.args
.first()
.ok_or_else(|| syn::Error::new(span, no_arg))
}
/// Resolves the cache key type and the expression that builds a key from the inputs.
///
/// Returns `(key_type_tokens, key_expr_tokens)`:
/// - `key` and `convert` set: `key` is parsed as the key type and `convert` as a block
///   that produces the key (`ty` is ignored).
/// - only `convert` and `ty` set: the key type tokens are left empty and `convert` is the
///   key block (presumably the caller derives the key type from `ty` — confirm).
/// - neither `key` nor `convert` set: the key type is `(T1, T2, ...)` over `input_tys` and
///   the key expression clones each input name. With a single input this is just `T` /
///   `name.clone()`, because `(x)` is not a tuple.
///
/// `input_names` is a slice. Callers passing `&Vec<Pat>` keep working via deref coercion.
///
/// # Errors
/// Fails if `key` or `convert` does not parse, if `key` is set without `convert`, or if
/// `convert` is set without `key` or `ty`.
pub(super) fn make_cache_key_type(
    key: &Option<String>,
    convert: &Option<String>,
    ty: &Option<String>,
    input_tys: Vec<Type>,
    input_names: &[Pat],
) -> Result<(TokenStream2, TokenStream2), syn::Error> {
    match (key, convert, ty) {
        (Some(key_str), Some(convert_str), _) => {
            let cache_key_ty = parse_str::<Type>(key_str)?;
            let key_convert_block = parse_str::<Block>(convert_str)?;
            Ok((quote! {#cache_key_ty}, quote! {#key_convert_block}))
        }
        (None, Some(convert_str), Some(_)) => {
            let key_convert_block = parse_str::<Block>(convert_str)?;
            Ok((quote! {}, quote! {#key_convert_block}))
        }
        (None, None, _) => Ok((
            quote! {(#(#input_tys),*)},
            quote! {(#(#input_names.clone()),*)},
        )),
        (Some(_), None, _) => Err(syn::Error::new(
            Span::call_site(),
            "`key` requires `convert` to be set",
        )),
        (None, Some(_), None) => Err(syn::Error::new(
            Span::call_site(),
            "`convert` requires `key` or `ty` to be set",
        )),
    }
}
/// Collects each parameter's binding pattern, in declaration order, with a top-level
/// `mut` removed by `match_pattern_type`.
///
/// # Panics
/// Panics on the first receiver (`self`, `&self`, ...) found in `inputs`.
pub(super) fn get_input_names(inputs: &Punctuated<FnArg, Comma>) -> Vec<Pat> {
    let mut names = Vec::with_capacity(inputs.len());
    for arg in inputs {
        let FnArg::Typed(pat_type) = arg else {
            panic!("methods (functions taking 'self') are not supported");
        };
        names.push(*match_pattern_type(&pat_type));
    }
    names
}
/// Appends `cache_fn_doc_extra` to the item's documentation attributes.
///
/// If the item already carries at least one `#[doc]` attribute, an empty doc line and a
/// `# Caching` heading are appended first, so the extra text forms its own section.
/// Otherwise the extra text becomes the only doc attribute.
pub(super) fn fill_in_attributes(attributes: &mut Vec<Attribute>, cache_fn_doc_extra: String) {
    let already_documented = attributes.iter().any(|attr| attr.path().is_ident("doc"));
    if already_documented {
        attributes.push(parse_quote! { #[doc = ""] });
        attributes.push(parse_quote! { #[doc = "# Caching"] });
    }
    attributes.push(parse_quote! { #[doc = #cache_fn_doc_extra] });
}
pub(super) fn get_input_types(inputs: &Punctuated<FnArg, Comma>) -> Vec<Type> {
inputs
.iter()
.map(|input| match input {
FnArg::Receiver(_) => panic!("methods (functions taking 'self') are not supported"),
FnArg::Typed(pat_type) => *pat_type.ty.clone(),
})
.collect()
}
pub(super) fn with_cache_flag_error(output_span: Span, output_type_display: String) -> TokenStream {
syn::Error::new(
output_span,
format!(
"\nWhen specifying `with_cached_flag = true`, \
the return type must be wrapped in `cached::Return<T>`. \n\
The following return types are supported: \n\
| `cached::Return<T>`\n\
| `std::result::Result<cached::Return<T>, E>`\n\
| `std::option::Option<cached::Return<T>>`\n\
Found type: {t}.",
t = output_type_display
),
)
.to_compile_error()
.into()
}
/// Wraps `return_cache_block` in an expiry check when a lifetime is configured.
///
/// - `None`: the block is emitted unchanged.
/// - `Some(secs)`: the emitted code destructures `result` into `(created_sec, result)` and
///   runs the block only when `now.saturating_duration_since(*created_sec)` is below `secs`
///   seconds. It relies on `result` and `now` already being in scope at the expansion site
///   (`now` is presumably an `Instant`).
pub(super) fn gen_return_cache_block(
    time: Option<u64>,
    return_cache_block: TokenStream2,
) -> TokenStream2 {
    match time {
        None => quote! { #return_cache_block },
        Some(secs) => quote! {
            let (created_sec, result) = result;
            if now.saturating_duration_since(*created_sec) < ::cached::time::Duration::from_secs(#secs) {
                #return_cache_block
            }
        },
    }
}
/// Reports whether `ty` is `cached::Return<T>`, possibly nested inside `Result`/`Option`.
///
/// - A path whose last segment is `Return` matches only when the whole path is `Return`
///   or `cached::Return`.
/// - A path whose last segment is `Result` or `Option` matches when its first generic
///   *type* argument matches (checked recursively).
/// - Anything else does not match.
fn type_is_cached_return(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    let segments = &type_path.path.segments;
    let Some(last) = segments.last() else {
        return false;
    };
    let last_name = last.ident.to_string();

    if last_name == "Return" {
        // The last segment is already `Return`; only the prefix remains to be checked.
        return match segments.len() {
            1 => true,
            2 => segments
                .first()
                .is_some_and(|first| first.ident.to_string() == "cached"),
            _ => false,
        };
    }

    if last_name != "Result" && last_name != "Option" {
        return false;
    }

    let PathArguments::AngleBracketed(bracketed) = &last.arguments else {
        return false;
    };
    bracketed
        .args
        .iter()
        .find_map(|arg| {
            if let GenericArgument::Type(inner) = arg {
                Some(inner)
            } else {
                None
            }
        })
        .is_some_and(type_is_cached_return)
}
/// Returns `true` when `with_cached_flag` is set but the return type is not a supported
/// `cached::Return<T>` form, i.e. when an error should be reported.
///
/// A function with no declared return type is also reported when the flag is set.
pub(super) fn check_with_cache_flag(with_cached_flag: bool, output: &ReturnType) -> bool {
    with_cached_flag && !matches!(output, ReturnType::Type(_, ty) if type_is_cached_return(ty))
}