use deepsize::DeepSizeOf;
use super::{RowAddrMask, RowAddrTreeMap, RowSetOps};
/// A set of row addresses paired with a set of rows flagged as null.
///
/// A row listed in `nulls` is never reported as selected by
/// [`NullableRowAddrSet::selected`], even when it also appears in `selected`.
/// The TRUE rows of the set are therefore `selected` minus `nulls`.
#[derive(Clone, Debug, Default, DeepSizeOf)]
pub struct NullableRowAddrSet {
// Candidate rows; a row only counts as TRUE when it is not also in `nulls`.
selected: RowAddrTreeMap,
// Rows flagged as null; may overlap with `selected`.
nulls: RowAddrTreeMap,
}
impl NullableRowAddrSet {
    /// Builds a set from the selected rows and the rows flagged as null.
    ///
    /// The two maps may overlap: a row present in `nulls` is treated as null
    /// (never as selected) even if it is also present in `selected`.
    pub fn new(selected: RowAddrTreeMap, nulls: RowAddrTreeMap) -> Self {
        Self { selected, nulls }
    }

    /// Replaces the null rows with `nulls`, leaving the selected rows untouched.
    pub fn with_nulls(mut self, nulls: RowAddrTreeMap) -> Self {
        self.nulls = nulls;
        self
    }

    /// Returns a set with no selected rows and no null rows.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Returns the number of TRUE rows, i.e. the `len()` of [`Self::true_rows`].
    ///
    /// The result is `None` whenever the underlying map reports `None`
    /// (presumably when the count cannot be determined; confirm against
    /// `RowAddrTreeMap`).
    pub fn len(&self) -> Option<u64> {
        self.true_rows().len()
    }

    /// Returns `true` when the `selected` map is empty.
    ///
    /// NOTE(review): this inspects `selected` only, not the TRUE rows. A set
    /// whose selected rows are all null has no TRUE rows yet is reported as
    /// non-empty. Behaviour is kept as is because callers may depend on it;
    /// confirm the intent.
    pub fn is_empty(&self) -> bool {
        self.selected.is_empty()
    }

    /// Returns `true` when `row_id` is TRUE: present in `selected` and not
    /// flagged as null.
    pub fn selected(&self, row_id: u64) -> bool {
        self.selected.contains(row_id) && !self.nulls.contains(row_id)
    }

    /// Returns the rows flagged as null.
    pub fn null_rows(&self) -> &RowAddrTreeMap {
        &self.nulls
    }

    /// Returns the TRUE rows: `selected` with every null row removed.
    pub fn true_rows(&self) -> RowAddrTreeMap {
        // Only `selected` has to be cloned; the subtraction can borrow `nulls`.
        self.selected.clone() - &self.nulls
    }

