computed_map 0.1.0

A Rust proc-macro crate for generating indexed maps with computed fields and indices.
Documentation
extern crate proc_macro;

use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{parse::Parse, parse_macro_input, punctuated::Punctuated, ExprClosure, Ident, LitBool, ReturnType, Token, Type};

/// One computed column of a generated table, written as `Name, |value| -> Ret { ... }, index`.
struct MappingInput {
    /// Field / index name, converted to snake_case while parsing.
    mapping_name: Ident,
    /// User closure that is called with a clone of every inserted value.
    closure_expr: ExprClosure,
    /// The closure's declared return type; used as the field type and as the index key type.
    return_type: Type,
    /// Whether a `BTreeMap<return_type, BTreeSet<key>>` reverse index is kept for this mapping.
    create_index: bool,
}
impl Parse for MappingInput {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mapping_name = Ident::new(&input.parse::<Ident>()?.to_string().to_case(Case::Snake), Span::call_site());
        let _comma: Token![,] = input.parse()?;
        let closure_expr: ExprClosure = input.parse()?;
        let _comma: Token![,] = input.parse()?;
        let create_index: LitBool = input.parse()?;
        if let ReturnType::Type(_, return_type) = closure_expr.output.clone() {
            Ok(Self {
                mapping_name,
                closure_expr,
                return_type: *return_type,
                create_index: create_index.value
            })
        } else {
            Err(syn::Error::new(Span::call_site(), "Return type required on closure"))
        }
    }
}

/// Complete input of `generate_table!`: `TableName, KeyType, ValueType` plus optional mappings.
struct MacroInput {
    /// Name of the generated table struct.
    table_ident: Ident,
    /// Key type of the table's `BTreeMap`, also stored in every index set.
    key_type: Type,
    /// User value type, stored alongside its computed mapping values.
    value_type_base: Type,
    /// Computed columns, in declaration order.
    mappings: Vec<MappingInput>,
}

impl Parse for MacroInput {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let table_ident = input.parse()?;
        let _comma: Token![,] = input.parse()?;
        let key_type = input.parse()?;
        let _comma: Token![,] = input.parse()?;
        let value_type_base: Type = input.parse()?;
        let ending_comma: Result<Token![,], _> = input.parse();
        let mappings: Vec<MappingInput> = if ending_comma.is_ok() {
            Punctuated::<MappingInput, Token![,]>::parse_terminated(input)?.into_iter().collect()
        } else {
            vec![]
        };
        
        Ok(Self {
            table_ident,
            key_type,
            value_type_base,
            mappings
        })
    }
}

// struct JoinedTableInput {
//     table_ident: Ident,
//     key_type: Type,
//     value_type_base: Type,
//     mappings: Vec<MappingInput>
// }

/// Placeholder for a future joined-table generator.
///
/// It currently ignores its input and expands to nothing.
#[proc_macro]
pub fn joined_table(_input: TokenStream) -> TokenStream {
    TokenStream::new()
}

