use std::borrow::Cow;
use std::cell::OnceCell;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use async_trait::async_trait;
use futures::FutureExt;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::BoxFuture;
use futures::stream;
use lance_core::utils::deletion::DeletionVector;
use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
use lance_core::utils::tokio::spawn_cpu;
use lance_table::format::Fragment;
use lance_table::format::IndexMetadata;
use lance_table::rowids::RowIdSequence;
use roaring::RoaringBitmap;
use tokio::join;
use tracing::Instrument;
use tracing::instrument;
use crate::Dataset;
use crate::Result;
use crate::dataset::fragment::FileFragment;
use crate::dataset::rowids::load_row_id_sequence;
use crate::utils::future::SharedPrerequisite;
pub use lance_index::prefilter::{FilterLoader, PreFilter};
/// Pre-filter applied to index search results over a [`Dataset`].
///
/// Combines up to three optional inputs into one [`RowAddrMask`]:
/// - rows deleted from (or belonging to missing) fragments covered by the indices,
/// - rows selected by an optional user-supplied filter,
/// - whole fragments marked deleted via `set_deleted_fragments`.
///
/// The combined mask is computed once, by `PreFilter::wait_for_ready`, and cached
/// in `final_mask`; `PreFilter::mask` panics if called before that.
pub struct DatasetPreFilter {
// Deletion mask prerequisite started in `new`; `None` when `create_deletion_mask`
// found no missing fragments and no deletion files among the covered fragments.
pub(super) deleted_ids: Option<Arc<SharedPrerequisite<Arc<RowAddrMask>>>>,
// Prerequisite loading the user filter passed to `new`, if any.
pub(super) filtered_ids: Option<Arc<SharedPrerequisite<RowAddrMask>>>,
// Fragment ids whose rows are all blocked; set by `set_deleted_fragments`.
pub(super) deleted_fragments: Option<RoaringBitmap>,
// Combination of the masks above, initialized once by `wait_for_ready`.
pub(super) final_mask: Mutex<OnceCell<Arc<RowAddrMask>>>,
}
impl DatasetPreFilter {
pub fn new(
dataset: Arc<Dataset>,
indices: &[IndexMetadata],
filter: Option<Box<dyn FilterLoader>>,
) -> Self {
let mut fragments = RoaringBitmap::new();
if indices.iter().any(|idx| idx.fragment_bitmap.is_none()) {
fragments.insert_range(0..dataset.manifest.max_fragment_id.unwrap_or(0));
} else {
indices.iter().for_each(|idx| {
fragments |= idx.fragment_bitmap.as_ref().unwrap();
});
}
let deleted_ids =
Self::create_deletion_mask(dataset, fragments).map(SharedPrerequisite::spawn);
let filtered_ids = filter
.map(|filtered_ids| SharedPrerequisite::spawn(filtered_ids.load().in_current_span()));
Self {
deleted_ids,
filtered_ids,
deleted_fragments: None,
final_mask: Mutex::new(OnceCell::new()),
}
}
#[instrument(level = "debug", skip_all)]
async fn do_create_deletion_mask(
dataset: Arc<Dataset>,
missing_frags: Vec<u32>,
frags_with_deletion_files: Vec<u32>,
) -> Result<Arc<RowAddrMask>> {
let fragments = dataset.get_fragments();
let frag_map: Arc<HashMap<u32, &FileFragment>> = Arc::new(HashMap::from_iter(
fragments.iter().map(|frag| (frag.id() as u32, frag)),
));
let frag_id_deletion_vectors = stream::iter(
frags_with_deletion_files
.iter()
.map(|frag_id| (frag_id, frag_map.clone())),
)
.map(|(frag_id, frag_map)| async move {
let frag = frag_map.get(frag_id).unwrap();
frag.get_deletion_vector()
.await
.transpose()
.unwrap()
.map(|deletion_vector| (*frag_id, RoaringBitmap::from(deletion_vector.as_ref())))
})
.collect::<Vec<_>>()
.await;
let mut frag_id_deletion_vectors = stream::iter(frag_id_deletion_vectors)
.buffer_unordered(dataset.object_store.io_parallelism());
let mut deleted_ids = RowAddrTreeMap::new();
while let Some((id, deletion_vector)) = frag_id_deletion_vectors.try_next().await? {
deleted_ids.insert_bitmap(id, deletion_vector);
}
for frag_id in missing_frags.into_iter() {
deleted_ids.insert_fragment(frag_id);
}
Ok(Arc::new(RowAddrMask::from_block(deleted_ids)))
}
#[instrument(level = "debug", skip_all)]
async fn do_create_deletion_mask_row_id(dataset: Arc<Dataset>) -> Result<Arc<RowAddrMask>> {
async fn load_row_ids_and_deletions(
dataset: &Dataset,
) -> Result<Vec<(Arc<RowIdSequence>, Option<Arc<DeletionVector>>)>> {
stream::iter(dataset.get_fragments())
.map(|frag| async move {
let row_ids = load_row_id_sequence(dataset, frag.metadata());
let deletion_vector = frag.get_deletion_vector();
let (row_ids, deletion_vector) = join!(row_ids, deletion_vector);
Ok::<_, crate::Error>((row_ids?, deletion_vector?))
})
.buffer_unordered(dataset.object_store().io_parallelism())
.try_collect::<Vec<_>>()
.await
}
let dataset_clone = dataset.clone();
let key = crate::session::caches::RowAddrMaskKey {
version: dataset.manifest().version,
};
dataset
.metadata_cache
.as_ref()
.get_or_insert_with_key(key, move || {
async move {
let row_ids_and_deletions = load_row_ids_and_deletions(&dataset_clone).await?;
let allow_list = spawn_cpu(move || {
Result::Ok(row_ids_and_deletions.into_iter().fold(
RowAddrTreeMap::new(),
|mut allow_list, (row_ids, deletion_vector)| {
let seq = if let Some(deletion_vector) = deletion_vector {
let mut row_ids = row_ids.as_ref().clone();
row_ids.mask(deletion_vector.to_sorted_iter()).unwrap();
Cow::<RowIdSequence>::Owned(row_ids)
} else {
Cow::<RowIdSequence>::Borrowed(row_ids.as_ref())
};
let treemap = RowAddrTreeMap::from(seq.as_ref());
allow_list |= treemap;
allow_list
},
))
})
.await?;
Ok(RowAddrMask::from_allowed(allow_list))
}
})
.await
}
pub fn set_deleted_fragments(&mut self, fragments: RoaringBitmap) {
self.deleted_fragments = Some(fragments);
}
pub fn create_deletion_mask(
dataset: Arc<Dataset>,
fragments: RoaringBitmap,
) -> Option<BoxFuture<'static, Result<Arc<RowAddrMask>>>> {
let mut missing_frags = Vec::new();
let mut frags_with_deletion_files = Vec::new();
let frag_map: HashMap<u32, &Fragment> = HashMap::from_iter(
dataset
.manifest
.fragments
.iter()
.map(|frag| (frag.id as u32, frag)),
);
for frag_id in fragments.iter() {
let frag = frag_map.get(&frag_id);
if let Some(frag) = frag {
if frag.deletion_file.is_some() {
frags_with_deletion_files.push(frag_id);
}
} else {
missing_frags.push(frag_id);
}
}
if missing_frags.is_empty() && frags_with_deletion_files.is_empty() {
None
} else if dataset.manifest.uses_stable_row_ids() {
Some(Self::do_create_deletion_mask_row_id(dataset.clone()).boxed())
} else {
Some(
Self::do_create_deletion_mask(dataset, missing_frags, frags_with_deletion_files)
.boxed(),
)
}
}
}
#[async_trait]
impl PreFilter for DatasetPreFilter {
    /// Await every pending input, then build and cache the combined mask.
    ///
    /// The user filter is awaited before the deletion mask, and the first error is
    /// returned. The combined mask is the intersection of four parts, applied in
    /// this order:
    /// 1. the default mask;
    /// 2. the filter mask;
    /// 3. the deletion mask;
    /// 4. a block-list of the explicitly deleted fragments.
    ///
    /// It is computed at most once; later calls keep the cached value.
    #[instrument(level = "debug", skip(self))]
    async fn wait_for_ready(&self) -> Result<()> {
        if let Some(pending) = self.filtered_ids.as_ref() {
            pending.wait_ready().await?;
        }
        if let Some(pending) = self.deleted_ids.as_ref() {
            pending.wait_ready().await?;
        }
        // Nothing is awaited past this point, so holding the std mutex is fine.
        let cell = self.final_mask.lock().unwrap();
        cell.get_or_init(|| {
            let mut mask = RowAddrMask::default();
            if let Some(ready) = self.filtered_ids.as_ref() {
                mask = mask & ready.get_ready();
            }
            if let Some(ready) = self.deleted_ids.as_ref() {
                let deletions = RowAddrMask::clone(&ready.get_ready());
                mask = mask & deletions;
            }
            if let Some(fragment_ids) = self.deleted_fragments.as_ref() {
                let blocked =
                    fragment_ids
                        .iter()
                        .fold(RowAddrTreeMap::new(), |mut acc, frag_id| {
                            acc.insert_fragment(frag_id);
                            acc
                        });
                mask = mask & RowAddrMask::from_block(blocked);
            }
            Arc::new(mask)
        });
        Ok(())
    }

