use std::cell::OnceCell;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::stream;
use futures::FutureExt;
use futures::StreamExt;
use futures::TryStreamExt;
use lance_core::utils::mask::RowIdMask;
use lance_core::utils::mask::RowIdTreeMap;
use lance_table::format::Fragment;
use lance_table::format::Index;
use roaring::RoaringBitmap;
use tracing::instrument;
use tracing::Instrument;
use crate::dataset::fragment::FileFragment;
use crate::error::Result;
use crate::utils::future::SharedPrerequisite;
use crate::Dataset;
/// Asynchronous source of the row-id mask that a [`PreFilter`] applies.
///
/// `PreFilter::new` accepts an optional boxed loader, calls [`FilterLoader::load`]
/// on it once, and hands the resulting future to `SharedPrerequisite::spawn`.
/// The `Send + 'static` bounds are part of the trait so that the boxed loader
/// can be moved into that future.
#[async_trait]
pub trait FilterLoader: Send + 'static {
/// Consumes the boxed loader and resolves to the [`RowIdMask`] it produces,
/// or to an error if the mask could not be produced.
async fn load(self: Box<Self>) -> Result<RowIdMask>;
}
/// Filter applied to row ids, combining an optional user-supplied mask with
/// an optional set of deleted / missing rows.
///
/// Both inputs are wrapped in `SharedPrerequisite` and are only consulted once
/// `wait_for_ready` has completed; `check_one` and `filter_row_ids` panic if
/// they are called before that.
pub struct PreFilter {
/// Row ids to block because they are deleted or belong to fragments that
/// are no longer in the manifest. `None` when `create_deletion_mask` found
/// nothing to block.
deleted_ids: Option<Arc<SharedPrerequisite<Arc<RowIdTreeMap>>>>,
/// Mask produced by the optional `FilterLoader` passed to `new`. `None`
/// when no loader was supplied.
filtered_ids: Option<Arc<SharedPrerequisite<RowIdMask>>>,
/// Combination of `filtered_ids` and `deleted_ids`, computed at most once
/// by `wait_for_ready` and read by `check_one` / `filter_row_ids`.
final_mask: Mutex<OnceCell<RowIdMask>>,
}
impl PreFilter {
    /// Creates a pre-filter for the fragments covered by `indices`.
    ///
    /// The set of fragments to inspect is the union of the indices'
    /// `fragment_bitmap`s. If any index has no bitmap, the range
    /// `0..dataset.manifest.max_fragment_id` is used instead. An empty
    /// `indices` slice yields an empty fragment set.
    ///
    /// Two optional inputs are then handed to `SharedPrerequisite::spawn`:
    /// * the deletion mask from [`Self::create_deletion_mask`], if it
    ///   returned a future;
    /// * the mask produced by `filter`, if a loader was supplied (its future
    ///   is instrumented with the current tracing span).
    ///
    /// Call [`Self::wait_for_ready`] before querying the filter.
    pub fn new(
        dataset: Arc<Dataset>,
        indices: &[Index],
        filter: Option<Box<dyn FilterLoader>>,
    ) -> Self {
        // Union of every index's fragment bitmap; short-circuits to `None` as
        // soon as one index carries no bitmap.
        let indexed_fragments = indices
            .iter()
            .try_fold(RoaringBitmap::new(), |mut acc, idx| {
                acc |= idx.fragment_bitmap.as_ref()?;
                Some(acc)
            });
        let fragments = indexed_fragments.unwrap_or_else(|| {
            // NOTE(review): the range is half-open, so the id equal to
            // `max_fragment_id` is not included. Confirm whether
            // `max_fragment_id` is an inclusive maximum.
            let mut all = RoaringBitmap::new();
            all.insert_range(0..dataset.manifest.max_fragment_id);
            all
        });

        // `dataset` is not needed afterwards, so it is moved rather than cloned.
        let deleted_ids =
            Self::create_deletion_mask(dataset, fragments).map(SharedPrerequisite::spawn);
        let filtered_ids =
            filter.map(|loader| SharedPrerequisite::spawn(loader.load().in_current_span()));

        Self {
            deleted_ids,
            filtered_ids,
            final_mask: Mutex::new(OnceCell::new()),
        }
    }

    /// Returns `true` when there is neither a deletion mask nor a filter
    /// mask, i.e. this pre-filter has nothing to apply.
    pub fn is_empty(&self) -> bool {
        self.deleted_ids.is_none() && self.filtered_ids.is_none()
    }

    /// Returns whether `row_id` is selected by the combined mask.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::wait_for_ready`] has not completed successfully
    /// beforehand, or if the internal mutex is poisoned.
    pub fn check_one(&self, row_id: u64) -> bool {
        let final_mask = self.final_mask.lock().unwrap();
        final_mask
            .get()
            .expect("check_one called before wait_for_ready")
            .selected(row_id)
    }

