use roaring::RoaringBitmap;
/// A set of `u32` tids backed by a [`RoaringBitmap`], paired with a byte cap
/// that the inserting methods check against the bitmap's serialized size.
///
/// NOTE(review): the derived `Default` leaves `cap_bytes` at 0, which
/// `check_cap` treats as "no cap", whereas `TidBitmap::new` installs a
/// 32 MiB cap. Confirm that this difference is intended.
#[derive(Debug, Clone, Default)]
pub struct TidBitmap {
/// The stored tids.
inner: RoaringBitmap,
/// Limit compared against `inner.serialized_size()`; 0 disables the check.
cap_bytes: usize,
}
/// Errors reported by the cap-checked bitmap operations in this file.
#[derive(Debug)]
pub enum BitmapError {
    /// The bitmap's measured size (`current`, in bytes) is above `cap`.
    TooLarge { current: usize, cap: usize },
}

impl std::fmt::Display for BitmapError {
    /// Renders the error as a single human-readable line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has exactly one variant, so this pattern is irrefutable;
        // adding a variant later turns this line into a compile error.
        let Self::TooLarge { current, cap } = self;
        write!(f, "tid bitmap {current} bytes exceeds cap {cap}")
    }
}

impl std::error::Error for BitmapError {}
impl TidBitmap {
    /// Creates an empty bitmap with a 32 MiB cap.
    pub fn new() -> Self {
        Self::with_cap_bytes(32 * 1024 * 1024)
    }

    /// Creates an empty bitmap with the given cap in bytes.
    ///
    /// A cap of 0 disables the size check entirely.
    pub fn with_cap_bytes(cap_bytes: usize) -> Self {
        let inner = RoaringBitmap::new();
        Self { inner, cap_bytes }
    }

    /// Inserts `tid`, then verifies the cap.
    ///
    /// On success, returns the value reported by the underlying bitmap's
    /// `insert` (whether the tid was newly added).
    ///
    /// # Errors
    ///
    /// Returns [`BitmapError::TooLarge`] if the bitmap is over the cap after
    /// the insertion. The tid is not removed again in that case.
    pub fn insert(&mut self, tid: u32) -> Result<bool, BitmapError> {
        let was_absent = self.inner.insert(tid);
        self.check_cap().map(|()| was_absent)
    }

    /// Inserts every tid produced by `iter`.
    ///
    /// The cap is checked after every 4096 consumed items and once more at
    /// the end. On success, returns the number of items consumed from the
    /// iterator; this counts duplicates too, not just newly added tids.
    ///
    /// # Errors
    ///
    /// Returns [`BitmapError::TooLarge`] as soon as a check finds the bitmap
    /// over the cap. Tids inserted before that point stay in the bitmap.
    pub fn extend_from_iter(
        &mut self,
        iter: impl IntoIterator<Item = u32>,
    ) -> Result<usize, BitmapError> {
        // Checking the size on every insert would be wasteful, so it is
        // sampled at this interval instead.
        const CHECK_INTERVAL: usize = 4096;

        let mut consumed = 0usize;
        for tid in iter {
            self.inner.insert(tid);
            consumed += 1;
            if consumed.is_multiple_of(CHECK_INTERVAL) {
                self.check_cap()?;
            }
        }
        // Covers the tail that did not land on an interval boundary.
        self.check_cap()?;
        Ok(consumed)
    }

    /// Compares the bitmap's serialized size against `cap_bytes`.
    ///
    /// A cap of 0 means "unlimited" and short-circuits before measuring.
    fn check_cap(&self) -> Result<(), BitmapError> {
        let cap = self.cap_bytes;
        if cap == 0 {
            return Ok(());
        }
        match self.inner.serialized_size() {
            current if current > cap => Err(BitmapError::TooLarge { current, cap }),
            _ => Ok(()),
        }
    }

    /// Returns `true` if `tid` is in the bitmap.
    pub fn contains(&self, tid: u32) -> bool {
        self.inner.contains(tid)
    }

    /// Returns the number of tids in the bitmap.
    pub fn len(&self) -> u64 {
        self.inner.len()
    }

