use crate::error::{Result, ZiporaError};
use crate::hash_map::ZiporaHashMap;
use std::fmt;
use std::hash::Hash;
use std::mem::MaybeUninit;
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
use std::arch::x86_64::*;
/// Maximum number of entries held inline (no heap allocation) before the map
/// is promoted to a heap-backed `ZiporaHashMap` (see `SmallMapStorage`).
pub const SMALL_MAP_THRESHOLD: usize = 8;
/// A map optimized for small collections: up to [`SMALL_MAP_THRESHOLD`]
/// entries are stored inline; past that the storage is promoted to a
/// `ZiporaHashMap`.
///
/// `#[repr(align(64))]` aligns the struct to a typical cache-line size.
#[repr(align(64))]
pub struct SmallMap<K, V>
where
    K: Clone + std::hash::Hash + Eq,
    V: Clone,
{
    // Either inline (`Small`) or heap-backed (`Large`) storage.
    storage: SmallMapStorage<K, V>,
}
/// Backing storage for [`SmallMap`].
enum SmallMapStorage<K, V>
where
    K: Clone + std::hash::Hash + Eq,
    V: Clone,
{
    /// Inline storage. Invariant: exactly the slots `0..len` of `keys` and
    /// `values` are initialized; slots at index >= `len` are uninitialized
    /// memory and must never be read or dropped.
    Small {
        keys: [MaybeUninit<K>; SMALL_MAP_THRESHOLD],
        values: [MaybeUninit<V>; SMALL_MAP_THRESHOLD],
        len: usize,
    },
    /// Heap-backed storage used once the map outgrows the inline capacity.
    Large(ZiporaHashMap<K, V>),
}
impl<K, V> SmallMap<K, V>
where
    K: Clone + std::hash::Hash + Eq,
    V: Clone,
{
    /// Creates an empty map backed by inline (small) storage.
    pub fn new() -> Self {
        let storage = SmallMapStorage::Small {
            keys: [const { MaybeUninit::uninit() }; SMALL_MAP_THRESHOLD],
            values: [const { MaybeUninit::uninit() }; SMALL_MAP_THRESHOLD],
            len: 0,
        };
        Self { storage }
    }

    /// Number of entries currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        match &self.storage {
            SmallMapStorage::Small { len, .. } => *len,
            SmallMapStorage::Large(map) => map.len(),
        }
    }

    /// Returns `true` when the map holds no entries.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Current capacity: the inline threshold while small, otherwise the
    /// backing hash map's capacity.
    #[inline]
    pub fn capacity(&self) -> usize {
        if let SmallMapStorage::Large(map) = &self.storage {
            map.capacity()
        } else {
            SMALL_MAP_THRESHOLD
        }
    }
}
impl<K: PartialEq + Hash + Eq + 'static + Clone, V: Clone> SmallMap<K, V> {
#[inline(always)]
fn find_key_index(
&self,
key: &K,
keys: &[MaybeUninit<K>; SMALL_MAP_THRESHOLD],
len: usize,
) -> Option<usize> {
if len == 0 {
return None;
}
debug_assert!(len <= SMALL_MAP_THRESHOLD, "len {} exceeds SMALL_MAP_THRESHOLD {}", len, SMALL_MAP_THRESHOLD);
match len {
1 => {
let k0 = unsafe { keys[0].assume_init_ref() };
if k0 == key { Some(0) } else { None }
}
2 => {
let k0 = unsafe { keys[0].assume_init_ref() };
let k1 = unsafe { keys[1].assume_init_ref() };
if k0 == key {
Some(0)
} else if k1 == key {
Some(1)
} else {
None
}
}
3 => {
let k0 = unsafe { keys[0].assume_init_ref() };
let k1 = unsafe { keys[1].assume_init_ref() };
let k2 = unsafe { keys[2].assume_init_ref() };
if k0 == key {
Some(0)
} else if k1 == key {
Some(1)
} else if k2 == key {
Some(2)
} else {
None
}
}
4 => {
let k0 = unsafe { keys[0].assume_init_ref() };
let k1 = unsafe { keys[1].assume_init_ref() };
let k2 = unsafe { keys[2].assume_init_ref() };
let k3 = unsafe { keys[3].assume_init_ref() };
if k0 == key {
Some(0)
} else if k1 == key {
Some(1)
} else if k2 == key {
Some(2)
} else if k3 == key {
Some(3)
} else {
None
}
}
5..=8 => {
let k0 = unsafe { keys[0].assume_init_ref() };
let k1 = unsafe { keys[1].assume_init_ref() };
let k2 = unsafe { keys[2].assume_init_ref() };
let k3 = unsafe { keys[3].assume_init_ref() };
if k0 == key {
return Some(0);
}
if k1 == key {
return Some(1);
}
if k2 == key {
return Some(2);
}
if k3 == key {
return Some(3);
}
for i in 4..len {
let existing_key = unsafe { keys[i].assume_init_ref() };
if existing_key == key {
return Some(i);
}
}
None
}
_ => {
self.find_key_fallback(key, keys, len)
}
}
}
/// Cold-path linear scan used when `len` falls outside the unrolled cases of
/// `find_key_index`; kept out of line so the hot path stays compact.
///
/// Caller contract: slots `0..len` of `keys` must be initialized.
#[inline(never)]
#[cold]
fn find_key_fallback(
    &self,
    key: &K,
    keys: &[MaybeUninit<K>; SMALL_MAP_THRESHOLD],
    len: usize,
) -> Option<usize> {
    // SAFETY: slots 0..len are initialized (Small-storage invariant).
    (0..len).find(|&i| unsafe { keys[i].assume_init_ref() } == key)
}
/// Inserts `key` -> `value`, returning the previous value if the key was
/// already present.
///
/// While small: an existing key is updated in place; a new key is appended,
/// and when the inline storage is full the map is first promoted to a
/// `ZiporaHashMap` and the insert retried on the large path.
///
/// # Errors
/// Returns an error if promotion or the large-map insert fails.
pub fn insert(&mut self, key: K, value: V) -> Result<Option<V>> {
    match &mut self.storage {
        SmallMapStorage::Small { keys, values, len } => {
            for i in 0..*len {
                // SAFETY: slots 0..len are initialized.
                let existing_key = unsafe { keys[i].assume_init_ref() };
                if *existing_key == key {
                    // Key already present: swap the value in place.
                    // SAFETY: slot i is initialized.
                    let existing_value = unsafe { values[i].assume_init_mut() };
                    let old_value = std::mem::replace(existing_value, value);
                    return Ok(Some(old_value));
                }
            }
            if *len >= SMALL_MAP_THRESHOLD {
                // Inline storage full: promote, then insert via the Large arm.
                self.promote_to_large()?;
                return self.insert(key, value);
            }
            // Append into the first uninitialized slot.
            keys[*len] = MaybeUninit::new(key);
            values[*len] = MaybeUninit::new(value);
            *len += 1;
            Ok(None)
        }
        SmallMapStorage::Large(map) => map
            .insert(key, value)
            .map_err(|_| ZiporaError::invalid_data("Failed to insert into large map")),
    }
}
/// Returns a shared reference to the value stored for `key`, if any.
pub fn get(&self, key: &K) -> Option<&V> {
    match &self.storage {
        SmallMapStorage::Small { keys, values, len } => {
            if let Some(index) = self.find_key_index(key, keys, *len) {
                // Best-effort prefetch of the value slot before it is read;
                // compiled only with SIMD support on x86_64.
                #[cfg(all(target_arch = "x86_64", feature = "simd"))]
                unsafe {
                    if *len > 4 {
                        _mm_prefetch(values[index].as_ptr() as *const i8, _MM_HINT_T0);
                    }
                }
                // SAFETY: find_key_index only returns indices < len, and
                // slots 0..len are initialized.
                Some(unsafe { values[index].assume_init_ref() })
            } else {
                None
            }
        }
        SmallMapStorage::Large(map) => map.get(key),
    }
}
/// Returns a mutable reference to the value stored for `key`, if any.
pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
    match &mut self.storage {
        SmallMapStorage::Small { keys, values, len } => {
            let count = *len;
            // SAFETY: slots 0..count are initialized (Small-storage invariant).
            let hit = (0..count).find(|&i| unsafe { keys[i].assume_init_ref() } == key);
            // SAFETY: `hit` indexes an initialized slot (< count).
            hit.map(|index| unsafe { values[index].assume_init_mut() })
        }
        SmallMapStorage::Large(map) => map.get_mut(key),
    }
}
/// Removes `key`, returning its value if it was present.
///
/// In small storage this is a swap-remove: the last entry is bit-copied into
/// the vacated slot, so insertion order is NOT preserved.
pub fn remove(&mut self, key: &K) -> Option<V> {
    match &mut self.storage {
        SmallMapStorage::Small { keys, values, len } => {
            for i in 0..*len {
                // SAFETY: slots 0..len are initialized.
                let existing_key = unsafe { keys[i].assume_init_ref() };
                if existing_key == key {
                    *len -= 1;
                    // SAFETY: slot i is initialized; reading moves the pair
                    // out and leaves the slot logically uninitialized.
                    let removed_value = unsafe { values[i].assume_init_read() };
                    let _removed_key = unsafe { keys[i].assume_init_read() };
                    if i < *len {
                        // SAFETY: slot *len (the old last entry) is still
                        // initialized; bit-copy it into the hole. Its source
                        // slot is now beyond `len` and therefore treated as
                        // uninitialized, so no double-drop can occur.
                        keys[i] = unsafe { std::ptr::read(&keys[*len]) };
                        values[i] = unsafe { std::ptr::read(&values[*len]) };
                    }
                    return Some(removed_value);
                }
            }
            None
        }
        SmallMapStorage::Large(map) => map.remove(key),
    }
}
/// Returns `true` if the map contains an entry for `key`.
pub fn contains_key(&self, key: &K) -> bool {
    self.get(key).is_some()
}
/// Removes all entries.
///
/// Small storage: drops each initialized pair in place and resets `len`.
/// Large storage: the hash map is replaced with fresh inline storage, so a
/// cleared map reports the inline capacity again.
pub fn clear(&mut self) {
    match &mut self.storage {
        SmallMapStorage::Small { keys, values, len } => {
            for i in 0..*len {
                // SAFETY: slots 0..len are initialized; after the loop `len`
                // is zeroed so nothing is dropped twice.
                unsafe {
                    keys[i].assume_init_drop();
                    values[i].assume_init_drop();
                }
            }
            *len = 0;
        }
        SmallMapStorage::Large(_) => {
            // Assigning drops the old Large storage (the hash map's own drop
            // glue frees its contents) and demotes back to inline storage.
            self.storage = SmallMapStorage::Small {
                keys: [const { MaybeUninit::uninit() }; SMALL_MAP_THRESHOLD],
                values: [const { MaybeUninit::uninit() }; SMALL_MAP_THRESHOLD],
                len: 0,
            };
        }
    }
}
/// Iterates over all `(key, value)` pairs in slot order.
///
/// # Panics
/// Panics if the map has been promoted to large storage — iteration over
/// `ZiporaHashMap` is not implemented yet. Note that `Debug`, `Clone` and
/// `PartialEq` for this type all call `iter` and inherit this panic.
pub fn iter(&self) -> SmallMapIter<'_, K, V> {
    match &self.storage {
        SmallMapStorage::Small { keys, values, len } => SmallMapIter::Small {
            keys,
            values,
            index: 0,
            len: *len,
        },
        SmallMapStorage::Large(_map) => {
            panic!("Iterator not yet implemented for large maps with ZiporaHashMap")
        },
    }
}
/// Moves every inline entry into a freshly allocated `ZiporaHashMap` and
/// switches `storage` to the `Large` variant. No-op if already large.
///
/// Soundness fix: all pairs are moved out of the `MaybeUninit` slots and
/// `len` is zeroed BEFORE any fallible call. Previously, if an insert into
/// the new map failed mid-loop, `len` still counted elements that had
/// already been moved out, so `Drop`/`clear` would have dropped them a
/// second time (undefined behavior). Now an error can only lose the
/// entries (they are dropped exactly once, by the temporary `Vec` or the
/// partially filled map), never double-drop them.
///
/// # Errors
/// Returns an error if the hash map cannot be created or an insert fails.
fn promote_to_large(&mut self) -> Result<()>
where
    K: Clone,
    V: Clone,
{
    if let SmallMapStorage::Small { keys, values, len } = &mut self.storage {
        let count = *len;
        // Mark the inline storage empty before moving anything out, so the
        // Small variant never claims ownership of moved-out slots.
        *len = 0;
        let mut pairs = Vec::with_capacity(count);
        for i in 0..count {
            // SAFETY: slots 0..count were initialized while len == count,
            // and len is already 0 so nothing else will read or drop them.
            let key = unsafe { keys[i].assume_init_read() };
            let value = unsafe { values[i].assume_init_read() };
            pairs.push((key, value));
        }
        let mut large_map = ZiporaHashMap::new()?;
        for (key, value) in pairs {
            large_map
                .insert(key, value)
                .map_err(|_| ZiporaError::invalid_data("Failed to promote to large map"))?;
        }
        self.storage = SmallMapStorage::Large(large_map);
    }
    Ok(())
}
}
impl<K, V> Default for SmallMap<K, V>
where
K: Clone + std::hash::Hash + Eq,
V: Clone,
{
fn default() -> Self {
Self::new()
}
}
/// Drops the `len` initialized pairs held in inline storage. The `Large`
/// variant needs no handling here: `ZiporaHashMap` has its own drop glue,
/// which runs when the `storage` field is dropped.
impl<K, V> Drop for SmallMap<K, V>
where
    K: Clone + std::hash::Hash + Eq,
    V: Clone,
{
    fn drop(&mut self) {
        if let SmallMapStorage::Small { keys, values, len } = &mut self.storage {
            for i in 0..*len {
                // SAFETY: slots 0..len are initialized and dropped once.
                unsafe {
                    keys[i].assume_init_drop();
                    values[i].assume_init_drop();
                }
            }
        }
    }
}
/// Formats the map as `{key: value, ...}` via `Formatter::debug_map`.
///
/// # Panics
/// Panics for maps in large storage, because [`SmallMap::iter`] is not
/// implemented for that case yet.
impl<K: fmt::Debug + PartialEq + Hash + Eq + 'static + Clone, V: fmt::Debug + Clone> fmt::Debug for SmallMap<K, V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_map().entries(self.iter()).finish()
    }
}
impl<K: Clone + PartialEq + Hash + Eq + 'static, V: Clone> Clone for SmallMap<K, V> {
    /// Clones by re-inserting every pair into a fresh map.
    ///
    /// NOTE(review): an insert error silently stops the loop, yielding a
    /// truncated clone — consider propagating or panicking instead. Also
    /// inherits `iter`'s panic for maps in large storage.
    fn clone(&self) -> Self {
        let mut new_map = Self::new();
        for (key, value) in self.iter() {
            if let Err(_) = new_map.insert(key.clone(), value.clone()) {
                break;
            }
        }
        new_map
    }
}
impl<K: PartialEq + Hash + Eq + 'static + Clone, V: PartialEq + Clone> PartialEq for SmallMap<K, V> {
    /// Maps are equal when they hold exactly the same key/value pairs,
    /// regardless of storage order. Short-circuits on the first mismatch.
    fn eq(&self, other: &Self) -> bool {
        self.len() == other.len()
            && self
                .iter()
                .all(|(key, value)| other.get(key) == Some(value))
    }
}
// Marker impl: equality is a full equivalence relation when `K` and `V` are `Eq`.
impl<K: Eq + PartialEq + Hash + 'static + Clone, V: Eq + Clone> Eq for SmallMap<K, V> {}
/// Borrowing iterator over a [`SmallMap`] in inline storage.
///
/// Only a `Small` variant exists; `SmallMap::iter` panics before one could
/// be constructed for large storage.
pub enum SmallMapIter<'a, K, V> {
    Small {
        // Invariant: slots 0..len are initialized; `index` never exceeds `len`.
        keys: &'a [MaybeUninit<K>; SMALL_MAP_THRESHOLD],
        values: &'a [MaybeUninit<V>; SMALL_MAP_THRESHOLD],
        index: usize,
        len: usize,
    },
}
impl<'a, K, V> Iterator for SmallMapIter<'a, K, V> {
    type Item = (&'a K, &'a V);