    /// Combines all `selections` with OR semantics.
    ///
    /// The resulting `selected` map is the union of every input's TRUE rows.
    /// The resulting `nulls` map is the union of every input's null rows minus
    /// those TRUE rows, because a row that is TRUE in any input is TRUE in the
    /// result.
    pub fn union_all(selections: &[Self]) -> Self {
        // The TRUE maps are temporaries, so they must be owned before
        // references to them can be handed to `RowAddrTreeMap::union_all`.
        let true_rows: Vec<RowAddrTreeMap> = selections.iter().map(|s| s.true_rows()).collect();
        let true_row_refs: Vec<&RowAddrTreeMap> = true_rows.iter().collect();
        let selected = RowAddrTreeMap::union_all(&true_row_refs);

        let null_refs: Vec<&RowAddrTreeMap> = selections.iter().map(|s| &s.nulls).collect();
        // TRUE overrides NULL under OR, so drop TRUE rows from the null union.
        let nulls = RowAddrTreeMap::union_all(&null_refs) - &selected;

        Self { selected, nulls }
    }
}
impl PartialEq for NullableRowAddrSet {
    /// Two sets are equal when they agree on both their null rows and their
    /// TRUE rows. Rows that sit in `selected` but are masked by `nulls` do not
    /// take part in the comparison beyond being null.
    fn eq(&self, other: &Self) -> bool {
        if self.nulls != other.nulls {
            return false;
        }
        self.true_rows() == other.true_rows()
    }
}
impl std::ops::BitAndAssign<&Self> for NullableRowAddrSet {
    /// In-place AND with null propagation.
    ///
    /// A row ends up null when it is null on both sides, or null on one side
    /// and present in the other side's `selected` map. `selected` becomes the
    /// intersection of both `selected` maps.
    fn bitand_assign(&mut self, rhs: &Self) {
        if self.nulls.is_empty() && rhs.nulls.is_empty() {
            // Fast path: nothing is null on either side.
            self.nulls = RowAddrTreeMap::new();
        } else {
            // null AND null -> null
            let null_both = self.nulls.clone() & &rhs.nulls;
            // null AND true -> null
            let null_lhs = self.nulls.clone() & &rhs.selected;
            // true AND null -> null
            let null_rhs = rhs.nulls.clone() & &self.selected;
            self.nulls = null_both | null_lhs | null_rhs;
        }

        self.selected &= &rhs.selected;
    }
}
impl std::ops::BitOrAssign<&Self> for NullableRowAddrSet {
    /// In-place OR with null propagation.
    ///
    /// A row ends up null when it is null on at least one side and TRUE on
    /// neither. `selected` becomes the union of both `selected` maps.
    fn bitor_assign(&mut self, rhs: &Self) {
        if self.nulls.is_empty() && rhs.nulls.is_empty() {
            // Fast path: nothing is null on either side.
            self.nulls = RowAddrTreeMap::new();
        } else {
            let lhs_true = self.selected.clone() - &self.nulls;
            let rhs_true = rhs.selected.clone() - &rhs.nulls;
            let any_null = self.nulls.clone() | &rhs.nulls;
            // TRUE on either side overrides NULL.
            self.nulls = any_null - (lhs_true | rhs_true);
        }

        self.selected |= &rhs.selected;
    }
}
/// A row-address mask whose underlying set also tracks null rows.
///
/// * `AllowList`: a row is selected when it is in the set's `selected` rows
///   and not in its `nulls`.
/// * `BlockList`: a row is selected when it is in neither the set's `selected`
///   rows nor its `nulls`.
#[derive(Clone, Debug)]
pub enum NullableRowAddrMask {
// Only the set's non-null `selected` rows pass.
AllowList(NullableRowAddrSet),
// Every row passes except the set's `selected` rows and its null rows.
BlockList(NullableRowAddrSet),
}
impl NullableRowAddrMask {
    /// Returns `true` when `row_id` passes the mask.
    ///
    /// Null rows never pass. Otherwise an allow list passes rows it contains
    /// and a block list passes rows it does not contain.
    pub fn selected(&self, row_id: u64) -> bool {
        match self {
            Self::AllowList(set) => set.selected.contains(row_id) && !set.nulls.contains(row_id),
            Self::BlockList(set) => !(set.selected.contains(row_id) || set.nulls.contains(row_id)),
        }
    }

    /// Converts into a plain [`RowAddrMask`] in which null rows are not
    /// selected.
    ///
    /// For an allow list the null rows are removed from the allowed rows; for
    /// a block list they are added to the blocked rows.
    pub fn drop_nulls(self) -> RowAddrMask {
        match self {
            Self::AllowList(set) => RowAddrMask::AllowList(set.selected - set.nulls),
            Self::BlockList(set) => RowAddrMask::BlockList(set.selected | set.nulls),
        }
    }
}
impl std::ops::Not for NullableRowAddrMask {
    type Output = Self;

    /// Negates the mask by swapping the variant while keeping the inner set.
    ///
    /// The `nulls` map is carried over unchanged, so rows that were null
    /// before the negation are still null (and still not selected) afterwards.
    fn not(self) -> Self::Output {
        match self {
            Self::BlockList(inner) => Self::AllowList(inner),
            Self::AllowList(inner) => Self::BlockList(inner),
        }
    }
}
impl std::ops::BitAnd for NullableRowAddrMask {
    type Output = Self;