#[proc_macro]
pub fn generate_table(input: TokenStream) -> TokenStream {
    let MacroInput {
        table_ident,
        key_type,
        value_type_base,
        mappings,
    } = parse_macro_input!(input as MacroInput);

    let mapping_indexes_ident = Ident::new(&("MappingIndexes".to_owned() + &table_ident.to_string().to_case(Case::UpperCamel)), Span::call_site());
    let mapped_value_struct_ident = Ident::new(&("MappedValuesStruct".to_owned() + &table_ident.to_string().to_case(Case::UpperCamel)), Span::call_site());
    let utility_structs_mod_ident = Ident::new(&table_ident.to_string().to_case(Case::Snake), Span::call_site());

    let mapped_values_struct: Vec<_> = mappings.iter().map(|mapping| {
        let ident = mapping.mapping_name.clone();
        let return_type = mapping.return_type.clone();
        quote! {
            #ident: #return_type
        }
    }).collect();

    let value_structs = quote! {
        #[derive(Debug, PartialEq, Clone)]
        struct #mapped_value_struct_ident {
            #(#mapped_values_struct),*
        }
    };

    let mapping_index_struct: Vec<_> = mappings
        .iter()
        .filter(|mapping| {
            mapping.create_index
        })
        .map(|mapping| {
            let ident = mapping.mapping_name.clone();
            let return_type = mapping.return_type.clone();
            quote! {
                #ident: std::collections::BTreeMap<#return_type, std::collections::BTreeSet<#key_type>>
            }
        }).collect();

    let mapping_index_new: Vec<_> = mappings
        .iter()
        .filter(|mapping| {
            mapping.create_index
        })
        .map(|mapping| {
            let ident = mapping.mapping_name.clone();
            quote! {
                #ident: std::collections::BTreeMap::new()
            }
        }).collect();

    let value_tuple = quote! {
        (#value_type_base, #mapped_value_struct_ident)
    };

    let mapping_index_gets: Vec<_> = mappings
        .iter()
        .filter(|mapping| {
            mapping.create_index
        })
        .map(|mapping| {
            let ident = mapping.mapping_name.clone();
            let getter_ident = Ident::new(&("get_".to_owned() + &ident.to_string()), Span::call_site());
            let iter_ident = Ident::new(&("iter_".to_owned() + &ident.to_string()), Span::call_site());
            let return_type = mapping.return_type.clone();
            quote! {
                fn #getter_ident(&self, mapped_key: &#return_type) -> Option<&std::collections::BTreeSet<#key_type>> {
                    self.mapping_indexes.#ident.get(mapped_key)
                }

                fn #iter_ident(&self) -> std::collections::btree_map::Iter<#return_type, std::collections::BTreeSet<#key_type>> {
                    self.mapping_indexes.#ident.iter()
                }
            }
        }).collect();

    let insert_fn = insert(&key_type, &value_type_base, &value_tuple, &mappings, &mapped_value_struct_ident);

    let expanded = quote! {
        pub mod #utility_structs_mod_ident {
            #[derive(Debug)]
            pub enum ReadableTableState {}
            #[derive(Debug)]
            pub enum WritableTableState {}
            pub trait TableStatus {}
            impl TableStatus for ReadableTableState {}
            impl TableStatus for WritableTableState {}
        }

        #value_structs

        #[derive(Default, Debug)]
        struct #mapping_indexes_ident {
            #(#mapping_index_struct),*
        }
        impl #mapping_indexes_ident {
            fn new() -> Self {
                Self {
                    #(#mapping_index_new),*
                }
            }
        }

        #[derive(Default, Debug)]
        struct #table_ident<S: #utility_structs_mod_ident::TableStatus> {
            table: std::collections::BTreeMap<#key_type, #value_tuple>,
            mapping_indexes: #mapping_indexes_ident,
            state: std::marker::PhantomData<S>
        }
        impl #table_ident<#utility_structs_mod_ident::ReadableTableState> {
            fn new() -> Self {
                Self {
                    table: std::collections::BTreeMap::new(),
                    mapping_indexes: #mapping_indexes_ident::new(),
                    state: std::marker::PhantomData
                }
            }

            fn lock(mut self) -> #table_ident<#utility_structs_mod_ident::WritableTableState> {
                #table_ident {
                    table: self.table,
                    mapping_indexes: self.mapping_indexes,
                    state: std::marker::PhantomData::<#utility_structs_mod_ident::WritableTableState>
                }
            }

            fn get(&self, key: &#key_type) -> Option<&#value_tuple> {
                self.table.get(key)
            }

            fn iter(&self) -> std::collections::btree_map::Iter<#key_type, #value_tuple> {
                self.table.iter()
            }

            fn len(&self) -> usize {
                self.table.len()
            }

            #(#mapping_index_gets)*
        }
        impl #table_ident<#utility_structs_mod_ident::WritableTableState> {
            fn unlock(mut self) -> #table_ident<#utility_structs_mod_ident::ReadableTableState> {
                #table_ident {
                    table: self.table,
                    mapping_indexes: self.mapping_indexes,
                    state: std::marker::PhantomData::<#utility_structs_mod_ident::ReadableTableState>
                }
            }

            #insert_fn
        }
    };
    println!("{expanded}");
    TokenStream::from(expanded)
}

