use hashbrown::HashTable;
use std::borrow::Borrow;
use std::fmt::Debug;
use std::hash::Hasher;
use std::iter::FusedIterator;
/// A hash map whose entries live in dense vectors, so every entry can be
/// addressed both by key and by its position.
///
/// The value stored for the key at position `i` of `index_set` is `values[i]`.
#[derive(Clone, Debug)]
pub struct IndexMap<K, V, S = ahash::RandomState> {
    /// Values, positionally aligned with the keys held by `index_set`.
    pub(crate) values: Vec<V>,
    /// Keys plus the hash table that maps each key to its position.
    pub(crate) index_set: IndexSet<K, S>,
}
impl<K, V, S> IndexMap<K, V, S>
where
    K: Eq + std::hash::Hash + Clone,
    S: std::hash::BuildHasher,
{
    /// Creates an empty map that hashes keys with `hash_builder`.
    pub fn with_hasher(hash_builder: S) -> Self {
        IndexMap {
            values: Vec::new(),
            index_set: IndexSet::with_hasher(hash_builder),
        }
    }

    /// Creates an empty map with storage pre-allocated for `capacity` entries,
    /// using `S::default()` as the hasher.
    pub fn with_capacity(capacity: usize) -> Self
    where
        S: Default,
    {
        IndexMap {
            values: Vec::with_capacity(capacity),
            index_set: IndexSet::with_capacity(capacity),
        }
    }

    /// Creates an empty map using `S::default()` as the hasher.
    pub fn new() -> Self
    where
        S: Default,
    {
        IndexMap {
            values: Vec::new(),
            index_set: IndexSet::new(),
        }
    }

    /// Creates an empty map with storage pre-allocated for `capacity` entries
    /// that hashes keys with `hash_builder`.
    pub fn with_hasher_capacity(hash_builder: S, capacity: usize) -> Self {
        IndexMap {
            values: Vec::with_capacity(capacity),
            index_set: IndexSet::with_hasher_capacity(hash_builder, capacity),
        }
    }

    /// Shrinks the value storage and the underlying index set.
    pub fn shrink_to_fit(&mut self) {
        self.values.shrink_to_fit();
        self.index_set.shrink_to_fit();
    }

    /// Returns the number of entries in the map.
    #[inline]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Iterates over the values in positional order.
    #[inline]
    pub fn iter_values(&self) -> std::slice::Iter<'_, V> {
        self.values.iter()
    }

    /// Iterates over the keys in positional order.
    #[inline]
    pub fn iter_keys(&self) -> std::slice::Iter<'_, K> {
        self.index_set.keys.iter()
    }

    /// Iterates over `(&key, &value)` pairs in positional order.
    #[inline]
    pub fn iter(&self) -> IndexMapIter<'_, K, V, S> {
        IndexMapIter {
            map: self,
            index: 0,
            end: self.len(),
        }
    }

    /// Consumes the map and yields owned `(key, value)` pairs in positional order.
    #[inline]
    pub fn into_iter(self) -> impl Iterator<Item = (K, V)> {
        let IndexMap { values, index_set } = self;
        index_set.keys.into_iter().zip(values)
    }

    /// Returns the value storage.
    #[inline]
    pub fn values(&self) -> &Vec<V> {
        &self.values
    }

    /// Returns the key storage.
    #[inline]
    pub fn keys(&self) -> &Vec<K> {
        self.index_set.keys()
    }

    /// Returns the index set that holds the keys and the lookup table.
    #[inline]
    pub fn as_index_set(&self) -> &IndexSet<K, S> {
        &self.index_set
    }

    /// Returns a reference to the value stored for `key`, if any.
    #[inline]
    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let idx = self.index_set.table_get(key)?;
        // SAFETY: `table_get` only returns `idx` after its equality closure
        // performed a bounds-checked `keys[idx]`, and `keys` and `values` are
        // always pushed to and removed from together, so they have equal length.
        Some(unsafe { self.values.get_unchecked(idx) })
    }

    /// Returns a mutable reference to the value stored for `key`, if any.
    #[inline]
    pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let idx = self.index_set.table_get(key)?;
        // SAFETY: see `get`; `idx` was bounds-checked against `keys`, which has
        // the same length as `values`.
        Some(unsafe { self.values.get_unchecked_mut(idx) })
    }

    /// Returns the value at position `index`, or `None` when out of range.
    #[inline]
    pub fn get_with_index(&self, index: usize) -> Option<&V> {
        self.values.get(index)
    }

    /// Returns the value at position `index` mutably, or `None` when out of range.
    #[inline]
    pub fn get_with_index_mut(&mut self, index: usize) -> Option<&mut V> {
        self.values.get_mut(index)
    }

    /// Returns the key at position `index`, or `None` when out of range.
    #[inline]
    pub fn get_key_with_index(&self, index: usize) -> Option<&K> {
        self.index_set.get_with_index(index)
    }

    /// Returns the `(key, value)` pair at position `index`, or `None` when out of range.
    #[inline]
    pub fn get_key_value_with_index(&self, index: usize) -> Option<(&K, &V)> {
        self.index_set.keys.get(index).zip(self.values.get(index))
    }

    /// Returns the position of `key`, if it is present.
    #[inline]
    pub fn get_index<Q: ?Sized>(&self, key: &Q) -> Option<usize>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        self.index_set.table_get(key)
    }

    /// Returns `true` when `key` is present in the map.
    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        self.index_set.contains_key(key)
    }

    /// Inserts `value` under `key`.
    ///
    /// When the key is already present, the stored key and value are replaced
    /// in place and returned in [`InsertResult::Override`] together with the
    /// unchanged position. Otherwise the entry is appended and
    /// [`InsertResult::New`] reports its position.
    #[inline]
    pub fn insert(&mut self, key: K, value: V) -> InsertResult<K, V> {
        match self.index_set.table_get(&key) {
            Some(index) => {
                // The table slot already stores `index` and the hash of an equal
                // key is unchanged, so only the stored key and value are swapped.
                let old_value = std::mem::replace(&mut self.values[index], value);
                let old_key = std::mem::replace(&mut self.index_set.keys[index], key);
                InsertResult::Override {
                    old_value,
                    old_key,
                    index,
                }
            }
            None => {
                let index = self.values.len();
                // SAFETY: `key` was just confirmed absent, and `index` is the
                // position the key and value are pushed to right below.
                unsafe {
                    self.index_set.table_append(&key, &index);
                }
                self.index_set.keys.push(key);
                self.values.push(value);
                InsertResult::New { index }
            }
        }
    }

    /// Looks up `key` and returns an entry describing whether it is present.
    ///
    /// An occupied entry borrows the stored value mutably; a vacant entry keeps
    /// the map borrowed so the value can be inserted later.
    #[inline]
    pub fn entry_mut<'a>(&'a mut self, key: K) -> EntryMut<'a, K, V, S> {
        match self.index_set.table_get(&key) {
            Some(index) => EntryMut::Occupied {
                key,
                // SAFETY: see `get`; `index` was bounds-checked against `keys`,
                // which has the same length as `values`.
                value: unsafe { self.values.get_unchecked_mut(index) },
                index,
            },
            None => EntryMut::Vacant { key, map: self },
        }
    }

    /// Removes `key` by swapping the last entry into its position.
    ///
    /// Returns [`RemoveResult::Removed`] carrying the removed key, its value and
    /// the position it occupied, or [`RemoveResult::None`] when the key is
    /// absent. The entry that used to be last (if different) now lives at the
    /// returned position.
    #[inline]
    pub fn swap_remove<Q: ?Sized>(&mut self, key: &Q) -> RemoveResult<K, V>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let idx = match self.index_set.table_get(key) {
            Some(idx) => idx,
            None => return RemoveResult::None,
        };
        // `idx` was found, so the map holds at least one entry.
        let last_idx = self.values.len() - 1;
        if idx == last_idx {
            // SAFETY: `key` is present; its key and value are removed below.
            unsafe {
                self.index_set.table_swap_remove(key);
            }
        } else {
            // Cloned because the table helpers need `&mut self.index_set`
            // while the key is used for hashing and comparison.
            let last_key = self.index_set.keys[last_idx].clone();
            // SAFETY: `key` is present, and the last key is re-pointed to `idx`,
            // which is exactly where the `swap_remove` calls below move it.
            unsafe {
                self.index_set.table_swap_remove(key);
                self.index_set.table_override::<K>(&last_key, &idx);
            }
        }
        let old_value = self.values.swap_remove(idx);
        // Return the key that was actually removed, not the one moved into its slot.
        let old_key = self.index_set.keys.swap_remove(idx);
        RemoveResult::Removed {
            old_value,
            old_key,
            index: idx,
        }
    }

    /// Builds a map from parallel key and value vectors, using `S::default()`
    /// as the hasher.
    ///
    /// Pairs are inserted in order. If a key occurs more than once, the first
    /// occurrence fixes the position and the last occurrence provides the value.
    /// If the vectors differ in length, the surplus elements of the longer one
    /// are dropped.
    #[inline]
    pub fn from_kv_vec(k_vec: Vec<K>, v_vec: Vec<V>) -> Self
    where
        S: std::hash::BuildHasher + Default,
    {
        let capacity = k_vec.len().min(v_vec.len());
        let mut map = IndexMap::with_hasher_capacity(S::default(), capacity);
        for (k, v) in k_vec.into_iter().zip(v_vec) {
            // Going through `insert` keeps the table free of duplicate keys.
            map.insert(k, v);
        }
        map
    }
}
/// Borrowing iterator over the `(key, value)` pairs of an [`IndexMap`].
///
/// It walks the positions `index..end` of the map's key and value vectors.
pub struct IndexMapIter<'a, K, V, S> {
    /// Map being iterated.
    pub map: &'a IndexMap<K, V, S>,
    /// Position of the next pair yielded from the front.
    pub index: usize,
    /// One past the position of the next pair yielded from the back.
    pub end: usize,
}
impl<'a, K, V, S> Iterator for IndexMapIter<'a, K, V, S>
where
    K: Eq + std::hash::Hash,
    S: std::hash::BuildHasher,
{
    type Item = (&'a K, &'a V);

    /// Yields the pair at the front cursor and advances it.
    ///
    /// Accesses are bounds-checked: `map`, `index` and `end` are public fields,
    /// so safe code can put the cursors out of range and unchecked indexing
    /// would be unsound. An out-of-range cursor ends the iteration.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            return None;
        }
        let map = self.map;
        let key = map.index_set.keys.get(self.index)?;
        let value = map.values.get(self.index)?;
        // Advance only after both lookups succeeded so a failed lookup keeps
        // returning `None` on later calls.
        self.index += 1;
        Some((key, value))
    }

    /// Reports the exact number of remaining pairs.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Saturate: the public cursors may have been set so that `index > end`.
        let remaining = self.end.saturating_sub(self.index);
        (remaining, Some(remaining))
    }

    /// Skips `n` pairs in constant time and yields the next one.
    #[inline]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let target = self.index.saturating_add(n);
        // Clamp to `end` so skipping past the back leaves the iterator exhausted.
        self.index = target.min(self.end);
        self.next()
    }
}
impl<'a, K, V, S> ExactSizeIterator for IndexMapIter<'a, K, V, S>
where
    K: Eq + std::hash::Hash,
    S: std::hash::BuildHasher,
{
    /// Returns the number of pairs not yet yielded from either end.
    #[inline]
    fn len(&self) -> usize {
        // `index` and `end` are public fields; saturate instead of underflowing
        // if they were set so that `index > end`.
        self.end.saturating_sub(self.index)
    }
}
impl<'a, K, V, S> DoubleEndedIterator for IndexMapIter<'a, K, V, S>
where
    K: Eq + std::hash::Hash,
    S: std::hash::BuildHasher,
{
    /// Yields the pair just before the back cursor and moves the cursor down.
    ///
    /// Accesses are bounds-checked: `map`, `index` and `end` are public fields,
    /// so safe code can put the cursors out of range and unchecked indexing
    /// would be unsound. An out-of-range cursor ends the iteration.
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            return None;
        }
        // `end > index >= 0`, so the subtraction cannot underflow.
        let last = self.end - 1;
        let map = self.map;
        let key = map.index_set.keys.get(last)?;
        let value = map.values.get(last)?;
        // Commit the cursor only after both lookups succeeded so a failed
        // lookup keeps returning `None` on later calls.
        self.end = last;
        Some((key, value))
    }
}
/// Marks [`IndexMapIter`] as an iterator that keeps returning `None` once it
/// has returned `None`.
impl<'a, K, V, S> FusedIterator for IndexMapIter<'a, K, V, S>
where
    K: Eq + std::hash::Hash,
    S: std::hash::BuildHasher,
{
}
/// Result of looking up a key for in-place modification or insertion.
pub enum EntryMut<'a, K, V, S> {
    /// The key is present; `value` borrows the stored value at position `index`.
    Occupied {
        key: K,
        value: &'a mut V,
        index: usize,
    },
    /// The key is absent; `map` stays borrowed so the value can be inserted.
    Vacant {
        key: K,
        map: &'a mut IndexMap<K, V, S>,
    },
}
impl<'a, K, V, S> EntryMut<'a, K, V, S>
where
K: Eq + std::hash::Hash + Clone,
S: std::hash::BuildHasher,
{
#[inline]
pub fn is_occupied(&self) -> bool {
matches!(self, EntryMut::Occupied { .. })
}
#[inline]
pub fn or_insert_with<F>(self, value: F) -> &'a mut V
where
F: FnOnce() -> V,
K: Clone,
{
match self {
EntryMut::Occupied { value: v, .. } => v,
EntryMut::Vacant { key, map } => {
map.insert(key.clone(), value());
map.get_mut(&key).unwrap()
}
}
}
}
/// Outcome of inserting a key/value pair into an `IndexMap`.
#[derive(Debug, PartialEq)]
pub enum InsertResult<K, V> {
    /// The key was already present; the previous key and value are handed back
    /// together with the position they occupied.
    Override {
        old_value: V,
        old_key: K,
        index: usize,
    },
    /// The key was absent; the entry was stored at `index`.
    New { index: usize },
}
/// Outcome of removing a key from an `IndexMap`.
#[derive(Debug, PartialEq)]
pub enum RemoveResult<K, V> {
    /// The key was present; its key and value are handed back together with
    /// the position they occupied.
    Removed {
        old_value: V,
        old_key: K,
        index: usize,
    },
    /// The key was not present; nothing was removed.
    None,
}
/// A hash set that keeps its keys in a dense vector so that every key has a
/// position.
#[derive(Clone, Debug)]
pub struct IndexSet<K, S = ahash::RandomState> {
    /// Keys in positional order.
    pub(crate) keys: Vec<K>,
    /// Cached hash of each key, positionally aligned with `keys`.
    pub(crate) hashes: Vec<u64>,
    /// Hash table whose entries are positions into `keys`.
    pub(crate) table: HashTable<usize>,
    /// Builder for the hashers used to hash keys.
    pub(crate) hash_builder: S,
}
impl<K, S> IndexSet<K, S>
where
    K: Eq + std::hash::Hash,
    S: std::hash::BuildHasher,
{
    /// Creates an empty set that hashes keys with `hash_builder`.
    pub fn with_hasher(hash_builder: S) -> Self {
        IndexSet {
            keys: Vec::new(),
            hashes: Vec::new(),
            table: HashTable::new(),
            hash_builder,
        }
    }

    /// Creates an empty set using `S::default()` as the hasher.
    pub fn new() -> Self
    where
        S: Default,
    {
        IndexSet {
            keys: Vec::new(),
            hashes: Vec::new(),
            table: HashTable::new(),
            hash_builder: S::default(),
        }
    }

    /// Creates an empty set with storage pre-allocated for `capacity` keys,
    /// using `S::default()` as the hasher.
    pub fn with_capacity(capacity: usize) -> Self
    where
        S: Default,
    {
        IndexSet {
            keys: Vec::with_capacity(capacity),
            hashes: Vec::with_capacity(capacity),
            table: HashTable::with_capacity(capacity),
            hash_builder: S::default(),
        }
    }

    /// Creates an empty set with storage pre-allocated for `capacity` keys
    /// that hashes keys with `hash_builder`.
    pub fn with_hasher_capacity(hash_builder: S, capacity: usize) -> Self {
        IndexSet {
            keys: Vec::with_capacity(capacity),
            hashes: Vec::with_capacity(capacity),
            table: HashTable::with_capacity(capacity),
            hash_builder,
        }
    }

    /// Returns the capacity of the key vector.
    pub fn capacity(&self) -> usize {
        self.keys.capacity()
    }

    /// Shrinks the key vector, the cached hashes and the hash table.
    pub fn shrink_to_fit(&mut self) {
        self.keys.shrink_to_fit();
        self.hashes.shrink_to_fit();
        // The table stores positions, so it is rehashed from the cached hashes.
        let hashes = &self.hashes;
        self.table.shrink_to_fit(|&i| hashes[i]);
    }

    /// Hashes `key` with a hasher built from this set's hash builder.
    #[inline]
    fn hash_key<Q: ?Sized>(&self, key: &Q) -> u64
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Re-points the table entry of `key` to position `*idx`.
    ///
    /// Returns `Some(*idx)` when the entry was found and updated, `None` when
    /// `key` has no table entry.
    ///
    /// # Safety
    ///
    /// This only touches the table; the caller must make sure that `keys` and
    /// `hashes` hold (or are about to hold) `key` at position `*idx`.
    #[inline]
    unsafe fn table_override<Q: ?Sized>(&mut self, key: &Q, idx: &usize) -> Option<usize>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let hash = self.hash_key(key);
        match self.table.find_entry(hash, |&i| self.keys[i].borrow() == key) {
            Ok(mut slot) => {
                *slot.get_mut() = *idx;
                Some(*idx)
            }
            Err(_) => None,
        }
    }

    /// Records a new table entry for `key` at position `*idx` and caches its hash.
    ///
    /// # Safety
    ///
    /// The caller must ensure `key` has no table entry yet, that `*idx` equals
    /// the current length of `hashes`, and must push the key itself onto `keys`
    /// afterwards; this function does not store the key.
    #[inline]
    unsafe fn table_append<Q: ?Sized>(&mut self, key: &Q, idx: &usize)
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let hash = self.hash_key(key);
        self.hashes.push(hash);
        // Existing entries are rehashed from the cached hashes if the table grows.
        self.table.insert_unique(hash, *idx, |&i| self.hashes[i]);
    }

    /// Returns the position recorded in the table for `key`, if any.
    #[inline]
    fn table_get<Q: ?Sized>(&self, key: &Q) -> Option<usize>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let hash = self.hash_key(key);
        self.table
            .find(hash, |&i| self.keys[i].borrow() == key)
            .copied()
    }

    /// Removes the table entry of `key` and swap-removes its cached hash.
    ///
    /// Returns the position the entry pointed to, or `None` when `key` has no
    /// table entry.
    ///
    /// # Safety
    ///
    /// `keys` is left untouched; the caller must swap-remove the same position
    /// from `keys` and re-point the table entry of the key that gets moved.
    #[inline]
    unsafe fn table_swap_remove<Q: ?Sized>(&mut self, key: &Q) -> Option<usize>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let hash = self.hash_key(key);
        match self.table.find_entry(hash, |&i| self.keys[i].borrow() == key) {
            Ok(entry) => {
                let (old_idx, _) = entry.remove();
                self.hashes.swap_remove(old_idx);
                Some(old_idx)
            }
            Err(_) => None,
        }
    }

    /// Returns `true` when `key` is present in the set.
    #[inline]
    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        self.table_get(key).is_some()
    }

    /// Returns the number of keys in the set.
    #[inline]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Returns the position of `key`, if it is present.
    #[inline]
    pub fn get_index<Q: ?Sized>(&self, key: &Q) -> Option<usize>
    where
        K: Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        self.table_get(key)
    }

    /// Iterates over the keys in positional order.
    #[inline]
    pub fn iter(&self) -> std::slice::Iter<'_, K> {
        self.keys.iter()
    }

    /// Returns the key storage.
    #[inline]
    pub fn keys(&self) -> &Vec<K> {
        &self.keys
    }

    /// Returns the key at position `index`, or `None` when out of range.
    #[inline]
    pub fn get_with_index(&self, index: usize) -> Option<&K> {
        self.keys.get(index)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    type M = IndexMap<u64, i64>;

    /// Asserts that the value, key and hash vectors agree with each other and
    /// with the hash table, and that no key is stored twice.
    fn assert_internal_invariants(map: &M) {
        assert_eq!(map.values.len(), map.index_set.keys.len(), "values/keys len mismatch");
        assert_eq!(map.values.len(), map.index_set.hashes.len(), "values/hashes len mismatch");
        for (i, k) in map.index_set.keys.iter().enumerate() {
            let idx = map.index_set.table_get(k).expect("table_get must find existing key");
            assert_eq!(idx, i, "table idx mismatch for key");
        }
        for i in 0..map.len() {
            let k = &map.index_set.keys[i];
            assert!(map.contains_key(k), "contains_key false for existing key");
            let v = map.get(k).expect("get must return for existing key");
            assert_eq!(*v, map.values[i], "get value mismatch");
        }
        for i in 0..map.index_set.keys.len() {
            for j in (i + 1)..map.index_set.keys.len() {
                assert!(map.index_set.keys[i] != map.index_set.keys[j], "duplicate keys detected");
            }
        }
    }

    /// Asserts that `map` and `oracle` hold exactly the same key/value pairs.
    fn assert_equals_oracle(map: &M, oracle: &HashMap<u64, i64>) {
        assert_eq!(map.len(), oracle.len(), "len mismatch");
        for (k, v) in oracle.iter() {
            let got = map.get(k).copied();
            assert_eq!(got, Some(*v), "value mismatch for key={k}");
        }
        for (k, v) in map.iter() {
            assert_eq!(oracle.get(k).copied(), Some(*v), "extra/mismatch entry key={k}");
        }
    }

    #[test]
    fn swap_remove_last_and_middle() {
        let mut m = M::new();
        for i in 0..10 {
            m.insert(i, (i as i64) * 10);
        }
        let v = m.swap_remove(&9);
        assert_eq!(v, RemoveResult::Removed {
            old_value: 90,
            old_key: 9,
            index: 9,
        });
        assert!(m.get(&9).is_none());
        let v = m.swap_remove(&3);
        assert_eq!(v, RemoveResult::Removed {
            old_value: 30,
            old_key: 3,
            index: 3,
        });
        assert!(m.get(&3).is_none());
        assert_internal_invariants(&m);
    }

    #[test]
    fn entry_or_insert_with_works() {
        let mut m = M::new();
        let v = m.entry_mut(7).or_insert_with(|| 123);
        assert_eq!(*v, 123);
        let v2 = m.entry_mut(7).or_insert_with(|| 999);
        assert_eq!(*v2, 123);
        assert_internal_invariants(&m);
    }

    #[test]
    fn compare_with_std_hashmap_small_scripted() {
        let mut m = M::new();
        let mut o = HashMap::<u64, i64>::new();
        for i in 0..50u64 {
            m.insert(i, i as i64);
            o.insert(i, i as i64);
        }
        for i in 0..50u64 {
            if i % 3 == 0 {
                // Earlier swap-removes move entries around, so the expected
                // position has to be read from the map right before removal.
                let expected_index = m.get_index(&i).expect("key must be present before removal");
                let a = m.swap_remove(&i);
                let b = o.remove(&i);
                assert_eq!(a, RemoveResult::Removed {
                    old_value: b.unwrap(),
                    old_key: i,
                    index: expected_index,
                });
            }
        }
        for i in 0..50u64 {
            if i % 5 == 0 {
                m.insert(i, (i as i64) * 100);
                o.insert(i, (i as i64) * 100);
            }
        }
        assert_internal_invariants(&m);
        assert_equals_oracle(&m, &o);
    }

    #[test]
    fn randomized_ops_compare_with_oracle() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(0xC0FFEE);
        let mut m = M::new();
        let mut o = HashMap::<u64, i64>::new();
        const STEPS: usize = 30_000;
        const KEY_SPACE: u64 = 2_000;
        for _ in 0..STEPS {
            let op = rng.gen_range(0..100);
            let k = rng.gen_range(0..KEY_SPACE);
            match op {
                0..=59 => {
                    let v = rng.gen_range(-1_000_000..=1_000_000);
                    let a = m.insert(k, v);
                    let b = o.insert(k, v);
                    match (a, b) {
                        (InsertResult::New { .. }, None) => {}
                        (InsertResult::Override { old_key, old_value, .. }, Some(old)) => {
                            assert_eq!(old_key, k);
                            assert_eq!(old_value, old);
                        }
                        _ => panic!("insert mismatch"),
                    }
                }
                60..=79 => {
                    // The key may be absent, and a present key can sit at any
                    // position, so compare against what the map reports.
                    let expected_index = m.get_index(&k);
                    let a = m.swap_remove(&k);
                    let b = o.remove(&k);
                    match (a, b) {
                        (RemoveResult::None, None) => {}
                        (RemoveResult::Removed { old_value, old_key, index }, Some(old)) => {
                            assert_eq!(old_key, k);
                            assert_eq!(old_value, old);
                            assert_eq!(Some(index), expected_index);
                        }
                        _ => panic!("remove mismatch"),
                    }
                }
                80..=94 => {
                    let a = m.get(&k).copied();
                    let b = o.get(&k).copied();
                    assert_eq!(a, b);
                }
                _ => {
                    let a = m.contains_key(&k);
                    let b = o.contains_key(&k);
                    assert_eq!(a, b);
                }
            }
            if rng.gen_ratio(1, 200) {
                assert_internal_invariants(&m);
                assert_equals_oracle(&m, &o);
            }
        }
        assert_internal_invariants(&m);
        assert_equals_oracle(&m, &o);
    }

    #[test]
    fn empty_map_basics() {
        let m = M::new();
        assert_eq!(m.len(), 0);
        assert!(m.get(&123).is_none());
        assert!(!m.contains_key(&123));
        assert_eq!(m.values.len(), 0);
        assert_eq!(m.index_set.keys.len(), 0);
        assert_eq!(m.index_set.hashes.len(), 0);
    }

    #[test]
    fn swap_remove_single_element_roundtrip() {
        let mut m = M::new();
        m.insert(42, -7);
        assert_internal_invariants(&m);
        let v = m.swap_remove(&42);
        assert_eq!(v, RemoveResult::Removed {
            old_value: -7,
            old_key: 42,
            index: 0,
        });
        assert_eq!(m.len(), 0);
        assert!(m.get(&42).is_none());
        assert!(!m.contains_key(&42));
        assert_internal_invariants(&m);
    }

    #[test]
    fn remove_then_reinsert_same_key() {
        let mut m = M::new();
        m.insert(1, 10);
        m.insert(2, 20);
        m.insert(3, 30);
        assert_internal_invariants(&m);
        assert_eq!(m.swap_remove(&2), RemoveResult::Removed {
            old_value: 20,
            old_key: 2,
            index: 1,
        });
        assert!(m.get(&2).is_none());
        assert_internal_invariants(&m);
        // Two entries remain (key 3 moved into position 1), so the re-inserted
        // key is appended at position 2.
        assert_eq!(m.insert(2, 200), InsertResult::New { index: 2 });
        assert_eq!(m.get(&2).copied(), Some(200));
        assert_internal_invariants(&m);
    }

    #[test]
    fn from_kv_vec_builds_valid_map() {
        let keys = vec![1u64, 2u64, 3u64, 10u64];
        let values = vec![10i64, 20i64, 30i64, 100i64];
        let m = M::from_kv_vec(keys.clone(), values.clone());
        assert_eq!(m.len(), 4);
        assert_eq!(m.index_set.keys, keys);
        assert_eq!(m.values, values);
        assert_internal_invariants(&m);
    }

    #[test]
    fn iter_order_matches_internal_storage_even_after_removes() {
        let mut m = M::new();
        for i in 0..8u64 {
            m.insert(i, (i as i64) + 100);
        }
        assert_internal_invariants(&m);
        assert_eq!(m.swap_remove(&0), RemoveResult::Removed {
            old_value: 100,
            old_key: 0,
            index: 0,
        });
        assert_eq!(m.swap_remove(&5), RemoveResult::Removed {
            old_value: 105,
            old_key: 5,
            index: 5,
        });
        assert_internal_invariants(&m);
        let collected: Vec<(u64, i64)> = m.iter().map(|(k, v)| (*k, *v)).collect();
        let expected: Vec<(u64, i64)> = m
            .index_set
            .keys
            .iter()
            .copied()
            .zip(m.values.iter().copied())
            .collect();
        assert_eq!(collected, expected);
    }

    #[test]
    fn num_23257_issue() {
        // Repeatedly goes through the entry API for one key; the map must keep
        // a single entry whose vector receives every pushed element.
        const ITER_NUM: u64 = 223259;
        let mut map: IndexMap<u64, Vec<u64>> = IndexMap::new();
        for i in 0..ITER_NUM {
            map.entry_mut(0).or_insert_with(Vec::new).push(i);
        }
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&0).unwrap().len() as u64, ITER_NUM);
    }
}