use super::{ItemIndex, item_set::IndexRemap, map_hash::MapHash};
use crate::internal::{TableValidationError, ValidateCompact};
use alloc::{
collections::{BTreeSet, btree_set},
vec::Vec,
};
use core::{
cell::Cell,
cmp::Ordering,
hash::{BuildHasher, Hash},
marker::PhantomData,
};
use equivalent::Comparable;
thread_local! {
// Thread-local slot holding the comparator currently in effect for the
// `Ord`/`PartialEq` impls of `Index`. Installed by `CmpDropGuard::new` for
// the duration of a single `BTreeSet` operation and cleared on guard drop.
static CMP: Cell<Option<&'static IndexCmp<'static>>>
= const { Cell::new(None) };
}
// Comparator over stored `Index` entries; built by `find_cmp`/`insert_cmp`.
type IndexCmp<'a> = dyn Fn(&Index, &Index) -> Ordering + 'a;
/// B-tree-backed index table used to keep map entries in key order.
///
/// Stores `ItemIndex` values only; the keys they correspond to live in the
/// parent map and are materialized on demand through `lookup` closures
/// passed to each keyed operation.
#[derive(Clone, Debug, Default)]
pub(crate) struct MapBTreeTable {
// Indexes, ordered by the thread-local comparator active at mutation time.
items: BTreeSet<Index>,
// Fixed-seed hasher state used by `compute_hash`.
hash_state: foldhash::fast::FixedState,
}
/// Ordered table of item indexes backed by a `BTreeSet`.
///
/// Because `BTreeSet` requires a fixed `Ord`, each keyed operation here
/// temporarily installs a comparator into the thread-local `CMP` slot (via
/// `CmpDropGuard`), which the `Ord`/`PartialEq` impls of `Index` consult.
impl MapBTreeTable {
/// Creates an empty table. The hash state uses a fixed seed of 0, so
/// computed hashes are stable across table instances.
pub(crate) const fn new() -> Self {
Self {
items: BTreeSet::new(),
hash_state: foldhash::fast::FixedState::with_seed(0),
}
}
/// Returns the number of stored indexes.
#[doc(hidden)]
pub(crate) fn len(&self) -> usize {
self.items.len()
}
/// Validates internal invariants: the length matches `expected_len`, the
/// sentinel is never stored, and — for `ValidateCompact::Compact` — the
/// stored indexes are exactly the values `0..expected_len`.
#[doc(hidden)]
pub(crate) fn validate(
&self,
expected_len: usize,
compactness: ValidateCompact,
) -> Result<(), TableValidationError> {
if self.len() != expected_len {
return Err(TableValidationError::new(format!(
"expected length {expected_len}, was {}",
self.len(),
)));
}
match compactness {
ValidateCompact::Compact => {
// Collect every stored index (rejecting the sentinel), then
// check that, sorted, they form 0..expected_len with no gaps.
let mut indexes: Vec<ItemIndex> =
Vec::with_capacity(expected_len);
for index in &self.items {
let v = index.value();
if v == Index::SENTINEL_VALUE {
return Err(TableValidationError::new(
"sentinel value should not be stored in map",
));
}
indexes.push(v);
}
indexes.sort_unstable();
for (i, index) in indexes.iter().enumerate() {
if index.as_u32() as usize != i {
return Err(TableValidationError::new(format!(
"value at index {i} should be {i}, was {index}",
)));
}
}
}
ValidateCompact::NonCompact => {
// Indexes may be sparse here, but must still be unique and
// must not include the sentinel.
let indexes: Vec<ItemIndex> =
self.items.iter().map(|ix| ix.value()).collect();
let index_set: BTreeSet<ItemIndex> =
indexes.iter().copied().collect();
if index_set.len() != indexes.len() {
return Err(TableValidationError::new(format!(
"expected no duplicates, but found {} duplicates \
(values: {:?})",
indexes.len() - index_set.len(),
indexes,
)));
}
if index_set.contains(&Index::SENTINEL_VALUE) {
return Err(TableValidationError::new(
"sentinel value should not be stored in map",
));
}
}
}
Ok(())
}
/// Returns the first index in stored key order, if any. No comparator is
/// installed here: `BTreeSet::first` only walks the tree structure —
/// presumably it never invokes `Ord` (confirm against std's contract).
#[inline]
pub(crate) fn first(&self) -> Option<ItemIndex> {
self.items.first().map(|ix| ix.value())
}
/// Returns the last index in stored key order, if any (see `first`).
#[inline]
pub(crate) fn last(&self) -> Option<ItemIndex> {
self.items.last().map(|ix| ix.value())
}
/// Finds the index whose key (as produced by `lookup`) compares equal to
/// `key`.
///
/// Probes the set with a sentinel `Index`; the comparator installed for
/// the duration of the call treats the sentinel as standing in for `key`
/// (see `find_cmp`). The guard clears the thread-local comparator even if
/// a user comparison panics.
pub(crate) fn find_index<K, Q, F>(
&self,
key: &Q,
lookup: F,
) -> Option<ItemIndex>
where
K: Ord,
Q: ?Sized + Comparable<K>,
F: Fn(ItemIndex) -> K,
{
let f = find_cmp(key, lookup);
let guard = CmpDropGuard::new(&f);
let ret = match self.items.get(&Index::sentinel()) {
// Defensive check: `insert` never stores the sentinel value.
Some(ix) if ix.value() == Index::SENTINEL_VALUE => {
panic!("internal map shouldn't store sentinel value")
}
Some(ix) => Some(ix.value()),
None => {
None
}
};
// Uninstall the comparator before `f` goes out of scope.
drop(guard);
ret
}
/// Inserts `index`, positioned according to `key`.
///
/// The comparator substitutes `key` for `index` during comparisons, so
/// `lookup` is never invoked on the index being inserted (see
/// `insert_cmp`).
pub(crate) fn insert<K, Q, F>(
&mut self,
index: ItemIndex,
key: &Q,
lookup: F,
) where
K: Ord,
Q: ?Sized + Comparable<K>,
F: Fn(ItemIndex) -> K,
{
let f = insert_cmp(index, key, lookup);
let guard = CmpDropGuard::new(&f);
self.items.insert(Index::new(index));
drop(guard);
}
/// Removes `index`, locating it via `key`. Uses the same comparator as
/// `insert`, which substitutes `key` for `index` during comparisons.
pub(crate) fn remove<K, F>(&mut self, index: ItemIndex, key: K, lookup: F)
where
F: Fn(ItemIndex) -> K,
K: Ord,
{
let f = insert_cmp(index, &key, lookup);
let guard = CmpDropGuard::new(&f);
self.items.remove(&Index::new(index));
drop(guard);
}
/// Keeps only the indexes for which `f` returns true. No comparator is
/// installed — this relies on `BTreeSet::retain` not invoking `Ord`;
/// NOTE(review): confirm against std's implementation.
pub(crate) fn retain<F>(&mut self, mut f: F)
where
F: FnMut(ItemIndex) -> bool,
{
self.items.retain(|index| f(index.value()));
}
/// Rewrites each stored index through `remap` in place, via the cells'
/// interior mutability. Only reads the tree, so user `Ord` is never
/// invoked (the test module exercises exactly this property).
///
/// NOTE(review): assumes the remap preserves the relative order of
/// surviving entries; otherwise the `BTreeSet`'s ordering invariant would
/// be silently broken — confirm with callers.
pub(crate) fn remap_indexes(&mut self, remap: &IndexRemap) {
for idx in self.items.iter() {
let new = remap.remap(idx.value());
idx.set_value(new);
}
}
/// Removes all entries.
#[inline]
pub(crate) fn clear(&mut self) {
self.items.clear();
}
/// Iterates over the stored indexes in key order.
pub(crate) fn iter(&self) -> Iter<'_> {
Iter::new(self.items.iter())
}
/// Consumes the table, yielding the stored indexes in key order.
pub(crate) fn into_iter(self) -> IntoIter {
IntoIter::new(self.items.into_iter())
}
/// Returns the table's fixed-seed hasher state (see `new`).
pub(crate) fn state(&self) -> &foldhash::fast::FixedState {
&self.hash_state
}
/// Hashes `key` with this table's fixed hash state.
pub(crate) fn compute_hash<K: Hash>(&self, key: K) -> MapHash {
MapHash { hash: self.hash_state.hash_one(key) }
}
}
/// Borrowing iterator over the table's indexes, in stored key order.
#[derive(Clone, Debug)]
pub(crate) struct Iter<'a> {
inner: btree_set::Iter<'a, Index>,
}
impl<'a> Iter<'a> {
fn new(inner: btree_set::Iter<'a, Index>) -> Self {
Self { inner }
}
pub(crate) fn len(&self) -> usize {
self.inner.len()
}
}
impl<'a> Iterator for Iter<'a> {
    type Item = ItemIndex;

