use super::{RankSelectInterleaved256, RankSelectOps};
use crate::error::{Result, ZiporaError};
use crate::succinct::BitVector;
use crate::system::{CpuFeatures, get_cpu_features};
use crate::FastVec;
use std::sync::Arc;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Rank/select index over `DIMS` bit vectors ("dimensions") that all share
/// the same bit length, each backed by its own [`RankSelectInterleaved256`].
///
/// `BLOCK_SIZE` is part of the type and is reported by `Debug`, but every
/// dimension is stored with the fixed 256-bit interleaved layout regardless
/// of its value.
#[derive(Clone)]
pub struct MultiDimRankSelect<const DIMS: usize, const BLOCK_SIZE: usize = 256> {
    /// Bit length shared by every dimension.
    total_bits: usize,
    /// Process-wide CPU capability table, consulted to choose SIMD code paths.
    cpu_features: &'static CpuFeatures,
    /// One rank/select index per dimension; `Arc` makes cloning the whole
    /// structure a reference-count bump per dimension instead of a deep copy.
    dimensions: [Arc<RankSelectInterleaved256>; DIMS],
}
impl<const DIMS: usize, const BLOCK_SIZE: usize> MultiDimRankSelect<DIMS, BLOCK_SIZE> {
    /// Builds the index from exactly `DIMS` bit vectors of identical length.
    ///
    /// Each bit vector is consumed and wrapped in its own
    /// [`RankSelectInterleaved256`].
    ///
    /// # Errors
    ///
    /// Returns `ZiporaError::invalid_data` when `DIMS` is 0 or greater than 32,
    /// when `bit_vectors.len() != DIMS`, or when the vectors differ in length.
    /// Any error produced by `RankSelectInterleaved256::new` is propagated.
    pub fn new(bit_vectors: Vec<BitVector>) -> Result<Self> {
        if DIMS == 0 {
            return Err(ZiporaError::invalid_data(
                "MultiDimRankSelect requires at least 1 dimension",
            ));
        }
        if DIMS > 32 {
            return Err(ZiporaError::invalid_data(
                "MultiDimRankSelect supports maximum 32 dimensions",
            ));
        }
        if bit_vectors.len() != DIMS {
            return Err(ZiporaError::invalid_data(format!(
                "Expected {} dimensions, got {}",
                DIMS,
                bit_vectors.len()
            )));
        }

        // With DIMS >= 1 and len == DIMS this cannot fail; kept as a defensive guard.
        let total_bits = bit_vectors
            .first()
            .ok_or_else(|| ZiporaError::invalid_data("Empty bit vectors"))?
            .len();
        for (i, bv) in bit_vectors.iter().enumerate() {
            if bv.len() != total_bits {
                return Err(ZiporaError::invalid_data(format!(
                    "Dimension {} has {} bits, expected {}",
                    i,
                    bv.len(),
                    total_bits
                )));
            }
        }

        let mut dimensions_vec = Vec::with_capacity(DIMS);
        for bv in bit_vectors {
            dimensions_vec.push(Arc::new(RankSelectInterleaved256::new(bv)?));
        }
        // The element count was validated above, so this conversion only fails
        // if that invariant is broken.
        let dimensions: [Arc<RankSelectInterleaved256>; DIMS] = dimensions_vec
            .try_into()
            .map_err(|_| ZiporaError::invalid_data("Failed to convert dimensions to array"))?;

        Ok(Self {
            dimensions,
            total_bits,
            cpu_features: get_cpu_features(),
        })
    }

    /// Bit length shared by every dimension.
    #[inline]
    pub fn total_bits(&self) -> usize {
        self.total_bits
    }

    /// Number of dimensions (`DIMS`).
    #[inline]
    pub const fn num_dimensions(&self) -> usize {
        DIMS
    }