    /// Builds the set of row ids to block.
    ///
    /// For every id in `frags_with_deletion_files` the fragment's deletion
    /// vector is loaded and inserted as a bitmap; every id in `missing_frags`
    /// is inserted as a whole fragment.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while loading a deletion vector.
    ///
    /// # Panics
    ///
    /// Panics if an id in `frags_with_deletion_files` is not among
    /// `dataset.get_fragments()`, or if such a fragment reports no deletion
    /// vector.
    #[instrument(level = "debug", skip_all)]
    async fn do_create_deletion_mask(
        dataset: Arc<Dataset>,
        missing_frags: Vec<u32>,
        frags_with_deletion_files: Vec<u32>,
    ) -> Result<Arc<RowIdTreeMap>> {
        let fragments = dataset.get_fragments();
        let frag_map: Arc<HashMap<u32, &FileFragment>> = Arc::new(HashMap::from_iter(
            fragments.iter().map(|frag| (frag.id() as u32, frag)),
        ));

        // First build one (not yet polled) future per fragment, each owning a
        // handle to the lookup map; they are driven concurrently below.
        let frag_id_deletion_vectors = stream::iter(
            frags_with_deletion_files
                .iter()
                .map(|frag_id| (frag_id, frag_map.clone())),
        )
        .map(|(frag_id, frag_map)| async move {
            let frag = frag_map
                .get(frag_id)
                .expect("fragment with a deletion file is missing from the dataset");
            frag.get_deletion_vector()
                .await
                .transpose()
                .expect("fragment with a deletion file has no deletion vector")
                .map(|deletion_vector| (*frag_id, RoaringBitmap::from(deletion_vector.as_ref())))
        })
        .collect::<Vec<_>>()
        .await;

        // Poll up to `num_cpus::get()` of the loads at a time.
        let mut frag_id_deletion_vectors =
            stream::iter(frag_id_deletion_vectors).buffer_unordered(num_cpus::get());

        let mut deleted_ids = RowIdTreeMap::new();
        while let Some((id, deletion_vector)) = frag_id_deletion_vectors.try_next().await? {
            deleted_ids.insert_bitmap(id, deletion_vector);
        }
        for frag_id in missing_frags {
            deleted_ids.insert_fragment(frag_id);
        }
        Ok(Arc::new(deleted_ids))
    }

    /// Prepares the future that computes the rows to block for `fragments`.
    ///
    /// Each fragment id is looked up in the dataset manifest:
    /// * ids that are not in the manifest are treated as missing fragments;
    /// * ids whose manifest entry has a deletion file need their deletion
    ///   vector loaded;
    /// * all other ids need no work.
    ///
    /// Returns `None` when no id falls in the first two groups, otherwise a
    /// boxed future resolving to the set of row ids to block.
    pub fn create_deletion_mask(
        dataset: Arc<Dataset>,
        fragments: RoaringBitmap,
    ) -> Option<BoxFuture<'static, Result<Arc<RowIdTreeMap>>>> {
        let mut missing_frags = Vec::new();
        let mut frags_with_deletion_files = Vec::new();

        let frag_map: HashMap<u32, &Fragment> = dataset
            .manifest
            .fragments
            .iter()
            .map(|frag| (frag.id as u32, frag))
            .collect();

        for frag_id in fragments.iter() {
            match frag_map.get(&frag_id) {
                Some(frag) if frag.deletion_file.is_some() => {
                    frags_with_deletion_files.push(frag_id);
                }
                // Present and without deletions: nothing to block.
                Some(_) => {}
                None => missing_frags.push(frag_id),
            }
        }

        if missing_frags.is_empty() && frags_with_deletion_files.is_empty() {
            return None;
        }
        Some(
            Self::do_create_deletion_mask(dataset, missing_frags, frags_with_deletion_files)
                .boxed(),
        )
    }

    /// Waits for both inputs and computes the combined mask.
    ///
    /// The combined mask starts from `RowIdMask::default()`, is intersected
    /// (`&`) with the filter mask if there is one, and then has the deleted
    /// ids added through `also_block` if there are any. It is computed at
    /// most once; later calls reuse the stored value.
    ///
    /// # Errors
    ///
    /// Returns the error of whichever input failed to become ready.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub async fn wait_for_ready(&self) -> Result<()> {
        if let Some(filtered_ids) = &self.filtered_ids {
            filtered_ids.wait_ready().await?;
        }
        if let Some(deleted_ids) = &self.deleted_ids {
            deleted_ids.wait_ready().await?;
        }

        // The guard is taken only after the last `.await`, so the std mutex
        // is never held across a suspension point.
        let final_mask = self.final_mask.lock().unwrap();
        final_mask.get_or_init(|| {
            let mut combined = RowIdMask::default();
            if let Some(filtered_ids) = &self.filtered_ids {
                combined = combined & filtered_ids.get_ready();
            }
            if let Some(deleted_ids) = &self.deleted_ids {
                combined = combined.also_block((*deleted_ids.get_ready()).clone());
            }
            combined
        });
        Ok(())
    }

    /// Applies the combined mask to `row_ids` and returns the result of
    /// `RowIdMask::selected_indices`.
    ///
    /// NOTE(review): the callee's name suggests the returned values are
    /// positions within `row_ids` rather than the ids themselves. Confirm
    /// against `RowIdMask`.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::wait_for_ready`] has not completed successfully
    /// beforehand, or if the internal mutex is poisoned.
    #[instrument(level = "debug", skip_all)]
    pub fn filter_row_ids(&self, row_ids: &[u64]) -> Vec<u64> {
        let final_mask = self.final_mask.lock().unwrap();
        final_mask
            .get()
            .expect("filter_row_ids called without call to wait_for_ready")
            .selected_indices(row_ids)
    }
}