use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use arrow_schema::SchemaRef;
use lance_core::Result;
use uuid::Uuid;
use super::data_source::{LsmDataSource, LsmGeneration, RegionSnapshot};
use crate::dataset::Dataset;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
/// Shared handles to a region's in-memory (not yet flushed) memtable.
///
/// Cloning is cheap with respect to the stores: `batch_store` and
/// `index_store` are `Arc`s, so a clone shares the same underlying stores.
#[derive(Clone)]
pub struct ActiveMemTableRef {
/// Shared handle to the memtable's batch store.
pub batch_store: Arc<BatchStore>,
/// Shared handle to the memtable's index store.
pub index_store: Arc<IndexStore>,
/// Arrow schema associated with this memtable.
pub schema: SchemaRef,
/// Generation number of this memtable; the collector wraps it with
/// `LsmGeneration::memtable` when emitting an active-memtable source.
pub generation: u64,
}
/// Gathers the data sources of an LSM-style read: the base table, the flushed
/// memtables listed in each region snapshot, and any registered active
/// memtables.
pub struct LsmDataSourceCollector {
/// The base dataset; always emitted as the first collected source.
base_table: Arc<Dataset>,
/// URI of `base_table` with trailing `/` characters trimmed, used as the
/// prefix when building flushed-memtable paths.
base_path: String,
/// Per-region snapshots whose flushed generations are emitted as sources.
region_snapshots: Vec<RegionSnapshot>,
/// Active memtables keyed by region id (at most one per region).
active_memtables: HashMap<Uuid, ActiveMemTableRef>,
}
impl LsmDataSourceCollector {
    /// Creates a collector over `base_table` and the given region snapshots.
    ///
    /// The base table URI is captured once with any trailing `/` characters
    /// removed, so flushed-memtable paths can be built by joining with `/`.
    /// No active memtables are registered initially.
    pub fn new(base_table: Arc<Dataset>, region_snapshots: Vec<RegionSnapshot>) -> Self {
        let base_path = base_table.uri().trim_end_matches('/').to_string();
        Self {
            base_table,
            base_path,
            region_snapshots,
            active_memtables: HashMap::new(),
        }
    }

    /// Registers the active memtable for `region_id` and returns the collector.
    ///
    /// Registering a second memtable for the same region replaces the first.
    pub fn with_active_memtable(mut self, region_id: Uuid, memtable: ActiveMemTableRef) -> Self {
        self.active_memtables.insert(region_id, memtable);
        self
    }

    /// Returns the base table.
    pub fn base_table(&self) -> &Arc<Dataset> {
        &self.base_table
    }

    /// Returns the region snapshots this collector was created with.
    pub fn region_snapshots(&self) -> &[RegionSnapshot] {
        &self.region_snapshots
    }

    /// Returns the registered active memtables, keyed by region id.
    pub fn active_memtables(&self) -> &HashMap<Uuid, ActiveMemTableRef> {
        &self.active_memtables
    }

    /// Collects every data source known to this collector.
    ///
    /// The result starts with the base table, followed by the flushed
    /// memtables of each region snapshot (in snapshot order, then in the order
    /// the generations are listed), followed by the active memtables. Active
    /// memtables are stored in a `HashMap`, so their relative order is
    /// unspecified.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error; the `Result` is kept
    /// for interface compatibility.
    pub fn collect(&self) -> Result<Vec<LsmDataSource>> {
        Ok(self.collect_filtered(None))
    }

    /// Collects the base table plus the sources belonging to `region_ids`.
    ///
    /// The base table is always included. Flushed and active memtables are
    /// included only when their region id is contained in `region_ids`.
    /// Ordering follows the same rules as [`Self::collect`].
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error; the `Result` is kept
    /// for interface compatibility.
    pub fn collect_for_regions(&self, region_ids: &HashSet<Uuid>) -> Result<Vec<LsmDataSource>> {
        Ok(self.collect_filtered(Some(region_ids)))
    }

