#![cfg_attr(not(test), no_std)]
#![forbid(unused_must_use)]
extern crate alloc;
use alloc::{boxed::Box, string::String, vec::Vec};
use core::{borrow::Borrow, fmt, mem, num::NonZero};
/// A key type that can be read as a sequence of bits, as needed by `CritBitMap` and
/// `CritBitSet`.
///
/// The tree branches on the first index at which two keys' `get` results differ, so
/// implementations are expected to give distinct keys distinct bit sequences. For
/// iteration to come out sorted, the bit sequences should also compare the same way
/// as the type's `Ord`.
pub trait BitString: Ord {
/// `true` when every value of the type reports the same `bits()`. The slice
/// implementation checks this to decide whether it can locate an element by
/// division instead of scanning.
const IS_CONSTANT_SIZE: bool;
/// Returns the bit at `index`. Index 0 is the first bit compared; the integer
/// implementations map it to the most significant bit.
fn get(&self, index: usize) -> bool;
/// Returns the number of bits this value reports for itself.
fn bits(&self) -> usize;
}
/// A map from `BitString` keys to values, stored as a crit-bit (binary radix) tree.
///
/// Entries are visited in the order given by the keys' bit sequences.
#[derive(Clone)]
pub struct CritBitMap<K, V> {
// Root of the tree; `None` when the map is empty.
root: Option<Node<K, V>>,
// Value reported by `len()`; meant to count the stored entries.
len: usize,
}
/// A set of `BitString` keys, implemented as a `CritBitMap` whose values are `()`.
#[derive(Clone)]
pub struct CritBitSet<K> {
// Backing map; only its keys carry information.
map: CritBitMap<K, ()>,
}
/// Iterator over the `(key, value)` pairs of a `CritBitMap`.
pub struct IterMap<'a, K, V> {
// Nodes still to visit. The flag is `false` when the next descent from a parent is
// into branch `0`, and `true` when it is into branch `1`.
stack: Vec<(&'a Node<K, V>, bool)>,
}
/// Iterator over the keys of a `CritBitMap`.
pub struct IterKeys<'a, K, V> {
iter: IterMap<'a, K, V>,
}
/// Iterator over the values of a `CritBitMap`.
pub struct IterValues<'a, K, V> {
iter: IterMap<'a, K, V>,
}
/// Iterator over the keys of a `CritBitSet`.
pub struct IterSet<'a, K> {
iter: IterKeys<'a, K, ()>,
}
/// A node of the crit-bit tree.
#[derive(Clone)]
enum Node<K, V> {
/// Internal node. Lookups send a key whose bit `bit` is `false` to `branches[0]`
/// and all other keys to `branches[1]`. Insertion places a new parent above the
/// first parent with a larger `bit`, so `bit` values grow from the root downward.
Parent {
branches: Box<[Self; 2]>,
bit: usize,
},
/// A stored entry.
Leaf {
key: K,
value: V,
},
}
impl<K, V> CritBitMap<K, V> {
    /// Creates an empty map without allocating.
    pub const fn new() -> Self {
        Self { root: None, len: 0 }
    }

    /// Returns `true` if the map holds no entries.
    pub fn is_empty(&self) -> bool {
        self.root.is_none()
    }

