use crate::instrumentation::{record_swiss_probe, set_swiss_load_factor};
use foldhash::fast::FixedState;
use hashbrown::HashMap;
use std::borrow::Borrow;
use std::hash::Hash;
use tracing::Level;
#[derive(Debug, Clone, Default)]
pub struct SwissIndex<K, V> {
inner: HashMap<K, V, FixedState>,
}
impl<K, V> SwissIndex<K, V>
where
    K: Eq + Hash,
{
    /// Creates an empty index.
    #[inline]
    pub fn new() -> Self {
        let hasher = FixedState::default();
        Self {
            inner: HashMap::with_hasher(hasher),
        }
    }

    /// Creates an empty index pre-sized for `capacity` entries
    /// (forwarded to `HashMap::with_capacity_and_hasher`).
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let hasher = FixedState::default();
        Self {
            inner: HashMap::with_capacity_and_hasher(capacity, hasher),
        }
    }

    /// Returns the capacity reported by the underlying `HashMap`.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    /// Returns the number of entries currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when no entries are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Stores `value` under `key` and returns the value the key held before, if any.
    ///
    /// When `TRACE` is enabled, a probe is recorded and the load-factor gauge is refreshed,
    /// both after the write.
    #[inline]
    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
        let previous = self.inner.insert(key, value);
        self.maybe_record_probe_and_load_factor();
        previous
    }

    /// Returns a shared reference to the value stored under `key`, if present.
    ///
    /// Records a probe when `TRACE` is enabled.
    #[inline]
    pub fn get<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let found = self.inner.get(key);
        self.maybe_record_probe();
        found
    }

    /// Returns a mutable reference to the value stored under `key`, if present.
    ///
    /// The probe is recorded *before* the lookup, because the returned `&mut V` keeps `self`
    /// mutably borrowed until the call returns.
    #[inline]
    pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.maybe_record_probe();
        self.inner.get_mut(key)
    }

    /// Returns a mutable reference to the value for `key`, first inserting `default()` if
    /// the key is absent.
    ///
    /// Only a probe is recorded, and it is recorded before the lookup. Unlike
    /// [`insert`](Self::insert), this does not refresh the load-factor gauge, even when it
    /// adds an entry.
    #[inline]
    pub fn entry_or_insert_with(&mut self, key: K, default: impl FnOnce() -> V) -> &mut V {
        self.maybe_record_probe();
        self.inner.entry(key).or_insert_with(default)
    }

    /// Returns `true` if `key` is present. Records one probe when `TRACE` is enabled.
    #[inline]
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        // `get` already records the probe, so there is no separate instrumentation call here.
        self.get(key).is_some()
    }

    /// Removes `key` and returns its value if it was present.
    ///
    /// When `TRACE` is enabled, a probe is recorded and the load-factor gauge is refreshed,
    /// both after the removal.
    #[inline]
    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let removed = self.inner.remove(key);
        self.maybe_record_probe_and_load_factor();
        removed
    }

    /// Removes every entry.
    ///
    /// When `TRACE` is enabled, the load-factor gauge is refreshed. No probe is recorded.
    #[inline]
    pub fn clear(&mut self) {
        self.inner.clear();
        if !tracing::enabled!(Level::TRACE) {
            return;
        }
        self.update_load_factor();
    }

    /// Iterates over the stored values in the map's unspecified internal order.
    #[inline]
    pub fn values(&self) -> hashbrown::hash_map::Values<'_, K, V> {
        self.inner.values()
    }

    /// Iterates over `(&K, &V)` pairs in the map's unspecified internal order.
    #[inline]
    pub fn iter(&self) -> hashbrown::hash_map::Iter<'_, K, V> {
        self.inner.iter()
    }

    /// Iterates over `(&K, &mut V)` pairs in the map's unspecified internal order.
    #[inline]
    pub fn iter_mut(&mut self) -> hashbrown::hash_map::IterMut<'_, K, V> {
        self.inner.iter_mut()
    }

    /// Fast-path guard: takes the cold probe path only when `TRACE` is enabled.
    #[inline]
    fn maybe_record_probe(&self) {
        if tracing::enabled!(Level::TRACE) {
            self.record_probe_cold();
        }
    }

    /// Fast-path guard: when `TRACE` is enabled, records a probe and then refreshes the
    /// load-factor gauge.
    #[inline]
    fn maybe_record_probe_and_load_factor(&self) {
        if !tracing::enabled!(Level::TRACE) {
            return;
        }
        self.record_probe_cold();
        self.update_load_factor();
    }

    /// Out-of-line slow path. Calls `record_swiss_probe`, then emits a `hash_probe` span
    /// that carries the current item count and load factor (as a fraction).
    #[cold]
    #[inline(never)]
    fn record_probe_cold(&self) {
        record_swiss_probe();
        let span = tracing::span!(
            Level::TRACE,
            "hash_probe",
            probes = 1,
            items = self.len(),
            load_factor = self.load_factor_milli() as f64 / 1000.0
        );
        // Enter the span; the guard drops (exits) before the span itself closes.
        let _entered = span.enter();
    }

    /// Publishes the current per-mille load factor through `set_swiss_load_factor`.
    #[cold]
    #[inline(never)]
    fn update_load_factor(&self) {
        let milli = self.load_factor_milli();
        set_swiss_load_factor(milli);
    }

    /// Returns occupancy as entries per 1000 units of capacity, truncated toward zero.
    /// Returns 0 when the capacity is zero.
    #[inline]
    fn load_factor_milli(&self) -> u64 {
        match self.inner.capacity() {
            0 => 0,
            capacity => self.inner.len() as u64 * 1000 / capacity as u64,
        }
    }
}
impl<K, V, Q> std::ops::Index<&Q> for SwissIndex<K, V>
where
    K: Eq + Hash + Borrow<Q>,
    Q: Hash + Eq + ?Sized,
{
    type Output = V;

    /// Returns the value stored under `key`, using `index[key]` syntax.
    ///
    /// This delegates to [`SwissIndex::get`], so `[]` lookups record a probe when `TRACE`
    /// is enabled, the same as every other lookup path.
    ///
    /// # Panics
    ///
    /// Panics with `"key not found"` if `key` is not present.
    #[inline]
    fn index(&self, key: &Q) -> &V {
        self.get(key).expect("key not found")
    }
}
impl<K, V> IntoIterator for SwissIndex<K, V>
where
    K: Eq + Hash,
{
    type Item = (K, V);
    type IntoIter = hashbrown::hash_map::IntoIter<K, V>;

    /// Consumes the index and yields every owned `(key, value)` pair.
    fn into_iter(self) -> Self::IntoIter {
        let Self { inner } = self;
        inner.into_iter()
    }
}
impl<'a, K, V> IntoIterator for &'a SwissIndex<K, V>
where
    K: Eq + Hash,
{
    type Item = (&'a K, &'a V);
    type IntoIter = hashbrown::hash_map::Iter<'a, K, V>;

    /// Borrows the index and yields `(&K, &V)` pairs, enabling `for (k, v) in &index`.
    fn into_iter(self) -> Self::IntoIter {
        self.inner.iter()
    }
}
impl<'a, K, V> IntoIterator for &'a mut SwissIndex<K, V>
where
    K: Eq + Hash,
{
    type Item = (&'a K, &'a mut V);
    type IntoIter = hashbrown::hash_map::IterMut<'a, K, V>;

    /// Mutably borrows the index and yields `(&K, &mut V)` pairs, enabling
    /// `for (k, v) in &mut index`.
    fn into_iter(self) -> Self::IntoIter {
        self.inner.iter_mut()
    }
}
impl<K, V> FromIterator<(K, V)> for SwissIndex<K, V>
where
K: Eq + Hash,
{
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
let inner = HashMap::from_iter(iter);
let index = Self { inner };
index.update_load_factor();
index
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_swiss_index_basic_ops() {
        let mut map = SwissIndex::new();
        assert!(map.is_empty());
        map.insert("key1", 1);
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("key1"), Some(&1));
        assert!(map.contains_key("key1"));
        map.insert("key2", 2);
        assert_eq!(map.len(), 2);
        assert_eq!(map.remove("key1"), Some(1));
        assert_eq!(map.len(), 1);
        assert!(!map.contains_key("key1"));
    }

    #[test]
    fn test_swiss_index_remove_absent_returns_none_and_values_yields_remaining() {
        let mut map: SwissIndex<&str, i32> = SwissIndex::new();
        map.insert("a", 1);
        map.insert("b", 2);
        assert_eq!(map.remove("missing"), None);
        assert_eq!(map.len(), 2);
        assert_eq!(map.remove("a"), Some(1));
        assert_eq!(map.len(), 1);
        assert_eq!(map.remove("a"), None);
        let mut vals: Vec<i32> = map.values().copied().collect();
        vals.sort_unstable();
        assert_eq!(vals, vec![2]);
    }

    #[test]
    fn test_swiss_index_get_mut_clear_and_insert_returns_old() {
        let mut map: SwissIndex<&str, i32> = SwissIndex::new();
        assert_eq!(map.insert("k", 1), None);
        assert_eq!(map.insert("k", 2), Some(1));
        assert_eq!(map.get("k"), Some(&2));
        *map.get_mut("k").unwrap() += 40;
        assert_eq!(map.get("k"), Some(&42));
        assert!(map.get_mut("absent").is_none());
        map.insert("k2", 7);
        assert_eq!(map.len(), 2);
        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
        assert!(map.get("k").is_none());
    }

    #[test]
    fn test_swiss_index_capacity_and_load_factor() {
        let mut map = SwissIndex::with_capacity(100);
        assert_eq!(map.load_factor_milli(), 0);
        for i in 0..50 {
            map.insert(i, i * 10);
        }
        let lf = map.load_factor_milli();
        assert!(lf > 0);
        assert!(lf < 1000);
    }

    #[test]
    fn test_swiss_index_entry_or_insert_with() {
        let mut map = SwissIndex::new();
        let val = map.entry_or_insert_with(42, || 100);
        assert_eq!(*val, 100);
        let val = map.entry_or_insert_with(42, || 999);
        assert_eq!(*val, 100);
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn test_swiss_index_from_iter() {
        let map: SwissIndex<i32, i32> = [(1, 10), (2, 20), (3, 30)].into_iter().collect();
        assert_eq!(map.len(), 3);
        assert_eq!(map.get(&2), Some(&20));
    }

    #[test]
    fn test_swiss_index_from_iter_later_duplicate_wins() {
        let map: SwissIndex<i32, &str> = [(1, "first"), (1, "second")].into_iter().collect();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&1), Some(&"second"));
    }

    #[test]
    fn test_swiss_index_default_is_empty() {
        let map: SwissIndex<&str, i32> = SwissIndex::default();
        assert!(map.is_empty());
        assert_eq!(map.capacity(), 0);
        assert_eq!(map.load_factor_milli(), 0);
    }

    #[test]
    fn test_swiss_index_borrowed_lookup_and_index() {
        let mut map: SwissIndex<String, i32> = SwissIndex::new();
        map.insert("alpha".to_string(), 1);
        map.insert("beta".to_string(), 2);
        // Owned `String` keys can be queried with `&str` through `Borrow<str>`.
        assert_eq!(map.get("alpha"), Some(&1));
        assert!(map.contains_key("beta"));
        assert_eq!(map["beta"], 2);
        assert_eq!(map.remove("alpha"), Some(1));
        assert!(map.get("alpha").is_none());
    }

    #[test]
    #[should_panic]
    fn test_swiss_index_index_missing_key_panics() {
        let map: SwissIndex<&str, i32> = SwissIndex::new();
        let _value = map["missing"];
    }

    #[test]
    fn test_swiss_index_iteration_impls() {
        let mut map: SwissIndex<u32, u32> = (0..4).map(|i| (i, i * 10)).collect();

        let mut seen: Vec<(u32, u32)> = map.iter().map(|(&k, &v)| (k, v)).collect();
        seen.sort_unstable();
        assert_eq!(seen, vec![(0, 0), (1, 10), (2, 20), (3, 30)]);

        for (_, v) in &mut map {
            *v += 1;
        }
        let mut total = 0;
        for (_, v) in &map {
            total += *v;
        }
        assert_eq!(total, 1 + 11 + 21 + 31);

        for (_, v) in map.iter_mut() {
            *v *= 2;
        }
        let mut owned: Vec<(u32, u32)> = map.into_iter().collect();
        owned.sort_unstable();
        assert_eq!(owned, vec![(0, 2), (1, 22), (2, 42), (3, 62)]);
    }

    #[test]
    fn test_swiss_index_clear_then_reuse() {
        let mut map = SwissIndex::with_capacity(16);
        map.insert(1u8, 'a');
        map.clear();
        assert_eq!(map.load_factor_milli(), 0);
        assert!(map.capacity() >= 16);
        map.insert(2u8, 'b');
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&2), Some(&'b'));
        assert!(map.get(&1).is_none());
    }
}