use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::hash::{Hash, Hasher};
use core::mem;
use super::Waiter;
use super::lock::Mutex;
/// Number of shards `FastMap` uses when its `N` const parameter is not given.
const DEFAULT_SHARDS: usize = 16;
/// One key/value node in a bucket's singly linked chain.
///
/// NOTE(review): a chain is dropped through the default drop glue, which recurses
/// once per `next` link; very long chains may exhaust the stack when the whole map
/// is dropped — consider an iterative `Drop` on the owning shard.
struct Entry<K, V> {
key: K,
value: V,
/// The next node in the same bucket, or `None` at the end of the chain.
next: Option<Box<Entry<K, V>>>,
}
/// One independently locked partition of the map.
///
/// `buckets` holds one chain head per bucket (`None` when the bucket is empty).
/// NOTE(review): nothing in the visible code resizes this vector after
/// construction, so its length is fixed for the lifetime of the map.
struct Shard<K, V> {
buckets: Mutex<Vec<Option<Box<Entry<K, V>>>>>,
}
/// A hash map split into `N` shards (default `DEFAULT_SHARDS`), each protected
/// by its own `Mutex` that is acquired with `.await`, so operations on keys that
/// fall into different shards do not contend for the same lock.
///
/// NOTE(review): no `Arc::clone` of a shard appears in the visible code; storing
/// `Shard` values directly would remove one pointer indirection per operation.
pub struct FastMap<K, V, const N: usize = DEFAULT_SHARDS> {
shards: [Arc<Shard<K, V>>; N],
}
impl<K, V, const N: usize> FastMap<K, V, N>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Send + Sync + 'static,
{
pub fn new() -> Self {
Self::with_capacity(64)
}
pub fn with_capacity(capacity: usize) -> Self {
let capacity_per_shard = (capacity + N - 1) / N;
let shards = core::array::from_fn(|_| {
let mut buckets = Vec::with_capacity(capacity_per_shard);
buckets.resize_with(capacity_per_shard, || None);
Arc::new(Shard {
buckets: Mutex::new(buckets),
})
});
Self { shards }
}
fn shard_index(&self, key: &K) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish() as usize % N
}
fn bucket_index(&self, key: &K, _shard_idx: usize) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let bucket_hash = hash.wrapping_mul(0x9e3779b9);
bucket_hash as usize
}
pub async fn insert(&self, key: K, value: V) -> Option<V> {
let shard_idx = self.shard_index(&key);
let shard = &self.shards[shard_idx];
let mut buckets = shard.buckets.lock(Waiter::default()).await;
let bucket_idx = {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let bucket_hash = hash.wrapping_mul(0x9e3779b9);
bucket_hash as usize % buckets.len()
};
{
let mut current = &mut buckets[bucket_idx];
while let Some(entry) = current {
if entry.key == key {
let old_value = mem::replace(&mut entry.value, value);
return Some(old_value);
}
current = &mut entry.next;
}
}
let old_head = buckets[bucket_idx].take();
let new_entry = Box::new(Entry {
key,
value,
next: old_head,
});
buckets[bucket_idx] = Some(new_entry);
None
}
pub async fn get(&self, key: &K) -> Option<V>
where
V: Clone,
{
let shard_idx = self.shard_index(key);
let shard = &self.shards[shard_idx];
let buckets = shard.buckets.lock(Waiter::default()).await;
let bucket_idx = {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let bucket_hash = hash.wrapping_mul(0x9e3779b9);
bucket_hash as usize % buckets.len()
};
let mut current = &buckets[bucket_idx];
while let Some(entry) = current {
if entry.key == *key {
return Some(entry.value.clone());
}
current = &entry.next;
}
None
}
pub async fn get_ref<F, R>(&self, key: &K, f: F) -> Option<R>
where
F: FnOnce(&V) -> R,
{
let shard_idx = self.shard_index(key);
let shard = &self.shards[shard_idx];
let buckets = shard.buckets.lock(Waiter::default()).await;
let bucket_idx = {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let bucket_hash = hash.wrapping_mul(0x9e3779b9);
bucket_hash as usize % buckets.len()
};
let mut current = &buckets[bucket_idx];
while let Some(entry) = current {
if entry.key == *key {
return Some(f(&entry.value));
}
current = &entry.next;
}
None
}
pub async fn remove(&self, key: &K) -> Option<V> {
let shard_idx = self.shard_index(key);
let shard = &self.shards[shard_idx];
let mut buckets = shard.buckets.lock(Waiter::default()).await;
let bucket_idx = {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let bucket_hash = hash.wrapping_mul(0x9e3779b9);
bucket_hash as usize % buckets.len()
};
let bucket = &mut buckets[bucket_idx];
if let Some(entry) = bucket {
if entry.key == *key {
let mut removed = bucket.take().unwrap();
*bucket = removed.next.take();
return Some(removed.value);
}
}
let mut current = bucket;
while let Some(entry) = current {
if let Some(next) = &mut entry.next {
if next.key == *key {
let mut removed = entry.next.take().unwrap();
entry.next = removed.next.take();
return Some(removed.value);
}
}
current = &mut entry.next;
}
None
}
pub async fn contains_key(&self, key: &K) -> bool {
let shard_idx = self.shard_index(key);
let shard = &self.shards[shard_idx];
let buckets = shard.buckets.lock(Waiter::default()).await;
let bucket_idx = {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let bucket_hash = hash.wrapping_mul(0x9e3779b9);
bucket_hash as usize % buckets.len()
};
let mut current = &buckets[bucket_idx];
while let Some(entry) = current {
if entry.key == *key {
return true;
}
current = &entry.next;
}
false
}
pub async fn len(&self) -> usize {
let mut total = 0;
for shard in &self.shards {
let buckets = shard.buckets.lock(Waiter::default()).await;
for bucket in buckets.iter() {
let mut current = bucket;
while let Some(entry) = current {
total += 1;
current = &entry.next;
}
}
}
total
}
pub async fn is_empty(&self) -> bool {
self.len().await == 0
}
pub async fn clear(&self) {
for shard in &self.shards {
let mut buckets = shard.buckets.lock(Waiter::default()).await;
for bucket in buckets.iter_mut() {
*bucket = None;
}
}
}
}
/// Small, deterministic, non-cryptographic hasher used to place keys in the map.
///
/// Input bytes are folded in with a multiply-add step (`state * PRIME + byte`,
/// using the 64-bit FNV prime), and `finish` passes the state through the
/// MurmurHash3 `fmix64` finalizer. The finalizer matters: in a multiply-add chain
/// the low `k` bits of the state depend only on the low `k` bits of each input
/// byte, and shards are chosen from the low bits (`hash % N`). Without mixing,
/// for example, every `u32` key below 256 that is a multiple of 16 hashes to the
/// same residue modulo 16.
///
/// NOTE: unseeded and not DoS-resistant; do not rely on it for attacker-chosen keys.
struct DefaultHasher {
    state: u64,
}

