// shardize-core 0.1.0
//
// Core libraries for shardize
// Documentation
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
    parse::{Parse, ParseStream},
    punctuated::Punctuated,
    token::Comma,
    FnArg, Ident, ItemTrait, TraitItem, TraitItemMethod,
};

// pub trait Shardable: Default {
//     // fn shard_key();
// }

pub fn shardize_transform(
    macro_config: MacroConfig,
    trait_definition: TraitDefinition,
) -> Result<TokenStream, &'static str> {
    let new_struct_name = macro_config.new_struct_name;
    let trait_name = trait_definition.name();
    let original_trait = trait_definition.to_token_stream();
    let impl_methods = trait_definition.impl_methods();

    Ok(quote! {
        #original_trait

        struct #new_struct_name<Impl, const NUM_SHARDS: usize>
        where
            Impl: #trait_name
        {
            sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
            shards: [Impl; NUM_SHARDS],
        }

        impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name,
            [Impl; NUM_SHARDS]: Default
        {
            fn new(
                    sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
            ) -> Self {
                let shards: [Impl; NUM_SHARDS] = Default::default();

                Self {
                    sharder,
                    shards,
                }
            }
        }

        impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name,
        {
            fn shard_key(&self, key: usize) -> usize {
                (self.sharder)(key) % NUM_SHARDS
            }
        }

        impl<Impl, const NUM_SHARDS: usize> Default for #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name
        {
            fn default() -> Self {
                panic!();
            }
        }


        impl <Impl, const NUM_SHARDS: usize> #trait_name for #new_struct_name<Impl, NUM_SHARDS>
        where
            Impl: #trait_name
        {
            #impl_methods
        }
    })
}

/// Configuration parsed from the macro's arguments.
pub struct MacroConfig {
    /// Identifier used as the name of the generated sharded struct.
    pub new_struct_name: Ident,
}

impl Parse for MacroConfig {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let new_struct_name = Ident::parse(input)?;
        Ok(Self { new_struct_name })
    }
}

/// Newtype around a parsed `syn::ItemTrait`, i.e. the trait being sharded.
pub struct TraitDefinition(ItemTrait);

impl Parse for TraitDefinition {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self(ItemTrait::parse(input)?))
    }
}

impl TraitDefinition {
    pub fn name(&self) -> &Ident {
        &self.0.ident
    }

    pub fn to_token_stream(&self) -> TokenStream {
        self.0.to_token_stream()
    }

    pub fn impl_methods(&self) -> TokenStream {
        self.0
            .items
            .iter()
            .filter_map(|item| match item {
                TraitItem::Method(method) => {
                    Some(Self::to_sharded_method(method))
                }
                _ => None,
            })
            .collect()
    }

    fn to_sharded_method(method: &TraitItemMethod) -> TokenStream {
        let signature = method.sig.to_token_stream();
        let method_name = &method.sig.ident;
        // TODO ensure it takes self
        let args: Punctuated<_, Comma> = method
            .sig
            .inputs
            .iter()
            .filter_map(|arg| match arg {
                FnArg::Typed(pat_type) => Some(&pat_type.pat),
                _ => None,
            })
            .collect();
        // TODO ignore associated functions

        quote!(
            #signature {
                let k = self.shard_key(key);
                self.shards[k].#method_name(#args)
            }
        )
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// A bare identifier parses into a `MacroConfig` carrying that name.
    #[test]
    fn macro_config_parse_test() {
        let tokens = quote!(ShardedHashMap);
        let parsed = syn::parse2::<MacroConfig>(tokens).unwrap();

        assert_eq!(parsed.new_struct_name, "ShardedHashMap");
    }

    /// An empty trait parses and reports its own identifier.
    #[test]
    fn trait_definition_parse_test() {
        let tokens = quote!(
            trait MyTrait {}
        );
        let parsed = syn::parse2::<TraitDefinition>(tokens).unwrap();

        assert_eq!(parsed.name(), "MyTrait");
    }

    /// Each trait method is turned into a forwarding method whose body
    /// selects a shard and delegates with the typed arguments.
    #[test]
    fn impl_methods_test() {
        let input = quote!(
            trait MyTrait {
                fn get(&self, key: String);
                fn set(&self, key: String, value: String);
            }
        );
        let expected = quote!(
            fn get(&self, key: String) {
                let k = self.shard_key(key);
                self.shards[k].get(key)
            }

            fn set(&self, key: String, value: String) {
                let k = self.shard_key(key);
                self.shards[k].set(key, value)
            }
        );

        let parsed = syn::parse2::<TraitDefinition>(input).unwrap();

        assert_eq!(parsed.impl_methods().to_string(), expected.to_string());
    }
}