    /// Returns the number of entries, as tracked by the `len` field.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns an iterator over `(key, value)` pairs.
    ///
    /// For every parent, the whole `0` branch is yielded before the `1` branch.
    pub fn iter(&self) -> IterMap<'_, K, V> {
        let stack = match &self.root {
            Some(root) => Vec::from([(root, false)]),
            None => Vec::new(),
        };
        IterMap { stack }
    }

    /// Calls `f` on every entry, `0` branches before `1` branches (the same order as
    /// [`iter`](Self::iter)), without allocating an explicit stack.
    pub fn iter_with<F>(&self, mut f: F)
    where
        F: FnMut(&K, &V),
    {
        if let Some(root) = &self.root {
            root.iter_with(&mut f);
        }
    }

    /// Calls `f` on every entry with mutable access to the value, in the same order
    /// as [`iter_with`](Self::iter_with).
    pub fn iter_with_mut<F>(&mut self, mut f: F)
    where
        F: FnMut(&K, &mut V),
    {
        if let Some(root) = &mut self.root {
            root.iter_with_mut(&mut f);
        }
    }

    /// Returns an iterator over the keys.
    pub fn keys(&self) -> IterKeys<'_, K, V> {
        IterKeys { iter: self.iter() }
    }

    /// Returns an iterator over the values.
    pub fn values(&self) -> IterValues<'_, K, V> {
        IterValues { iter: self.iter() }
    }

    /// Returns the entry reached by always taking branch `0`, or `None` if empty.
    pub fn first_key_value(&self) -> Option<(&K, &V)> {
        self.key_value_at_end(false)
    }

    /// Returns the entry reached by always taking branch `1`, or `None` if empty.
    pub fn last_key_value(&self) -> Option<(&K, &V)> {
        self.key_value_at_end(true)
    }

    /// Walks from the root to a leaf, always taking branch `usize::from(dir)`.
    fn key_value_at_end(&self, dir: bool) -> Option<(&K, &V)> {
        let mut cur = self.root.as_ref()?;
        loop {
            match cur {
                Node::Parent { branches, .. } => cur = &branches[usize::from(dir)],
                Node::Leaf { key, value } => break Some((key, value)),
            }
        }
    }
}
impl<K> CritBitSet<K> {
    /// Creates an empty set without allocating.
    pub const fn new() -> Self {
        let map = CritBitMap::new();
        Self { map }
    }

    /// Returns `true` if the set holds no keys.
    pub fn is_empty(&self) -> bool {
        CritBitMap::is_empty(&self.map)
    }

    /// Returns the number of keys reported by the backing map.
    pub fn len(&self) -> usize {
        CritBitMap::len(&self.map)
    }