impl DefaultHasher {
    /// Initial state; an arbitrary non-zero constant.
    const SEED: u64 = 0xcafe_babe_dead_beef;
    /// 64-bit FNV prime, used as the per-byte multiplier.
    const PRIME: u64 = 0x100000001b3;

    fn new() -> Self {
        Self { state: Self::SEED }
    }

    /// MurmurHash3 64-bit finalizer: a bijective mix that spreads the influence of
    /// high input bits into the low output bits.
    fn fmix64(mut h: u64) -> u64 {
        h ^= h >> 33;
        h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
        h ^= h >> 33;
        h = h.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
        h ^= h >> 33;
        h
    }
}

impl Hasher for DefaultHasher {
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.state = self
                .state
                .wrapping_mul(Self::PRIME)
                .wrapping_add(u64::from(byte));
        }
    }

    fn finish(&self) -> u64 {
        Self::fmix64(self.state)
    }
}
impl<K, V, const N: usize> Default for FastMap<K, V, N>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
// SAFETY: `FastMap` reaches its keys and values only through its private `Arc<Shard>`s,
// and no `Arc::clone` of a shard appears in the visible code. Sending the map therefore
// moves sole ownership of every `K` and `V` with it, which is sound for `K: Send, V: Send`.
// NOTE(review): the compiler would derive this automatically if `super::lock::Mutex<T>`
// is `Send + Sync` for `T: Send`; confirm whether this manual impl is still needed.
unsafe impl<K: Send, V: Send, const N: usize> Send for FastMap<K, V, N> {}
// SAFETY: every `&self` method touches keys/values only while holding the owning shard's
// `Mutex` (assumed to provide mutual exclusion — confirm in `super::lock`). Values leave
// the map only as clones (`get`) or as a `&V` lent to a callback while the lock is held
// (`get_ref`). std's `Mutex<T>` is `Sync` for `T: Send`, so also requiring `Sync` here
// is conservative.
unsafe impl<K: Send + Sync, V: Send + Sync, const N: usize> Sync for FastMap<K, V, N> {}