use core::ops::RangeBounds;
use core::slice;
use smallvec::SmallVec;
use spacetimedb_sats::memory_usage::MemoryUsage;
use std::collections::btree_map::{BTreeMap, Range};
/// An ordered multimap that stores any number of values per key.
///
/// Keys live in a [`BTreeMap`], so key iteration is in ascending key order.
/// The values for a key are held in a [`SmallVec`] with inline capacity one,
/// so a key with a single value needs no extra heap allocation for its list.
#[derive(Debug, Eq, PartialEq)]
pub struct MultiMap<K, V> {
    /// Key → values stored under that key.
    map: BTreeMap<K, SmallVec<[V; 1]>>,
}
impl<K, V> Default for MultiMap<K, V> {
fn default() -> Self {
Self { map: BTreeMap::new() }
}
}
impl<K: MemoryUsage, V: MemoryUsage> MemoryUsage for MultiMap<K, V> {
fn heap_usage(&self) -> usize {
let Self { map } = self;
map.heap_usage()
}
}
impl<K: Ord, V: Ord> MultiMap<K, V> {
    /// Inserts `val` under `key`.
    ///
    /// Duplicate `(key, val)` pairs are allowed; each call adds one more entry.
    pub fn insert(&mut self, key: K, val: V) {
        self.map.entry(key).or_default().push(val);
    }

    /// Deletes one occurrence of `val` stored under `key`.
    ///
    /// Returns `true` if an entry was removed and `false` if `(key, val)` was not present.
    /// The remaining values under `key` may be reordered (removal uses `swap_remove`).
    /// When the last value under `key` is removed, `key` itself is removed as well,
    /// so the map never keeps a key with an empty value list.
    pub fn delete(&mut self, key: &K, val: &V) -> bool {
        if let Some(vset) = self.map.get_mut(key) {
            if let Some(idx) = vset.iter().position(|v| v == val) {
                vset.swap_remove(idx);
                if vset.is_empty() {
                    // Without this, `num_keys` would still count `key`, its map node would
                    // stay allocated, and range scans would keep visiting an empty entry.
                    self.map.remove(key);
                }
                return true;
            }
        }
        false
    }

    /// Returns an iterator over every value whose key lies in `range`.
    ///
    /// Keys are visited in ascending order. Values under one key come out in their stored
    /// order, which `delete` may have permuted.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`BTreeMap::range`]: the range start is
    /// greater than its end, or start and end are equal and both bounds are `Excluded`.
    pub fn values_in_range(&self, range: &impl RangeBounds<K>) -> MultiMapRangeIter<'_, K, V> {
        MultiMapRangeIter {
            outer: self.map.range((range.start_bound(), range.end_bound())),
            inner: None,
        }
    }

    /// Returns an iterator over the values stored under `key`.
    /// The iterator is empty if `key` is absent.
    pub fn values_in_point(&self, key: &K) -> MultiMapPointIter<'_, V> {
        // An absent key falls back to an empty slice (`<&[V]>::default()`).
        let vals = self.map.get(key).map(|vs| &**vs).unwrap_or_default();
        let iter = vals.iter();
        MultiMapPointIter { iter }
    }

    /// Returns the number of distinct keys that have at least one value.
    pub fn num_keys(&self) -> usize {
        self.map.len()
    }

    /// Returns the total number of values across all keys.
    ///
    /// This is O(number of keys), because it sums the length of every value list.
    // Not currently called within the crate; kept as part of the collection API.
    #[allow(unused)]
    pub fn len(&self) -> usize {
        self.map.values().map(|ptrs| ptrs.len()).sum()
    }

    /// Returns `true` if the multimap holds no values.
    // Not currently called within the crate; kept as part of the collection API.
    #[allow(unused)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes all keys and values.
    pub fn clear(&mut self) {
        self.map.clear();
    }
}
/// Iterator over the values stored under a single key of a `MultiMap`.
///
/// It yields nothing when the key is absent.
pub struct MultiMapPointIter<'a, V> {
    /// Cursor over the key's value slice.
    iter: slice::Iter<'a, V>,
}
impl<'a, V> Iterator for MultiMapPointIter<'a, V> {
    type Item = &'a V;

    /// Yields the next value stored under the key, if any.
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    /// Forwards the exact bounds of the underlying slice iterator.
    ///
    /// Without this, the default `(0, None)` would stop `collect`/`extend`
    /// from preallocating, even though the remaining count is known.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// Iterator over all values whose keys fall within a key range of a `MultiMap`.
///
/// It walks the matching keys one at a time and drains each key's value list
/// before moving on to the next key.
pub struct MultiMapRangeIter<'a, K, V> {
    /// Remaining `(key, values)` entries inside the requested range.
    outer: Range<'a, K, SmallVec<[V; 1]>>,
    /// Values of the key currently being drained; `None` before the first key
    /// is fetched and after the current key's values run out.
    inner: Option<slice::Iter<'a, V>>,
}
impl<'a, K, V> Iterator for MultiMapRangeIter<'a, K, V> {
    type Item = &'a V;

    /// Yields the next value in the range, moving to the next key once
    /// the current key's values are exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Take the next value from the key currently being drained, if there is one.
            let pending = self.inner.as_mut().and_then(|vals| vals.next());
            if pending.is_some() {
                return pending;
            }

            // The current key is drained (or none has started yet). Move to the next key,
            // and return `None` once the key range itself is exhausted.
            self.inner = None;
            let (_, vals) = self.outer.next()?;
            self.inner = Some(vals.iter());
        }
    }
}