    /// Yields the next stored index, unwrapping the `Index` cell.
    fn next(&mut self) -> Option<Self::Item> {
        let entry = self.inner.next()?;
        Some(entry.value())
    }
}
/// Owning iterator over the table's indexes, in stored key order.
#[derive(Debug)]
pub(crate) struct IntoIter {
inner: btree_set::IntoIter<Index>,
}
impl IntoIter {
fn new(inner: btree_set::IntoIter<Index>) -> Self {
Self { inner }
}
}
impl Iterator for IntoIter {
    type Item = ItemIndex;

    /// Yields the next stored index, unwrapping the `Index` cell.
    fn next(&mut self) -> Option<Self::Item> {
        let entry = self.inner.next()?;
        Some(entry.value())
    }
}
/// Builds a comparator for lookups: the sentinel value stands in for the
/// search `key`, while every real index is compared via its looked-up key.
fn find_cmp<'a, K, Q, F>(
    key: &'a Q,
    lookup: F,
) -> impl Fn(&Index, &Index) -> Ordering + 'a
where
    Q: ?Sized + Comparable<K>,
    F: 'a + Fn(ItemIndex) -> K,
    K: Ord,
{
    move |lhs: &Index, rhs: &Index| {
        let lhs = lhs.value();
        let rhs = rhs.value();
        // Identical values (including sentinel vs sentinel) are trivially
        // equal; the branches below therefore never see two sentinels.
        if lhs == rhs {
            Ordering::Equal
        } else if lhs == Index::SENTINEL_VALUE {
            // Sentinel on the left: compare the search key against rhs.
            key.compare(&lookup(rhs))
        } else if rhs == Index::SENTINEL_VALUE {
            // Sentinel on the right: same comparison, direction flipped.
            key.compare(&lookup(lhs)).reverse()
        } else {
            // Two real indexes: compare their materialized keys.
            lookup(lhs).cmp(&lookup(rhs))
        }
    }
}
/// Builds a comparator for insert/remove: the index being inserted or
/// removed stands in for `key` (so `lookup` is never called on it), and the
/// sentinel must never appear on this path.
fn insert_cmp<'a, K, Q, F>(
    index: ItemIndex,
    key: &'a Q,
    lookup: F,
) -> impl Fn(&Index, &Index) -> Ordering + 'a
where
    Q: ?Sized + Comparable<K>,
    F: 'a + Fn(ItemIndex) -> K,
    K: Ord,
{
    move |lhs: &Index, rhs: &Index| {
        let lhs = lhs.value();
        let rhs = rhs.value();
        // Identical values are trivially equal.
        if lhs == rhs {
            return Ordering::Equal;
        }
        if lhs == Index::SENTINEL_VALUE || rhs == Index::SENTINEL_VALUE {
            panic!("sentinel value should not be invoked in insert path")
        }
        if lhs == index {
            // The target index stands in for the caller-supplied key.
            key.compare(&lookup(rhs))
        } else if rhs == index {
            key.compare(&lookup(lhs)).reverse()
        } else {
            lookup(lhs).cmp(&lookup(rhs))
        }
    }
}
/// RAII guard that installs a comparator into the thread-local `CMP` slot
/// on construction and clears it on drop — including during unwinding, so a
/// panicking user comparator cannot leave a stale comparator installed.
struct CmpDropGuard<'a> {
// Ties the guard to the lifetime of the borrowed comparator.
_marker: PhantomData<&'a ()>,
}
impl<'a> CmpDropGuard<'a> {
    /// Installs `f` as the thread-local comparator for the lifetime of the
    /// returned guard; the guard's `Drop` impl clears the slot again.
    fn new(f: &'a IndexCmp<'a>) -> Self {
        let ret = Self { _marker: PhantomData };
        // SAFETY: the 'static lifetime is a fiction confined to the CMP
        // cell. The guard borrows `f` for 'a (via `_marker`), and its Drop
        // impl clears CMP, so within this module the reference is never
        // read after 'a ends — every comparison site installs a fresh
        // guard first. Use `core::mem::transmute` (of which `std::mem::
        // transmute` is merely a re-export) so this module stays consistent
        // with its core/alloc-only imports and builds without `std`.
        let as_static = unsafe {
            core::mem::transmute::<&'a IndexCmp<'a>, &'static IndexCmp<'static>>(
                f,
            )
        };
        CMP.set(Some(as_static));
        ret
    }
}
impl Drop for CmpDropGuard<'_> {
    /// Clears the thread-local comparator installed by `CmpDropGuard::new`.
    fn drop(&mut self) {
        CMP.with(|slot| slot.set(None));
    }
}
/// Interior-mutable wrapper around a raw `u32` item index.
///
/// `remap_indexes` rewrites values through a shared reference while the
/// cells sit inside the `BTreeSet`; the `AtomicU32` (Relaxed) provides that
/// mutability — presumably chosen over `Cell` to keep the type `Sync`;
/// confirm against the table's threading requirements.
#[repr(transparent)]
#[derive(Debug, Default)]
struct IndexCell(core::sync::atomic::AtomicU32);
impl Clone for IndexCell {
fn clone(&self) -> Self {
Self(core::sync::atomic::AtomicU32::new(self.get().as_u32()))
}
}
impl IndexCell {
#[inline]
const fn new(value: ItemIndex) -> Self {
Self(core::sync::atomic::AtomicU32::new(value.as_u32()))
}
#[inline]
fn get(&self) -> ItemIndex {
ItemIndex::new(self.0.load(core::sync::atomic::Ordering::Relaxed))
}
#[inline]
fn set(&self, value: ItemIndex) {
debug_assert_ne!(
value,
ItemIndex::SENTINEL,
"IndexCell::set: sentinel must never be stored in the table",
);
self.0.store(value.as_u32(), core::sync::atomic::Ordering::Relaxed);
}
}
/// Element type stored in the `BTreeSet`: an item index whose ordering and
/// equality are delegated to the thread-local comparator (see impls below).
#[derive(Clone, Debug)]
struct Index(IndexCell);
impl Index {
const SENTINEL_VALUE: ItemIndex = ItemIndex::SENTINEL;
#[inline]
const fn sentinel() -> Self {
Self(IndexCell::new(Self::SENTINEL_VALUE))
}
#[inline]
fn new(value: ItemIndex) -> Self {
if value == Self::SENTINEL_VALUE {
panic!("btree map overflow, index with value {value:?} was added")
}
Self(IndexCell::new(value))
}
#[inline]
fn value(&self) -> ItemIndex {
self.0.get()
}
#[inline]
fn set_value(&self, value: ItemIndex) {
self.0.set(value)
}
}
impl PartialEq for Index {
    /// Two real indexes compare by value; if either side is the sentinel
    /// probe, equality is delegated to the installed comparator.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.value();
        let rhs = other.value();
        if lhs == Self::SENTINEL_VALUE || rhs == Self::SENTINEL_VALUE {
            CMP.with(|cmp| {
                let cmp = cmp.get().expect("cmp should be set");
                cmp(self, other) == Ordering::Equal
            })
        } else {
            lhs == rhs
        }
    }
}
impl Eq for Index {}
impl Ord for Index {
    /// Delegates entirely to the thread-local comparator; panics if no
    /// guard has installed one.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        CMP.with(|slot| {
            let f = slot.get().expect("cmp should be set");
            f(self, other)
        })
    }
}
impl PartialOrd for Index {
    /// Total order: always defined, forwarding to `Ord::cmp`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use crate::support::{alloc::Global, item_set::ItemSet};
use core::cell::Cell;
thread_local! {
// When set, every `PanickingKey` comparison panics; used to prove that a
// code path never invokes the user key's `Ord`.
static PANIC_TRIGGER: Cell<bool> = const { Cell::new(false) };
}
// Key type whose `Ord` can be armed to panic via `PANIC_TRIGGER`.
#[derive(Clone, Debug, PartialEq, Eq)]
struct PanickingKey(u32);
impl PartialOrd for PanickingKey {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PanickingKey {
fn cmp(&self, other: &Self) -> Ordering {
if PANIC_TRIGGER.with(|c| c.get()) {
panic!("simulated Ord panic");
}
self.0.cmp(&other.0)
}
}
/// `remap_indexes` must only rewrite stored index cells, never compare
/// keys: arming `PANIC_TRIGGER` around the call proves user `Ord` is not
/// invoked.
#[test]
fn remap_indexes_does_not_call_user_ord() {
// Build a 5-item set with holes at indexes 1 and 3, then compact it to
// obtain a non-identity remap.
let mut set: ItemSet<PanickingKey, Global> = ItemSet::new();
for i in 0..5u32 {
set.assert_can_grow().insert(PanickingKey(i * 10));
}
set.remove(ItemIndex::new(1));
set.remove(ItemIndex::new(3));
let remap = set.shrink_to_fit();
assert!(!remap.is_identity(), "remap should carry two holes");
// Populate the table with the surviving pre-compaction indexes.
let mut table = MapBTreeTable::new();
let pre_lookup = |ix: ItemIndex| -> PanickingKey {
match ix.as_u32() {
0 => PanickingKey(0),
2 => PanickingKey(20),
4 => PanickingKey(40),
_ => panic!("unexpected index in pre-compaction lookup: {ix}"),
}
};
for ix in [0u32, 2, 4] {
let ix = ItemIndex::new(ix);
let key = pre_lookup(ix);
table.insert(ix, &key, pre_lookup);
}
assert_eq!(table.len(), 3);
assert_eq!(
table
.items
.iter()
.map(|i| i.value().as_u32())
.collect::<alloc::vec::Vec<_>>(),
[0u32, 2, 4],
);
// Arm the panic: if remap_indexes compares keys, this test fails.
PANIC_TRIGGER.with(|c| c.set(true));
table.remap_indexes(&remap);
PANIC_TRIGGER.with(|c| c.set(false));
// Indexes 0, 2, 4 collapse to 0, 1, 2 after compaction.
assert_eq!(
table
.items
.iter()
.map(|i| i.value().as_u32())
.collect::<alloc::vec::Vec<_>>(),
[0u32, 1, 2],
);
}
}