    /// Returns an iterator over the keys, in the backing map's iteration order.
    pub fn iter(&self) -> IterSet<'_, K> {
        let iter = self.map.keys();
        IterSet { iter }
    }

    /// Calls `f` on every key, in the backing map's `iter_with` order.
    pub fn iter_with<F>(&self, mut f: F)
    where
        F: FnMut(&K),
    {
        self.map.iter_with(move |key, _unit| f(key));
    }

    /// Returns the key reached by always taking branch `0`, if any.
    pub fn first(&self) -> Option<&K> {
        let (key, _unit) = self.map.first_key_value()?;
        Some(key)
    }

    /// Returns the key reached by always taking branch `1`, if any.
    pub fn last(&self) -> Option<&K> {
        let (key, _unit) = self.map.last_key_value()?;
        Some(key)
    }
}
impl<K, V> CritBitMap<K, V>
where
    K: BitString,
{
    /// Removes every entry, leaving an empty map.
    pub fn clear(&mut self) {
        *self = Self::new();
    }

    /// Returns `true` if the map holds an entry whose key equals `key`.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        self.get(key).is_some()
    }

    /// Returns a reference to the value stored under `key`, if any.
    pub fn get<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        self.get_key_value(key).map(|x| x.1)
    }

    /// Returns the stored key and its value for `k`, if present.
    ///
    /// The walk follows `k`'s bit at each parent's `bit` index down to one leaf. That
    /// leaf is the only candidate, so a single equality check decides the result. This
    /// assumes `K` and its borrowed form `Q` report the same bits for equal keys.
    pub fn get_key_value<Q>(&self, k: &Q) -> Option<(&K, &V)>
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        let mut cur = self.root.as_ref()?;
        loop {
            match cur {
                Node::Parent { branches, bit } => cur = &branches[usize::from(k.get(*bit))],
                Node::Leaf { key, value } => break (key.borrow() == k).then_some((key, value)),
            }
        }
    }

    /// Mutable counterpart of [`get_key_value`](Self::get_key_value).
    fn get_key_value_mut<Q>(&mut self, k: &Q) -> Option<(&K, &mut V)>
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        let mut cur = self.root.as_mut()?;
        loop {
            match cur {
                Node::Parent { branches, bit } => cur = &mut branches[usize::from(k.get(*bit))],
                Node::Leaf { key, value } => break ((&*key).borrow() == k).then_some((key, value)),
            }
        }
    }

    /// Returns a mutable reference to the value stored under `key`, if any.
    pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        self.get_key_value_mut(key).map(|x| x.1)
    }

    /// Inserts `value` under `key`.
    ///
    /// If an equal key is already present, its value is replaced and the old value is
    /// returned; the stored key itself is kept. Otherwise a new leaf is added and `None`
    /// is returned.
    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
        let Some(mut cur) = self.root.as_mut() else {
            self.root = Some(Node::Leaf { key, value });
            // The first entry must be counted too; `insert_new_unchecked` only runs on a
            // non-empty tree.
            self.len += 1;
            return None;
        };
        // Follow `key`'s bits to a leaf. Either it holds an equal key, or we compute the
        // first bit at which the two keys differ.
        let bit = loop {
            match cur {
                Node::Parent { branches, bit } => cur = &mut branches[usize::from(key.get(*bit))],
                Node::Leaf { key: k, value: v } if *k == key => {
                    return Some(mem::replace(v, value));
                }
                Node::Leaf { key: k, value: _ } => break find_different_bit(k, &key),
            }
        };
        self.insert_new_unchecked(key, value, bit);
        None
    }

    /// Adds a leaf for `key`, which must not already be present, splitting at bit
    /// `newbit`.
    ///
    /// The new parent replaces the first node on `key`'s path that is a leaf or a parent
    /// with a larger `bit`. That node becomes the new parent's other child.
    fn insert_new_unchecked(&mut self, key: K, value: V, newbit: usize) {
        let mut cur = self.root.as_mut().expect("must have a root");
        loop {
            match cur {
                Node::Leaf { .. } => break,
                Node::Parent { bit, .. } if *bit > newbit => break,
                Node::Parent { bit, branches } => cur = &mut branches[usize::from(key.get(*bit))],
            }
        }
        // Allocate and compute everything that could panic *before* moving the old node
        // out of `cur`.
        let mut newbranches: Box<[mem::MaybeUninit<Node<K, V>>; 2]> =
            Box::new([const { mem::MaybeUninit::uninit() }; 2]);
        let newdir = usize::from(key.get(newbit));
        // SAFETY: `cur` is a live exclusive borrow, so the raw pointer is valid for reads
        // and writes. The old node is moved out with `read` into `newbranches`, and the slot
        // is refilled with `write` before anything else touches it. Nothing in between can
        // panic (`newdir` is 0 or 1), so the node is never duplicated or dropped twice. Both
        // array slots are initialized before the `Box<[MaybeUninit<Node>; 2]>` is
        // reinterpreted as `Box<[Node; 2]>`; `MaybeUninit<T>` has the same layout as `T`.
        unsafe {
            let cur = cur as *mut Node<K, V>;
            newbranches[newdir].write(Node::Leaf { key, value });
            newbranches[1 - newdir].write(cur.read());
            let branches = mem::transmute::<_, Box<[Node<K, V>; 2]>>(newbranches);
            cur.write(Node::Parent {
                branches,
                bit: newbit,
            });
        }
        self.len += 1;
    }

    /// Removes the entry for `key`, returning its value if it was present.
    ///
    /// The matching leaf's parent is replaced by the leaf's sibling (a leaf or a whole
    /// subtree), so every remaining parent still has two children. Each remaining node is
    /// kept exactly once, and the removed key is dropped exactly once.
    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        let mut cur = self.root.as_mut()?;
        // A leaf root has no parent to collapse: remove it by emptying the tree.
        if let Node::Leaf { key: k, .. } = cur {
            if (&*k).borrow() != key {
                return None;
            }
            let Some(Node::Leaf { value, .. }) = self.root.take() else {
                unreachable!("must have a leaf root")
            };
            self.len -= 1;
            return Some(value);
        }
        // Descend until `cur` is the parent whose child in direction `dir` is the leaf
        // holding `key`. Peeking through a shared borrow first keeps the mutable reborrow
        // only on the path that continues the descent.
        let dir = loop {
            let Node::Parent { branches, bit } = &*cur else {
                unreachable!("must be a parent node")
            };
            let dir = usize::from(key.get(*bit));
            match &branches[dir] {
                Node::Parent { .. } => {}
                Node::Leaf { key: k, .. } if k.borrow() == key => break dir,
                Node::Leaf { .. } => return None,
            }
            let Node::Parent { branches, .. } = cur else {
                unreachable!("must be a parent node")
            };
            cur = &mut branches[dir];
        };
        let slot: *mut Node<K, V> = cur;
        // SAFETY: `slot` comes from the exclusive borrow `cur`, so it is valid for reads and
        // writes and nothing else can observe `*slot` during this block. The parent is moved
        // out with `read`. Its box is then consumed by moving both children out of it, so the
        // box drops no node. On every path a valid node is written back into `*slot` before
        // the block ends or any panic occurs. Therefore `*slot` is never dropped twice and is
        // never left uninitialized.
        let leaf = unsafe {
            match slot.read() {
                Node::Parent { branches, .. } => {
                    let [zero, one] = *branches;
                    let (leaf, sibling) = if dir == 0 { (zero, one) } else { (one, zero) };
                    slot.write(sibling);
                    leaf
                }
                // Not reachable: the loop only stops on a parent. Restore the node before
                // panicking so it is not dropped twice.
                node @ Node::Leaf { .. } => {
                    slot.write(node);
                    unreachable!("must be a parent node")
                }
            }
        };
        self.len -= 1;
        let Node::Leaf { value, .. } = leaf else {
            unreachable!("must be a leaf")
        };
        Some(value)
    }
}
impl<K> CritBitSet<K>
where
    K: BitString,
{
    /// Returns `true` if a key equal to `key` is stored.
    pub fn contains<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        matches!(self.get(key), Some(_))
    }

    /// Returns the stored key equal to `key`, if any.
    pub fn get<Q>(&self, key: &Q) -> Option<&K>
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        let (stored, _unit) = self.map.get_key_value(key)?;
        Some(stored)
    }

    /// Adds `key`; returns `true` if it was not already present.
    pub fn insert(&mut self, key: K) -> bool {
        matches!(self.map.insert(key, ()), None)
    }

    /// Removes `key`; returns `true` if it was present.
    pub fn remove<Q>(&mut self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: BitString + ?Sized,
    {
        matches!(self.map.remove(key), Some(()))
    }
}
impl<'a, K, V> Iterator for IterMap<'a, K, V> {
    type Item = (&'a K, &'a V);

