#![deny(missing_docs)]
mod bitops;
mod eq;
mod intersect_by_rank;
#[cfg(test)]
mod tests;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::ops::Bound;
use std::ops::RangeBounds;
use std::sync::Arc;
use std::sync::OnceLock;
use itertools::Itertools;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;
/// Three-way answer returned by [`Mask`] accessors: the request covers every position
/// (`All`), no position (`None`), or is described by a concrete payload (`Some`).
///
/// This lets callers short-circuit the all-true / all-false cases without materializing
/// indices, slices, or bit buffers.
#[derive(Debug, PartialEq, Eq)]
pub enum AllOr<T> {
    /// Every position is selected.
    All,
    /// No position is selected.
    None,
    /// The selection is described by the contained value.
    Some(T),
}

impl<T> AllOr<T> {
    /// Returns the `Some` payload, or builds a value with `all_true` for [`AllOr::All`] or
    /// with `all_false` for [`AllOr::None`]. Only the closure that is needed is called.
    #[inline]
    pub fn unwrap_or_else<F, G>(self, all_true: F, all_false: G) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce() -> T,
    {
        match self {
            Self::All => all_true(),
            Self::None => all_false(),
            Self::Some(v) => v,
        }
    }
}

impl<T> AllOr<&T> {
    /// Converts an `AllOr<&T>` into an `AllOr<T>` by cloning the `Some` payload.
    #[inline]
    pub fn cloned(self) -> AllOr<T>
    where
        T: Clone,
    {
        match self {
            Self::All => AllOr::All,
            Self::None => AllOr::None,
            Self::Some(v) => AllOr::Some(v.clone()),
        }
    }
}
/// A boolean selection over a fixed number of positions.
///
/// The all-true and all-false cases store only their length, so they need no allocation.
/// Any other selection is held as shared [`MaskValues`] behind an [`Arc`], which makes
/// cloning a `Mask` a reference-count bump.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
pub enum Mask {
    /// Every one of the given number of positions is true.
    AllTrue(usize),
    /// Every one of the given number of positions is false.
    AllFalse(usize),
    /// A mixed selection backed by a bit buffer.
    ///
    /// The constructors in this file only produce this variant when at least one position
    /// is true and at least one is false.
    Values(Arc<MaskValues>),
}
impl Debug for Mask {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Uniform masks print a label plus their length; mixed masks defer to `MaskValues`.
        let (label, len) = match self {
            Self::Values(values) => return write!(f, "{values:?}"),
            Self::AllTrue(len) => ("All true", *len),
            Self::AllFalse(len) => ("All false", *len),
        };
        write!(f, "{label}({len})")
    }
}
impl Default for Mask {
fn default() -> Self {
Self::new_true(0)
}
}
/// Bit-level contents of a mixed [`Mask`], plus lazily computed views of the true positions.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MaskValues {
    /// One bit per position; `value(i)` reads bit `i` of this buffer.
    buffer: BitBuffer,
    /// Cache of the sorted positions of set bits. It is filled on the first `indices()` call,
    /// or eagerly by `Mask::from_indices`. It is not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    indices: OnceLock<Vec<usize>>,
    /// Cache of `(start, end)` runs of set bits, with `end` exclusive. It is filled on the
    /// first `slices()` call, or eagerly by the slice-based constructors. It is not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    slices: OnceLock<Vec<(usize, usize)>>,
    /// Number of set bits in `buffer`.
    true_count: usize,
    /// `true_count / buffer.len()`, computed at construction.
    density: f64,
}
impl Debug for MaskValues {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Summary statistics first, then whichever lazy caches happen to be populated.
        write!(f, "true_count={}, density={}, ", self.true_count, self.density)?;
        if let Some(cached) = self.indices.get() {
            write!(f, "indices={cached:?}, ")?;
        }
        if let Some(cached) = self.slices.get() {
            write!(f, "slices={cached:?}, ")?;
        }
        // In `{:#?}` mode the raw bits go on their own line.
        if f.alternate() {
            writeln!(f)?;
        }
        write!(f, "{}", self.buffer)
    }
}
impl Mask {
/// Creates a uniform mask of `length` positions, all set to `value`.
pub fn new(length: usize, value: bool) -> Self {
    match value {
        true => Self::AllTrue(length),
        false => Self::AllFalse(length),
    }
}

/// Creates a mask of `length` positions that are all true.
#[inline]
pub fn new_true(length: usize) -> Self {
    Mask::AllTrue(length)
}

/// Creates a mask of `length` positions that are all false.
#[inline]
pub fn new_false(length: usize) -> Self {
    Mask::AllFalse(length)
}
pub fn from_buffer(buffer: BitBuffer) -> Self {
let len = buffer.len();
let true_count = buffer.true_count();
if true_count == 0 {
return Self::AllFalse(len);
}
if true_count == len {
return Self::AllTrue(len);
}
Self::Values(Arc::new(MaskValues {
buffer,
indices: Default::default(),
slices: Default::default(),
true_count,
density: true_count as f64 / len as f64,
}))
}
/// Builds a mask of `len` positions where exactly the listed `indices` are true.
///
/// `indices` must be sorted in ascending order, and every entry must be `< len`. Repeated
/// entries are collapsed, so the true count is the number of *distinct* indices. When the
/// result is mixed, the deduplicated vector is kept as the mask's cached indices view.
///
/// # Panics
///
/// Panics if `indices` is not sorted or contains an index `>= len`.
pub fn from_indices(len: usize, mut indices: Vec<usize>) -> Self {
    assert!(indices.is_sorted(), "Mask indices must be sorted");
    assert!(
        indices.last().is_none_or(|&idx| idx < len),
        "Mask indices must be in bounds (len={len})"
    );
    // `is_sorted` accepts duplicates. Left in, they would inflate `true_count` and could
    // misclassify the mask as all-true (e.g. len=3, [0, 0, 1]). On a sorted vector, `dedup`
    // removes them in a single pass.
    indices.dedup();
    let true_count = indices.len();
    if true_count == 0 {
        return Self::AllFalse(len);
    }
    if true_count == len {
        return Self::AllTrue(len);
    }
    let mut buf = BitBufferMut::new_unset(len);
    indices.iter().for_each(|&idx| buf.set(idx));
    debug_assert_eq!(buf.len(), len);
    Self::Values(Arc::new(MaskValues {
        buffer: buf.freeze(),
        indices: OnceLock::from(indices),
        slices: Default::default(),
        true_count,
        density: true_count as f64 / len as f64,
    }))
}
/// Builds a mask of `len` positions that are all true except at the given `indices`.
///
/// `indices` may come in any order and may repeat. The true count is taken from the
/// resulting bits, not from the number of indices supplied. Each index is expected to be
/// `< len`; out-of-range handling is left to `BitBufferMut::unset` (NOTE(review): presumably
/// a panic, confirm).
pub fn from_excluded_indices(len: usize, indices: impl IntoIterator<Item = usize>) -> Self {
    let mut buf = BitBufferMut::new_set(len);
    for idx in indices {
        buf.unset(idx);
    }
    debug_assert_eq!(buf.len(), len);
    let buffer = buf.freeze();
    // Count the surviving bits instead of the excluded indices. A repeated index would
    // otherwise be subtracted twice, corrupting `true_count`/density or underflowing
    // `len - false_count`.
    let true_count = buffer.true_count();
    // Checked before the all-false case so that an empty mask is `AllTrue(0)`.
    if true_count == len {
        return Self::AllTrue(len);
    }
    if true_count == 0 {
        return Self::AllFalse(len);
    }
    Self::Values(Arc::new(MaskValues {
        buffer,
        indices: Default::default(),
        slices: Default::default(),
        true_count,
        density: true_count as f64 / len as f64,
    }))
}
pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
Self::check_slices(len, &vec);
Self::from_slices_unchecked(len, vec)
}
/// Builds a mask from `(start, end)` ranges without validating them in release builds.
///
/// Callers must pass ranges that satisfy the rules `check_slices` enforces; debug builds
/// re-check them. For a mixed result, the ranges are kept as the cached slices view.
fn from_slices_unchecked(len: usize, slices: Vec<(usize, usize)>) -> Self {
    #[cfg(debug_assertions)]
    Self::check_slices(len, &slices);

    let true_count: usize = slices.iter().map(|&(start, end)| end - start).sum();
    match true_count {
        0 => return Self::AllFalse(len),
        covered if covered == len => return Self::AllTrue(len),
        _ => {}
    }

    let mut bits = BitBufferMut::new_unset(len);
    slices
        .iter()
        .flat_map(|&(start, end)| start..end)
        .for_each(|idx| bits.set(idx));
    debug_assert_eq!(bits.len(), len);

    let density = true_count as f64 / len as f64;
    Self::Values(Arc::new(MaskValues {
        buffer: bits.freeze(),
        indices: OnceLock::new(),
        slices: OnceLock::from(slices),
        true_count,
        density,
    }))
}
/// Asserts that `vec` is a valid list of true ranges for a mask of length `len`.
///
/// Every `(start, end)` must satisfy `start < end <= len`. Consecutive ranges must be sorted
/// by start and must not overlap. Touching ranges, where one's end equals the next one's
/// start, are allowed.
///
/// # Panics
///
/// Panics with a message naming the offending range(s) if any rule is violated.
fn check_slices(len: usize, vec: &[(usize, usize)]) {
    // Every range is bounds-checked before any ordering check, so a bounds violation is
    // reported first when both kinds are present.
    for &(start, end) in vec {
        assert!(
            start < end && end <= len,
            "Slices must be non-empty and within bounds (len={len}), got ({start}, {end})"
        );
    }
    for (first, second) in vec.iter().tuple_windows() {
        assert!(
            first.0 < second.0,
            "Slices must be sorted, got {first:?} and {second:?}"
        );
        assert!(
            first.1 <= second.0,
            "Slices must be non-overlapping, got {first:?} and {second:?}"
        );
    }
}
/// Builds a mask of `len` positions that are true exactly where both `lhs` and `rhs` yield
/// the same index.
///
/// This is a merge-style walk, so both iterators are expected to yield indices in ascending
/// order. Each common index is recorded once, even if it repeats in both inputs.
///
/// # Panics
///
/// Panics (via [`Mask::from_indices`]) if a common index is `>= len`.
pub fn from_intersection_indices(
    len: usize,
    lhs: impl Iterator<Item = usize>,
    rhs: impl Iterator<Item = usize>,
) -> Self {
    // The intersection is no larger than either input, so size the vector from the
    // iterators' upper bounds rather than `len`. This vector becomes the mask's cached
    // indices, so over-reserving would keep `len` slots alive for the mask's lifetime.
    let capacity = lhs
        .size_hint()
        .1
        .unwrap_or(len)
        .min(rhs.size_hint().1.unwrap_or(len))
        .min(len);
    let mut intersection: Vec<usize> = Vec::with_capacity(capacity);
    let mut lhs = lhs.peekable();
    let mut rhs = rhs.peekable();
    while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
        match l.cmp(&r) {
            Ordering::Less => {
                lhs.next();
            }
            Ordering::Greater => {
                rhs.next();
            }
            Ordering::Equal => {
                // Skip repeats so each common index appears once.
                if intersection.last() != Some(&l) {
                    intersection.push(l);
                }
                lhs.next();
                rhs.next();
            }
        }
    }
    Self::from_indices(len, intersection)
}
pub fn clear(&mut self) {
*self = Self::new_false(0);
}
/// Returns the number of positions the mask covers.
#[inline]
pub fn len(&self) -> usize {
    match self {
        Self::AllTrue(n) | Self::AllFalse(n) => *n,
        Self::Values(values) => values.len(),
    }
}

