use super::*;
#[cfg(feature = "async")]
impl<K, V> AsyncShardedHashMap<K, V, FxBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Creates a map with `shard_count` shards using the Fx hasher.
    ///
    /// Shorthand for `with_shards_and_hasher(shard_count, FxBuildHasher)`, so the
    /// requested count is normalized and clamped exactly the same way.
    #[tracing::instrument(level = "trace")]
    pub fn new(shard_count: usize) -> Self {
        let hasher = FxBuildHasher;
        Self::with_shards_and_hasher(shard_count, hasher)
    }
}
#[cfg(feature = "async")]
impl<K, V, S> AsyncShardedHashMap<K, V, S>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
S: BuildHasher + Clone + Send + Sync,
{
/// Assembles a map with `count` empty shard slots and zeroed counters.
///
/// `count` is used as-is: callers normalize and cap it beforehand. The per-shard
/// maps are not allocated here; each slot stays `None` until first access.
#[inline]
fn build_with_count(count: usize, hasher: S) -> Self {
    let empty_slots = vec![None; count];
    Self {
        shards: Arc::new(TokioRwLock::new(empty_slots)),
        shard_count: count,
        hasher,
        total_len: Arc::new(AtomicUsize::new(0)),
        #[cfg(feature = "advanced")]
        version: Arc::new(AtomicUsize::new(0)),
        #[cfg(feature = "advanced")]
        profiling_enabled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        #[cfg(feature = "advanced")]
        replicas: Arc::new(StdRwLock::new(Vec::new())),
        #[cfg(feature = "advanced")]
        quorum_config: Arc::new(StdRwLock::new(None)),
    }
}
/// Creates a map with `shard_count` shards and the given hasher.
///
/// The count is passed through `normalized_shard_count` and then clamped to
/// `MAX_SHARDS`. When clamping changes the value, a warning is logged and the
/// map is still built with the clamped count.
#[tracing::instrument(skip(hasher), level = "trace")]
pub fn with_shards_and_hasher(shard_count: usize, hasher: S) -> Self {
    let requested = normalized_shard_count(shard_count);
    let count = capped_shard_count(requested, MAX_SHARDS);
    let was_clamped = count != requested;
    if was_clamped {
        tracing::warn!(
            requested_shards = requested,
            capped_shards = count,
            max_shards = MAX_SHARDS,
            "requested shard_count exceeded default cap and was clamped"
        );
    }
    Self::build_with_count(count, hasher)
}
/// Like `with_shards_and_hasher`, but clamps to a caller-supplied cap.
///
/// A `max_shards` of 0 is treated as 1. When the normalized request exceeds
/// the cap, a warning is logged and the clamped count is used.
#[tracing::instrument(skip(hasher), level = "trace")]
pub fn with_shards_and_hasher_capped(shard_count: usize, hasher: S, max_shards: usize) -> Self {
    let effective_max = if max_shards == 0 { 1 } else { max_shards };
    let requested = normalized_shard_count(shard_count);
    let count = capped_shard_count(requested, effective_max);
    if count != requested {
        tracing::warn!(
            requested_shards = requested,
            capped_shards = count,
            max_shards = effective_max,
            "requested shard_count exceeded configured cap and was clamped"
        );
    }
    Self::build_with_count(count, hasher)
}
/// Strict constructor: rejects an invalid `shard_count` instead of clamping it.
///
/// Uses `MAX_SHARDS` as the cap. Validation is delegated to
/// `try_with_shards_and_hasher_capped`.
///
/// # Errors
/// Returns the `ShardCountError` produced by that validation.
#[tracing::instrument(skip(hasher), level = "trace")]
pub fn try_with_shards_and_hasher(
    shard_count: usize,
    hasher: S,
) -> Result<Self, ShardCountError> {
    let cap = MAX_SHARDS;
    Self::try_with_shards_and_hasher_capped(shard_count, hasher, cap)
}
/// Strict constructor with a caller-supplied cap (a cap of 0 is treated as 1).
///
/// # Errors
/// Propagates whatever `strict_shard_count` reports for an unacceptable
/// `shard_count`. No map is built in that case.
#[tracing::instrument(skip(hasher), level = "trace")]
pub fn try_with_shards_and_hasher_capped(
    shard_count: usize,
    hasher: S,
    max_shards: usize,
) -> Result<Self, ShardCountError> {
    let validated = strict_shard_count(shard_count, std::cmp::max(max_shards, 1))?;
    Ok(Self::build_with_count(validated, hasher))
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn shard_count(&self) -> usize {
self.shard_count
}
/// Counts how many shard slots currently hold an allocated shard.
#[tracing::instrument(skip(self), level = "trace")]
pub async fn initialized_shards(&self) -> usize {
    let slots = self.shards.read().await;
    slots.iter().flatten().count()
}
/// Maps `key` to a shard slot: its hash modulo the shard count.
///
/// Assumes `shard_count` is non-zero (a zero would panic on the modulo).
/// TODO(review): confirm `normalized_shard_count` guarantees this.
#[inline]
#[tracing::instrument(skip(self, key), level = "trace")]
fn shard_index(&self, key: &K) -> usize {
    let hash = self.hasher.hash_one(key);
    let slots = self.shard_count as u64;
    (hash % slots) as usize
}
/// Returns the shard at `index`, allocating it on first use.
///
/// The common case, an already-initialized shard, only needs a shared read
/// lock on the slot table, so concurrent operations on different (or the same)
/// shards do not serialize on the table. The exclusive write lock is taken only
/// when the slot is still empty. The slot is re-checked under that lock
/// (`get_or_insert_with`) because another task may have filled it after the
/// read lock was released.
///
/// `index` must be `< self.shard_count`; the slot vector has exactly that
/// length, and an out-of-range index panics on the lookup.
#[inline]
#[tracing::instrument(skip(self), level = "trace")]
async fn get_or_init_shard(&self, index: usize) -> AsyncShard<K, V, S> {
    {
        let slots = self.shards.read().await;
        if let Some(shard) = slots[index].as_ref() {
            return shard.clone();
        }
        // The read guard is dropped here, before the write lock is requested.
        // Upgrading while still holding it would deadlock.
    }
    let mut slots = self.shards.write().await;
    slots[index]
        .get_or_insert_with(|| {
            let map = AsyncShardMap::with_hasher(self.hasher.clone());
            Arc::new(TokioRwLock::new(map))
        })
        .clone()
}
/// Groups `(key, value)` pairs by destination shard index.
///
/// Within a shard, pairs keep their input order. The bucket map is pre-sized
/// from the iterator's lower size bound, capped at the shard count.
#[inline]
fn bucketize_entries<I>(&self, entries: I) -> HashMap<usize, Vec<(K, V)>, FxBuildHasher>
where
    I: IntoIterator<Item = (K, V)>,
{
    let entries = entries.into_iter();
    let capacity = std::cmp::min(entries.size_hint().0, self.shard_count);
    let mut by_shard: HashMap<usize, Vec<(K, V)>, FxBuildHasher> =
        HashMap::with_capacity_and_hasher(capacity, FxBuildHasher);
    entries.for_each(|(k, v)| {
        let target = self.shard_index(&k);
        by_shard.entry(target).or_insert_with(Vec::new).push((k, v));
    });
    by_shard
}
/// Groups owned keys by destination shard index, preserving input order
/// within each shard.
#[inline]
fn bucketize_keys<I>(&self, keys: I) -> HashMap<usize, Vec<K>, FxBuildHasher>
where
    I: IntoIterator<Item = K>,
{
    let keys = keys.into_iter();
    let capacity = std::cmp::min(keys.size_hint().0, self.shard_count);
    let mut by_shard: HashMap<usize, Vec<K>, FxBuildHasher> =
        HashMap::with_capacity_and_hasher(capacity, FxBuildHasher);
    keys.for_each(|k| {
        let target = self.shard_index(&k);
        by_shard.entry(target).or_insert_with(Vec::new).push(k);
    });
    by_shard
}
/// Groups borrowed keys by destination shard index.
///
/// Each key is paired with its position in `keys`, so callers can scatter
/// per-shard results back into input order.
#[inline]
fn bucketize_key_refs<'a>(
    &self,
    keys: &'a [K],
) -> HashMap<usize, Vec<(usize, &'a K)>, FxBuildHasher> {
    let capacity = std::cmp::min(keys.len(), self.shard_count);
    let mut by_shard: HashMap<usize, Vec<(usize, &'a K)>, FxBuildHasher> =
        HashMap::with_capacity_and_hasher(capacity, FxBuildHasher);
    keys.iter().enumerate().for_each(|(position, key)| {
        let target = self.shard_index(key);
        by_shard.entry(target).or_insert_with(Vec::new).push((position, key));
    });
    by_shard
}
/// Inserts `key -> value`, returning the previous value if the key existed.
///
/// The length counter is bumped only for a brand-new key. The bump happens
/// while the shard's write lock is still held.
#[tracing::instrument(skip(self, key, value), level = "trace")]
pub async fn insert(&self, key: K, value: V) -> Option<V> {
    let target = self.shard_index(&key);
    let shard = self.get_or_init_shard(target).await;
    let mut map = shard.write().await;
    let previous = map.insert(key, value);
    if previous.is_none() {
        self.total_len.fetch_add(1, Ordering::Relaxed);
    }
    previous
}
/// Returns a clone of the value stored for `key`, if any.
///
/// Goes through `get_or_init_shard`, so the target shard is allocated if it did
/// not exist yet. A non-blocking read is tried first; if the shard is
/// contended, the read lock is awaited instead.
#[tracing::instrument(skip(self, key), level = "trace")]
pub async fn get(&self, key: &K) -> Option<V> {
    let shard = self.get_or_init_shard(self.shard_index(key)).await;
    match shard.try_read() {
        Ok(map) => map.get(key).cloned(),
        Err(_) => shard.read().await.get(key).cloned(),
    }
}
/// Reports whether `key` is present.
///
/// Like `get`, this allocates the target shard on first access. It tries a
/// non-blocking read before awaiting the read lock.
#[tracing::instrument(skip(self, key), level = "trace")]
pub async fn contains(&self, key: &K) -> bool {
    let shard = self.get_or_init_shard(self.shard_index(key)).await;
    match shard.try_read() {
        Ok(map) => map.contains_key(key),
        Err(_) => shard.read().await.contains_key(key),
    }
}
/// Removes `key`, returning its value if it was present.
///
/// The length counter is decremented only on an actual removal. The decrement
/// happens while the shard's write lock is still held.
#[tracing::instrument(skip(self, key), level = "trace")]
pub async fn remove(&self, key: &K) -> Option<V> {
    let target = self.shard_index(key);
    let shard = self.get_or_init_shard(target).await;
    let mut map = shard.write().await;
    map.remove(key).inspect(|_| {
        self.total_len.fetch_sub(1, Ordering::Relaxed);
    })
}
/// Returns the number of entries according to the shared length counter.
///
/// No shard lock is taken; this is a relaxed atomic load. While writers are
/// active, the value is only a point-in-time approximation.
#[inline]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn len(&self) -> usize {
    let counter = &self.total_len;
    counter.load(Ordering::Relaxed)
}
/// Returns `true` when `len()` reports zero entries.
#[inline]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_empty(&self) -> bool {
    let entries = self.len().await;
    entries == 0
}
/// Removes every entry from every initialized shard.
///
/// The slot table is read-locked for the whole pass, so no new shard can be
/// allocated meanwhile. Inserts into existing shards can still proceed
/// concurrently.
///
/// The length counter is reduced by each shard's size while that shard's write
/// lock is held. A blanket reset to zero after the loop would be wrong: it
/// would discard increments from inserts that land in already-cleared shards
/// while later shards are still being cleared, leaving `len()` below the real
/// entry count.
#[tracing::instrument(skip(self), level = "trace")]
pub async fn clear(&self) {
    let slots = self.shards.read().await;
    for shard in slots.iter().flatten() {
        let mut map = shard.write().await;
        let removed = map.len();
        map.clear();
        if removed > 0 {
            self.total_len.fetch_sub(removed, Ordering::Relaxed);
        }
    }
}
/// Returns an owned copy of all entries as `(key, value)` pairs.
///
/// Shard handles are collected first, so the slot-table lock is not held while
/// waiting on individual shards. Shards are then read one at a time. The result
/// is therefore not an atomic snapshot across shards: concurrent writes to a
/// shard not yet visited may or may not be included.
///
/// Each entry is cloned exactly once, straight out of the read-locked shard.
/// Cloning each shard map first and then cloning its entries out of that copy
/// would double both the copying work and peak memory.
#[tracing::instrument(skip(self), level = "trace")]
pub async fn iter(&self) -> Vec<(K, V)> {
    let shard_arcs: Vec<AsyncShard<K, V, S>> = {
        let slots = self.shards.read().await;
        slots.iter().flatten().cloned().collect()
    };
    let mut items = Vec::new();
    for shard in shard_arcs {
        // Take a non-blocking read first; await the lock only if it is contended.
        let map = match shard.try_read() {
            Ok(map) => map,
            Err(_) => shard.read().await,
        };
        items.extend(map.iter().map(|(k, v)| (k.clone(), v.clone())));
    }
    items
}
/// Inserts all `entries`, grouped by shard so that each shard's write lock is
/// taken once.
///
/// Returns the number of keys newly added. Overwrites of existing keys, and
/// repeated keys within the batch after their first occurrence, are not counted.
///
/// The length counter is updated per shard while that shard's lock is still
/// held, as `insert` does. If the increment were deferred until every lock had
/// been released, a concurrent remover of these new keys (`remove`, `retain`,
/// ...) could decrement first and transiently wrap the unsigned counter.
#[tracing::instrument(skip(self, entries), level = "trace")]
pub async fn batch_insert<I>(&self, entries: I) -> usize
where
    I: IntoIterator<Item = (K, V)>,
{
    let buckets = self.bucketize_entries(entries);
    let mut added_total = 0;
    for (shard_idx, pairs) in buckets {
        let shard = self.get_or_init_shard(shard_idx).await;
        let mut guard = shard.write().await;
        let mut added = 0;
        for (k, v) in pairs {
            if guard.insert(k, v).is_none() {
                added += 1;
            }
        }
        if added > 0 {
            self.total_len.fetch_add(added, Ordering::Relaxed);
        }
        added_total += added;
    }
    added_total
}
/// Removes all `keys`, grouped by shard so that each shard's write lock is
/// taken once.
///
/// Returns the number of keys that were actually present and removed.
///
/// The length counter is decremented per shard while that shard's lock is still
/// held, as `remove` does. This keeps the counter in step with the shard
/// contents. A single deferred decrement after all locks were released would
/// overstate `len()` in the meantime, and could underflow if the counter is
/// reset or recomputed by another operation in between.
#[tracing::instrument(skip(self, keys), level = "trace")]
pub async fn batch_remove<I>(&self, keys: I) -> usize
where
    I: IntoIterator<Item = K>,
{
    let buckets = self.bucketize_keys(keys);
    let mut removed_total = 0;
    for (shard_idx, shard_keys) in buckets {
        let shard = self.get_or_init_shard(shard_idx).await;
        let mut guard = shard.write().await;
        let mut removed = 0;
        for k in shard_keys {
            if guard.remove(&k).is_some() {
                removed += 1;
            }
        }
        if removed > 0 {
            self.total_len.fetch_sub(removed, Ordering::Relaxed);
        }
        removed_total += removed;
    }
    removed_total
}
/// Looks up every key in `keys`, returning results in the same order.
///
/// Keys are grouped by shard, so each shard is read-locked once. Missing keys
/// yield `None`. Shards touched by the lookup are allocated if absent.
#[tracing::instrument(skip(self, keys), level = "trace")]
pub async fn batch_get(&self, keys: &[K]) -> Vec<Option<V>> {
    let mut results: Vec<Option<V>> = (0..keys.len()).map(|_| None).collect();
    for (shard_idx, positioned) in self.bucketize_key_refs(keys) {
        let shard = self.get_or_init_shard(shard_idx).await;
        let map = shard.read().await;
        for (position, key) in positioned {
            results[position] = map.get(key).cloned();
        }
    }
    results
}
/// Recomputes the value for `key` if it is present.
///
/// `f` receives a clone of the current value:
/// - `Some(new)` stores `new` and returns a clone of it.
/// - `None` removes the entry (decrementing the length counter) and returns
///   `None`.
///
/// If the key is absent, `f` is never called and `None` is returned. The whole
/// operation runs under the shard's write lock.
#[tracing::instrument(skip(self, key, f), level = "trace")]
pub async fn compute_if_present<F>(&self, key: &K, f: F) -> Option<V>
where
    F: FnOnce(V) -> Option<V>,
{
    let target = self.shard_index(key);
    let shard = self.get_or_init_shard(target).await;
    let mut map = shard.write().await;
    let current = map.get(key).cloned()?;
    match f(current) {
        Some(updated) => {
            let returned = updated.clone();
            map.insert(key.clone(), updated);
            Some(returned)
        }
        None => {
            map.remove(key);
            self.total_len.fetch_sub(1, Ordering::Relaxed);
            None
        }
    }
}
/// Returns the value for `key`, inserting `f()` first if the key is absent.
///
/// `f` runs under the shard's write lock and only when the key is missing. A
/// newly inserted key increments the length counter.
#[tracing::instrument(skip(self, key, f), level = "trace")]
pub async fn compute_if_absent<F>(&self, key: K, f: F) -> V
where
    F: FnOnce() -> V,
{
    let shard = self.get_or_init_shard(self.shard_index(&key)).await;
    let mut map = shard.write().await;
    if let Some(existing) = map.get(&key) {
        return existing.clone();
    }
    let created = f();
    map.insert(key, created.clone());
    self.total_len.fetch_add(1, Ordering::Relaxed);
    created
}
/// Keeps only the entries for which `predicate(key, value)` returns `true`.
///
/// Shard handles are snapshotted first, so shards allocated after that point
/// are not visited. Each shard is filtered under its write lock, and the length
/// counter is reduced by the number of entries that shard dropped.
#[tracing::instrument(skip(self, predicate), level = "trace")]
pub async fn retain<F>(&self, predicate: F)
where
    F: Fn(&K, &V) -> bool,
{
    let shards_snapshot: Vec<AsyncShard<K, V, S>> = {
        let slots = self.shards.read().await;
        slots.iter().flatten().cloned().collect()
    };
    for shard in &shards_snapshot {
        let mut map = shard.write().await;
        let before = map.len();
        map.retain(|k, v| predicate(k, v));
        let dropped = before - map.len();
        if dropped != 0 {
            self.total_len.fetch_sub(dropped, Ordering::Relaxed);
        }
    }
}
/// Applies the operations of `txn` while holding the write locks of every
/// shard the transaction touches.
///
/// Shard indices are sorted and de-duplicated, and the locks are acquired in
/// that ascending order. A consistent global order is what keeps two
/// concurrent transactions from deadlocking on each other.
///
/// `Read` ops perform a lookup and discard the result. `Write`/`Remove` ops
/// update the length counter like `insert`/`remove`. The `Aborted` path is
/// defensive: indices are derived from the same ops, so a lookup miss should
/// not happen. If it ever did, ops already applied are NOT rolled back.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self, txn), level = "trace")]
pub async fn execute_transaction(&self, txn: Transaction<K, V>) -> TransactionResult<()> {
    let mut shard_indices: Vec<usize> = txn
        .ops
        .iter()
        .map(|op| {
            let key = match op {
                TxnOp::Read(k) | TxnOp::Write(k, _) | TxnOp::Remove(k) => k,
            };
            self.shard_index(key)
        })
        .collect();
    shard_indices.sort_unstable();
    shard_indices.dedup();

    let mut shard_arcs = Vec::with_capacity(shard_indices.len());
    for &idx in &shard_indices {
        shard_arcs.push(self.get_or_init_shard(idx).await);
    }
    let mut guards = Vec::with_capacity(shard_arcs.len());
    for shard_arc in &shard_arcs {
        guards.push(shard_arc.write().await);
    }

    // Position of a key's shard guard within `guards`, or `None` (logged) if
    // the shard was not locked up front.
    let locate = |key: &K| -> Option<usize> {
        let idx = self.shard_index(key);
        match shard_indices.binary_search(&idx) {
            Ok(pos) => Some(pos),
            Err(_) => {
                tracing::error!(
                    shard_index = idx,
                    "shard index missing in async transaction"
                );
                None
            }
        }
    };

    for op in txn.ops {
        match op {
            TxnOp::Read(k) => {
                let Some(pos) = locate(&k) else {
                    return TransactionResult::Aborted;
                };
                let _ = guards[pos].get(&k);
            }
            TxnOp::Write(k, v) => {
                let Some(pos) = locate(&k) else {
                    return TransactionResult::Aborted;
                };
                if guards[pos].insert(k, v).is_none() {
                    self.total_len.fetch_add(1, Ordering::Relaxed);
                }
            }
            TxnOp::Remove(k) => {
                let Some(pos) = locate(&k) else {
                    return TransactionResult::Aborted;
                };
                if guards[pos].remove(&k).is_some() {
                    self.total_len.fetch_sub(1, Ordering::Relaxed);
                }
            }
        }
    }
    TransactionResult::Committed(())
}
/// Replaces the value for `key` with `new` only if it currently equals
/// `expected`.
///
/// Returns `Success(new)` on replacement and `Failure(current)` on a mismatch.
/// NOTE(review): when the key is absent this returns `Failure(new)`, i.e. the
/// failure payload is the proposed value rather than a current value. Callers
/// cannot distinguish "absent" from "current == new" by the payload alone.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self, key, expected, new), level = "trace")]
pub async fn compare_and_swap(&self, key: &K, expected: &V, new: V) -> CasResult<V>
where
    V: PartialEq,
{
    let shard = self.get_or_init_shard(self.shard_index(key)).await;
    let mut map = shard.write().await;
    let Some(current) = map.get(key) else {
        return CasResult::Failure(new);
    };
    if current == expected {
        map.insert(key.clone(), new.clone());
        CasResult::Success(new)
    } else {
        CasResult::Failure(current.clone())
    }
}
/// Removes `key` only if its current value equals `expected`.
///
/// Returns `true` (and decrements the length counter) on removal, and `false`
/// when the key is absent or holds a different value.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self, key, expected), level = "trace")]
pub async fn compare_and_remove(&self, key: &K, expected: &V) -> bool
where
    V: PartialEq,
{
    let shard = self.get_or_init_shard(self.shard_index(key)).await;
    let mut map = shard.write().await;
    let matches = map.get(key).is_some_and(|current| current == expected);
    if !matches {
        return false;
    }
    map.remove(key);
    self.total_len.fetch_sub(1, Ordering::Relaxed);
    true
}
/// Builds a `CowSnapshot` from `iter()`, tagged with the current version.
///
/// The version counter is read after the entries are collected and is not
/// advanced.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn cow_snapshot(&self) -> CowSnapshot<K, V> {
    let entries = self.iter().await;
    CowSnapshot::new(entries, self.version.load(Ordering::SeqCst) as u64)
}
/// Builds an `IsolatedSnapshot` from `iter()` and advances the version counter.
///
/// The snapshot is tagged with the counter value from before the increment.
/// Entries are collected before the counter moves.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn versioned_snapshot(&self) -> IsolatedSnapshot<K, V> {
    let entries = self.iter().await;
    let assigned = self.version.fetch_add(1, Ordering::SeqCst) as u64;
    IsolatedSnapshot::new(assigned, entries)
}
/// Returns a fresh snapshot only if `version` equals the current version.
///
/// This method does not reconstruct past versions. Any other `version` yields
/// `None`.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn snapshot_at_version(&self, version: u64) -> Option<IsolatedSnapshot<K, V>> {
    let current_version = self.version.load(Ordering::SeqCst) as u64;
    if current_version != version {
        return None;
    }
    Some(IsolatedSnapshot::new(version, self.iter().await))
}
/// Returns one `LockProfile` per initialized shard while profiling is enabled,
/// or an empty vector when it is disabled.
///
/// NOTE: currently a placeholder. Only `shard_id` is meaningful; every counter
/// and timing field is reported as 0.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn lock_profiles(&self) -> Vec<LockProfile> {
    if !self.profiling_enabled.load(Ordering::Relaxed) {
        return Vec::new();
    }
    let slots = self.shards.read().await;
    slots
        .iter()
        .enumerate()
        .filter(|(_, slot)| slot.is_some())
        .map(|(shard_id, _)| LockProfile {
            shard_id,
            contention_count: 0,
            avg_wait_time_ns: 0,
            max_wait_time_ns: 0,
            reads: 0,
            writes: 0,
        })
        .collect()
}
/// Turns lock profiling on or off; this is the flag consulted by `lock_profiles`.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self), level = "trace")]
pub fn enable_profiling(&self, enabled: bool) {
    let flag = &self.profiling_enabled;
    flag.store(enabled, Ordering::Relaxed);
}
/// Creates a map configured for replicated writes.
///
/// Builds the map via `with_shards_and_hasher` with a default-constructed
/// hasher, then installs `replicas` and `quorum_config`. Those are used by
/// `insert_replicated` / `remove_replicated`.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(replicas, quorum_config), level = "trace")]
pub fn with_replication(
    shard_count: usize,
    replicas: Vec<Arc<dyn Replica<K, V>>>,
    quorum_config: QuorumConfig,
) -> Self
where
    S: Default,
{
    let map = Self::with_shards_and_hasher(shard_count, S::default());
    // Each temporary guard is released at the end of its own statement.
    *std_write_guard(&map.replicas, "with_replication") = replicas;
    *std_write_guard(&map.quorum_config, "with_replication_quorum") = Some(quorum_config);
    map
}
/// Inserts locally, then forwards the insert to every configured replica.
///
/// The replicas are contacted concurrently via `tokio::spawn`, so this must run
/// inside a Tokio runtime. A replica task that fails or panics counts as not
/// acknowledged.
///
/// # Errors
/// Returns `ReplicaError::QuorumFailed` when a quorum config is set and fewer
/// than `write_quorum` replicas acknowledged. The local write is NOT rolled
/// back in that case: the local map already holds the new value when the error
/// is returned. With no replicas configured, the quorum is not checked.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self, key, value), level = "trace")]
pub async fn insert_replicated(&self, key: K, value: V) -> Result<Option<V>, ReplicaError> {
    let old = self.insert(key.clone(), value.clone()).await;
    // std guards are confined to these blocks so none is held across an await.
    let replicas = {
        let r = std_read_guard(&self.replicas, "insert_replicated");
        r.clone()
    };
    if replicas.is_empty() {
        return Ok(old);
    }
    let quorum_config = {
        let q = std_read_guard(&self.quorum_config, "insert_replicated_quorum");
        q.clone()
    };
    // `key` and `value` are not used after this point, so move them into the op
    // instead of cloning them a second time.
    let op = ReplicationOp::Insert { key, value };
    let mut tasks = Vec::with_capacity(replicas.len());
    for replica in replicas.iter() {
        let replica_clone = Arc::clone(replica);
        let op_clone = op.clone();
        tasks.push(tokio::spawn(async move {
            replica_clone.replicate(op_clone).await
        }));
    }
    let mut success_count = 0;
    for task in tasks {
        if let Ok(Ok(())) = task.await {
            success_count += 1;
        }
    }
    if let Some(config) = quorum_config
        && success_count < config.write_quorum
    {
        return Err(ReplicaError::QuorumFailed);
    }
    Ok(old)
}
/// Removes locally, then forwards the removal to every configured replica.
///
/// The replicas are contacted concurrently via `tokio::spawn` (a Tokio runtime
/// is required). A failed or panicked replica task counts as not acknowledged.
///
/// # Errors
/// Returns `ReplicaError::QuorumFailed` when a quorum config is set and fewer
/// than `write_quorum` replicas acknowledged. The local removal is NOT undone.
/// With no replicas configured, the quorum is not checked.
#[cfg(feature = "advanced")]
#[tracing::instrument(skip(self, key), level = "trace")]
pub async fn remove_replicated(&self, key: &K) -> Result<Option<V>, ReplicaError> {
    let old = self.remove(key).await;
    let replicas = {
        let guard = std_read_guard(&self.replicas, "remove_replicated");
        guard.clone()
    };
    if replicas.is_empty() {
        return Ok(old);
    }
    let quorum_config = {
        let guard = std_read_guard(&self.quorum_config, "remove_replicated_quorum");
        guard.clone()
    };
    let op = ReplicationOp::Remove { key: key.clone() };
    let handles: Vec<_> = replicas
        .iter()
        .map(|replica| {
            let replica = Arc::clone(replica);
            let op = op.clone();
            tokio::spawn(async move { replica.replicate(op).await })
        })
        .collect();
    let mut acknowledged = 0;
    for handle in handles {
        if matches!(handle.await, Ok(Ok(()))) {
            acknowledged += 1;
        }
    }
    match quorum_config {
        Some(config) if acknowledged < config.write_quorum => Err(ReplicaError::QuorumFailed),
        _ => Ok(old),
    }
}
/// Returns all keys, taken from an `iter()` copy (same consistency caveats).
#[tracing::instrument(skip(self), level = "trace")]
pub async fn keys(&self) -> Vec<K> {
    let pairs = self.iter().await;
    let mut keys = Vec::with_capacity(pairs.len());
    keys.extend(pairs.into_iter().map(|(key, _)| key));
    keys
}
/// Returns all values, taken from an `iter()` copy (same consistency caveats).
#[tracing::instrument(skip(self), level = "trace")]
pub async fn values(&self) -> Vec<V> {
    let pairs = self.iter().await;
    let mut values = Vec::with_capacity(pairs.len());
    values.extend(pairs.into_iter().map(|(_, value)| value));
    values
}
/// Summarizes shard occupancy.
///
/// Fields of the returned `ShardStats`:
/// - `total`: number of slots.
/// - `initialized`: allocated shards.
/// - `empty`: allocated shards with no entries.
/// - `max_load` / `avg_load`: max and mean entry count over allocated shards
///   (`avg_load` is 0.0 when none are allocated).
///
/// The slot table stays read-locked while each shard is read in turn.
#[tracing::instrument(skip(self), level = "trace")]
pub async fn shard_stats(&self) -> ShardStats {
    let slots = self.shards.read().await;
    let total = slots.len();
    let mut initialized = 0;
    let mut loads: Vec<usize> = Vec::new();
    for shard in slots.iter().flatten() {
        initialized += 1;
        let map = shard.read().await;
        loads.push(map.len());
    }
    let empty = loads.iter().filter(|load| **load == 0).count();
    let max_load = loads.iter().copied().max().unwrap_or_default();
    let avg_load = if initialized == 0 {
        0.0
    } else {
        loads.iter().sum::<usize>() as f64 / initialized as f64
    };
    ShardStats {
        initialized,
        total,
        empty,
        avg_load,
        max_load,
    }
}
/// Returns `ShardStats::utilization_percent` for the current `shard_stats()`.
#[tracing::instrument(skip(self), level = "trace")]
pub async fn shard_utilization(&self) -> f64 {
    self.shard_stats().await.utilization_percent()
}
/// Reports entry count and `HashMap` capacity for each allocated shard,
/// in slot order.
#[cfg(feature = "lifecycle")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn per_shard_load(&self) -> Vec<PerShardLoad> {
    let slots = self.shards.read().await;
    let mut stats = Vec::new();
    let allocated = slots
        .iter()
        .enumerate()
        .filter_map(|(shard_idx, slot)| slot.as_ref().map(|shard| (shard_idx, shard)));
    for (shard_idx, shard) in allocated {
        let map = shard.read().await;
        stats.push(PerShardLoad {
            shard_idx,
            entry_count: map.len(),
            capacity: map.capacity(),
        });
    }
    stats
}
/// Aggregates allocation figures across allocated shards.
///
/// `load_factor` is total entries divided by total `HashMap` capacity, and
/// 0.0 when there is no capacity.
#[cfg(feature = "lifecycle")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn memory_stats(&self) -> MemoryStats {
    let slots = self.shards.read().await;
    let mut shards_allocated = 0;
    let (mut total_capacity, mut total_entries) = (0usize, 0usize);
    for shard in slots.iter().flatten() {
        shards_allocated += 1;
        let map = shard.read().await;
        total_capacity += map.capacity();
        total_entries += map.len();
    }
    let load_factor = match total_capacity {
        0 => 0.0,
        capacity => total_entries as f64 / capacity as f64,
    };
    MemoryStats {
        shards_allocated,
        total_capacity,
        load_factor,
    }
}
/// Removes and returns every entry from every allocated shard.
///
/// Shards are emptied one at a time under their write locks while the slot
/// table is read-locked. The length counter is reduced by each shard's size
/// while that shard is locked. Resetting it to zero after the loop would
/// discard increments from concurrent inserts into shards that were already
/// drained, leaving `len()` below the real entry count.
#[cfg(feature = "lifecycle")]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn drain(&self) -> DrainIterator<K, V> {
    let slots = self.shards.read().await;
    let mut items = Vec::new();
    for shard in slots.iter().flatten() {
        let mut map = shard.write().await;
        let removed = map.len();
        items.extend(map.drain());
        if removed > 0 {
            self.total_len.fetch_sub(removed, Ordering::Relaxed);
        }
    }
    DrainIterator { items, index: 0 }
}
}