    /// AND of two masks with null propagation:
    ///
    /// * null AND true  -> null
    /// * null AND null  -> null
    /// * null AND false -> false
    ///
    /// The result is a block list only when both inputs are block lists;
    /// every other combination yields an allow list.
    fn bitand(self, rhs: Self) -> Self::Output {
        match (self, rhs) {
            (Self::AllowList(left), Self::AllowList(right)) => {
                let nulls = if left.nulls.is_empty() && right.nulls.is_empty() {
                    // Fast path: no nulls to propagate.
                    RowAddrTreeMap::new()
                } else {
                    // null AND null -> null
                    let null_both = left.nulls.clone() & &right.nulls;
                    // null AND true -> null
                    let null_left = left.nulls & &right.selected;
                    // true AND null -> null
                    let null_right = right.nulls & &left.selected;
                    null_both | null_left | null_right
                };
                let selected = left.selected & right.selected;
                Self::AllowList(NullableRowAddrSet { selected, nulls })
            }
            (Self::AllowList(allow), Self::BlockList(block))
            | (Self::BlockList(block), Self::AllowList(allow)) => {
                let nulls = if allow.nulls.is_empty() && block.nulls.is_empty() {
                    // Fast path: no nulls to propagate.
                    RowAddrTreeMap::new()
                } else {
                    // null AND null -> null
                    let null_both = allow.nulls.clone() & &block.nulls;
                    // Allow-side null survives unless the block side lists the row.
                    let null_allow = allow.nulls - &block.selected;
                    // Block-side null survives where the allow side lists the row.
                    let null_block = block.nulls & &allow.selected;
                    null_both | null_allow | null_block
                };
                // Allowed rows that are not blocked.
                let selected = allow.selected - block.selected;
                Self::AllowList(NullableRowAddrSet { selected, nulls })
            }
            (Self::BlockList(left), Self::BlockList(right)) => {
                let nulls = if left.nulls.is_empty() && right.nulls.is_empty() {
                    // Fast path: no nulls to propagate.
                    RowAddrTreeMap::new()
                } else {
                    // null AND null -> null
                    let null_both = left.nulls.clone() & &right.nulls;
                    // A null on one side survives unless the other side blocks the row.
                    let null_left = left.nulls - &right.selected;
                    let null_right = right.nulls - &left.selected;
                    null_both | null_left | null_right
                };
                // A row blocked by either side stays blocked.
                let selected = left.selected | right.selected;
                Self::BlockList(NullableRowAddrSet { selected, nulls })
            }
        }
    }
}
impl std::ops::BitOr for NullableRowAddrMask {
    type Output = Self;

