use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
token::Comma,
FnArg, Ident, ItemTrait, TraitItem, TraitItemMethod,
};
/// Expands a trait definition into the trait itself plus a sharded wrapper type.
///
/// The returned tokens contain, in order:
/// - the original trait, re-emitted unchanged;
/// - `struct <new_struct_name><Impl, const NUM_SHARDS: usize>`, holding a sharding
///   function and a `[Impl; NUM_SHARDS]` array of shards;
/// - `new(sharder)`, which fills every shard with `Default::default()` and is only
///   available when `[Impl; NUM_SHARDS]: Default`;
/// - `shard_key(key)`, which maps a `usize` key to a shard index with
///   `sharder(key) % NUM_SHARDS` (so `NUM_SHARDS` must be non-zero);
/// - a `Default` impl that always panics, because no sharder can be defaulted;
/// - an impl of the trait that forwards each method to the shard chosen by its
///   `key` argument (see [`TraitDefinition::impl_methods`]).
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is kept as part of the interface.
pub fn shardize_transform(
    macro_config: MacroConfig,
    trait_definition: TraitDefinition,
) -> Result<TokenStream, &'static str> {
    let new_struct_name = macro_config.new_struct_name;
    let trait_name = trait_definition.name();
    let original_trait = trait_definition.to_token_stream();
    let impl_methods = trait_definition.impl_methods();
    Ok(quote! {
        #original_trait

        struct #new_struct_name<Impl, const NUM_SHARDS: usize>
        where
            Impl: #trait_name
        {
            sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
            shards: [Impl; NUM_SHARDS],
        }

        impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name,
            [Impl; NUM_SHARDS]: Default
        {
            fn new(
                sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
            ) -> Self {
                let shards: [Impl; NUM_SHARDS] = Default::default();
                Self {
                    sharder,
                    shards,
                }
            }
        }

        impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name,
        {
            // Panics with a remainder-by-zero error if NUM_SHARDS == 0.
            fn shard_key(&self, key: usize) -> usize {
                (self.sharder)(key) % NUM_SHARDS
            }
        }

        // There is no sensible default sharder, so this impl can only panic.
        // NOTE(review): presumably present only to satisfy a `Default` bound
        // elsewhere - confirm which caller needs it.
        impl<Impl, const NUM_SHARDS: usize> Default for #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name
        {
            fn default() -> Self {
                panic!(
                    "{} cannot be built with `Default`; construct it with `new(sharder)`",
                    stringify!(#new_struct_name)
                )
            }
        }

        impl <Impl, const NUM_SHARDS: usize> #trait_name for #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name
        {
            #impl_methods
        }
    })
}
/// Configuration parsed from the macro's arguments.
#[derive(Debug, Clone)]
pub struct MacroConfig {
    /// Identifier given to the generated sharded wrapper struct.
    pub new_struct_name: Ident,
}

impl Parse for MacroConfig {
    /// Consumes exactly one identifier and uses it as the new struct's name.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let new_struct_name: Ident = input.parse()?;
        Ok(Self { new_struct_name })
    }
}
/// A parsed `trait` item that will receive a sharded implementation.
pub struct TraitDefinition(ItemTrait);

impl Parse for TraitDefinition {
    /// Parses one complete `trait` item and wraps it.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        input.parse().map(Self)
    }
}
impl TraitDefinition {
pub fn name(&self) -> &Ident {
&self.0.ident
}
pub fn to_token_stream(&self) -> TokenStream {
self.0.to_token_stream()
}
pub fn impl_methods(&self) -> TokenStream {
self.0
.items
.iter()
.filter_map(|item| match item {
TraitItem::Method(method) => {
Some(Self::to_sharded_method(method))
}
_ => None,
})
.collect()
}
fn to_sharded_method(method: &TraitItemMethod) -> TokenStream {
let signature = method.sig.to_token_stream();
let method_name = &method.sig.ident;
let args: Punctuated<_, Comma> = method
.sig
.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Typed(pat_type) => Some(&pat_type.pat),
_ => None,
})
.collect();
quote!(
#signature {
let k = self.shard_key(key);
self.shards[k].#method_name(#args)
}
)
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Parses a token stream into `T`, panicking if syn rejects it.
    fn parse_tokens<T: Parse>(tokens: TokenStream) -> T {
        syn::parse2(tokens).unwrap()
    }

    #[test]
    fn macro_config_parse_test() {
        let config: MacroConfig = parse_tokens(quote!(ShardedHashMap));
        assert_eq!(config.new_struct_name, "ShardedHashMap");
    }

    #[test]
    fn trait_definition_parse_test() {
        let definition: TraitDefinition = parse_tokens(quote!(
            trait MyTrait {}
        ));
        assert_eq!(definition.name(), "MyTrait");
    }

    #[test]
    fn impl_methods_test() {
        let definition: TraitDefinition = parse_tokens(quote!(
            trait MyTrait {
                fn get(&self, key: String);
                fn set(&self, key: String, value: String);
            }
        ));
        let expected = quote!(
            fn get(&self, key: String) {
                let k = self.shard_key(key);
                self.shards[k].get(key)
            }
            fn set(&self, key: String, value: String) {
                let k = self.shard_key(key);
                self.shards[k].set(key, value)
            }
        );
        assert_eq!(definition.impl_methods().to_string(), expected.to_string());
    }
}