    /// Returns `true` if the bitmap holds no tids.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Keeps only the tids that are also present in `other`.
    pub fn intersect_with(&mut self, other: &TidBitmap) {
        self.inner &= &other.inner;
    }

    /// Adds every tid from `other`.
    ///
    /// No cap check is performed here.
    pub fn union_with(&mut self, other: &TidBitmap) {
        self.inner |= &other.inner;
    }

    /// Removes every tid that is present in `other`.
    pub fn difference_with(&mut self, other: &TidBitmap) {
        self.inner -= &other.inner;
    }

    /// Iterates over the stored tids, borrowing the bitmap.
    pub fn iter(&self) -> impl Iterator<Item = u32> + '_ {
        self.inner.iter()
    }

    /// Consumes the bitmap and collects its tids in the order the underlying
    /// bitmap's iterator yields them.
    pub fn into_sorted_vec(self) -> Vec<u32> {
        let Self { inner, .. } = self;
        inner.into_iter().collect()
    }

    /// Splits the tids into `(page, rows)` groups, where
    /// `page = tid / rows_per_page` and each row is `tid % rows_per_page`.
    ///
    /// Groups are emitted in iteration order, and a new group is started
    /// whenever the page differs from that of the previous tid. Returns an
    /// empty vector when `rows_per_page` is 0.
    pub fn group_by_page(&self, rows_per_page: u32) -> Vec<(u32, Vec<u32>)> {
        if rows_per_page == 0 {
            return Vec::new();
        }
        let mut groups: Vec<(u32, Vec<u32>)> = Vec::new();
        for tid in self.inner.iter() {
            let (page, row) = (tid / rows_per_page, tid % rows_per_page);
            match groups.last_mut() {
                // Same page as the group being built: append to it.
                Some((open_page, rows)) if *open_page == page => rows.push(row),
                // First tid, or the page changed: open a new group.
                _ => groups.push((page, vec![row])),
            }
        }
        groups
    }

    /// Returns the size of the union with `other` without modifying either
    /// bitmap.
    pub fn union_cardinality(&self, other: &TidBitmap) -> u64 {
        self.inner.union_len(&other.inner)
    }

    /// Returns the size of the intersection with `other` without modifying
    /// either bitmap.
    pub fn intersection_cardinality(&self, other: &TidBitmap) -> u64 {
        self.inner.intersection_len(&other.inner)
    }
}
/// Builds a bitmap via [`TidBitmap::new`] and fills it from `iter`.
///
/// # Errors
///
/// Propagates the [`BitmapError`] returned by
/// [`TidBitmap::extend_from_iter`]; the partially filled bitmap is dropped.
pub fn from_iter(iter: impl IntoIterator<Item = u32>) -> Result<TidBitmap, BitmapError> {
    let mut filled = TidBitmap::new();
    // The consumed-item count is not needed here; only success matters.
    filled.extend_from_iter(iter).map(|_consumed| filled)
}
/// Intersects all `bitmaps` into a single bitmap.
///
/// The first bitmap is used as the accumulator, so the result keeps that
/// bitmap's cap. An empty input yields [`TidBitmap::new`]. Folding stops as
/// soon as the accumulator becomes empty, since further intersections cannot
/// add anything.
pub fn intersect_all(bitmaps: Vec<TidBitmap>) -> TidBitmap {
    // Take the first element through the iterator rather than
    // `Vec::remove(0)`, which would shift every remaining bitmap.
    let mut rest = bitmaps.into_iter();
    let Some(mut acc) = rest.next() else {
        return TidBitmap::new();
    };
    for other in rest {
        acc.intersect_with(&other);
        if acc.is_empty() {
            return acc;
        }
    }
    acc
}
/// Unions all `bitmaps` into a single bitmap.
///
/// The first bitmap is used as the accumulator, so the result keeps that
/// bitmap's cap. An empty input yields [`TidBitmap::new`]. No cap check is
/// performed while merging.
pub fn union_all(bitmaps: Vec<TidBitmap>) -> TidBitmap {
    let mut remaining = bitmaps.into_iter();
    match remaining.next() {
        None => TidBitmap::new(),
        Some(first) => remaining.fold(first, |mut merged, next| {
            merged.union_with(&next);
            merged
        }),
    }
}