    /// OR of two masks with null propagation:
    ///
    /// * null OR true  -> true
    /// * null OR null  -> null
    /// * null OR false -> null
    ///
    /// The result is an allow list only when both inputs are allow lists;
    /// every other combination yields a block list.
    fn bitor(self, rhs: Self) -> Self::Output {
        match (self, rhs) {
            (Self::AllowList(left), Self::AllowList(right)) => {
                let nulls = if left.nulls.is_empty() && right.nulls.is_empty() {
                    // Fast path: no nulls to propagate.
                    RowAddrTreeMap::new()
                } else {
                    let left_true = left.selected.clone() - &left.nulls;
                    let right_true = right.selected.clone() - &right.nulls;
                    // Null on either side, unless some side is TRUE for the row.
                    (left.nulls | right.nulls) - (left_true | right_true)
                };
                let selected = (left.selected | right.selected) | &nulls;
                Self::AllowList(NullableRowAddrSet { selected, nulls })
            }
            (Self::AllowList(allow), Self::BlockList(block))
            | (Self::BlockList(block), Self::AllowList(allow)) => {
                // Rows that are definitely TRUE on the allow side and
                // definitely FALSE on the block side.
                let allow_true = allow.selected.clone() - &allow.nulls;
                let block_false = block.selected.clone() - &block.nulls;
                let nulls = if allow.nulls.is_empty() && block.nulls.is_empty() {
                    // Fast path: no nulls to propagate.
                    RowAddrTreeMap::new()
                } else {
                    // null OR false -> null
                    let null_allow = allow.nulls & &block_false;
                    // Block-side null survives unless the allow side is TRUE.
                    let null_block = block.nulls - &allow_true;
                    null_allow | null_block
                };
                // Blocked in the result: FALSE on the block side and not TRUE
                // on the allow side, plus every null row.
                let selected = (block_false - &allow_true) | &nulls;
                Self::BlockList(NullableRowAddrSet { selected, nulls })
            }
            (Self::BlockList(left), Self::BlockList(right)) => {
                // Rows that are definitely FALSE on each side.
                let left_false = left.selected.clone() - &left.nulls;
                let right_false = right.selected.clone() - &right.nulls;
                let nulls = if left.nulls.is_empty() && right.nulls.is_empty() {
                    // Fast path: no nulls to propagate.
                    RowAddrTreeMap::new()
                } else {
                    // null OR false -> null
                    let null_left = left.nulls.clone() & &right_false;
                    // false OR null -> null
                    let null_right = right.nulls.clone() & &left_false;
                    // null OR null -> null
                    let null_both = left.nulls & &right.nulls;
                    null_left | null_right | null_both
                };
                // Blocked in the result: FALSE on both sides, plus every null row.
                let selected = (left_false & right_false) | &nulls;
                Self::BlockList(NullableRowAddrSet { selected, nulls })
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RowAddrTreeMap` from a slice of row ids.
    fn rows(ids: &[u64]) -> RowAddrTreeMap {
        RowAddrTreeMap::from_iter(ids)
    }

    /// Builds a `NullableRowAddrSet` from selected and null row ids.
    fn nullable_set(selected: &[u64], nulls: &[u64]) -> NullableRowAddrSet {
        NullableRowAddrSet::new(rows(selected), rows(nulls))
    }

    /// Builds an allow-list mask from selected and null row ids.
    fn allow(selected: &[u64], nulls: &[u64]) -> NullableRowAddrMask {
        NullableRowAddrMask::AllowList(nullable_set(selected, nulls))
    }

    /// Builds a block-list mask from blocked and null row ids.
    fn block(selected: &[u64], nulls: &[u64]) -> NullableRowAddrMask {
        NullableRowAddrMask::BlockList(nullable_set(selected, nulls))
    }

    /// Asserts that every id in `selected` passes `mask` and that every id in
    /// `not_selected` does not.
    fn assert_mask_selects(mask: &NullableRowAddrMask, selected: &[u64], not_selected: &[u64]) {
        for &id in selected {
            assert!(mask.selected(id), "Expected row {} to be selected", id);
        }
        for &id in not_selected {
            assert!(!mask.selected(id), "Expected row {} to NOT be selected", id);
        }
    }

    #[test]
    fn test_not_with_nulls() {
        // Row 1 is TRUE, row 2 is NULL, row 0 is FALSE.
        let mask = allow(&[1, 2], &[2]);
        let not_mask = !mask;
        // NOT FALSE = TRUE, NOT TRUE = FALSE, NOT NULL = NULL (not selected).
        assert_mask_selects(&not_mask, &[0], &[1, 2]);
    }

    #[test]
    fn test_and_with_nulls() {
        let true_mask = allow(&[0, 1, 2, 3, 4], &[]);
        let null_mask = allow(&[0, 1, 2, 3, 4], &[1, 3]);

        // TRUE AND NULL = NULL, so the null rows drop out.
        let result = true_mask & null_mask.clone();
        assert_mask_selects(&result, &[0, 2, 4], &[1, 3]);

        // FALSE AND anything = FALSE.
        let false_mask = block(&[0, 1, 2, 3, 4], &[]);
        let result = false_mask & null_mask;
        assert_mask_selects(&result, &[], &[0, 1, 2, 3, 4]);

        let mask1 = allow(&[0, 1, 2], &[1]);
        let mask2 = allow(&[0, 2, 3], &[2]);
        let result = mask1 & mask2;
        assert_mask_selects(&result, &[0], &[1, 2, 3]);
    }

    #[test]
    fn test_or_with_nulls() {
        let false_mask = block(&[0, 1, 2], &[]);
        let null_mask = allow(&[0, 1, 2], &[1, 2]);

        // FALSE OR NULL = NULL, FALSE OR TRUE = TRUE.
        let result = false_mask | null_mask.clone();
        assert_mask_selects(&result, &[0], &[1, 2]);

        // TRUE OR NULL = TRUE.
        let true_mask = allow(&[0, 1, 2], &[]);
        let result = true_mask | null_mask;
        assert_mask_selects(&result, &[0, 1, 2], &[]);

        let mask1 = block(&[0, 1, 2, 3], &[1, 2]);
        let mask2 = block(&[0, 1, 2, 3], &[2, 3]);
        let result = mask1 | mask2;
        assert_mask_selects(&result, &[], &[0, 1, 2, 3]);
    }

    #[test]
    fn test_or_allow_block_keeps_block_nulls() {
        let allow_mask = allow(&[1], &[0]);
        let block_mask = block(&[], &[0]);
        let result = allow_mask | block_mask;
        assert_mask_selects(&result, &[1], &[0]);
    }

    #[test]
    fn test_or_allow_block_keeps_block_nulls_with_false_rows() {
        let allow_mask = allow(&[2], &[]);
        let block_mask = block(&[1], &[0]);
        let result = allow_mask | block_mask;
        assert_mask_selects(&result, &[2], &[0, 1]);
    }

    #[test]
    fn test_or_block_block_true_overrides_null() {
        let true_mask = block(&[], &[]);
        let null_mask = block(&[], &[0]);
        let result = true_mask | null_mask;
        assert_mask_selects(&result, &[0], &[]);
    }

    #[test]
    fn test_row_selection_bit_or() {
        let left = nullable_set(&[1, 2, 3, 4], &[2, 4]);
        let right = nullable_set(&[3, 4, 5, 6], &[4, 6, 7]);
        let expected_true = rows(&[1, 3, 5]);
        let expected_nulls = rows(&[2, 4, 6, 7]);

        let mut result = left.clone();
        result |= &right;
        assert_eq!(&result.true_rows(), &expected_true);
        assert_eq!(result.null_rows(), &expected_nulls);

        // The operation must be commutative.
        let mut result = right.clone();
        result |= &left;
        assert_eq!(&result.true_rows(), &expected_true);
        assert_eq!(result.null_rows(), &expected_nulls);
    }

    #[test]
    fn test_row_selection_bit_and() {
        let left = nullable_set(&[1, 2, 3, 4], &[2, 4]);
        let right = nullable_set(&[3, 4, 5, 6], &[4, 6, 7]);
        let expected_true = rows(&[3]);
        let expected_nulls = rows(&[4]);

        let mut result = left.clone();
        result &= &right;
        assert_eq!(&result.true_rows(), &expected_true);
        assert_eq!(result.null_rows(), &expected_nulls);

        // The operation must be commutative.
        let mut result = right.clone();
        result &= &left;
        assert_eq!(&result.true_rows(), &expected_true);
        assert_eq!(result.null_rows(), &expected_nulls);
    }

    #[test]
    fn test_union_all() {
        let set1 = nullable_set(&[1, 2, 3, 4], &[4, 5, 6]);
        let set2 = nullable_set(&[1, 4, 7, 8], &[2, 5, 8]);
        let set3 = NullableRowAddrSet::empty();

        let result = NullableRowAddrSet::union_all(&[set1, set2, set3]);

        assert_eq!(&result.true_rows(), &rows(&[1, 2, 3, 4, 7]));
        assert_eq!(result.null_rows(), &rows(&[5, 6, 8]));
    }

    #[test]
    fn test_nullable_row_addr_set_with_nulls() {
        let set = NullableRowAddrSet::new(rows(&[1, 2, 3]), RowAddrTreeMap::new());
        let set_with_nulls = set.with_nulls(rows(&[2]));

        assert!(set_with_nulls.selected(1) && set_with_nulls.selected(3));
        assert!(!set_with_nulls.selected(2));
    }

    #[test]
    fn test_nullable_row_addr_set_len_and_is_empty() {
        let set = nullable_set(&[1, 2, 3, 4, 5], &[2, 4]);
        assert_eq!(set.len(), Some(3));
        assert!(!set.is_empty());

        let empty_set = NullableRowAddrSet::empty();
        assert!(empty_set.is_empty());
        assert_eq!(empty_set.len(), Some(0));
    }

    #[test]
    fn test_nullable_row_addr_set_selected() {
        let set = nullable_set(&[1, 2, 3], &[2]);

        assert!(set.selected(1) && set.selected(3));
        assert!(!set.selected(2));
        assert!(!set.selected(4));
    }

    #[test]
    fn test_nullable_row_addr_set_partial_eq() {
        let set1 = nullable_set(&[1, 2, 3], &[2]);
        let set2 = nullable_set(&[1, 2, 3], &[2]);
        let set3 = nullable_set(&[1, 3], &[3]);

        assert_eq!(set1, set2);
        assert_ne!(set1, set3);
    }

    #[test]
    fn test_nullable_row_addr_set_bitand_fast_path() {
        let set1 = nullable_set(&[1, 2, 3], &[]);
        let set2 = nullable_set(&[2, 3, 4], &[]);

        let mut result = set1;
        result &= &set2;

        assert!(result.selected(2) && result.selected(3));
        assert!(!result.selected(1) && !result.selected(4));
        assert!(result.null_rows().is_empty());
    }

    #[test]
    fn test_nullable_row_addr_set_bitor_fast_path() {
        let set1 = nullable_set(&[1, 2], &[]);
        let set2 = nullable_set(&[3, 4], &[]);

        let mut result = set1;
        result |= &set2;

        for id in [1, 2, 3, 4] {
            assert!(result.selected(id));
        }
        assert!(result.null_rows().is_empty());
    }

    #[test]
    fn test_nullable_row_id_mask_drop_nulls() {
        let allow_mask = allow(&[1, 2, 3, 4], &[2, 4]);
        let dropped = allow_mask.drop_nulls();
        assert!(dropped.selected(1) && dropped.selected(3));
        assert!(!dropped.selected(2) && !dropped.selected(4));

        let block_mask = block(&[1, 2], &[3]);
        let dropped = block_mask.drop_nulls();
        assert!(!dropped.selected(1) && !dropped.selected(2) && !dropped.selected(3));
        assert!(dropped.selected(4) && dropped.selected(5));
    }

    #[test]
    fn test_nullable_row_id_mask_not_blocklist() {
        let block_mask = block(&[1, 2], &[2]);
        let not_mask = !block_mask;
        assert!(matches!(not_mask, NullableRowAddrMask::AllowList(_)));
    }

    #[test]
    fn test_nullable_row_id_mask_bitand_allow_allow_fast_path() {
        let mask1 = allow(&[1, 2, 3], &[]);
        let mask2 = allow(&[2, 3, 4], &[]);
        let result = mask1 & mask2;
        assert_mask_selects(&result, &[2, 3], &[1, 4]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitand_allow_block() {
        let allow_mask = allow(&[1, 2, 3, 4, 5], &[2]);
        let block_mask = block(&[3, 4], &[4]);
        let result = allow_mask & block_mask;
        assert_mask_selects(&result, &[1, 5], &[2, 3, 4]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitand_allow_block_fast_path() {
        let allow_mask = allow(&[1, 2, 3], &[]);
        let block_mask = block(&[2], &[]);
        let result = allow_mask & block_mask;
        assert_mask_selects(&result, &[1, 3], &[2]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitand_block_block() {
        let block1 = block(&[1, 2], &[2]);
        let block2 = block(&[2, 3], &[3]);
        let result = block1 & block2;
        assert_mask_selects(&result, &[4], &[1, 2, 3]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitand_block_block_fast_path() {
        let block1 = block(&[1], &[]);
        let block2 = block(&[2], &[]);
        let result = block1 & block2;
        assert_mask_selects(&result, &[3], &[1, 2]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitor_allow_allow_fast_path() {
        let mask1 = allow(&[1, 2], &[]);
        let mask2 = allow(&[3, 4], &[]);
        let result = mask1 | mask2;
        assert_mask_selects(&result, &[1, 2, 3, 4], &[5]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitor_allow_block() {
        let allow_mask = allow(&[1, 2, 3], &[2]);
        let block_mask = block(&[1, 4], &[4]);
        let result = allow_mask | block_mask;
        assert_mask_selects(&result, &[1, 2, 3], &[]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitor_allow_block_fast_path() {
        let allow_mask = allow(&[1], &[]);
        let block_mask = block(&[2], &[]);
        let result = allow_mask | block_mask;
        assert_mask_selects(&result, &[1, 3], &[2]);
    }

    #[test]
    fn test_nullable_row_id_mask_bitor_block_block_fast_path() {
        let block1 = block(&[1, 2], &[]);
        let block2 = block(&[2, 3], &[]);
        let result = block1 | block2;
        assert_mask_selects(&result, &[1, 3, 4], &[2]);
    }
}