/// Returns `true` if the mask covers zero positions.
#[inline]
pub fn is_empty(&self) -> bool {
    match self {
        Self::AllTrue(n) | Self::AllFalse(n) => *n == 0,
        Self::Values(values) => values.is_empty(),
    }
}

/// Returns the number of true positions.
#[inline]
pub fn true_count(&self) -> usize {
    match self {
        Self::AllTrue(n) => *n,
        Self::AllFalse(_) => 0,
        Self::Values(values) => values.true_count,
    }
}

/// Returns the number of false positions.
#[inline]
pub fn false_count(&self) -> usize {
    match self {
        Self::AllTrue(_) => 0,
        Self::AllFalse(n) => *n,
        Self::Values(values) => values.buffer.len() - values.true_count,
    }
}

/// Returns `true` if every position is true. This is vacuously true for an empty mask.
#[inline]
pub fn all_true(&self) -> bool {
    match self {
        Self::AllTrue(_) => true,
        Self::AllFalse(n) => *n == 0,
        Self::Values(values) => values.true_count == values.buffer.len(),
    }
}

/// Returns `true` if no position is true, which includes the empty mask.
#[inline]
pub fn all_false(&self) -> bool {
    0 == self.true_count()
}

/// Returns the fraction of positions that are true.
#[inline]
pub fn density(&self) -> f64 {
    match self {
        Self::AllTrue(_) => 1.0,
        Self::AllFalse(_) => 0.0,
        Self::Values(values) => values.density,
    }
}
/// Returns whether position `idx` is true.
///
/// Uniform masks answer without looking at `idx`. Only the mixed variant consults the bit
/// buffer (NOTE(review): any bounds check is `BitBuffer::value`'s, confirm).
#[inline]
pub fn value(&self, idx: usize) -> bool {
    if let Mask::Values(values) = self {
        values.buffer.value(idx)
    } else {
        matches!(self, Mask::AllTrue(_))
    }
}
/// Returns the position of the first true value, or `None` if there is none.
///
/// A mixed mask uses an already-cached indices or slices view when one exists, and
/// otherwise scans the bit buffer.
pub fn first(&self) -> Option<usize> {
    let values = match self {
        Self::AllTrue(len) => return (*len > 0).then_some(0),
        Self::AllFalse(_) => return None,
        Self::Values(values) => values,
    };
    if let Some(indices) = values.indices.get() {
        indices.first().copied()
    } else if let Some(slices) = values.slices.get() {
        slices.first().map(|&(start, _)| start)
    } else {
        values.buffer.set_indices().next()
    }
}
/// Returns the position of the last true value, or `None` if there is none.
///
/// A mixed mask uses an already-cached indices or slices view when one exists, and
/// otherwise walks the buffer's runs of set bits to the end.
pub fn last(&self) -> Option<usize> {
    let values = match self {
        Self::AllTrue(len) => return (*len > 0).then_some(*len - 1),
        Self::AllFalse(_) => return None,
        Self::Values(values) => values,
    };
    if let Some(indices) = values.indices.get() {
        indices.last().copied()
    } else if let Some(slices) = values.slices.get() {
        // Slice ends are exclusive, so the last true position is `end - 1`.
        slices.last().map(|&(_, end)| end - 1)
    } else {
        values.buffer.set_slices().last().map(|(_, end)| end - 1)
    }
}
/// Returns the position of the `n`-th true value (0-based).
///
/// For a mixed mask this materializes and caches the indices view.
///
/// # Panics
///
/// Panics if `n >= self.true_count()`.
pub fn rank(&self, n: usize) -> usize {
    let true_count = self.true_count();
    if n >= true_count {
        vortex_panic!(
            "Rank {n} out of bounds for mask with true count {}",
            true_count
        );
    }
    match self {
        Self::Values(values) => values.indices()[n],
        Self::AllTrue(_) => n,
        Self::AllFalse(_) => unreachable!("no true values in all-false mask"),
    }
}
/// Returns the sub-mask covering `range`.
///
/// A slice of a mixed mask is re-classified, so it may come back uniform.
///
/// # Panics
///
/// Panics if the resolved range is inverted or extends past `self.len()`.
pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
    let start = match range.start_bound().cloned() {
        Bound::Unbounded => 0,
        Bound::Included(s) => s,
        Bound::Excluded(s) => s + 1,
    };
    let end = match range.end_bound().cloned() {
        Bound::Unbounded => self.len(),
        Bound::Included(e) => e + 1,
        Bound::Excluded(e) => e,
    };
    assert!(start <= end);
    assert!(start <= self.len());
    assert!(end <= self.len());

    let sliced_len = end - start;
    match self {
        Self::Values(values) => Self::from_buffer(values.buffer.slice(range)),
        Self::AllTrue(_) => Self::new_true(sliced_len),
        Self::AllFalse(_) => Self::new_false(sliced_len),
    }
}
/// Borrows the underlying bit buffer of a mixed mask. Uniform masks are reported as
/// [`AllOr::All`] / [`AllOr::None`].
#[inline]
pub fn bit_buffer(&self) -> AllOr<&BitBuffer> {
    match self {
        Self::Values(values) => AllOr::Some(&values.buffer),
        Self::AllTrue(_) => AllOr::All,
        Self::AllFalse(_) => AllOr::None,
    }
}

