use heck::CamelCase;
use proc_macro::TokenStream;
use quote::{format_ident, quote, ToTokens};
use std::{
cell::RefCell,
collections::{BTreeSet, HashSet},
};
use syn::visit::Visit;
thread_local! {
    /// Source text of every function registered through `mark_query`, in registration order.
    ///
    /// `mark_query` appends the stringified function; `derive_query_db` drains the list when
    /// it generates the database.
    // NOTE(review): this presumably relies on all macro expansions of a crate running on the
    // same thread, and on `mark_query` expansions happening before `derive_query_db` — confirm.
    static QUERIES: RefCell<Vec<String>> = RefCell::new(Vec::new());
}
pub(crate) fn mark_query(_args: TokenStream, input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input as syn::ItemFn);
QUERIES.with(|c| c.borrow_mut().push(input.to_token_stream().to_string()));
let note = format!("call query via database: cx.{}(...)", input.sig.ident);
let output = quote! {
#[deprecated(note = #note)]
#input
};
output.into()
}
/// Generate the query database from all functions registered via `mark_query`.
///
/// Drains the thread-local `QUERIES` list and emits, after the unmodified `input` tokens:
///
/// - a `QueryDatabase` trait with one caching wrapper method per query (the first parameter
///   of each query is replaced by `&self`, and the query is called with `self.context()`),
/// - a `QueryStorage` struct holding the query stack, the in-flight set, and one
///   `cached_<name>` map per query,
/// - a `QueryTag` enum with one variant per query, plus its `Debug` impl,
/// - one `<Name>QueryKey` tuple struct per query, holding the argument values.
///
/// Lifetime parameters of all queries are collected and lifted onto the generated trait,
/// storage struct and tag enum; other generic parameters stay on the individual methods.
///
/// # Panics
///
/// Panics if a recorded query cannot be re-parsed as a function, or if a query has no
/// explicit return type.
pub(crate) fn derive_query_db(input: TokenStream) -> TokenStream {
    /// Collects the names of every lifetime and every single-identifier type path that
    /// occurs in the visited types. Used to decide which of a query's generic parameters
    /// its key struct needs.
    struct KeyGenerics(HashSet<String>);
    impl Visit<'_> for KeyGenerics {
        fn visit_lifetime(&mut self, node: &syn::Lifetime) {
            self.0.insert(node.to_string());
            syn::visit::visit_lifetime(self, node);
        }
        fn visit_type(&mut self, node: &syn::Type) {
            if let syn::Type::Path(x) = node {
                if let Some(x) = x.path.get_ident() {
                    self.0.insert(x.to_string());
                }
            }
            syn::visit::visit_type(self, node);
        }
    }

    let input = proc_macro2::TokenStream::from(input);
    let mut output = proc_macro2::TokenStream::new();

    // Take ownership of everything registered so far, leaving the list empty.
    let queries = QUERIES.with(|c| std::mem::take(&mut *c.borrow_mut()));

    // Local names used inside the generated wrapper; arguments must not shadow them.
    let reserved: HashSet<_> = ["query_key", "query_storage", "query_tag", "query_result"]
        .iter()
        .copied()
        .collect();

    let mut ltset = BTreeSet::new();
    let mut funcs = vec![];
    let mut caches = vec![];
    let mut tags = vec![];
    let mut tag_debugs = vec![];
    let mut keys = vec![];

    for raw_query in &queries {
        let query: syn::ItemFn = syn::parse_str(raw_query)
            .unwrap_or_else(|e| panic!("failed to re-parse recorded query: {}", e));
        let name = query.sig.ident.clone();
        let result = match &query.sig.output {
            syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
            syn::ReturnType::Default => panic!("query {} has no return type", name),
        };
        let doc_attrs = query.attrs.iter().filter(|a| a.path.is_ident("doc"));

        // Skip the first parameter: the wrapper passes `self.context()` in its place.
        let typed_args: Vec<&syn::PatType> = query
            .sig
            .inputs
            .iter()
            .skip(1)
            .map(|arg| match arg {
                syn::FnArg::Typed(pat) => pat,
                _ => unreachable!("there should be no self args left at this point"),
            })
            .collect();

        // Keep plain identifier names unless they clash with a wrapper local; anything
        // else (destructuring patterns, reserved names) becomes `arg<index>`.
        let arg_names: Vec<_> = typed_args
            .iter()
            .enumerate()
            .map(|(i, pat)| match pat.pat.as_ref() {
                syn::Pat::Ident(id) if !reserved.contains(id.ident.to_string().as_str()) => {
                    id.ident.clone()
                }
                _ => format_ident!("arg{}", i),
            })
            .collect();
        let arg_types: Vec<_> = typed_args
            .iter()
            .map(|pat| pat.ty.as_ref().clone())
            .collect();

        // Lift lifetime parameters out of the method generics and into the shared set.
        let original_generics = query.sig.generics.clone();
        let mut generics = original_generics.clone();
        for param in std::mem::take(&mut generics.params) {
            match param {
                syn::GenericParam::Lifetime(ltdef) => {
                    ltset.insert(ltdef.lifetime);
                }
                _ => generics.params.push(param),
            }
        }

        // The key struct only gets the generic parameters its argument types mention.
        let mut key_generics_set = KeyGenerics(Default::default());
        for arg_type in &arg_types {
            key_generics_set.visit_type(arg_type);
        }
        let mut key_generics = syn::Generics {
            lt_token: original_generics.lt_token,
            params: Default::default(),
            gt_token: original_generics.gt_token,
            where_clause: None,
        };
        for param in &original_generics.params {
            let keep = match param {
                syn::GenericParam::Type(x) => key_generics_set.0.contains(&x.ident.to_string()),
                syn::GenericParam::Lifetime(x) => {
                    key_generics_set.0.contains(&x.lifetime.to_string())
                }
                syn::GenericParam::Const(x) => key_generics_set.0.contains(&x.ident.to_string()),
            };
            if keep {
                key_generics.params.push(param.clone());
            }
        }
        // Declarations keep the full parameter list (with bounds); every place that merely
        // *names* the key type must use the bound-free form, otherwise a bounded parameter
        // would produce invalid tokens such as `Key<'a: 'b>` in type position.
        let (key_impl_generics, key_ty_generics, _where_clause) = key_generics.split_for_impl();

        let key_name = format_ident!("{}QueryKey", name.to_string().to_camel_case());
        let key_indices = (0..arg_types.len()).map(syn::Index::from);
        let doc = format!("The arguments passed to the `{}` query.", name);
        keys.push(quote! {
            #[doc = #doc]
            #[derive(Clone, PartialEq, Eq, Hash)]
            pub struct #key_name #key_generics (#(pub #arg_types),*);
            impl #key_impl_generics std::fmt::Debug for #key_name #key_ty_generics {
                fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                    f.debug_tuple("")#(.field(&self.#key_indices))*.finish()
                }
            }
        });

        let cache_name = format_ident!("cached_{}", name);
        let tag_name = format_ident!("{}", name.to_string().to_camel_case());
        let doc = format!("The `{}` query.", name);
        tags.push(quote! {
            #[doc = #doc]
            #tag_name (#key_name #key_ty_generics),
        });
        tag_debugs.push(quote! {
            QueryTag::#tag_name (x) => write!(f, "{}{:?}", stringify!(#name), x),
        });

        // The wrapper: serve from cache if possible, otherwise track the query on the stack
        // and in the in-flight set (a repeated in-flight tag means a cycle), run it, and
        // cache the result.
        funcs.push(quote! {
            #(#doc_attrs)*
            fn #name #generics (&self, #(#arg_names: #arg_types),*) -> #result {
                let query_storage = self.storage();
                let query_key = #key_name(#(Clone::clone(&#arg_names)),*);
                let query_tag = QueryTag::#tag_name(query_key.clone());
                self.before_query(&query_tag);
                if let Some(result) = query_storage.#cache_name.borrow().get(&query_key) {
                    trace!("Serving {}{:?} from cache", stringify!(#name), query_key);
                    return Clone::clone(&result);
                }
                trace!("Executing {}{:?}", stringify!(#name), query_key);
                query_storage.stack.borrow_mut().push(query_tag.clone());
                if !query_storage.inflight.borrow_mut().insert(query_tag.clone()) {
                    self.handle_cycle();
                }
                #[allow(deprecated)]
                let query_result = #name(self.context(), #(#arg_names),*);
                query_storage.#cache_name.borrow_mut().insert(query_key, Clone::clone(&query_result));
                query_storage.inflight.borrow_mut().remove(&query_tag);
                query_storage.stack.borrow_mut().pop();
                self.after_query(&query_tag);
                query_result
            }
        });

        let doc = format!("Cached results of the `{}` query.", name);
        caches.push(quote! {
            #[doc = #doc]
            pub #cache_name: RefCell<HashMap<#key_name #key_ty_generics, #result>>,
        });
    }

    // Generic argument list shared by the trait, the storage struct and the tag enum.
    let lts = if ltset.is_empty() {
        quote! {}
    } else {
        let lts = ltset.iter();
        quote! { <#(#lts),*> }
    };

    output.extend(quote! {
        #input
        pub trait QueryDatabase #lts {
            type Context: crate::Context #lts;
            fn context(&self) -> &Self::Context;
            fn storage(&self) -> &QueryStorage #lts;
            fn handle_cycle(&self) -> ! {
                error!("Query cycle detected:");
                for x in self.storage().stack.borrow().iter() {
                    error!(" - {:?}", x);
                }
                panic!("query cycle detected");
            }
            fn before_query(&self, tag: &QueryTag #lts) {}
            fn after_query(&self, tag: &QueryTag #lts) {}
            #(#funcs)*
        }
    });
    output.extend(quote! {
        #[derive(Default)]
        pub struct QueryStorage #lts {
            pub stack: RefCell<Vec<QueryTag #lts>>,
            pub inflight: RefCell<HashSet<QueryTag #lts>>,
            #(#caches)*
        }
    });
    output.extend(quote! {
        #[derive(Clone, Hash, Eq, PartialEq)]
        pub enum QueryTag #lts {
            #(#tags)*
        }
        impl #lts std::fmt::Debug for QueryTag #lts {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                match self {
                    #(#tag_debugs)*
                }
            }
        }
    });
    output.extend(quote! {
        #(#keys)*
    });
    output.into()
}