    /// Computes `rank1` for every dimension at once.
    ///
    /// `result[d]` is `rank1(positions[d])` of dimension `d`. A position
    /// greater than [`total_bits`](Self::total_bits) is never forwarded to
    /// `rank1`; its entry in the result is 0.
    pub fn bulk_rank_multidim(&self, positions: &[usize; DIMS]) -> [usize; DIMS] {
        let mut ranks = [0usize; DIMS];

        #[cfg(target_arch = "x86_64")]
        {
            if DIMS <= 4 && self.cpu_features.has_avx2 && is_x86_feature_detected!("popcnt") {
                // SAFETY: `bulk_rank_avx2` enables `avx2` and `popcnt`; both were
                // confirmed available on this CPU by the checks above.
                unsafe { self.bulk_rank_avx2(positions, &mut ranks) };
                return ranks;
            }
        }

        self.bulk_rank_scalar(positions, &mut ranks);
        ranks
    }

    /// Portable bulk rank: fills `ranks[d]` for each in-range position and
    /// leaves out-of-range entries untouched (callers pass a zeroed array).
    #[inline]
    fn bulk_rank_scalar(&self, positions: &[usize; DIMS], ranks: &mut [usize; DIMS]) {
        for dim in 0..DIMS {
            if positions[dim] <= self.total_bits {
                ranks[dim] = self.dimensions[dim].rank1(positions[dim]);
            }
        }
    }

    /// Bulk rank for up to four dimensions, compiled with AVX2/POPCNT enabled.
    ///
    /// The per-dimension work is the same `rank1` call as the scalar path.
    /// Before moving on, it prefetches the next dimension's heap-allocated
    /// rank/select structure so that structure is hopefully cached by the time
    /// it is queried.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports both AVX2 and POPCNT.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,popcnt")]
    unsafe fn bulk_rank_avx2(&self, positions: &[usize; DIMS], ranks: &mut [usize; DIMS]) {
        debug_assert!(DIMS <= 4, "AVX2 bulk rank supports maximum 4 dimensions");
        for dim in 0..DIMS {
            if positions[dim] > self.total_bits {
                continue;
            }
            ranks[dim] = self.dimensions[dim].rank1(positions[dim]);

            if dim + 1 < DIMS && positions[dim + 1] <= self.total_bits {
                // Prefetch the pointee of the next Arc (the rank/select structure
                // itself), not the Arc slot inside `self`, which is already cached.
                let next = Arc::as_ptr(&self.dimensions[dim + 1]);
                // SAFETY: prefetching is only a cache hint and never faults; `next`
                // points at a live allocation owned by `self.dimensions`.
                unsafe { _mm_prefetch::<_MM_HINT_T0>(next as *const i8) };
            }
        }
    }

    /// Runs `select1(ranks[d])` on each dimension `d`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a dimension's `select1`
    /// (presumably a rank beyond that dimension's population count — confirm
    /// against `RankSelectInterleaved256`).
    pub fn bulk_select_multidim(&self, ranks: &[usize; DIMS]) -> Result<[usize; DIMS]> {
        let mut positions = [0usize; DIMS];
        for dim in 0..DIMS {
            positions[dim] = self.dimensions[dim].select1(ranks[dim])?;
        }
        Ok(positions)
    }

    /// Returns the bitwise AND of dimensions `dim_a` and `dim_b` as a new
    /// bit vector of [`total_bits`](Self::total_bits) bits.
    ///
    /// # Errors
    ///
    /// Returns `ZiporaError::invalid_data` if either index is `>= DIMS`;
    /// errors from `BitVector::from_raw_bits` are propagated.
    pub fn intersect_dimensions(&self, dim_a: usize, dim_b: usize) -> Result<BitVector> {
        if dim_a >= DIMS || dim_b >= DIMS {
            return Err(ZiporaError::invalid_data(format!(
                "Invalid dimension indices: {}, {} (max: {})",
                dim_a, dim_b, DIMS
            )));
        }
        let bits_a = self.dimensions[dim_a].get_bit_data();
        let bits_b = self.dimensions[dim_b].get_bit_data();
        let result_bits = self.intersect_words(&bits_a, &bits_b);
        let intersection = BitVector::from_raw_bits(result_bits, self.total_bits)?;
        Ok(intersection)
    }