/// Returns an owned bit buffer for this mask. Uniform masks are materialized, and mixed
/// masks clone their buffer.
#[inline]
pub fn to_bit_buffer(&self) -> BitBuffer {
    match self {
        Self::Values(values) => values.bit_buffer().clone(),
        Self::AllTrue(len) => BitBuffer::new_set(*len),
        Self::AllFalse(len) => BitBuffer::new_unset(*len),
    }
}

/// Converts this mask into a bit buffer. The buffer is moved out without cloning when this
/// is the last reference to a mixed mask's values.
#[inline]
pub fn into_bit_buffer(self) -> BitBuffer {
    match self {
        Self::AllTrue(len) => BitBuffer::new_set(len),
        Self::AllFalse(len) => BitBuffer::new_unset(len),
        Self::Values(shared) => match Arc::try_unwrap(shared) {
            Ok(owned) => owned.into_bit_buffer(),
            Err(still_shared) => still_shared.bit_buffer().clone(),
        },
    }
}
/// Returns the sorted positions of true values. A mixed mask computes them on first use
/// and caches them.
#[inline]
pub fn indices(&self) -> AllOr<&[usize]> {
    match self {
        Self::Values(values) => AllOr::Some(values.indices()),
        Self::AllTrue(_) => AllOr::All,
        Self::AllFalse(_) => AllOr::None,
    }
}

