use std::collections::HashMap;
use arrow_array::{Array, RecordBatch};
use roaring::RoaringBitmap;
use crate::error::IndexError;
use crate::filter::FilterIndex;
use crate::scalar_value::{OwnedScalar, extract_scalar};
/// Index mapping each distinct column value to the set of rows holding it.
///
/// Row ids are stored as `u32`, so one index covers at most `u32::MAX` rows.
pub struct HashIndex {
    /// Number of rows that were indexed.
    total_rows: u32,
    /// Distinct value -> bitmap of the 0-based row ids where it occurs.
    map: HashMap<OwnedScalar, RoaringBitmap>,
}
impl HashIndex {
    /// Builds an index over every row of a single array.
    ///
    /// Row ids are the 0-based positions of the rows in `array`.
    ///
    /// # Errors
    ///
    /// * [`IndexError::TooManyRows`] if `array` has more than `u32::MAX` rows,
    ///   because row ids are stored as `u32`.
    /// * [`IndexError::UnsupportedType`] (carrying the array's data type) if
    ///   `extract_scalar` yields no value for some row.
    pub fn build(array: &dyn Array) -> Result<Self, IndexError> {
        let mut map = HashMap::new();
        let total_rows = Self::index_array(&mut map, array, 0)?;
        Ok(Self { map, total_rows })
    }

    /// Builds one index over a column spread across several record batches.
    ///
    /// Each item pairs a batch with the index of the column to read from it.
    /// Row ids run consecutively across batches in iteration order: the first
    /// row of a batch gets the id right after the last row of the previous batch.
    ///
    /// # Errors
    ///
    /// * [`IndexError::TooManyRows`] (carrying the running row count) as soon as
    ///   the batches together exceed `u32::MAX` rows.
    /// * [`IndexError::UnsupportedType`] (carrying the column's data type) if
    ///   `extract_scalar` yields no value for some row.
    ///
    /// # Panics
    ///
    /// Panics if a `col_idx` is out of bounds for its batch (via
    /// [`RecordBatch::column`]).
    pub fn build_batches(
        batches: impl Iterator<Item = (RecordBatch, usize)>,
    ) -> Result<Self, IndexError> {
        let mut map = HashMap::new();
        let mut total_rows: u32 = 0;
        for (batch, col_idx) in batches {
            let column = batch.column(col_idx);
            total_rows = Self::index_array(&mut map, column.as_ref(), total_rows)?;
        }
        Ok(Self { map, total_rows })
    }

    /// Records every row of `array` in `map`, numbering the rows from `first_row`.
    ///
    /// Returns the row id one past the last recorded row, i.e. the new running
    /// row count. The size check runs before anything is inserted.
    fn index_array(
        map: &mut HashMap<OwnedScalar, RoaringBitmap>,
        array: &dyn Array,
        first_row: u32,
    ) -> Result<u32, IndexError> {
        // `usize -> u64` is lossless on every supported target.
        let row_end = u64::from(first_row) + array.len() as u64;
        // Row ids are u32, so the running row count must itself fit in a u32.
        let end = u32::try_from(row_end).map_err(|_| IndexError::TooManyRows(row_end))?;
        for (pos, row_id) in (first_row..end).enumerate() {
            let scalar = extract_scalar(array, pos).ok_or_else(|| {
                IndexError::UnsupportedType(format!("{:?}", array.data_type()))
            })?;
            map.entry(scalar).or_default().insert(row_id);
        }
        Ok(end)
    }

    /// Returns the rows whose value equals `value`.
    ///
    /// Returns an empty index if the value never occurred.
    pub fn lookup(&self, value: &OwnedScalar) -> FilterIndex {
        match self.map.get(value) {
            Some(bitmap) => FilterIndex::from_bitmap_ref(bitmap),
            None => FilterIndex::from_ids(std::iter::empty::<u32>()),
        }
    }

    /// Returns the rows whose value equals any of `values` (the union of their rows).
    ///
    /// Values that never occurred contribute nothing. An empty slice yields an
    /// empty index.
    pub fn lookup_many(&self, values: &[OwnedScalar]) -> FilterIndex {
        let mut result = RoaringBitmap::new();
        for bitmap in values.iter().filter_map(|value| self.map.get(value)) {
            result |= bitmap;
        }
        FilterIndex::from_bitmap(result)
    }

    /// Number of distinct values seen while building the index.
    pub fn distinct_count(&self) -> usize {
        self.map.len()
    }

    /// Number of rows that were indexed.
    pub fn total_rows(&self) -> u32 {
        self.total_rows
    }

    /// Iterates over each distinct value together with the number of rows
    /// holding it.
    ///
    /// Iteration order is the hash map's (arbitrary) order.
    pub fn value_counts(&self) -> impl Iterator<Item = (&OwnedScalar, u64)> {
        self.map.iter().map(|(k, v)| (k, v.len()))
    }
}