/// Builds the writable-state `insert` and `remove` methods of a generated table.
///
/// * `key_type` / `value_type_base`: the table's key type and user value type.
/// * `value_tuple`: tokens for `(value_type_base, mapped_value_struct_ident)`, the type
///   stored for each key and returned for a replaced or removed entry.
/// * `mappings`: every mapping. All closures run on insert; only mappings with
///   `create_index` maintain a reverse index.
/// * `mapped_value_struct_ident`: name of the struct holding the computed mapping values.
///
/// The generated code needs `value_type_base: Clone`, because each closure receives its own
/// clone of the value. When at least one index exists it also needs `key_type: Clone`,
/// because the key is stored both in the table and in the index sets. Index sets that become
/// empty are removed, so `get_<mapping>` never returns `Some` of an empty set.
fn insert(
    key_type: &Type,
    value_type_base: &Type,
    value_tuple: &proc_macro2::TokenStream,
    mappings: &[MappingInput],
    mapped_value_struct_ident: &Ident
) -> proc_macro2::TokenStream {
    let indexed: Vec<&MappingInput> = mappings
        .iter()
        .filter(|mapping| mapping.create_index)
        .collect();

    // One field initialiser per mapping: call the user's closure on a clone of the value.
    let mapped_field_inits: Vec<_> = mappings
        .iter()
        .map(|mapping| {
            let ident = &mapping.mapping_name;
            let closure = &mapping.closure_expr;
            quote! { #ident: (#closure)(value.clone()) }
        })
        .collect();

    // Record the key under each freshly computed indexed value.
    let index_new_values: Vec<_> = indexed
        .iter()
        .map(|mapping| {
            let ident = &mapping.mapping_name;
            quote! {
                self.mapping_indexes.#ident
                    .entry(new_mapped_values.#ident.clone())
                    .or_default()
                    .insert(key.clone());
            }
        })
        .collect();

    // Emits code that removes the key from the index sets of `maybe_old_value`'s mapped
    // values. `key_ref` must be an expression of type `&#key_type`. In `insert` the key is
    // owned, while in `remove` it is already a reference, so each caller supplies its own.
    let unindex_old_values = |key_ref: proc_macro2::TokenStream| -> proc_macro2::TokenStream {
        if indexed.is_empty() {
            // Nothing to clean up; avoid emitting an unused `old_mapped_values` binding.
            return proc_macro2::TokenStream::new();
        }
        let removals: Vec<_> = indexed
            .iter()
            .map(|mapping| {
                let ident = &mapping.mapping_name;
                quote! {
                    if let Some(keys) = self.mapping_indexes.#ident.get_mut(&old_mapped_values.#ident) {
                        keys.remove(#key_ref);
                        if keys.is_empty() {
                            self.mapping_indexes.#ident.remove(&old_mapped_values.#ident);
                        }
                    }
                }
            })
            .collect();
        quote! {
            if let Some((_, old_mapped_values)) = &maybe_old_value {
                #(#removals)*
            }
        }
    };

    let unindex_on_insert = unindex_old_values(quote!(&key));
    let unindex_on_remove = unindex_old_values(quote!(key));

    // The key is only needed after the table insert when some index has to record it, so
    // tables without indexes keep working with non-`Clone` keys.
    let table_key = if indexed.is_empty() {
        quote!(key)
    } else {
        quote!(key.clone())
    };

    quote! {
        /// Inserts `value` under `key`, computing all mapped values and updating the indexes.
        /// Returns the previous entry for `key`, if any.
        fn insert(&mut self, key: #key_type, value: #value_type_base) -> Option<#value_tuple> {
            let new_mapped_values = #mapped_value_struct_ident {
                #(#mapped_field_inits),*
            };
            let maybe_old_value = self.table.insert(#table_key, (
                value,
                new_mapped_values.clone()
            ));
            #unindex_on_insert
            #(#index_new_values)*
            maybe_old_value
        }

        /// Removes the entry for `key` and its index references. Returns the removed entry, if any.
        fn remove(&mut self, key: &#key_type) -> Option<#value_tuple> {
            let maybe_old_value = self.table.remove(key);
            #unindex_on_remove
            maybe_old_value
        }
    }
}