    /// Word-wise AND dispatcher: AVX2 kernel when available, scalar otherwise.
    fn intersect_words(&self, bits_a: &[u64], bits_b: &[u64]) -> Vec<u64> {
        #[cfg(target_arch = "x86_64")]
        {
            if self.cpu_features.has_avx2 {
                // SAFETY: AVX2 support was just confirmed at runtime.
                return unsafe { Self::intersect_avx2(bits_a, bits_b) };
            }
        }
        Self::intersect_scalar(bits_a, bits_b)
    }

    /// Portable word-wise AND over the common prefix of both slices.
    fn intersect_scalar(bits_a: &[u64], bits_b: &[u64]) -> Vec<u64> {
        bits_a.iter().zip(bits_b).map(|(a, b)| a & b).collect()
    }

    /// AVX2 word-wise AND over the common prefix of both slices, four words
    /// per iteration plus a scalar tail.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn intersect_avx2(bits_a: &[u64], bits_b: &[u64]) -> Vec<u64> {
        let len = bits_a.len().min(bits_b.len());
        let mut result = vec![0u64; len];
        let simd_len = len - len % 4;
        for base in (0..simd_len).step_by(4) {
            // SAFETY: base + 4 <= simd_len <= len, and `len` does not exceed the
            // length of `bits_a`, `bits_b`, or `result`, so every 32-byte unaligned
            // load/store stays in bounds.
            unsafe {
                let vec_a = _mm256_loadu_si256(bits_a.as_ptr().add(base) as *const __m256i);
                let vec_b = _mm256_loadu_si256(bits_b.as_ptr().add(base) as *const __m256i);
                let anded = _mm256_and_si256(vec_a, vec_b);
                _mm256_storeu_si256(result.as_mut_ptr().add(base) as *mut __m256i, anded);
            }
        }
        for i in simd_len..len {
            result[i] = bits_a[i] & bits_b[i];
        }
        result
    }

    /// Returns the bitwise OR of the listed dimensions as a new bit vector of
    /// [`total_bits`](Self::total_bits) bits. Repeated indices are allowed.
    ///
    /// # Errors
    ///
    /// Returns `ZiporaError::invalid_data` if `dimensions` is empty or contains
    /// an index `>= DIMS`; errors from `BitVector::from_raw_bits` are propagated.
    pub fn union_dimensions(&self, dimensions: &[usize]) -> Result<BitVector> {
        if dimensions.is_empty() {
            return Err(ZiporaError::invalid_data("No dimensions specified for union"));
        }
        for &dim in dimensions {
            if dim >= DIMS {
                return Err(ZiporaError::invalid_data(format!(
                    "Invalid dimension index: {} (max: {})",
                    dim, DIMS
                )));
            }
        }
        let bit_data_vecs: Vec<Vec<u64>> = dimensions
            .iter()
            .map(|&dim| self.dimensions[dim].get_bit_data())
            .collect();
        let bit_data: Vec<&[u64]> = bit_data_vecs.iter().map(|v| v.as_slice()).collect();
        let result_bits = self.union_words(&bit_data);
        let union = BitVector::from_raw_bits(result_bits, self.total_bits)?;
        Ok(union)
    }

    /// Word-wise OR dispatcher: AVX2 kernel when available, scalar otherwise.
    fn union_words(&self, bit_data: &[&[u64]]) -> Vec<u64> {
        #[cfg(target_arch = "x86_64")]
        {
            if self.cpu_features.has_avx2 {
                // SAFETY: AVX2 support was just confirmed at runtime.
                return unsafe { Self::union_avx2(bit_data) };
            }
        }
        Self::union_scalar(bit_data)
    }

    /// Portable word-wise OR. The output has as many words as the first
    /// input; shorter inputs contribute only the words they have.
    fn union_scalar(bit_data: &[&[u64]]) -> Vec<u64> {
        let Some(first) = bit_data.first() else {
            return Vec::new();
        };
        let mut result = vec![0u64; first.len()];
        for bits in bit_data {
            for (dst, &src) in result.iter_mut().zip(bits.iter()) {
                *dst |= src;
            }
        }
        result
    }

    /// AVX2 word-wise OR with the same length semantics as `union_scalar`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn union_avx2(bit_data: &[&[u64]]) -> Vec<u64> {
        let Some(first) = bit_data.first() else {
            return Vec::new();
        };
        let len = first.len();
        let mut result = vec![0u64; len];
        let simd_len = len - len % 4;
        for base in (0..simd_len).step_by(4) {
            // SAFETY: `result` holds `len >= base + 4` words, and an input is only
            // loaded after checking that it also holds at least `base + 4` words.
            unsafe {
                let mut acc = _mm256_setzero_si256();
                for bits in bit_data {
                    if bits.len() >= base + 4 {
                        let v = _mm256_loadu_si256(bits.as_ptr().add(base) as *const __m256i);
                        acc = _mm256_or_si256(acc, v);
                    }
                }
                _mm256_storeu_si256(result.as_mut_ptr().add(base) as *mut __m256i, acc);
            }
        }
        for i in simd_len..len {
            result[i] = bit_data
                .iter()
                .filter_map(|bits| bits.get(i))
                .fold(0u64, |acc, &word| acc | word);
        }
        result
    }
}
impl<const DIMS: usize, const BLOCK_SIZE: usize> std::fmt::Debug
    for MultiDimRankSelect<DIMS, BLOCK_SIZE>
{
    /// Summarises the index shape (dimension count, block size, bit length and
    /// CPU features); the per-dimension bit data is deliberately not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_bits,
            cpu_features,
            ..
        } = self;
        let mut out = f.debug_struct("MultiDimRankSelect");
        out.field("dimensions", &DIMS);
        out.field("block_size", &BLOCK_SIZE);
        out.field("total_bits", total_bits);
        out.field("cpu_features", cpu_features);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `size`-bit vector whose bit `i` equals `pattern(i)`.
    fn create_test_bitvector(size: usize, pattern: impl Fn(usize) -> bool) -> Result<BitVector> {
        let mut bv = BitVector::new();
        for i in 0..size {
            bv.push(pattern(i))?;
        }
        Ok(bv)
    }

    /// Ground-truth rank from a standalone single-dimension index, so the
    /// multi-dimensional results are checked against the same rank convention.
    fn reference_rank(size: usize, pattern: impl Fn(usize) -> bool, pos: usize) -> Result<usize> {
        let rs = RankSelectInterleaved256::new(create_test_bitvector(size, pattern)?)?;
        Ok(rs.rank1(pos))
    }

    #[test]
    fn test_multidim_creation() -> Result<()> {
        let mut dims = vec![];
        for _ in 0..3 {
            dims.push(create_test_bitvector(100, |i| i % 2 == 0)?);
        }
        let multi_rs: MultiDimRankSelect<3> = MultiDimRankSelect::new(dims)?;
        assert_eq!(multi_rs.num_dimensions(), 3);
        assert_eq!(multi_rs.total_bits(), 100);
        Ok(())
    }

    #[test]
    fn test_wrong_dimension_count() -> Result<()> {
        let dims = vec![
            create_test_bitvector(10, |i| i % 2 == 0)?,
            create_test_bitvector(10, |i| i % 2 == 0)?,
        ];
        let result: Result<MultiDimRankSelect<3>> = MultiDimRankSelect::new(dims);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_bulk_rank_multidim() -> Result<()> {
        let mut dims = vec![];
        for d in 0..4 {
            dims.push(create_test_bitvector(200, |i| i % (d + 2) == 0)?);
        }
        let multi_rs: MultiDimRankSelect<4> = MultiDimRankSelect::new(dims)?;
        let positions = [50, 100, 150, 200];
        let ranks = multi_rs.bulk_rank_multidim(&positions);
        assert_eq!(ranks.len(), 4);
        for (d, (&pos, &rank)) in positions.iter().zip(ranks.iter()).enumerate() {
            let expected = reference_rank(200, |i| i % (d + 2) == 0, pos)?;
            assert_eq!(rank, expected, "dimension {} rank mismatch at {}", d, pos);
        }
        Ok(())
    }

    #[test]
    fn test_bulk_rank_out_of_range_is_zero() -> Result<()> {
        let dims = vec![
            create_test_bitvector(64, |_| true)?,
            create_test_bitvector(64, |_| true)?,
        ];
        let multi_rs: MultiDimRankSelect<2> = MultiDimRankSelect::new(dims)?;
        let ranks = multi_rs.bulk_rank_multidim(&[64, 65]);
        assert_eq!(ranks[0], reference_rank(64, |_| true, 64)?);
        assert_eq!(ranks[1], 0, "positions past total_bits must rank as 0");
        Ok(())
    }

    #[test]
    fn test_bulk_select_multidim() -> Result<()> {
        let mut dims = vec![];
        for _ in 0..3 {
            dims.push(create_test_bitvector(100, |i| i % 3 == 0)?);
        }
        let multi_rs: MultiDimRankSelect<3> = MultiDimRankSelect::new(dims)?;
        let ranks = [5, 10, 15];
        let positions = multi_rs.bulk_select_multidim(&ranks)?;
        assert_eq!(positions.len(), 3);
        for &pos in &positions {
            assert!(pos < 100);
            assert_eq!(pos % 3, 0, "select returned position {} whose bit is clear", pos);
        }
        // Identical dimensions queried with increasing ranks must yield
        // strictly increasing positions.
        assert!(positions[0] < positions[1] && positions[1] < positions[2]);
        Ok(())
    }

    #[test]
    fn test_intersect_dimensions() -> Result<()> {
        let mut dims = vec![];
        dims.push(create_test_bitvector(100, |i| i % 2 == 0)?);
        dims.push(create_test_bitvector(100, |i| i % 3 == 0)?);
        let multi_rs: MultiDimRankSelect<2> = MultiDimRankSelect::new(dims)?;
        let intersection = multi_rs.intersect_dimensions(0, 1)?;
        assert_eq!(intersection.len(), 100);
        for i in 0..100 {
            let expected = i % 2 == 0 && i % 3 == 0;
            let actual = intersection.get(i).ok_or_else(|| ZiporaError::out_of_bounds(i, 100))?;
            assert_eq!(actual, expected, "Bit {} mismatch", i);
        }
        Ok(())
    }

    #[test]
    fn test_union_dimensions() -> Result<()> {
        let mut dims = vec![];
        dims.push(create_test_bitvector(100, |i| i % 4 == 0)?);
        dims.push(create_test_bitvector(100, |i| i % 6 == 0)?);
        dims.push(create_test_bitvector(100, |i| i % 8 == 0)?);
        let multi_rs: MultiDimRankSelect<3> = MultiDimRankSelect::new(dims)?;
        let union = multi_rs.union_dimensions(&[0, 1, 2])?;
        assert_eq!(union.len(), 100);
        for i in 0..100 {
            let expected = i % 4 == 0 || i % 6 == 0 || i % 8 == 0;
            let actual = union.get(i).ok_or_else(|| ZiporaError::out_of_bounds(i, 100))?;
            assert_eq!(actual, expected, "Bit {} mismatch", i);
        }
        Ok(())
    }

    #[test]
    fn test_invalid_dimension_indices() -> Result<()> {
        let dims = vec![
            create_test_bitvector(10, |i| i % 2 == 0)?,
            create_test_bitvector(10, |i| i % 2 == 1)?,
        ];
        let multi_rs: MultiDimRankSelect<2> = MultiDimRankSelect::new(dims)?;
        assert!(multi_rs.intersect_dimensions(0, 2).is_err());
        assert!(multi_rs.intersect_dimensions(2, 0).is_err());
        assert!(multi_rs.union_dimensions(&[]).is_err());
        assert!(multi_rs.union_dimensions(&[0, 5]).is_err());
        Ok(())
    }

    #[test]
    fn test_invalid_dimension_sizes() -> Result<()> {
        let mut dims = vec![];
        dims.push(create_test_bitvector(100, |i| i % 2 == 0)?);
        dims.push(create_test_bitvector(200, |i| i % 2 == 0)?);
        let result: Result<MultiDimRankSelect<2>> = MultiDimRankSelect::new(dims);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_single_dimension() -> Result<()> {
        let mut dims = vec![];
        dims.push(create_test_bitvector(100, |i| i % 5 == 0)?);
        let multi_rs: MultiDimRankSelect<1> = MultiDimRankSelect::new(dims)?;
        assert_eq!(multi_rs.num_dimensions(), 1);
        let positions = [50];
        let ranks = multi_rs.bulk_rank_multidim(&positions);
        assert_eq!(ranks.len(), 1);
        assert_eq!(ranks[0], reference_rank(100, |i| i % 5 == 0, 50)?);
        Ok(())
    }

    #[test]
    fn test_high_dimensional() -> Result<()> {
        let mut dims = vec![];
        for d in 0..8 {
            dims.push(create_test_bitvector(100, |i| i % (d + 2) == 0)?);
        }
        let multi_rs: MultiDimRankSelect<8> = MultiDimRankSelect::new(dims)?;
        assert_eq!(multi_rs.num_dimensions(), 8);
        let positions = [10, 20, 30, 40, 50, 60, 70, 80];
        let ranks = multi_rs.bulk_rank_multidim(&positions);
        assert_eq!(ranks.len(), 8);
        for (d, (&pos, &rank)) in positions.iter().zip(ranks.iter()).enumerate() {
            let expected = reference_rank(100, |i| i % (d + 2) == 0, pos)?;
            assert_eq!(rank, expected, "dimension {} rank mismatch at {}", d, pos);
        }
        Ok(())
    }

    #[test]
    fn test_intersect_simd_remainder_boundary() -> Result<()> {
        for &size in &[1, 63, 65, 192, 256, 257, 500, 501, 1024, 1025] {
            let mut dims = vec![];
            dims.push(create_test_bitvector(size, |i| i % 2 == 0)?);
            dims.push(create_test_bitvector(size, |i| i % 3 == 0)?);
            let multi_rs: MultiDimRankSelect<2> = MultiDimRankSelect::new(dims)?;
            let intersection = multi_rs.intersect_dimensions(0, 1)?;
            assert_eq!(intersection.len(), size, "Intersection length mismatch for size {}", size);
            for i in 0..size {
                let expected = i % 2 == 0 && i % 3 == 0;
                let actual = intersection
                    .get(i)
                    .ok_or_else(|| ZiporaError::out_of_bounds(i, size))?;
                assert_eq!(actual, expected, "Intersect bit {} mismatch for size {}", i, size);
            }
        }
        Ok(())
    }

    #[test]
    fn test_union_simd_remainder_boundary() -> Result<()> {
        for &size in &[1, 63, 65, 192, 256, 257, 500, 501, 1024, 1025] {
            let mut dims = vec![];
            dims.push(create_test_bitvector(size, |i| i % 4 == 0)?);
            dims.push(create_test_bitvector(size, |i| i % 6 == 0)?);
            dims.push(create_test_bitvector(size, |i| i % 8 == 0)?);
            let multi_rs: MultiDimRankSelect<3> = MultiDimRankSelect::new(dims)?;
            let union = multi_rs.union_dimensions(&[0, 1, 2])?;
            assert_eq!(union.len(), size, "Union length mismatch for size {}", size);
            for i in 0..size {
                let expected = i % 4 == 0 || i % 6 == 0 || i % 8 == 0;
                let actual = union
                    .get(i)
                    .ok_or_else(|| ZiporaError::out_of_bounds(i, size))?;
                assert_eq!(actual, expected, "Union bit {} mismatch for size {}", i, size);
            }
        }
        Ok(())
    }
}