use std::collections::HashSet;
use std::sync::Arc;
use lance_core::{Error, Result};
use crate::dataset::{Dataset, DatasetBuilder};
use crate::session::Session;
/// Bounded cache of opened flushed-memtable datasets.
///
/// Entries are keyed by the dataset path string and hold the opened
/// [`Dataset`] behind an [`Arc`], so a cache hit hands out the same shared
/// handle rather than opening the dataset again.
pub struct FlushedMemTableCache {
    /// Path string -> shared opened dataset.
    inner: moka::future::Cache<String, Arc<Dataset>>,
}
impl FlushedMemTableCache {
pub fn new(max_entries: u64) -> Self {
Self {
inner: moka::future::Cache::builder()
.max_capacity(max_entries)
.support_invalidation_closures()
.build(),
}
}
pub async fn get_or_open(
&self,
path: &str,
session: Option<Arc<Session>>,
) -> Result<Arc<Dataset>> {
self.inner
.try_get_with(path.to_string(), async move {
let mut builder = DatasetBuilder::from_uri(path);
if let Some(session) = session {
builder = builder.with_session(session);
}
builder.load().await.map(Arc::new)
})
.await
.map_err(|e: Arc<Error>| Error::cloned(e.to_string()))
}
pub fn retain_paths(&self, live_paths: &HashSet<String>) {
let live = live_paths.clone();
let _ = self
.inner
.invalidate_entries_if(move |path, _| !live.contains(path));
}
}
impl std::fmt::Debug for FlushedMemTableCache {
    /// Formats the cache as `FlushedMemTableCache { entry_count: .. }`.
    ///
    /// Only the cache's reported entry count is shown; the cached datasets
    /// themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.inner.entry_count();
        let mut out = f.debug_struct("FlushedMemTableCache");
        out.field("entry_count", &entry_count);
        out.finish()
    }
}
/// Opens the flushed dataset at `path`, going through `cache` when one is given.
///
/// * With a cache, the lookup/open is delegated to
///   [`FlushedMemTableCache::get_or_open`], so repeated calls for the same
///   path can share one `Arc<Dataset>`.
/// * Without a cache, the dataset is loaded directly on every call and
///   wrapped in a fresh `Arc`.
///
/// In both cases `session`, when present, is attached to the dataset builder.
///
/// # Errors
///
/// Returns the error produced while loading the dataset (or the cache's
/// converted copy of it on the cached path).
pub(super) async fn open_flushed_dataset(
    path: &str,
    session: Option<&Arc<Session>>,
    cache: Option<&Arc<FlushedMemTableCache>>,
) -> Result<Arc<Dataset>> {
    // Cached path: let the cache decide whether an open is needed.
    if let Some(cache) = cache {
        return cache.get_or_open(path, session.cloned()).await;
    }

    // Uncached path: always open from storage.
    let builder = match session {
        Some(session) => DatasetBuilder::from_uri(path).with_session(Arc::clone(session)),
        None => DatasetBuilder::from_uri(path),
    };
    let dataset = builder.load().await?;
    Ok(Arc::new(dataset))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};

    use crate::dataset::WriteParams;

    /// Writes a dataset at `uri` with a single non-nullable `Int32` column
    /// named `id` containing `ids`.
    async fn write_dataset(uri: &str, ids: &[i32]) {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(ids.to_vec()))],
        )
        .unwrap();
        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
        Dataset::write(reader, uri, Some(WriteParams::default()))
            .await
            .unwrap();
    }

    /// A second lookup of the same path returns the very same `Arc` and the
    /// cache ends up holding exactly one entry.
    #[tokio::test]
    async fn test_hit_returns_same_arc() {
        let temp_dir = tempfile::tempdir().unwrap();
        let uri = format!("{}/gen_1", temp_dir.path().to_str().unwrap());
        write_dataset(&uri, &[1, 2, 3]).await;

        let cache = FlushedMemTableCache::new(8);
        let first = cache.get_or_open(&uri, None).await.unwrap();
        let second = cache.get_or_open(&uri, None).await.unwrap();
        assert!(
            Arc::ptr_eq(&first, &second),
            "a cache hit must return the same Arc<Dataset>, not re-open"
        );

        // `entry_count` is only an estimate until the cache's pending
        // maintenance tasks have been applied, so flush them first and assert
        // on the settled value only. Asserting on the pre-flush value would
        // pin unspecified housekeeping timing and make the test flaky.
        cache.inner.run_pending_tasks().await;
        assert_eq!(cache.inner.entry_count(), 1);
    }

    /// Sixteen tasks racing on the same path all receive the same `Arc`, and
    /// only one entry is cached.
    #[tokio::test]
    async fn test_concurrent_get_or_open_single_flight() {
        let temp_dir = tempfile::tempdir().unwrap();
        let uri = format!("{}/gen_1", temp_dir.path().to_str().unwrap());
        write_dataset(&uri, &[1, 2, 3]).await;

        let cache = Arc::new(FlushedMemTableCache::new(8));
        // Counts how many tasks actually started; it does not count opens.
        let calls = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..16 {
            let cache = cache.clone();
            let uri = uri.clone();
            let calls = calls.clone();
            handles.push(tokio::spawn(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                cache.get_or_open(&uri, None).await.unwrap()
            }));
        }

        let datasets: Vec<Arc<Dataset>> = futures::future::try_join_all(handles).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 16, "all tasks ran");

        let first = &datasets[0];
        for ds in &datasets {
            assert!(
                Arc::ptr_eq(first, ds),
                "all concurrent callers must share one opened dataset"
            );
        }

        cache.inner.run_pending_tasks().await;
        assert_eq!(cache.inner.entry_count(), 1, "exactly one entry cached");
    }

    /// `retain_paths` keeps entries listed in the live set and drops the rest.
    #[tokio::test]
    async fn test_retain_paths_drops_unreferenced() {
        let temp_dir = tempfile::tempdir().unwrap();
        let base = temp_dir.path().to_str().unwrap();
        let keep_uri = format!("{}/gen_1", base);
        let drop_uri = format!("{}/gen_2", base);
        write_dataset(&keep_uri, &[1]).await;
        write_dataset(&drop_uri, &[2]).await;

        let cache = FlushedMemTableCache::new(8);
        cache.get_or_open(&keep_uri, None).await.unwrap();
        cache.get_or_open(&drop_uri, None).await.unwrap();
        cache.inner.run_pending_tasks().await;
        assert_eq!(cache.inner.entry_count(), 2);

        let live: HashSet<String> = [keep_uri.clone()].into_iter().collect();
        cache.retain_paths(&live);
        // Flush pending tasks so the invalidation is reflected in the count.
        cache.inner.run_pending_tasks().await;
        assert_eq!(cache.inner.entry_count(), 1, "only live path retained");
        assert!(cache.inner.contains_key(&keep_uri));
        assert!(!cache.inner.contains_key(&drop_uri));
    }

    /// Without a cache every call opens a fresh dataset; with a cache the
    /// same `Arc` is reused across calls.
    #[tokio::test]
    async fn test_open_flushed_dataset_no_cache_matches_direct_open() {
        let temp_dir = tempfile::tempdir().unwrap();
        let uri = format!("{}/gen_1", temp_dir.path().to_str().unwrap());
        write_dataset(&uri, &[7, 8, 9]).await;

        let a = open_flushed_dataset(&uri, None, None).await.unwrap();
        let b = open_flushed_dataset(&uri, None, None).await.unwrap();
        assert!(
            !Arc::ptr_eq(&a, &b),
            "no-cache path must cold-open each call"
        );
        assert_eq!(a.count_rows(None).await.unwrap(), 3);

        let cache = Arc::new(FlushedMemTableCache::new(8));
        let c = open_flushed_dataset(&uri, None, Some(&cache))
            .await
            .unwrap();
        let d = open_flushed_dataset(&uri, None, Some(&cache))
            .await
            .unwrap();
        assert!(Arc::ptr_eq(&c, &d), "cached path must reuse the Arc");
    }
}