use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use arrow_schema::SchemaRef;
use lance_core::Result;
use uuid::Uuid;
use super::data_source::{LsmDataSource, LsmGeneration, ShardSnapshot};
use crate::dataset::Dataset;
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
/// Shared handles to a single in-memory memtable together with its generation.
///
/// Cloning copies the `Arc` handles and the generation number; the underlying
/// stores are shared between clones rather than duplicated.
#[derive(Clone)]
pub struct InMemoryMemTableRef {
    /// Shared handle to the memtable's batch store.
    pub batch_store: Arc<BatchStore>,
    /// Shared handle to the memtable's index store.
    pub index_store: Arc<IndexStore>,
    /// Arrow schema associated with this memtable.
    pub schema: SchemaRef,
    /// Generation number of this memtable. The collector orders a shard's
    /// in-memory memtables by this value, ascending.
    pub generation: u64,
}
/// Alias of [`InMemoryMemTableRef`].
// NOTE(review): presumably kept so callers using the older name keep compiling — confirm.
pub type ActiveMemTableRef = InMemoryMemTableRef;
/// The in-memory memtables registered for one shard: exactly one `active`
/// memtable plus zero or more `frozen` ones.
// NOTE(review): "frozen" presumably means sealed-but-not-yet-flushed — confirm
// against the writer side.
#[derive(Clone)]
pub struct InMemoryMemTables {
    /// The shard's active memtable.
    pub active: InMemoryMemTableRef,
    /// The shard's frozen memtables. No particular order is required here;
    /// the collector sorts by generation when it emits sources.
    pub frozen: Vec<InMemoryMemTableRef>,
}
/// Gathers every data source that an LSM scan has to consider: the optional
/// base table, the flushed generations listed in each shard snapshot, and any
/// in-memory memtables registered per shard.
pub struct LsmDataSourceCollector {
    /// Base dataset, when one was supplied at construction.
    base_table: Option<Arc<Dataset>>,
    /// Base location with trailing `/` characters removed; used as the prefix
    /// when resolving flushed-memtable paths.
    base_path: String,
    /// Per-shard snapshots whose flushed generations become sources.
    shard_snapshots: Vec<ShardSnapshot>,
    /// In-memory memtables keyed by shard id.
    in_memory_memtables: HashMap<Uuid, InMemoryMemTables>,
}
impl LsmDataSourceCollector {
pub fn new(base_table: Arc<Dataset>, shard_snapshots: Vec<ShardSnapshot>) -> Self {
let base_path = base_table.uri().trim_end_matches('/').to_string();
Self {
base_table: Some(base_table),
base_path,
shard_snapshots,
in_memory_memtables: HashMap::new(),
}
}
pub fn without_base_table(
base_path: impl Into<String>,
shard_snapshots: Vec<ShardSnapshot>,
) -> Self {
Self {
base_table: None,
base_path: base_path.into().trim_end_matches('/').to_string(),
shard_snapshots,
in_memory_memtables: HashMap::new(),
}
}
pub fn with_active_memtable(mut self, shard_id: Uuid, memtable: InMemoryMemTableRef) -> Self {
match self.in_memory_memtables.entry(shard_id) {
Entry::Occupied(mut e) => e.get_mut().active = memtable,
Entry::Vacant(e) => {
e.insert(InMemoryMemTables {
active: memtable,
frozen: Vec::new(),
});
}
}
self
}
pub fn with_in_memory_memtables(
mut self,
shard_id: Uuid,
memtables: InMemoryMemTables,
) -> Self {
self.in_memory_memtables.insert(shard_id, memtables);
self
}
pub fn base_table(&self) -> Option<&Arc<Dataset>> {
self.base_table.as_ref()
}
pub fn shard_snapshots(&self) -> &[ShardSnapshot] {
&self.shard_snapshots
}
pub fn in_memory_memtables(&self) -> &HashMap<Uuid, InMemoryMemTables> {
&self.in_memory_memtables
}
fn in_memory_sources(shard_id: Uuid, mems: &InMemoryMemTables) -> Vec<LsmDataSource> {
let mut refs: Vec<&InMemoryMemTableRef> = std::iter::once(&mems.active)
.chain(mems.frozen.iter())
.collect();
refs.sort_by_key(|m| m.generation);
refs.into_iter()
.map(|m| LsmDataSource::ActiveMemTable {
batch_store: m.batch_store.clone(),
index_store: m.index_store.clone(),
schema: m.schema.clone(),
shard_id,
generation: LsmGeneration::memtable(m.generation),
})
.collect()
}
pub fn collect(&self) -> Result<Vec<LsmDataSource>> {
let mut sources = Vec::new();
if let Some(base) = &self.base_table {
sources.push(LsmDataSource::BaseTable {
dataset: base.clone(),
});
}
for snapshot in &self.shard_snapshots {
for flushed in &snapshot.flushed_generations {
let path = self.resolve_flushed_path(&snapshot.shard_id, &flushed.path);
sources.push(LsmDataSource::FlushedMemTable {
path,
shard_id: snapshot.shard_id,
generation: LsmGeneration::memtable(flushed.generation),
});
}
}
for (shard_id, mems) in &self.in_memory_memtables {
sources.extend(Self::in_memory_sources(*shard_id, mems));
}
Ok(sources)
}
pub fn collect_for_shards(&self, shard_ids: &HashSet<Uuid>) -> Result<Vec<LsmDataSource>> {
let mut sources = Vec::new();
if let Some(base) = &self.base_table {
sources.push(LsmDataSource::BaseTable {
dataset: base.clone(),
});
}
for snapshot in &self.shard_snapshots {
if !shard_ids.contains(&snapshot.shard_id) {
continue;
}
for flushed in &snapshot.flushed_generations {
let path = self.resolve_flushed_path(&snapshot.shard_id, &flushed.path);
sources.push(LsmDataSource::FlushedMemTable {
path,
shard_id: snapshot.shard_id,
generation: LsmGeneration::memtable(flushed.generation),
});
}
}
for (shard_id, mems) in &self.in_memory_memtables {
if !shard_ids.contains(shard_id) {
continue;
}
sources.extend(Self::in_memory_sources(*shard_id, mems));
}
Ok(sources)
}
pub fn num_sources(&self) -> usize {
let flushed_count: usize = self
.shard_snapshots
.iter()
.map(|s| s.flushed_generations.len())
.sum();
let base_count = if self.base_table.is_some() { 1 } else { 0 };
let in_memory_count: usize = self
.in_memory_memtables
.values()
.map(|m| 1 + m.frozen.len())
.sum();
base_count + flushed_count + in_memory_count
}
fn resolve_flushed_path(&self, shard_id: &Uuid, folder_name: &str) -> String {
format!("{}/_mem_wal/{}/{}", self.base_path, shard_id, folder_name)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::mem_wal::scanner::data_source::FlushedGeneration;

    /// Two shard snapshots: the first with two flushed generations, the
    /// second with one.
    fn create_test_snapshots() -> Vec<ShardSnapshot> {
        let shard_a = Uuid::new_v4();
        let shard_b = Uuid::new_v4();
        vec![
            ShardSnapshot {
                shard_id: shard_a,
                spec_id: 1,
                current_generation: 3,
                flushed_generations: vec![
                    FlushedGeneration {
                        generation: 1,
                        path: "abc_gen_1".to_string(),
                    },
                    FlushedGeneration {
                        generation: 2,
                        path: "def_gen_2".to_string(),
                    },
                ],
            },
            ShardSnapshot {
                shard_id: shard_b,
                spec_id: 1,
                current_generation: 2,
                flushed_generations: vec![FlushedGeneration {
                    generation: 1,
                    path: "xyz_gen_1".to_string(),
                }],
            },
        ]
    }

    /// Builds a memtable ref with empty stores and the given generation.
    fn memtable_ref(generation: u64) -> InMemoryMemTableRef {
        InMemoryMemTableRef {
            batch_store: Arc::new(BatchStore::with_capacity(8)),
            index_store: Arc::new(IndexStore::new()),
            schema: Arc::new(arrow_schema::Schema::empty()),
            generation,
        }
    }

    #[test]
    fn test_collector_num_sources() {
        let snapshots = create_test_snapshots();
        let shard_a = snapshots[0].shard_id;
        let shard_b = snapshots[1].shard_id;
        assert_eq!(snapshots[0].flushed_generations.len(), 2);
        assert_eq!(snapshots[1].flushed_generations.len(), 1);

        // The trailing slash must be trimmed before flushed paths are resolved.
        let collector = LsmDataSourceCollector::without_base_table("/tmp/x/", snapshots);
        assert!(collector.base_table().is_none());
        assert_eq!(collector.num_sources(), 3);

        let sources = collector.collect().unwrap();
        assert_eq!(sources.len(), collector.num_sources());
        let expected_path = format!("/tmp/x/_mem_wal/{}/abc_gen_1", shard_a);
        assert!(sources.iter().any(|s| matches!(
            s,
            LsmDataSource::FlushedMemTable { path, .. } if *path == expected_path
        )));

        // Filtering by shard keeps only that shard's flushed generations.
        let only_b = collector
            .collect_for_shards(&HashSet::from([shard_b]))
            .unwrap();
        assert_eq!(only_b.len(), 1);
        assert!(only_b.iter().all(|s| matches!(
            s,
            LsmDataSource::FlushedMemTable { shard_id, .. } if *shard_id == shard_b
        )));
    }

    #[test]
    fn test_in_memory_memtable_ref() {
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let index_store = Arc::new(IndexStore::new());
        let schema = Arc::new(arrow_schema::Schema::empty());
        let memtable_ref = InMemoryMemTableRef {
            batch_store,
            index_store,
            schema,
            generation: 5,
        };
        assert_eq!(memtable_ref.generation, 5);
    }

    #[test]
    fn test_collect_includes_active_and_frozen() {
        let shard = Uuid::new_v4();
        let other = Uuid::new_v4();
        let mems = InMemoryMemTables {
            active: memtable_ref(5),
            frozen: vec![memtable_ref(4), memtable_ref(3)],
        };
        let collector = LsmDataSourceCollector::without_base_table("/tmp/x", vec![])
            .with_in_memory_memtables(shard, mems);
        assert_eq!(collector.num_sources(), 3);

        let sources = collector.collect().unwrap();
        assert_eq!(sources.len(), 3);
        assert!(sources.iter().all(|s| s.is_active_memtable()));
        assert!(sources.iter().all(|s| s.shard_id() == Some(shard)));
        // Sources come out in ascending generation order regardless of input order.
        let gens: Vec<u64> = sources.iter().map(|s| s.generation().as_u64()).collect();
        assert_eq!(gens, vec![3, 4, 5]);

        assert!(
            collector
                .collect_for_shards(&HashSet::from([other]))
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            collector
                .collect_for_shards(&HashSet::from([shard]))
                .unwrap()
                .len(),
            3
        );
    }
}