    /// True when there is nothing to filter: no deletion mask, no user filter and
    /// no explicitly deleted fragments.
    fn is_empty(&self) -> bool {
        !(self.deleted_ids.is_some()
            || self.filtered_ids.is_some()
            || self.deleted_fragments.is_some())
    }

    /// Return the cached combined mask.
    ///
    /// # Panics
    /// Panics if `wait_for_ready` has not completed successfully first.
    fn mask(&self) -> Arc<RowAddrMask> {
        let cell = self.final_mask.lock().unwrap();
        let ready = cell
            .get()
            .expect("mask called without call to wait_for_ready");
        Arc::clone(ready)
    }

    /// Keep only the row ids selected by the combined mask.
    #[instrument(level = "debug", skip_all)]
    fn filter_row_ids<'a>(&self, row_ids: Box<dyn Iterator<Item = &'a u64> + 'a>) -> Vec<u64> {
        let mask = self.mask();
        mask.selected_indices(row_ids)
    }
}
#[cfg(test)]
mod test {
    use lance_core::utils::mask::RowSetOps;
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};

    use crate::dataset::WriteParams;

    use super::*;

    /// Snapshots of one dataset taken at successive stages of deletion.
    ///
    /// The dataset starts with 9 rows of column `x`, written three rows per file,
    /// so it has three fragments (ids 0, 1, 2).
    struct TestDatasets {
        /// Freshly written; no deletions.
        no_deletions: Arc<Dataset>,
        /// After deleting `x = 8`: fragment 2 has a deletion file.
        deletions_no_missing_frags: Arc<Dataset>,
        /// After also deleting `3 <= x <= 5`: fragment 1 is gone entirely.
        deletions_missing_frags: Arc<Dataset>,
        /// After deleting `x >= 3`: only fragment 0 remains, without a deletion file.
        only_missing_frags: Arc<Dataset>,
    }

    /// Build the [`TestDatasets`] snapshots, optionally with stable row ids.
    async fn test_datasets(use_stable_row_id: bool) -> TestDatasets {
        let test_data = BatchGenerator::new()
            .col(Box::new(IncrementingInt32::new().named("x")))
            .batch(9);
        let mut dataset = Dataset::write(
            test_data,
            "memory://test",
            Some(WriteParams {
                max_rows_per_file: 3,
                enable_stable_row_ids: use_stable_row_id,
                ..Default::default()
            }),
        )
        .await
        .unwrap();
        let no_deletions = Arc::new(dataset.clone());
        dataset.delete("x = 8").await.unwrap();
        let deletions_no_missing_frags = Arc::new(dataset.clone());
        dataset.delete("x >= 3 and x <= 5").await.unwrap();
        assert_eq!(dataset.get_fragments().len(), 2);
        let deletions_missing_frags = Arc::new(dataset.clone());
        dataset.delete("x >= 3").await.unwrap();
        assert_eq!(dataset.get_fragments().len(), 1);
        assert!(
            dataset.get_fragments()[0]
                .metadata()
                .deletion_file
                .is_none()
        );
        let only_missing_frags = Arc::new(dataset.clone());
        TestDatasets {
            no_deletions,
            deletions_no_missing_frags,
            deletions_missing_frags,
            only_missing_frags,
        }
    }

    /// Without stable row ids the mask is a block-list of row addresses
    /// (`fragment_id << 32 | offset`) plus whole missing fragments.
    #[tokio::test]
    async fn test_deletion_mask() {
        let datasets = test_datasets(false).await;
        // No deletions and no missing fragments: no mask is needed.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.no_deletions.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_none());
        // Only the single deleted row (x = 8) is blocked.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.deletions_no_missing_frags.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1));
        // x = 8 is offset 2 of fragment 2; fragment 1 is missing entirely.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.deletions_missing_frags.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        let mut expected = RowAddrTreeMap::from_iter(vec![(2 << 32) + 2]);
        expected.insert_fragment(1);
        assert_eq!(mask.block_list(), Some(&expected));
        // Restricting to fragment 2 leaves just the one deleted row.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.deletions_missing_frags.clone(),
            RoaringBitmap::from_iter(2..3),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        assert_eq!(mask.block_list().and_then(|x| x.len()), Some(1));
        // Fragments 1 and 2 are both missing.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.only_missing_frags.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        let mut expected = RowAddrTreeMap::new();
        expected.insert_fragment(1);
        expected.insert_fragment(2);
        assert_eq!(mask.block_list(), Some(&expected));
    }

    /// With stable row ids the mask is an allow-list of the live row ids of the
    /// whole dataset.
    #[tokio::test]
    async fn test_deletion_mask_stable_row_id() {
        let datasets = test_datasets(true).await;
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.no_deletions.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_none());
        // Row ids 0..8 survive; id 8 (x = 8) was deleted.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.deletions_no_missing_frags.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        let expected = RowAddrTreeMap::from_iter(0..8);
        assert_eq!(mask.allow_list(), Some(&expected));
        // Live ids are 0, 1, 2, 6, 7.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.deletions_missing_frags.clone(),
            RoaringBitmap::from_iter(0..2),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(5));
        // Only ids 0, 1, 2 remain.
        let mask = DatasetPreFilter::create_deletion_mask(
            datasets.only_missing_frags.clone(),
            RoaringBitmap::from_iter(0..3),
        );
        assert!(mask.is_some());
        let mask = mask.unwrap().await.unwrap();
        assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(3));
    }
}