    /// Yields `(key, value)` pairs for slots `0..len` in order.
    fn next(&mut self) -> Option<Self::Item> {
        // Single-variant enum: the pattern is irrefutable.
        let SmallMapIter::Small { keys, values, index, len } = self;
        if *index >= *len {
            return None;
        }
        let i = *index;
        *index = i + 1;
        // SAFETY: i < len, and slots 0..len are initialized (iterator invariant).
        unsafe { Some((keys[i].assume_init_ref(), values[i].assume_init_ref())) }
    }

    /// Exact remaining count (`len - index`).
    fn size_hint(&self) -> (usize, Option<usize>) {
        let SmallMapIter::Small { index, len, .. } = self;
        let remaining = len - index;
        (remaining, Some(remaining))
    }
}
// Sound because `size_hint` reports the exact remaining count.
impl<'a, K, V> ExactSizeIterator for SmallMapIter<'a, K, V> {}
/// Type-specialized key search over the inline key array.
///
/// Implemented (behind the `simd` feature on x86_64) for primitive key types
/// that can be compared vector-wise; currently only wired into
/// `SmallMap<u8, V>::get_fast` within this file.
trait OptimizedSearch {
    /// Returns the index of the first slot equal to `self` among the first
    /// `len` initialized slots, or `None`.
    fn find_optimized(
        &self,
        keys: &[MaybeUninit<Self>; SMALL_MAP_THRESHOLD],
        len: usize,
    ) -> Option<usize>
    where
        Self: Sized + PartialEq;
}
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
impl OptimizedSearch for u8 {
    /// SSE2 search: compares up to 8 candidate slots against `*self` at once.
    ///
    /// Bug fixed: the padding bytes (slots >= `len` in `key_bytes`, plus the
    /// upper 8 lanes zeroed by `_mm_loadl_epi64`) compare equal to a search
    /// key of `0`, so looking up key `0` could return a false match at an
    /// out-of-bounds index — and the caller would then read an uninitialized
    /// value slot. The movemask result is now restricted to the `len` valid
    /// lanes before it is inspected.
    #[inline(always)]
    fn find_optimized(
        &self,
        keys: &[MaybeUninit<Self>; SMALL_MAP_THRESHOLD],
        len: usize,
    ) -> Option<usize> {
        if len == 0 {
            return None;
        }
        let lanes = len.min(SMALL_MAP_THRESHOLD);
        unsafe {
            let search_vec = _mm_set1_epi8(*self as i8);
            let mut key_bytes = [0u8; 8];
            for i in 0..lanes {
                // SAFETY: slots 0..len are initialized (Small-storage invariant).
                key_bytes[i] = *keys[i].assume_init_ref();
            }
            // Loads 8 bytes into the low half; the high 8 lanes are zeroed.
            let keys_vec = _mm_loadl_epi64(key_bytes.as_ptr() as *const __m128i);
            let cmp = _mm_cmpeq_epi8(search_vec, keys_vec);
            // Keep only the match bits for initialized slots: zero padding in
            // the register must never count as a hit.
            let valid = (1u32 << lanes) - 1;
            let mask = (_mm_movemask_epi8(cmp) as u32) & valid;
            if mask != 0 {
                return Some(mask.trailing_zeros() as usize);
            }
        }
        None
    }
}
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
impl OptimizedSearch for u32 {
    /// SSE2 search over up to 8 `u32` keys, four lanes per group.
    ///
    /// Bugs fixed:
    /// * `len` 1..=3 used to skip both lane groups (guarded by `len >= 4`)
    ///   and always return `None`, even when the key was present.
    /// * Zero padding in a partially filled lane group compared equal to a
    ///   search key of `0`, producing a false match at an out-of-bounds
    ///   index for `len` 5..=7. Each movemask is now restricted to the lanes
    ///   that hold initialized keys.
    #[inline(always)]
    fn find_optimized(
        &self,
        keys: &[MaybeUninit<Self>; SMALL_MAP_THRESHOLD],
        len: usize,
    ) -> Option<usize> {
        if len == 0 {
            return None;
        }
        unsafe {
            let search_vec = _mm_set1_epi32(*self as i32);
            // First group: indices 0..min(len, 4).
            let first = len.min(4);
            let mut key_array = [0u32; 4];
            for i in 0..first {
                // SAFETY: slots 0..len are initialized.
                key_array[i] = *keys[i].assume_init_ref();
            }
            let keys_vec = _mm_loadu_si128(key_array.as_ptr() as *const __m128i);
            let cmp = _mm_cmpeq_epi32(search_vec, keys_vec);
            let valid = (1u32 << first) - 1;
            let mask = (_mm_movemask_ps(_mm_castsi128_ps(cmp)) as u32) & valid;
            if mask != 0 {
                return Some(mask.trailing_zeros() as usize);
            }
            // Second group: indices 4..min(len, 8).
            if len > 4 {
                let second = len.min(SMALL_MAP_THRESHOLD) - 4;
                let mut key_array = [0u32; 4];
                for i in 0..second {
                    // SAFETY: slots 4..len are initialized.
                    key_array[i] = *keys[4 + i].assume_init_ref();
                }
                let keys_vec = _mm_loadu_si128(key_array.as_ptr() as *const __m128i);
                let cmp = _mm_cmpeq_epi32(search_vec, keys_vec);
                let valid = (1u32 << second) - 1;
                let mask = (_mm_movemask_ps(_mm_castsi128_ps(cmp)) as u32) & valid;
                if mask != 0 {
                    return Some(4 + mask.trailing_zeros() as usize);
                }
            }
        }
        None
    }
}
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
impl OptimizedSearch for u64 {
    /// Vector search over up to 8 `u64` keys, two lanes per iteration
    /// (`_mm_cmpeq_epi64` requires SSE4.1).
    ///
    /// A trailing odd slot is padded with `0`, but that padded lane's match
    /// bit is only honored when `i + 1 < len`, so the padding cannot produce
    /// an out-of-bounds false positive.
    #[inline(always)]
    fn find_optimized(
        &self,
        keys: &[MaybeUninit<Self>; SMALL_MAP_THRESHOLD],
        len: usize,
    ) -> Option<usize> {
        if len == 0 {
            return None;
        }
        unsafe {
            let search_vec = _mm_set1_epi64x(*self as i64);
            for i in (0..len).step_by(2) {
                let mut key_array = [0u64; 2];
                // SAFETY: i < len, so slot i is initialized.
                key_array[0] = *keys[i].assume_init_ref();
                if i + 1 < len {
                    // SAFETY: i + 1 < len, so slot i + 1 is initialized.
                    key_array[1] = *keys[i + 1].assume_init_ref();
                }
                let keys_vec = _mm_loadu_si128(key_array.as_ptr() as *const __m128i);
                let cmp = _mm_cmpeq_epi64(search_vec, keys_vec);
                // Two mask bits: 0x1 = lane 0 (index i), 0x2 = lane 1 (index i + 1).
                let mask = _mm_movemask_pd(_mm_castsi128_pd(cmp)) as u32;
                if mask & 0x1 != 0 {
                    return Some(i);
                }
                if mask & 0x2 != 0 && i + 1 < len {
                    return Some(i + 1);
                }
            }
        }
        None
    }
}
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
impl OptimizedSearch for i32 {
    /// Delegates to the `u32` implementation by reinterpreting the key bits;
    /// bit-pattern equality of `u32` is equality of the original `i32` keys.
    #[inline(always)]
    fn find_optimized(
        &self,
        keys: &[MaybeUninit<Self>; SMALL_MAP_THRESHOLD],
        len: usize,
    ) -> Option<usize> {
        unsafe {
            let u32_self = *self as u32;
            // SAFETY: MaybeUninit<i32> and MaybeUninit<u32> have identical
            // size, alignment and layout, so reinterpreting the array
            // reference is sound.
            let u32_keys = std::mem::transmute::<
                &[MaybeUninit<i32>; SMALL_MAP_THRESHOLD],
                &[MaybeUninit<u32>; SMALL_MAP_THRESHOLD],
            >(keys);
            u32_self.find_optimized(u32_keys, len)
        }
    }
}
/// `u8`-key specialization with a SIMD-accelerated lookup path.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
impl<V> SmallMap<u8, V>
where
    V: Clone,
{
    /// Key search tuned for `u8`: unrolled scalar compares for `len <= 4`,
    /// the vectorized `OptimizedSearch` path beyond that.
    ///
    /// Caller contract: slots `0..len` of `keys` must be initialized.
    #[inline(always)]
    fn find_key_index_simd(
        &self,
        key: &u8,
        keys: &[MaybeUninit<u8>; SMALL_MAP_THRESHOLD],
        len: usize,
    ) -> Option<usize> {
        if len <= 4 {
            match len {
                0 => None,
                1 => {
                    // SAFETY: slot 0 is initialized because len == 1.
                    let k0 = unsafe { keys[0].assume_init_ref() };
                    if k0 == key { Some(0) } else { None }
                }
                2 => {
                    // SAFETY: slots 0..2 are initialized because len == 2.
                    let k0 = unsafe { keys[0].assume_init_ref() };
                    let k1 = unsafe { keys[1].assume_init_ref() };
                    if k0 == key {
                        Some(0)
                    } else if k1 == key {
                        Some(1)
                    } else {
                        None
                    }
                }
                3 => {
                    // SAFETY: slots 0..3 are initialized because len == 3.
                    let k0 = unsafe { keys[0].assume_init_ref() };
                    let k1 = unsafe { keys[1].assume_init_ref() };
                    let k2 = unsafe { keys[2].assume_init_ref() };
                    if k0 == key {
                        Some(0)
                    } else if k1 == key {
                        Some(1)
                    } else if k2 == key {
                        Some(2)
                    } else {
                        None
                    }
                }
                4 => {
                    // SAFETY: slots 0..4 are initialized because len == 4.
                    let k0 = unsafe { keys[0].assume_init_ref() };
                    let k1 = unsafe { keys[1].assume_init_ref() };
                    let k2 = unsafe { keys[2].assume_init_ref() };
                    let k3 = unsafe { keys[3].assume_init_ref() };
                    if k0 == key {
                        Some(0)
                    } else if k1 == key {
                        Some(1)
                    } else if k2 == key {
                        Some(2)
                    } else if k3 == key {
                        Some(3)
                    } else {
                        None
                    }
                }
                // len <= 4 is guaranteed by the enclosing branch.
                _ => unreachable!(),
            }
        } else {
            key.find_optimized(keys, len)
        }
    }

    /// Like [`SmallMap::get`], but routed through the SIMD search and with
    /// prefetch hints on the value slots.
    #[inline(always)]
    pub fn get_fast(&self, key: &u8) -> Option<&V>
    where
        V: Clone,
    {
        match &self.storage {
            SmallMapStorage::Small { keys, values, len } => {
                if let Some(index) = self.find_key_index_simd(key, keys, *len) {
                    // Best-effort prefetch of the hit (and, for larger maps,
                    // the following slot) before the value is read.
                    unsafe {
                        _mm_prefetch(values[index].as_ptr() as *const i8, _MM_HINT_T0);
                        if *len > 6 && index + 1 < *len {
                            _mm_prefetch(values[index + 1].as_ptr() as *const i8, _MM_HINT_T1);
                        }
                    }
                    // SAFETY: index < len, and slots 0..len are initialized.
                    Some(unsafe { values[index].assume_init_ref() })
                } else {
                    None
                }
            }
            SmallMapStorage::Large(map) => map.get(key),
        }
    }
}
// NOTE(review): manual Send/Sync assertions. MaybeUninit<T> is Send/Sync iff
// T is, so if ZiporaHashMap is also Send/Sync these impls are redundant (the
// auto traits would apply); if ZiporaHashMap is NOT Send/Sync, they are
// unsound. TODO confirm against ZiporaHashMap's thread-safety guarantees.
unsafe impl<K: Send + Clone + std::hash::Hash + Eq, V: Send + Clone> Send for SmallMap<K, V> {}
unsafe impl<K: Sync + Clone + std::hash::Hash + Eq, V: Sync + Clone> Sync for SmallMap<K, V> {}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh map is empty and reports the inline capacity.
    #[test]
    fn test_new() {
        let map: SmallMap<String, i32> = SmallMap::new();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        assert_eq!(map.capacity(), SMALL_MAP_THRESHOLD);
    }

    // Insert returns the previous value on key collision, None otherwise.
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let mut map = SmallMap::new();
        assert_eq!(map.insert("key1", "value1")?, None);
        assert_eq!(map.insert("key2", "value2")?, None);
        assert_eq!(map.insert("key1", "new_value1")?, Some("value1"));
        assert_eq!(map.get(&"key1"), Some(&"new_value1"));
        assert_eq!(map.get(&"key2"), Some(&"value2"));
        assert_eq!(map.get(&"missing"), None);
        assert_eq!(map.len(), 2);
        Ok(())
    }

    // Removing a middle entry keeps the remaining entries reachable.
    #[test]
    fn test_remove() -> Result<()> {
        let mut map = SmallMap::new();
        map.insert("key1", "value1")?;
        map.insert("key2", "value2")?;
        map.insert("key3", "value3")?;
        assert_eq!(map.remove(&"key2"), Some("value2"));
        assert_eq!(map.remove(&"key2"), None);
        assert_eq!(map.len(), 2);
        assert_eq!(map.get(&"key1"), Some(&"value1"));
        assert_eq!(map.get(&"key3"), Some(&"value3"));
        Ok(())
    }

    #[test]
    fn test_contains_key() -> Result<()> {
        let mut map = SmallMap::new();
        assert!(!map.contains_key(&"key"));
        map.insert("key", "value")?;
        assert!(map.contains_key(&"key"));
        map.remove(&"key");
        assert!(!map.contains_key(&"key"));
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let mut map = SmallMap::new();
        map.insert("key1", "value1")?;
        map.insert("key2", "value2")?;
        assert_eq!(map.len(), 2);
        map.clear();
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());
        Ok(())
    }

    // Inserting one past the inline threshold promotes to large storage
    // without losing any existing entries.
    #[test]
    fn test_promotion_to_large() -> Result<()> {
        let mut map = SmallMap::new();
        for i in 0..SMALL_MAP_THRESHOLD {
            map.insert(i, i * 2)?;
        }
        assert_eq!(map.len(), SMALL_MAP_THRESHOLD);
        map.insert(SMALL_MAP_THRESHOLD, SMALL_MAP_THRESHOLD * 2)?;
        assert_eq!(map.len(), SMALL_MAP_THRESHOLD + 1);
        for i in 0..=SMALL_MAP_THRESHOLD {
            assert_eq!(map.get(&i), Some(&(i * 2)));
        }
        Ok(())
    }

    #[test]
    fn test_iter() -> Result<()> {
        let mut map = SmallMap::new();
        map.insert("a", 1)?;
        map.insert("b", 2)?;
        map.insert("c", 3)?;
        let mut items: Vec<_> = map.iter().collect();
        // Iteration order is slot order, so sort before comparing.
        items.sort_by_key(|&(k, _)| k);
        assert_eq!(items, vec![(&"a", &1), (&"b", &2), (&"c", &3)]);
        Ok(())
    }

    #[test]
    fn test_clone() -> Result<()> {
        let mut map = SmallMap::new();
        map.insert("key1", "value1")?;
        map.insert("key2", "value2")?;
        let cloned = map.clone();
        assert_eq!(map, cloned);
        Ok(())
    }

    #[test]
    fn test_equality() -> Result<()> {
        let mut map1 = SmallMap::new();
        let mut map2 = SmallMap::new();
        assert_eq!(map1, map2);
        map1.insert("key", "value")?;
        assert_ne!(map1, map2);
        map2.insert("key", "value")?;
        assert_eq!(map1, map2);
        Ok(())
    }

    #[test]
    fn test_get_mut() -> Result<()> {
        let mut map = SmallMap::new();
        map.insert("key", "value")?;
        if let Some(value) = map.get_mut(&"key") {
            *value = "new_value";
        }
        assert_eq!(map.get(&"key"), Some(&"new_value"));
        Ok(())
    }

    // Sanity bounds on the struct's footprint (inline arrays dominate).
    #[test]
    fn test_memory_efficiency() {
        let small_map = SmallMap::<u64, u64>::new();
        let size = std::mem::size_of::<SmallMap<u64, u64>>();
        println!("SmallMap<u64, u64> size: {} bytes", size);
        assert!(size <= 1024);
        assert!(size >= 128);
    }

    // Exercises insert/get/remove after promotion to large storage.
    #[test]
    fn test_large_map_behavior() -> Result<()> {
        let mut map = SmallMap::new();
        for i in 0..20 {
            map.insert(i, i.to_string())?;
        }
        assert_eq!(map.len(), 20);
        for i in 0..20 {
            assert_eq!(map.get(&i), Some(&i.to_string()));
        }
        assert_eq!(map.remove(&10), Some("10".to_string()));
        assert_eq!(map.len(), 19);
        assert_eq!(map.get(&10), None);
        Ok(())
    }

    // clear() on a large map demotes back to inline storage.
    #[test]
    fn test_clear_promotes_back_to_small() -> Result<()> {
        let mut map = SmallMap::new();
        for i in 0..20 {
            map.insert(i, i)?;
        }
        map.clear();
        assert_eq!(map.len(), 0);
        assert_eq!(map.capacity(), SMALL_MAP_THRESHOLD);
        Ok(())
    }

    // get and get_fast must agree on hits and misses for u8 keys.
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    #[test]
    fn test_simd_optimized_u8_search() -> Result<()> {
        let mut map = SmallMap::<u8, u32>::new();
        for i in 0u8..8 {
            map.insert(i, i as u32 * 100)?;
        }
        for i in 0u8..8 {
            assert_eq!(map.get(&i), Some(&(i as u32 * 100)));
        }
        for i in 0u8..8 {
            assert_eq!(map.get_fast(&i), Some(&(i as u32 * 100)));
        }
        assert_eq!(map.get(&10), None);
        assert_eq!(map.get_fast(&10), None);
        Ok(())
    }

    // Verifies the #[repr(align(64))] cache-line alignment.
    #[test]
    fn test_cache_line_alignment() {
        let alignment = std::mem::align_of::<SmallMap<u64, u64>>();
        assert_eq!(alignment, 64, "SmallMap should be cache-line aligned");
    }

    #[test]
    fn test_separated_layout_benefits() -> Result<()> {
        let mut map = SmallMap::<u32, String>::new();
        for i in 0..8 {
            map.insert(i, format!("value_{}", i))?;
        }
        let mut keys_exist = vec![];
        for i in 0..8 {
            keys_exist.push(map.contains_key(&i));
        }
        assert!(keys_exist.iter().all(|&x| x));
        for i in 0..8 {
            assert_eq!(map.get(&i), Some(&format!("value_{}", i)));
        }
        Ok(())
    }
}