/// Returns the `(start, end)` runs of true values, with `end` exclusive. A mixed mask
/// computes them on first use and caches them.
#[inline]
pub fn slices(&self) -> AllOr<&[(usize, usize)]> {
    match self {
        Self::Values(values) => AllOr::Some(values.slices()),
        Self::AllTrue(_) => AllOr::All,
        Self::AllFalse(_) => AllOr::None,
    }
}

/// Returns the slices view when density is at least `threshold`, otherwise the indices view.
#[inline]
pub fn threshold_iter(&self, threshold: f64) -> AllOr<MaskIter<'_>> {
    match self {
        Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)),
        Self::AllTrue(_) => AllOr::All,
        Self::AllFalse(_) => AllOr::None,
    }
}
/// Returns the contents of a mixed mask, or `None` for a uniform mask.
#[inline]
pub fn values(&self) -> Option<&MaskValues> {
    let Self::Values(values) = self else {
        return None;
    };
    Some(values)
}
/// For each entry of `indices`, returns how many true positions lie strictly before it.
///
/// `indices` is expected to be sorted in ascending order. For an all-true mask each count
/// equals the index itself; for an all-false mask every count is zero.
///
/// # Panics
///
/// For a mixed mask, panics if an index is larger than the mask length.
pub fn valid_counts_for_indices(&self, indices: &[usize]) -> Vec<usize> {
    let values = match self {
        Self::AllTrue(_) => return indices.to_vec(),
        Self::AllFalse(_) => return vec![0; indices.len()],
        Self::Values(values) => values,
    };
    let mut bits = values.bit_buffer().iter();
    let mut position = 0;
    let mut running_count = 0;
    indices
        .iter()
        .map(|&target| {
            // Consume bits up to (not including) `target`, counting the set ones.
            while position < target {
                position += 1;
                let bit = bits
                    .next()
                    .unwrap_or_else(|| vortex_panic!("Row indices exceed array length"));
                running_count += bit as usize;
            }
            running_count
        })
        .collect()
}
/// Keeps only the first `limit` true positions and clears any later ones. The mask's length
/// is unchanged.
pub fn limit(self, limit: usize) -> Self {
    // A mask with at most `limit` positions cannot hold more than `limit` true values.
    if limit >= self.len() {
        return self;
    }
    match self {
        Mask::AllFalse(_) => self,
        Mask::AllTrue(len) => {
            Self::from_iter([Self::new_true(limit), Self::new_false(len - limit)])
        }
        Mask::Values(ref mask_values) => {
            if mask_values.true_count() <= limit {
                return self;
            }
            let len = mask_values.len();
            debug_assert!(limit < mask_values.len());
            let mut kept = BitBufferMut::new_unset(len);
            for index in mask_values.bit_buffer().set_indices().take(limit) {
                // SAFETY: `index` is a set-bit position of a buffer with `len` bits, and
                // `kept` was allocated with `len` bits, so it is in bounds.
                unsafe { kept.set_unchecked(index) }
            }
            Self::from(kept.freeze())
        }
    }
}
/// Concatenates `masks` end to end into a single mask.
///
/// The result is uniform when every input is all-true, or when every input is all-false.
/// Empty inputs count as both. Otherwise the bits are copied into a new buffer and
/// re-classified. This body never produces an `Err`.
pub fn concat<'a>(masks: impl Iterator<Item = &'a Self>) -> VortexResult<Self> {
    let parts = masks.collect::<Vec<_>>();
    let total: usize = parts.iter().map(|part| part.len()).sum();
    if parts.iter().all(|part| part.all_true()) {
        return Ok(Mask::AllTrue(total));
    }
    if parts.iter().all(|part| part.all_false()) {
        return Ok(Mask::AllFalse(total));
    }
    let mut builder = BitBufferMut::with_capacity(total);
    parts.into_iter().for_each(|part| match part {
        Mask::AllTrue(n) => builder.append_n(true, *n),
        Mask::AllFalse(n) => builder.append_n(false, *n),
        Mask::Values(values) => builder.append_buffer(values.bit_buffer()),
    });
    Ok(Mask::from_buffer(builder.freeze()))
}
}
impl MaskValues {
/// Returns the number of positions covered.
#[inline]
pub fn len(&self) -> usize {
    self.buffer.len()
}

/// Returns `true` if the underlying buffer has no bits.
#[inline]
pub fn is_empty(&self) -> bool {
    self.buffer.is_empty()
}

/// Returns the fraction of positions that are true, computed at construction.
#[inline]
pub fn density(&self) -> f64 {
    self.density
}

/// Returns the number of true positions.
#[inline]
pub fn true_count(&self) -> usize {
    self.true_count
}

/// Borrows the underlying bit buffer.
#[inline]
pub fn bit_buffer(&self) -> &BitBuffer {
    &self.buffer
}

/// Consumes the values and returns the underlying bit buffer, discarding any cached views.
#[inline]
pub fn into_bit_buffer(self) -> BitBuffer {
    let Self { buffer, .. } = self;
    buffer
}

/// Returns whether position `index` is true.
#[inline]
pub fn value(&self, index: usize) -> bool {
    self.buffer.value(index)
}
/// Returns the sorted positions of true values, computing and caching them on first call.
///
/// If the slices view is already cached, it is expanded; otherwise the bit buffer is scanned.
pub fn indices(&self) -> &[usize] {
    self.indices.get_or_init(|| {
        let len = self.len();
        match self.true_count {
            0 => vec![],
            set if set == len => (0..len).collect(),
            set => {
                let mut indices = Vec::with_capacity(set);
                match self.slices.get() {
                    Some(runs) => {
                        indices.extend(runs.iter().flat_map(|(start, end)| *start..*end))
                    }
                    None => indices.extend(self.buffer.set_indices()),
                }
                debug_assert!(indices.is_sorted());
                assert_eq!(indices.len(), self.true_count);
                indices
            }
        }
    })
}
/// Returns the `(start, end)` runs of true values, with `end` exclusive. They are computed
/// and cached on first call.
#[inline]
pub fn slices(&self) -> &[(usize, usize)] {
    self.slices.get_or_init(|| {
        let len = self.len();
        if self.true_count != len {
            self.buffer.set_slices().collect()
        } else {
            vec![(0, len)]
        }
    })
}

/// Returns the slices view when `density >= threshold`, otherwise the indices view.
#[inline]
pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> {
    match self.density >= threshold {
        true => MaskIter::Slices(self.slices()),
        false => MaskIter::Indices(self.indices()),
    }
}
}
/// A borrowed view of a mask's true positions, as chosen by `threshold_iter`.
pub enum MaskIter<'a> {
    /// Sorted positions of the true values.
    Indices(&'a [usize]),
    /// `(start, end)` runs of true values, with `end` exclusive.
    Slices(&'a [(usize, usize)]),
}
impl From<BitBuffer> for Mask {
fn from(value: BitBuffer) -> Self {
Self::from_buffer(value)
}
}
impl FromIterator<bool> for Mask {
#[inline]
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
Self::from_buffer(BitBuffer::from_iter(iter))
}
}
impl FromIterator<Mask> for Mask {
    /// Concatenates the masks in order.
    ///
    /// Empty masks are dropped first. The result is uniform when every remaining mask is
    /// all-true or every one is all-false; otherwise the bits are copied and re-classified.
    fn from_iter<T: IntoIterator<Item = Mask>>(iter: T) -> Self {
        let parts: Vec<Mask> = iter.into_iter().filter(|part| !part.is_empty()).collect();
        let total_length: usize = parts.iter().map(Mask::len).sum();
        if parts.iter().all(Mask::all_true) {
            return Self::AllTrue(total_length);
        }
        if parts.iter().all(Mask::all_false) {
            return Self::AllFalse(total_length);
        }
        let mut bits = BitBufferMut::with_capacity(total_length);
        for part in parts {
            match part {
                Mask::AllTrue(count) => bits.append_n(true, count),
                Mask::AllFalse(count) => bits.append_n(false, count),
                Mask::Values(values) => bits.append_buffer(values.bit_buffer()),
            }
        }
        Self::from_buffer(bits.freeze())
    }
}