use std::{collections::HashMap, hash::Hash};
use proc_macro2::{TokenStream, TokenTree};
use quote::quote;
use crate::{context::Context, maybe_borrowed::MaybeBorrowed};
/// Selects which section of a `TokenStore` a batch of tokens is appended to.
#[derive(Debug, Eq, PartialEq, Hash)]
pub enum TokenLevel<'a> {
    /// Appended to the store's `generatable` stream.
    Generatable,
    /// Appended to the store's `template` stream.
    Template,
    /// Assertion tokens. With `None` they go to the store's plain `assert` stream;
    /// with `Some(generics)` they are grouped with every other assertion that was
    /// recorded under equal generics.
    Assert(&'a Option<MaybeBorrowed<'a, syn::Generics>>),
}
/// Accumulates generated tokens in separate buckets so they can be emitted in a
/// fixed section order by `TokenStore::into_tokens`.
#[derive(Default, Debug)]
pub struct TokenStore {
    /// Tokens recorded with `TokenLevel::Generatable`.
    generatable: TokenStream,
    /// Tokens recorded with `TokenLevel::Template`.
    template: TokenStream,
    /// Tokens recorded with `TokenLevel::Assert(&None)`.
    assert: TokenStream,
    /// Tokens recorded with `TokenLevel::Assert(Some(..))`, keyed by their generics.
    generics: HashMap<syn::Generics, TokenStream>,
}
/// Appends a sequence of token items of type `I` to the bucket chosen by a
/// `TokenLevel`.
pub trait AddTokens<I> {
    /// Extends the bucket selected by `level` with every item yielded by `tokens`.
    fn add_tokens<T: IntoIterator<Item = I>>(&mut self, level: TokenLevel<'_>, tokens: T);
}
/// Returns a mutable reference to the value stored under `key`, first inserting
/// `V::default()` if the key is absent.
///
/// Unlike `map.entry(key.clone()).or_default()`, the key is cloned only when a
/// new entry actually has to be created.
fn hash_map_get_or_default_clone_key<'a, K, V>(map: &'a mut HashMap<K, V>, key: &K) -> &'a mut V
where
    K: Eq + Hash + Clone,
    V: Default,
{
    // Insert-then-look-up (rather than returning early from `get_mut`) avoids the
    // current borrow checker's conditional-return-of-borrow limitation.
    if map.get(key).is_none() {
        let owned_key = key.clone();
        map.insert(owned_key, V::default());
    }
    map.get_mut(key)
        .expect("entry is inserted above whenever it is missing")
}
/// Appends whole token streams to the bucket selected by the `TokenLevel`.
impl AddTokens<TokenStream> for TokenStore {
    fn add_tokens<T>(&mut self, level: TokenLevel, tokens: T)
    where
        T: IntoIterator<Item = TokenStream>,
    {
        // Pick the destination first, then extend it exactly once.
        let bucket: &mut TokenStream = match level {
            TokenLevel::Generatable => &mut self.generatable,
            TokenLevel::Template => &mut self.template,
            TokenLevel::Assert(None) => &mut self.assert,
            // Assertions with generics are grouped per distinct generics value.
            TokenLevel::Assert(Some(t)) => hash_map_get_or_default_clone_key(&mut self.generics, t),
        };
        bucket.extend(tokens);
    }
}
impl AddTokens<TokenTree> for TokenStore {
fn add_tokens<T>(&mut self, level: TokenLevel, tokens: T)
where
T: IntoIterator<Item = TokenTree>,
{
match level {
TokenLevel::Generatable => self.generatable.extend(tokens),
TokenLevel::Template => self.template.extend(tokens),
TokenLevel::Assert(&None) => self.assert.extend(tokens),
TokenLevel::Assert(Some(t)) => {
hash_map_get_or_default_clone_key(&mut self.generics, t).extend(tokens);
}
}
}
}
impl TokenStore {
pub fn new() -> Self {
Self::default()
}
pub fn into_tokens(
self,
Context {
ident_generator, ..
}: &mut Context,
) -> TokenStream {
let Self {
generatable,
template,
assert,
generics,
} = self;
let generics: TokenStream = generics
.into_iter()
.map(|(generics, tokens)| {
let generics_ident = ident_generator.prefixed("generics");
let (impl_generics, _type_generics, where_clause) = generics.split_for_impl();
quote! {
fn #generics_ident #impl_generics () #where_clause {
#tokens
}
}
})
.collect();
quote! {
#generatable
#template
#assert
#generics
}
}
}