    /// Yields the next leaf, visiting every parent's `0` branch before its `1` branch.
    ///
    /// The flag stored with each node says which branch to descend into next. A parent
    /// popped with `false` is pushed back with `true`, so it is revisited exactly once to
    /// descend into branch `1`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (node, visited_zero) = self.stack.pop()?;
            match node {
                Node::Leaf { key, value } => return Some((key, value)),
                Node::Parent { branches, .. } => {
                    if !visited_zero {
                        self.stack.push((node, true));
                    }
                    self.stack.push((&branches[usize::from(visited_zero)], false));
                }
            }
        }
    }
}
impl<'a, K, V> Iterator for IterKeys<'a, K, V> {
    type Item = &'a K;

    /// Yields the key half of the next `(key, value)` pair.
    fn next(&mut self) -> Option<Self::Item> {
        let (key, _value) = self.iter.next()?;
        Some(key)
    }
}
impl<'a, K, V> Iterator for IterValues<'a, K, V> {
    type Item = &'a V;

    /// Yields the value half of the next `(key, value)` pair.
    fn next(&mut self) -> Option<Self::Item> {
        let (_key, value) = self.iter.next()?;
        Some(value)
    }
}
impl<'a, K> Iterator for IterSet<'a, K> {
    type Item = &'a K;

    /// Forwards to the wrapped key iterator of the backing map.
    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.iter)
    }
}
impl<K, V> Default for CritBitMap<K, V> {
    /// An empty map with no root and a length of zero.
    fn default() -> Self {
        Self { root: None, len: 0 }
    }
}
impl<K> Default for CritBitSet<K> {
fn default() -> Self {
Self::new()
}
}
impl<K, V> Node<K, V> {
    /// Recursively calls `f` on every leaf below this node, branch `0` before branch `1`.
    fn iter_with<F>(&self, f: &mut F)
    where
        F: FnMut(&K, &V),
    {
        match self {
            Self::Leaf { key, value } => f(key, value),
            Self::Parent { branches, .. } => {
                let [zero, one] = &**branches;
                zero.iter_with(f);
                one.iter_with(f);
            }
        }
    }

