use convert_case::Casing;
use proc_macro_error::{abort_call_site, proc_macro_error};
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput};
#[proc_macro_derive(MultiIndexMap, attributes(multi_index))]
#[proc_macro_error]
pub fn multi_index_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let fields = match input.data {
syn::Data::Struct(d) => d.fields,
_ => abort_call_site!("MultiIndexMap only supports structs as elements"),
};
let named_fields = match fields {
syn::Fields::Named(f) => f,
_ => abort_call_site!(
"Struct fields must be named, unnamed tuple structs and unit structs are not supported"
),
};
let fields_to_index = || {
named_fields.named.iter().filter(|f| {
f.attrs.first().is_some() && f.attrs.first().unwrap().path.is_ident("multi_index")
})
};
let lookup_table_fields = fields_to_index().map(|f| {
let index_name = format_ident!("_{}_index", f.ident.as_ref().unwrap());
let ty = &f.ty;
let (ordering, uniqueness) = get_index_kind(f).unwrap_or_else(|| {
abort_call_site!("Attributes must be in the style #[multi_index(hashed_unique)]")
});
match uniqueness {
Uniqueness::Unique => match ordering {
Ordering::Hashed => quote! {
#index_name: rustc_hash::FxHashMap<#ty, usize>,
},
Ordering::Ordered => quote! {
#index_name: std::collections::BTreeMap<#ty, usize>,
}
}
Uniqueness::NonUnique => match ordering {
Ordering::Hashed => quote! {
#index_name: rustc_hash::FxHashMap<#ty, Vec<usize>>,
},
Ordering::Ordered => quote! {
#index_name: std::collections::BTreeMap<#ty, Vec<usize>>,
}
}
}
});
let inserts: Vec<proc_macro2::TokenStream> = fields_to_index()
.map(|f| {
let field_name = f.ident.as_ref().unwrap();
let field_name_string = field_name.to_string();
let index_name = format_ident!("_{}_index", field_name);
let (_ordering, uniqueness) = get_index_kind(f).unwrap_or_else(|| {
abort_call_site!("Attributes must be in the style #[multi_index(hashed_unique)]")
});
match uniqueness {
Uniqueness::Unique => quote! {
let orig_elem_idx = self.#index_name.insert(elem.#field_name.clone(), idx);
if orig_elem_idx.is_some() {
panic!("Unable to insert element, uniqueness constraint violated on field '{}'", #field_name_string);
}
},
Uniqueness::NonUnique => quote! {
self.#index_name.entry(elem.#field_name.clone()).or_insert(Vec::with_capacity(1)).push(idx);
},
}
})
.collect();
let removes: Vec<proc_macro2::TokenStream> = fields_to_index()
.map(|f| {
let field_name = f.ident.as_ref().unwrap();
let index_name = format_ident!("_{}_index", field_name);
let field_name_string = field_name.to_string();
let error_msg = format!("Internal invariants broken, unable to find element in index '{field_name_string}' despite being present in another");
let (_ordering, uniqueness) = get_index_kind(f).unwrap_or_else(|| {
abort_call_site!("Attributes must be in the style #[multi_index(hashed_unique)]")
});
match uniqueness {
Uniqueness::Unique => quote! {
let removed_elem = self.#index_name.remove(&elem_orig.#field_name);
},
Uniqueness::NonUnique => quote! {
if let Some(mut elems) = self.#index_name.remove(&elem_orig.#field_name) {
if elems.len() > 1 {
let pos = elems.iter().position(|e| *e == idx).expect(#error_msg);
elems.remove(pos);
self.#index_name.insert(elem_orig.#field_name.clone(), elems);
}
}
}
}
})
.collect();
let modifies: Vec<proc_macro2::TokenStream> = fields_to_index().map(|f| {
let field_name = f.ident.as_ref().unwrap();
let field_name_string = field_name.to_string();
let index_name = format_ident!("_{}_index", field_name);
let error_msg = format!("Internal invariants broken, unable to find element in index '{field_name_string}' despite being present in another");
let (_ordering, uniqueness) = get_index_kind(f).unwrap_or_else(|| {
abort_call_site!("Attributes must be in the style #[multi_index(hashed_unique)]")
});
match uniqueness {
Uniqueness::Unique => quote! {
let idx = self.#index_name.remove(&elem_orig.#field_name).expect(#error_msg);
let orig_elem_idx = self.#index_name.insert(elem.#field_name.clone(), idx);
if orig_elem_idx.is_some() {
panic!("Unable to insert element, uniqueness constraint violated on field '{}'", #field_name_string);
}
},
Uniqueness::NonUnique => quote! {
let idxs = self.#index_name.get_mut(&elem_orig.#field_name).expect(#error_msg);
let pos = idxs.iter().position(|x| *x == idx).expect(#error_msg);
idxs.remove(pos);
self.#index_name.entry(elem.#field_name.clone()).or_insert(Vec::with_capacity(1)).push(idx);
},
}
}).collect();
let clears: Vec<proc_macro2::TokenStream> = fields_to_index()
.map(|f| {
let field_name = f.ident.as_ref().unwrap();
let index_name = format_ident!("_{}_index", field_name);
quote!{
self.#index_name.clear();
}
})
.collect();
let element_name = input.ident;
let map_name = format_ident!("MultiIndex{}Map", element_name);
let accessors = fields_to_index().map(|f| {
let field_name = f.ident.as_ref().unwrap();
let index_name = format_ident!("_{}_index", field_name);
let getter_name = format_ident!("get_by_{}", field_name);
let mut_getter_name = format_ident!("get_mut_by_{}", field_name);
let remover_name = format_ident!("remove_by_{}", field_name);
let modifier_name = format_ident!("modify_by_{}", field_name);
let iter_name = format_ident!("{}{}Iter", map_name, field_name.to_string().to_case(convert_case::Case::UpperCamel));
let iter_getter_name = format_ident!("iter_by_{}", field_name);
let ty = &f.ty;
let (_ordering, uniqueness) = get_index_kind(f).unwrap_or_else(|| {
abort_call_site!("Attributes must be in the style #[multi_index(hashed_unique)]")
});
let getter = match uniqueness {
Uniqueness::Unique => quote! {
pub(super) fn #getter_name(&self, key: &#ty) -> Option<&#element_name> {
Some(&self._store[*self.#index_name.get(key)?])
}
},
Uniqueness::NonUnique => quote! {
pub(super) fn #getter_name(&self, key: &#ty) -> Vec<&#element_name> {
if let Some(idxs) = self.#index_name.get(key) {
let mut elem_refs = Vec::with_capacity(idxs.len());
for idx in idxs {
elem_refs.push(&self._store[*idx])
}
elem_refs
} else {
Vec::new()
}
}
},
};
let mut_getter = match uniqueness {
Uniqueness::Unique => quote! {
pub(super) unsafe fn #mut_getter_name(&mut self, key: &#ty) -> Option<&mut #element_name> {
Some(&mut self._store[*self.#index_name.get(key)?])
}
},
Uniqueness::NonUnique => quote! {},
};
let remover = match uniqueness {
Uniqueness::Unique => quote! {
pub(super) fn #remover_name(&mut self, key: &#ty) -> Option<#element_name> {
let idx = self.#index_name.remove(key)?;
let elem_orig = self._store.remove(idx);
#(#removes)*
Some(elem_orig)
}
},
Uniqueness::NonUnique => quote! {
pub(super) fn #remover_name(&mut self, key: &#ty) -> Vec<#element_name> {
if let Some(idxs) = self.#index_name.remove(key) {
let mut elems = Vec::with_capacity(idxs.len());
for idx in idxs {
let elem_orig = self._store.remove(idx);
#(#removes)*
elems.push(elem_orig)
}
elems
} else {
Vec::new()
}
}
},
};
let modifier = match uniqueness {
Uniqueness::Unique => quote! {
pub(super) fn #modifier_name(&mut self, key: &#ty, f: impl FnOnce(&mut #element_name)) -> Option<&#element_name> {
let idx = *self.#index_name.get(key)?;
let elem = &mut self._store[idx];
let elem_orig = elem.clone();
f(elem);
#(#modifies)*
Some(elem)
}
},
Uniqueness::NonUnique => quote! {},
};
quote! {
#getter
#mut_getter
#remover
#modifier
pub(super) fn #iter_getter_name(&mut self) -> #iter_name {
#iter_name {
_store_ref: &self._store,
_iter: self.#index_name.iter(),
_inner_iter: None,
}
}
}
});
let iterators = fields_to_index().map(|f| {
let field_name = f.ident.as_ref().unwrap();
let field_name_string = field_name.to_string();
let error_msg = format!("Internal invariants broken, found empty slice in non_unique index '{field_name_string}'");
let iter_name = format_ident!(
"{}{}Iter",
map_name,
field_name
.to_string()
.to_case(convert_case::Case::UpperCamel)
);
let ty = &f.ty;
let (ordering, uniqueness) = get_index_kind(f).unwrap_or_else(|| {
abort_call_site!("Attributes must be in the style #[multi_index(hashed_unique)]")
});
let iter_type = match uniqueness {
Uniqueness::Unique => match ordering {
Ordering::Hashed => quote! {std::collections::hash_map::Iter<'a, #ty, usize>},
Ordering::Ordered => quote! {std::collections::btree_map::Iter<'a, #ty, usize>},
}
Uniqueness::NonUnique => match ordering {
Ordering::Hashed => quote! {std::collections::hash_map::Iter<'a, #ty, Vec<usize>>},
Ordering::Ordered => quote! {std::collections::btree_map::Iter<'a, #ty, Vec<usize>>},
}
};
let iter_action = match uniqueness {
Uniqueness::Unique => quote! { Some(&self._store_ref[*self._iter.next()?.1]) },
Uniqueness::NonUnique => quote! {
let inner_next = if let Some(inner_iter) = &mut self._inner_iter {
inner_iter.next()
} else {
None
};
if let Some(next_index) = inner_next {
Some(&self._store_ref[*next_index])
} else {
let hashmap_next = self._iter.next()?;
self._inner_iter = Some(hashmap_next.1.iter());
Some(&self._store_ref[*self._inner_iter.as_mut().unwrap().next().expect(#error_msg)])
}
},
};
quote! {
pub(super) struct #iter_name<'a> {
_store_ref: &'a slab::Slab<#element_name>,
_iter: #iter_type,
_inner_iter: Option<core::slice::Iter<'a, usize>>,
}
impl<'a> Iterator for #iter_name<'a> {
type Item = &'a #element_name;
fn next(&mut self) -> Option<Self::Item> {
#iter_action
}
}
}
});
let mod_name = format_ident!(
"multi_index_{}",
element_name
.to_string()
.to_case(convert_case::Case::Snake)
);
let expanded = quote! {
mod #mod_name {
use super::*;
#[derive(Default, Clone)]
pub(super) struct #map_name {
_store: slab::Slab<#element_name>,
#(#lookup_table_fields)*
}
impl #map_name {
pub(super) fn len(&self) -> usize {
self._store.len()
}
pub(super) fn is_empty(&self) -> bool {
self._store.is_empty()
}
pub(super) fn insert(&mut self, elem: #element_name) {
let idx = self._store.insert(elem);
let elem = &self._store[idx];
#(#inserts)*
}
pub(super) fn clear(&mut self) {
self._store.clear();
#(#clears)*
}
pub(super) fn iter(&self) -> slab::Iter<#element_name> {
self._store.iter()
}
pub(super) unsafe fn iter_mut(&mut self) -> slab::IterMut<#element_name> {
self._store.iter_mut()
}
#(#accessors)*
}
#(#iterators)*
}
};
proc_macro::TokenStream::from(expanded)
}
/// Lookup-table flavour of an indexed field: `Hashed` selects `rustc_hash::FxHashMap`,
/// `Ordered` selects `std::collections::BTreeMap` in the generated map.
enum Ordering { Hashed, Ordered }
/// Whether a key in a field's lookup table maps to a single slab index (`Unique`, stored as
/// `usize`) or to any number of them (`NonUnique`, stored as `Vec<usize>`).
// Both variant names end in `Unique`, which `clippy::enum_variant_names` would otherwise flag.
#[allow(clippy::enum_variant_names)]
enum Uniqueness { Unique, NonUnique }
/// Parses the index kind from a field's `#[multi_index(...)]` attribute.
///
/// The first attribute whose path is `multi_index` is used, so other attributes (for example
/// doc comments, which are `#[doc = ...]` attributes) may appear before it. Only the first
/// nested item of the attribute is inspected.
///
/// Returns `None` when the field has no `multi_index` attribute, when the attribute is not in
/// list form, when its first nested item is not a bare path, or when that path is not one of
/// `hashed_unique`, `ordered_unique`, `hashed_non_unique` or `ordered_non_unique`.
fn get_index_kind(f: &syn::Field) -> Option<(Ordering, Uniqueness)> {
    let attr = f
        .attrs
        .iter()
        .find(|attr| attr.path.is_ident("multi_index"))?;

    let meta_list = match attr.parse_meta() {
        Ok(syn::Meta::List(l)) => l,
        _ => return None,
    };

    let nested_path = match meta_list.nested.first()? {
        syn::NestedMeta::Meta(syn::Meta::Path(p)) => p,
        _ => return None,
    };

    // `get_ident` yields an ident only for a single-segment path, exactly like `is_ident`.
    let kind = match nested_path.get_ident()?.to_string().as_str() {
        "hashed_unique" => (Ordering::Hashed, Uniqueness::Unique),
        "ordered_unique" => (Ordering::Ordered, Uniqueness::Unique),
        "hashed_non_unique" => (Ordering::Hashed, Uniqueness::NonUnique),
        "ordered_non_unique" => (Ordering::Ordered, Uniqueness::NonUnique),
        _ => return None,
    };
    Some(kind)
}