    /// Returns the total number of sources: the base table, every flushed
    /// generation across all region snapshots, and every active memtable.
    pub fn num_sources(&self) -> usize {
        let flushed_count: usize = self
            .region_snapshots
            .iter()
            .map(|s| s.flushed_generations.len())
            .sum();
        1 + flushed_count + self.active_memtables.len()
    }

    /// Shared implementation of [`Self::collect`] and
    /// [`Self::collect_for_regions`].
    ///
    /// `region_ids == None` selects every region; `Some(ids)` restricts flushed
    /// and active memtables to regions contained in `ids`.
    fn collect_filtered(&self, region_ids: Option<&HashSet<Uuid>>) -> Vec<LsmDataSource> {
        let is_selected = |region_id: &Uuid| match region_ids {
            Some(ids) => ids.contains(region_id),
            None => true,
        };

        // `num_sources()` is exact when unfiltered and an upper bound when a
        // filter is applied, so the vector never has to grow.
        let mut sources = Vec::with_capacity(self.num_sources());
        sources.push(LsmDataSource::BaseTable {
            dataset: self.base_table.clone(),
        });

        for snapshot in &self.region_snapshots {
            if !is_selected(&snapshot.region_id) {
                continue;
            }
            sources.extend(snapshot.flushed_generations.iter().map(|flushed| {
                LsmDataSource::FlushedMemTable {
                    path: self.resolve_flushed_path(&snapshot.region_id, &flushed.path),
                    region_id: snapshot.region_id,
                    generation: LsmGeneration::memtable(flushed.generation),
                }
            }));
        }

        for (region_id, memtable) in &self.active_memtables {
            if !is_selected(region_id) {
                continue;
            }
            sources.push(LsmDataSource::ActiveMemTable {
                batch_store: memtable.batch_store.clone(),
                index_store: memtable.index_store.clone(),
                schema: memtable.schema.clone(),
                region_id: *region_id,
                generation: LsmGeneration::memtable(memtable.generation),
            });
        }

        sources
    }

    /// Builds the path of a flushed memtable:
    /// `<base_path>/_mem_wal/<region_id>/<folder_name>`.
    fn resolve_flushed_path(&self, region_id: &Uuid, folder_name: &str) -> String {
        format!("{}/_mem_wal/{}/{}", self.base_path, region_id, folder_name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::mem_wal::scanner::data_source::FlushedGeneration;

    /// Builds two region snapshots with fresh random region ids: the first
    /// has two flushed generations, the second has one.
    fn create_test_snapshots() -> Vec<RegionSnapshot> {
        let two_flushes = RegionSnapshot {
            region_id: Uuid::new_v4(),
            spec_id: 1,
            current_generation: 3,
            flushed_generations: vec![
                FlushedGeneration {
                    generation: 1,
                    path: "abc_gen_1".to_string(),
                },
                FlushedGeneration {
                    generation: 2,
                    path: "def_gen_2".to_string(),
                },
            ],
        };
        let one_flush = RegionSnapshot {
            region_id: Uuid::new_v4(),
            spec_id: 1,
            current_generation: 2,
            flushed_generations: vec![FlushedGeneration {
                generation: 1,
                path: "xyz_gen_1".to_string(),
            }],
        };
        vec![two_flushes, one_flush]
    }

    /// Checks the flushed-generation counts of the fixture snapshots.
    ///
    /// Note: no collector is constructed here; only the fixture shape is
    /// verified.
    #[test]
    fn test_collector_num_sources() {
        let flushed_counts: Vec<usize> = create_test_snapshots()
            .iter()
            .map(|snapshot| snapshot.flushed_generations.len())
            .collect();
        assert_eq!(flushed_counts, vec![2, 1]);
    }

    /// An `ActiveMemTableRef` keeps the generation it was built with.
    #[test]
    fn test_active_memtable_ref() {
        let memtable_ref = ActiveMemTableRef {
            batch_store: Arc::new(BatchStore::with_capacity(100)),
            index_store: Arc::new(IndexStore::new()),
            schema: Arc::new(arrow_schema::Schema::empty()),
            generation: 5,
        };
        assert_eq!(memtable_ref.generation, 5);
    }
}