    /// Like `iter_with`, but gives `f` mutable access to each value.
    fn iter_with_mut<F>(&mut self, f: &mut F)
    where
        F: FnMut(&K, &mut V),
    {
        match self {
            Self::Leaf { key, value } => f(key, value),
            Self::Parent { branches, .. } => {
                let [zero, one] = &mut **branches;
                zero.iter_with_mut(f);
                one.iter_with_mut(f);
            }
        }
    }
}
impl<T> BitString for &T
where
    T: ?Sized + BitString,
{
    /// A reference has exactly the bits of its referent.
    const IS_CONSTANT_SIZE: bool = <T as BitString>::IS_CONSTANT_SIZE;

    fn get(&self, index: usize) -> bool {
        <T as BitString>::get(*self, index)
    }

    fn bits(&self) -> usize {
        <T as BitString>::bits(*self)
    }
}
/// Implements `BitString` for primitive integers and their `NonZero` wrappers.
///
/// The leading literal must be `true` for signed types and `false` for unsigned ones.
/// For signed types the sign bit is inverted before bits are read, so negative values
/// read as smaller than non-negative ones and the bit order matches the numeric `Ord`.
macro_rules! impl_int {
    ($flip:literal $($ty:ident)*) => {
        $(
            impl BitString for $ty {
                const IS_CONSTANT_SIZE: bool = true;

                /// Bit `index` of the (sign-biased, for signed types) value, most
                /// significant bit first. Indices at or past `bits()` read as `false`.
                #[inline(always)]
                fn get(&self, index: usize) -> bool {
                    let bits = self.bits();
                    // `MIN` of a signed type is exactly its sign bit.
                    let biased = if $flip { *self ^ $ty::MIN } else { *self };
                    (bits - 1)
                        .checked_sub(index)
                        .is_some_and(|shift| (biased >> shift) & 1 != 0)
                }

                #[inline(always)]
                fn bits(&self) -> usize {
                    mem::size_of::<Self>() * 8
                }
            }

            /// Same bits as the underlying integer.
            impl BitString for NonZero<$ty> {
                const IS_CONSTANT_SIZE: bool = $ty::IS_CONSTANT_SIZE;

                #[inline(always)]
                fn get(&self, index: usize) -> bool {
                    NonZero::get(*self).get(index)
                }

                #[inline(always)]
                fn bits(&self) -> usize {
                    NonZero::get(*self).bits()
                }
            }
        )*
    }
}
// Unsigned integers: the leading flag is `false`.
impl_int!(false u8 u16 u32);
impl_int!(false u64 u128 usize);
// Signed integers: the leading flag is `true`.
impl_int!(true i8 i16 i32);
impl_int!(true i64 i128 isize);
impl<T> BitString for [T]
where
    T: BitString,
{
    /// Slices are variable-length keys.
    const IS_CONSTANT_SIZE: bool = false;

    /// Bit `index` of the slice's encoding.
    ///
    /// The encoding is built as follows:
    /// - For every element in order: a `1` marker bit followed by the element's own bits.
    /// - At the end: a single `0` terminator bit.
    /// - Every index past the end reads as `0`.
    ///
    /// The markers make the encoding prefix-free, so two different slices always differ
    /// at some index, even with nested variable-length elements such as `Vec<String>`.
    /// A shorter slice's terminator `0` meets a longer slice's `1` marker. As a result,
    /// the bit order matches the lexicographic `Ord` of slices whenever the element
    /// encoding matches `T`'s `Ord`.
    fn get(&self, index: usize) -> bool {
        if T::IS_CONSTANT_SIZE {
            // Every element has the same width, so the element is found by division.
            let Some(first) = self.first() else {
                // Empty slice: only the terminator.
                return false;
            };
            let stride = first.bits() + 1;
            // Inherent slice lookup (returns `Option<&T>`).
            match self.get(index / stride) {
                Some(element) => match index % stride {
                    0 => true,
                    offset => element.get(offset - 1),
                },
                // The terminator and everything after it.
                None => false,
            }
        } else {
            let mut remaining = index;
            for element in self {
                if remaining == 0 {
                    return true;
                }
                remaining -= 1;
                let width = element.bits();
                if remaining < width {
                    return element.get(remaining);
                }
                remaining -= width;
            }
            false
        }
    }

    /// Length of the encoding described on `get`: one marker bit plus the element's bits
    /// for each element, plus the terminator bit.
    fn bits(&self) -> usize {
        if T::IS_CONSTANT_SIZE {
            let stride = self.first().map_or(0, |first| first.bits() + 1);
            stride * self.len() + 1
        } else {
            self.iter().fold(1, |total, element| total + 1 + element.bits())
        }
    }
}
impl BitString for str {
const IS_CONSTANT_SIZE: bool = false;
fn get(&self, index: usize) -> bool {
BitString::get(self.as_bytes(), index)
}
fn bits(&self) -> usize {
self.as_bytes().bits()
}
}
impl<T> BitString for Vec<T>
where
T: BitString,
{
const IS_CONSTANT_SIZE: bool = false;
fn get(&self, index: usize) -> bool {
BitString::get(self.as_slice(), index)
}
fn bits(&self) -> usize {
self.as_slice().bits()
}
}
impl BitString for String {
const IS_CONSTANT_SIZE: bool = false;
fn get(&self, index: usize) -> bool {
BitString::get(self.as_bytes(), index)
}
fn bits(&self) -> usize {
self.as_bytes().bits()
}
}
impl<K, V> fmt::Debug for CritBitMap<K, V>
where
    K: fmt::Debug,
    V: fmt::Debug,
{
    /// Formats the map as `{key: value, ...}`, in iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_map().entries(self.iter()).finish()
    }
}
impl<K> fmt::Debug for CritBitSet<K>
where
    K: fmt::Debug,
{
    /// Formats the set as `{key, ...}`, in iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_set().entries(self.iter()).finish()
    }
}
/// Returns the lowest index at which `lhs` and `rhs` report different bits.
///
/// Callers must pass keys whose bit sequences differ somewhere. The search scans indices
/// upward from 0 and only returns once a difference is found.
fn find_different_bit<T>(lhs: &T, rhs: &T) -> usize
where
    T: BitString,
{
    let differs_at = |index: &usize| lhs.get(*index) != rhs.get(*index);
    (0usize..).find(differs_at).expect("at least one differing bit")
}
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that iterating `set` yields strictly increasing keys.
    #[track_caller]
    fn set_assert_sorted<K: Ord>(set: &CritBitSet<K>) {
        let keys: Vec<&K> = set.iter().collect();
        assert!(keys.windows(2).all(|pair| pair[0] < pair[1]))
    }

    #[test]
    fn set_i32() {
        let mut set = CritBitSet::<i32>::new();
        assert!(!set.contains(&0));
        assert!(set.insert(0));
        assert!(set.contains(&0));
        assert!(!set.insert(0));
        assert!(set.insert(1));
        for key in [0, 1] {
            assert!(set.contains(&key));
        }
        assert!(!set.contains(&3));
        for _ in 0..3 {
            assert!(!set.insert(1));
        }
        assert!(set.insert(5));
        assert!(set.insert(4));
        for key in [0, 1] {
            assert!(set.contains(&key));
        }
        assert!(!set.contains(&2));
        for key in [4, 5] {
            assert!(set.contains(&key));
        }
        set_assert_sorted(&set);
        assert!(!set.remove(&3));
        assert!(set.remove(&4));
        assert!(!set.remove(&4));
        set_assert_sorted(&set);
    }

    #[test]
    fn set_str() {
        let mut set = CritBitSet::<&str>::new();
        assert!(!set.contains("a"));
        assert!(set.insert("a"));
        assert!(set.contains("a"));
        assert!(!set.insert("a"));
        assert!(set.insert("ab"));
        for key in ["a", "ab"] {
            assert!(set.contains(key));
        }
        assert!(!set.contains("ba"));
        for _ in 0..3 {
            assert!(!set.insert("ab"));
        }
        assert!(set.insert("e"));
        assert!(set.insert("ac"));
        for key in ["a", "ab", "ac"] {
            assert!(set.contains(key));
        }
        assert!(!set.contains("ad"));
        assert!(set.contains("e"));
        set_assert_sorted(&set);
        assert!(!set.remove("ad"));
        assert!(set.remove("ac"));
        assert!(!set.remove("ac"));
        set_assert_sorted(&set);
    }
}