#![allow(clippy::print_stderr)]
use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock as StdRwLock};
use std::time::{Duration, Instant};
use arrow_array::RecordBatch;
use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait;
use lance_core::datatypes::Schema;
use lance_core::{Error, Result};
use lance_index::mem_wal::RegionManifest;
use lance_io::object_store::ObjectStore;
use log::{debug, error, info, warn};
use object_store::path::Path;
use tokio::sync::{RwLock, mpsc};
use tokio::task::JoinHandle;
use tokio::time::{Interval, interval_at};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
pub use super::index::{
BTreeIndexConfig, BTreeMemIndex, FtsIndexConfig, IndexStore, IvfPqIndexConfig, MemIndexConfig,
};
pub use super::memtable::CacheConfig;
pub use super::memtable::MemTable;
pub use super::memtable::batch_store::{BatchStore, StoreFull, StoredBatch};
pub use super::memtable::flush::MemTableFlusher;
pub use super::memtable::scanner::MemTableScanner;
pub use super::util::{WatchableOnceCell, WatchableOnceCellReader};
pub use super::wal::{WalEntry, WalEntryData, WalFlushResult, WalFlusher};
use super::memtable::flush::TriggerMemTableFlush;
use super::wal::TriggerWalFlush;
use super::manifest::RegionManifestStore;
/// Tunables for a [`RegionWriter`].
#[derive(Debug, Clone)]
pub struct RegionWriterConfig {
    /// Region this writer serves.
    pub region_id: Uuid,
    /// Region spec id passed to the manifest store when claiming an epoch.
    pub region_spec_id: u32,
    /// When true, `put` waits for WAL durability before returning.
    pub durable_write: bool,
    /// Whether index updates happen synchronously with the write.
    /// NOTE(review): not referenced in this file — confirm the consumer.
    pub sync_indexed_write: bool,
    /// Byte threshold of active-memtable growth per size-based WAL flush trigger.
    pub max_wal_buffer_size: usize,
    /// Max time between WAL flushes while batches are pending (`None` = size-based only).
    pub max_wal_flush_interval: Option<Duration>,
    /// Memtable size (bytes) at which the memtable is frozen and flushed.
    pub max_memtable_size: usize,
    /// Row-capacity hint used when sizing in-memory indexes.
    pub max_memtable_rows: usize,
    /// Max batches held by one memtable's batch store.
    pub max_memtable_batches: usize,
    /// Over-provisioning factor for IVF index partition capacity.
    pub ivf_index_partition_capacity_safety_factor: usize,
    /// Batch size used when scanning the region manifest store.
    pub manifest_scan_batch_size: usize,
    /// Unflushed (active + frozen) memtable bytes above which writes block.
    pub max_unflushed_memtable_bytes: usize,
    /// How long a blocked writer waits before logging a backpressure warning.
    pub backpressure_log_interval: Duration,
    /// Row buffer size for asynchronous indexing.
    /// NOTE(review): not referenced in this file — confirm the consumer.
    pub async_index_buffer_rows: usize,
    /// Interval for asynchronous indexing.
    /// NOTE(review): not referenced in this file — confirm the consumer.
    pub async_index_interval: Duration,
    /// How often to log write statistics (`None` disables).
    /// NOTE(review): not referenced in this file — confirm the consumer.
    pub stats_log_interval: Option<Duration>,
}
impl Default for RegionWriterConfig {
    /// Conservative defaults: durable, synchronously indexed writes with a
    /// 10 MiB WAL buffer, a 256 MiB memtable, and a 1 GiB unflushed-bytes cap.
    fn default() -> Self {
        Self {
            region_id: Uuid::new_v4(),
            region_spec_id: 0,
            durable_write: true,
            sync_indexed_write: true,
            max_wal_buffer_size: 10 * 1024 * 1024,
            max_wal_flush_interval: Some(Duration::from_millis(100)),
            max_memtable_size: 256 * 1024 * 1024,
            max_memtable_rows: 100_000,
            max_memtable_batches: 8_000,
            ivf_index_partition_capacity_safety_factor: 8,
            manifest_scan_batch_size: 2,
            max_unflushed_memtable_bytes: 1024 * 1024 * 1024,
            backpressure_log_interval: Duration::from_secs(30),
            async_index_buffer_rows: 10_000,
            async_index_interval: Duration::from_secs(1),
            stats_log_interval: Some(Duration::from_secs(60)),
        }
    }
}
impl RegionWriterConfig {
    /// Creates a config for `region_id` with every other field at its default.
    pub fn new(region_id: Uuid) -> Self {
        Self {
            region_id,
            ..Default::default()
        }
    }
    /// Sets the region spec id used when claiming the writer epoch.
    pub fn with_region_spec_id(mut self, spec_id: u32) -> Self {
        self.region_spec_id = spec_id;
        self
    }
    /// Sets whether `put` waits for WAL durability.
    pub fn with_durable_write(mut self, durable: bool) -> Self {
        self.durable_write = durable;
        self
    }
    /// Sets `sync_indexed_write`.
    pub fn with_sync_indexed_write(mut self, indexed: bool) -> Self {
        self.sync_indexed_write = indexed;
        self
    }
    /// Sets the byte threshold for size-based WAL flush triggers.
    pub fn with_max_wal_buffer_size(mut self, size: usize) -> Self {
        self.max_wal_buffer_size = size;
        self
    }
    /// Sets the maximum time between WAL flushes.
    pub fn with_max_wal_flush_interval(mut self, interval: Duration) -> Self {
        self.max_wal_flush_interval = Some(interval);
        self
    }
    /// Sets the memtable size (bytes) that triggers a freeze.
    pub fn with_max_memtable_size(mut self, size: usize) -> Self {
        self.max_memtable_size = size;
        self
    }
    /// Sets the row-capacity hint used when sizing in-memory indexes.
    pub fn with_max_memtable_rows(mut self, rows: usize) -> Self {
        self.max_memtable_rows = rows;
        self
    }
    /// Sets the max number of batches one memtable's batch store may hold.
    pub fn with_max_memtable_batches(mut self, batches: usize) -> Self {
        self.max_memtable_batches = batches;
        self
    }
    /// Sets the IVF index partition capacity safety factor.
    pub fn with_ivf_index_partition_capacity_safety_factor(mut self, factor: usize) -> Self {
        self.ivf_index_partition_capacity_safety_factor = factor;
        self
    }
    /// Sets the manifest store scan batch size.
    pub fn with_manifest_scan_batch_size(mut self, size: usize) -> Self {
        self.manifest_scan_batch_size = size;
        self
    }
    /// Sets the unflushed-memtable-bytes cap that triggers backpressure.
    pub fn with_max_unflushed_memtable_bytes(mut self, size: usize) -> Self {
        self.max_unflushed_memtable_bytes = size;
        self
    }
    /// Sets how long a blocked writer waits before logging a warning.
    pub fn with_backpressure_log_interval(mut self, interval: Duration) -> Self {
        self.backpressure_log_interval = interval;
        self
    }
    /// Sets `async_index_buffer_rows`.
    pub fn with_async_index_buffer_rows(mut self, rows: usize) -> Self {
        self.async_index_buffer_rows = rows;
        self
    }
    /// Sets `async_index_interval`.
    pub fn with_async_index_interval(mut self, interval: Duration) -> Self {
        self.async_index_interval = interval;
        self
    }
    /// Sets (or disables, with `None`) the stats logging interval.
    pub fn with_stats_log_interval(mut self, interval: Option<Duration>) -> Self {
        self.stats_log_interval = interval;
        self
    }
}
/// Factory that produces a fresh message each time a periodic ticker fires.
type MessageFactory<T> = Box<dyn Fn() -> T + Send + Sync>;
/// Handler driven by a [`TaskDispatcher`]: receives messages of type `T`,
/// may register periodic ticker messages, and gets a cleanup hook when the
/// dispatcher stops.
#[async_trait]
pub trait MessageHandler<T: Send + Debug + 'static>: Send {
    /// Periodic (interval, message-factory) pairs; the dispatcher synthesizes
    /// a message from the factory whenever its interval elapses. Default: none.
    fn tickers(&mut self) -> Vec<(Duration, MessageFactory<T>)> {
        vec![]
    }
    /// Processes one message; returning an error stops the dispatcher loop.
    async fn handle(&mut self, message: T) -> Result<()>;
    /// Called once when the dispatcher stops; `_shutdown_ok` is true when the
    /// loop exited without a handler error.
    async fn cleanup(&mut self, _shutdown_ok: bool) -> Result<()> {
        Ok(())
    }
}
/// Owns one message-processing loop: pulls messages from `rx` and feeds them
/// to `handler` until cancellation or channel closure.
struct TaskDispatcher<T: Send + Debug> {
    handler: Box<dyn MessageHandler<T>>,
    rx: mpsc::UnboundedReceiver<T>,
    // Shared token; cancelled by `TaskExecutor::shutdown_all`.
    cancellation_token: CancellationToken,
    // Used only in log messages.
    name: String,
}
impl<T: Send + Debug + 'static> TaskDispatcher<T> {
async fn run(mut self) -> Result<()> {
let tickers = self.handler.tickers();
let mut ticker_intervals: Vec<(Interval, MessageFactory<T>)> = tickers
.into_iter()
.map(|(duration, factory)| {
let interval = interval_at(tokio::time::Instant::now() + duration, duration);
(interval, factory)
})
.collect();
let result = loop {
if ticker_intervals.is_empty() {
tokio::select! {
biased;
_ = self.cancellation_token.cancelled() => {
debug!("Task '{}' received cancellation", self.name);
break Ok(());
}
msg = self.rx.recv() => {
match msg {
Some(message) => {
if let Err(e) = self.handler.handle(message).await {
error!("Task '{}' error handling message: {}", self.name, e);
break Err(e);
}
}
None => {
debug!("Task '{}' channel closed", self.name);
break Ok(());
}
}
}
}
} else {
let first_ticker = ticker_intervals.first_mut().unwrap();
let first_interval = &mut first_ticker.0;
tokio::select! {
biased;
_ = self.cancellation_token.cancelled() => {
debug!("Task '{}' received cancellation", self.name);
break Ok(());
}
_ = first_interval.tick() => {
let message = (ticker_intervals[0].1)();
if let Err(e) = self.handler.handle(message).await {
error!("Task '{}' error handling ticker message: {}", self.name, e);
break Err(e);
}
}
msg = self.rx.recv() => {
match msg {
Some(message) => {
if let Err(e) = self.handler.handle(message).await {
error!("Task '{}' error handling message: {}", self.name, e);
break Err(e);
}
}
None => {
debug!("Task '{}' channel closed", self.name);
break Ok(());
}
}
}
}
}
};
let cleanup_ok = result.is_ok();
self.handler.cleanup(cleanup_ok).await?;
info!("Task dispatcher '{}' stopped", self.name);
result
}
}
/// Owns background dispatcher task handles plus a shared cancellation token
/// so every task can be shut down together.
pub struct TaskExecutor {
    tasks: StdRwLock<Vec<(String, JoinHandle<Result<()>>)>>,
    cancellation_token: CancellationToken,
}
impl TaskExecutor {
    /// Creates an executor with no registered tasks.
    pub fn new() -> Self {
        Self {
            tasks: StdRwLock::new(Vec::new()),
            cancellation_token: CancellationToken::new(),
        }
    }
    /// Spawns a dispatcher task that feeds messages from `rx` to `handler`.
    /// The task is tracked under `name` and cancelled by [`Self::shutdown_all`].
    pub fn add_handler<T: Send + Debug + 'static>(
        &self,
        name: String,
        handler: Box<dyn MessageHandler<T>>,
        rx: mpsc::UnboundedReceiver<T>,
    ) -> Result<()> {
        let dispatcher = TaskDispatcher {
            handler,
            rx,
            cancellation_token: self.cancellation_token.clone(),
            name: name.clone(),
        };
        let handle = tokio::spawn(dispatcher.run());
        self.tasks.write().unwrap().push((name, handle));
        Ok(())
    }
    /// Cancels every registered task and awaits each one, logging (but never
    /// propagating) per-task errors and panics.
    pub async fn shutdown_all(&self) -> Result<()> {
        info!("Shutting down all tasks");
        self.cancellation_token.cancel();
        // Take the whole task list so a concurrent caller sees it empty.
        let drained = std::mem::take(&mut *self.tasks.write().unwrap());
        for (name, handle) in drained {
            match handle.await {
                Ok(Ok(())) => debug!("Task '{}' completed successfully", name),
                Ok(Err(e)) => warn!("Task '{}' completed with error: {}", name, e),
                Err(e) => error!("Task '{}' panicked: {}", name, e),
            }
        }
        Ok(())
    }
}
impl Default for TaskExecutor {
    fn default() -> Self {
        Self::new()
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DurabilityResult {
Durable,
Failed(String),
}
impl DurabilityResult {
pub fn ok() -> Self {
Self::Durable
}
pub fn err(msg: impl Into<String>) -> Self {
Self::Failed(msg.into())
}
pub fn is_ok(&self) -> bool {
matches!(self, Self::Durable)
}
pub fn into_result(self) -> Result<()> {
match self {
Self::Durable => Ok(()),
Self::Failed(msg) => Err(Error::io(msg)),
}
}
}
/// Read side of a durability notification; resolves once the write's fate is known.
pub type DurabilityWatcher = WatchableOnceCellReader<DurabilityResult>;
/// Write side of a durability notification.
pub type DurabilityCell = WatchableOnceCell<DurabilityResult>;
/// Running counters for write-path backpressure events.
///
/// Counters are monotonically increasing and updated with relaxed atomics;
/// readers get an eventually-consistent view, which is fine for diagnostics.
#[derive(Debug, Default)]
pub struct BackpressureStats {
    // Number of writes that hit backpressure.
    total_count: AtomicU64,
    // Total milliseconds writers spent blocked on backpressure.
    total_wait_ms: AtomicU64,
}
impl BackpressureStats {
    /// Creates a zeroed stats instance.
    pub fn new() -> Self {
        Default::default()
    }
    /// Records one backpressure event that blocked for `wait_ms` milliseconds.
    pub fn record(&self, wait_ms: u64) {
        self.total_count.fetch_add(1, Ordering::Relaxed);
        self.total_wait_ms.fetch_add(wait_ms, Ordering::Relaxed);
    }
    /// Total number of backpressure events recorded so far.
    pub fn count(&self) -> u64 {
        self.total_count.load(Ordering::Relaxed)
    }
    /// Total milliseconds spent waiting across all recorded events.
    pub fn total_wait_ms(&self) -> u64 {
        self.total_wait_ms.load(Ordering::Relaxed)
    }
    /// Captures a point-in-time copy of both counters.
    pub fn snapshot(&self) -> BackpressureStatsSnapshot {
        let total_count = self.total_count.load(Ordering::Relaxed);
        let total_wait_ms = self.total_wait_ms.load(Ordering::Relaxed);
        BackpressureStatsSnapshot {
            total_count,
            total_wait_ms,
        }
    }
}
/// Point-in-time copy of [`BackpressureStats`].
#[derive(Debug, Clone, Default)]
pub struct BackpressureStatsSnapshot {
    pub total_count: u64,
    pub total_wait_ms: u64,
}
/// Blocks writers while too much unflushed memtable data is outstanding.
pub struct BackpressureController {
    config: RegionWriterConfig,
    stats: Arc<BackpressureStats>,
}
impl BackpressureController {
    /// Creates a controller enforcing the limits in `config`.
    pub fn new(config: RegionWriterConfig) -> Self {
        let stats = Arc::new(BackpressureStats::new());
        Self { config, stats }
    }
    /// Shared handle to the backpressure counters.
    pub fn stats(&self) -> &Arc<BackpressureStats> {
        &self.stats
    }
    /// Waits until unflushed memtable bytes drop below the configured limit.
    ///
    /// `get_state` is re-polled each iteration and returns the current
    /// unflushed byte count plus (optionally) a watcher for the oldest frozen
    /// memtable's flush completion. When a watcher is available we wait on it
    /// (with a periodic warning log); otherwise we fall back to a short sleep.
    pub async fn maybe_apply_backpressure<F>(&self, mut get_state: F) -> Result<()>
    where
        F: FnMut() -> (usize, Option<DurabilityWatcher>),
    {
        let started = std::time::Instant::now();
        let mut iteration = 0u32;
        loop {
            let (unflushed_memtable_bytes, oldest_watcher) = get_state();
            if unflushed_memtable_bytes < self.config.max_unflushed_memtable_bytes {
                // Only record stats when we actually waited at least once.
                if iteration > 0 {
                    let wait_ms = started.elapsed().as_millis() as u64;
                    self.stats.record(wait_ms);
                }
                return Ok(());
            }
            iteration += 1;
            debug!(
                "Backpressure triggered: unflushed_bytes={}, max={}, iteration={}",
                unflushed_memtable_bytes, self.config.max_unflushed_memtable_bytes, iteration
            );
            match oldest_watcher {
                Some(mut mem_watcher) => {
                    tokio::select! {
                        _ = mem_watcher.await_value() => {}
                        _ = tokio::time::sleep(self.config.backpressure_log_interval) => {
                            warn!(
                                "Backpressure wait timeout, continuing to wait: unflushed_bytes={}, interval={}s, iteration={}",
                                unflushed_memtable_bytes,
                                self.config.backpressure_log_interval.as_secs(),
                                iteration
                            );
                        }
                    }
                }
                // No watcher available: poll again shortly.
                None => tokio::time::sleep(std::time::Duration::from_millis(10)).await,
            }
        }
    }
}
/// Result of a successful `put`: the contiguous range of batch positions
/// assigned inside the active memtable's batch store.
#[derive(Debug)]
pub struct WriteResult {
    pub batch_positions: std::ops::Range<usize>,
}
/// Mutable writer state guarded by a single async `RwLock`.
struct WriterState {
    // Active (mutable) memtable receiving new batches.
    memtable: MemTable,
    // WAL entry position carried forward at freeze time; initialized from the
    // manifest's last-seen position on open.
    last_flushed_wal_entry_position: u64,
    // Sum of estimated sizes of frozen-but-not-yet-flushed memtables.
    frozen_memtable_bytes: usize,
    // (frozen size, flush watcher) per frozen memtable, oldest first.
    frozen_flush_watchers: VecDeque<(usize, DurabilityWatcher)>,
    // Guards against re-entrant freezes while one is in progress.
    flush_requested: bool,
    // Number of size-threshold WAL flushes already triggered for the
    // active memtable's growth.
    wal_flush_trigger_count: usize,
    // `now_millis()` timestamp of the last WAL flush trigger (0 = never armed).
    last_wal_flush_trigger_time: u64,
}
/// Process-wide reference instant, captured on first call and cached.
fn start_time() -> std::time::Instant {
    static START: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    *START.get_or_init(std::time::Instant::now)
}
/// Milliseconds elapsed since the first call to [`start_time`].
fn now_millis() -> u64 {
    let elapsed = start_time().elapsed();
    elapsed.as_millis() as u64
}
/// Write-path state and plumbing shared between the public `RegionWriter`
/// API and its background flush tasks.
struct SharedWriterState {
    // All mutable writer state, guarded by one async RwLock.
    state: Arc<RwLock<WriterState>>,
    wal_flusher: Arc<WalFlusher>,
    // Sends flush requests to the WAL-flush dispatcher task.
    wal_flush_tx: mpsc::UnboundedSender<TriggerWalFlush>,
    // Sends frozen memtables to the memtable-flush dispatcher task.
    memtable_flush_tx: mpsc::UnboundedSender<TriggerMemTableFlush>,
    config: RegionWriterConfig,
    schema: Arc<ArrowSchema>,
    // Field ids of the schema's unenforced primary key columns.
    pk_field_ids: Vec<i32>,
    max_memtable_batches: usize,
    max_memtable_rows: usize,
    ivf_index_partition_capacity_safety_factor: usize,
    // Index configs instantiated afresh for every new memtable generation.
    index_configs: Vec<MemIndexConfig>,
}
impl SharedWriterState {
    /// Plain field-for-field constructor; callers wire up channels and state.
    #[allow(clippy::too_many_arguments)]
    fn new(
        state: Arc<RwLock<WriterState>>,
        wal_flusher: Arc<WalFlusher>,
        wal_flush_tx: mpsc::UnboundedSender<TriggerWalFlush>,
        memtable_flush_tx: mpsc::UnboundedSender<TriggerMemTableFlush>,
        config: RegionWriterConfig,
        schema: Arc<ArrowSchema>,
        pk_field_ids: Vec<i32>,
        max_memtable_batches: usize,
        max_memtable_rows: usize,
        ivf_index_partition_capacity_safety_factor: usize,
        index_configs: Vec<MemIndexConfig>,
    ) -> Self {
        Self {
            state,
            wal_flusher,
            wal_flush_tx,
            memtable_flush_tx,
            config,
            schema,
            pk_field_ids,
            max_memtable_batches,
            max_memtable_rows,
            ivf_index_partition_capacity_safety_factor,
            index_configs,
        }
    }
    /// Swaps the active memtable for a fresh one (next generation), freezes
    /// the old one, triggers any outstanding WAL flush for it, and hands it
    /// to the memtable-flush task.
    ///
    /// Must be called with the writer state write-lock held (`state`).
    /// Returns the new generation number.
    fn freeze_memtable(&self, state: &mut WriterState) -> Result<u64> {
        // Capture what still needs WAL flushing before the swap.
        let pending_wal_range = state.memtable.batch_store().pending_wal_flush_range();
        let last_wal_entry_position = state.last_flushed_wal_entry_position;
        let old_batch_store = state.memtable.batch_store();
        let old_indexes = state.memtable.indexes_arc();
        let next_generation = state.memtable.generation() + 1;
        let mut new_memtable = MemTable::with_capacity(
            self.schema.clone(),
            next_generation,
            self.pk_field_ids.clone(),
            CacheConfig::default(),
            self.max_memtable_batches,
        )?;
        if !self.index_configs.is_empty() {
            // The new generation gets fresh, empty index structures.
            let indexes = Arc::new(IndexStore::from_configs(
                &self.index_configs,
                self.max_memtable_rows,
                self.ivf_index_partition_capacity_safety_factor,
            )?);
            new_memtable.set_indexes_arc(indexes);
        }
        let mut old_memtable = std::mem::replace(&mut state.memtable, new_memtable);
        old_memtable.freeze(last_wal_entry_position);
        // Creating the completion cell registers the flush watcher read below.
        let _memtable_flush_watcher = old_memtable.create_memtable_flush_completion();
        if pending_wal_range.is_some() {
            // The frozen memtable still has unflushed batches: kick off a WAL
            // flush and let the memtable flusher wait on its completion.
            let completion_cell: WatchableOnceCell<std::result::Result<WalFlushResult, String>> =
                WatchableOnceCell::new();
            let completion_reader = completion_cell.reader();
            old_memtable.set_wal_flush_completion(completion_reader);
            let end_batch_position = old_batch_store.len();
            self.wal_flusher.trigger_flush(
                old_batch_store,
                old_indexes,
                end_batch_position,
                Some(completion_cell),
            )?;
        }
        // Account the frozen bytes toward backpressure until the flush is done.
        let frozen_size = old_memtable.estimated_size();
        state.frozen_memtable_bytes += frozen_size;
        state.last_flushed_wal_entry_position = last_wal_entry_position;
        let flush_watcher = old_memtable
            .get_memtable_flush_watcher()
            .expect("Flush watcher should exist after create_memtable_flush_completion");
        state
            .frozen_flush_watchers
            .push_back((frozen_size, flush_watcher));
        let frozen_memtable = Arc::new(old_memtable);
        debug!(
            "Frozen memtable generation {}, pending_count = {}",
            next_generation - 1,
            state.frozen_flush_watchers.len()
        );
        // Best effort: a closed channel means we are shutting down anyway.
        let _ = self.memtable_flush_tx.send(TriggerMemTableFlush {
            memtable: frozen_memtable,
            done: None,
        });
        Ok(next_generation)
    }
    /// Registers `batch_position` with the WAL flusher and returns a
    /// durability watcher for it.
    ///
    /// NOTE(review): the watcher returned by `wal_flusher.track_batch` is
    /// immediately dropped and the returned cell is resolved as `Durable`
    /// right away, so awaiting this watcher never actually waits for the WAL
    /// write — confirm whether this is an intentional placeholder.
    fn track_batch_for_wal(&self, batch_position: usize) -> DurabilityWatcher {
        let _wal_watcher = self.wal_flusher.track_batch(batch_position);
        let cell: WatchableOnceCell<DurabilityResult> = WatchableOnceCell::new();
        cell.write(DurabilityResult::ok());
        cell.reader()
    }
    /// Freezes the active memtable when it exceeds the configured size or its
    /// batch store is full. Must be called with the state write-lock held.
    fn maybe_trigger_memtable_flush(&self, state: &mut WriterState) -> Result<()> {
        if state.flush_requested {
            return Ok(());
        }
        let should_flush = state.memtable.estimated_size() >= self.config.max_memtable_size
            || state.memtable.is_batch_store_full();
        if should_flush {
            // flush_requested guards against re-entrancy while freezing.
            state.flush_requested = true;
            self.freeze_memtable(state)?;
            state.flush_requested = false;
        }
        Ok(())
    }
    /// Requests a WAL flush when either the configured flush interval has
    /// elapsed with pending batches, or the active memtable has grown past
    /// the next multiple of the WAL buffer size.
    /// Must be called with the state write-lock held.
    fn maybe_trigger_wal_flush(&self, state: &mut WriterState) {
        let threshold = self.config.max_wal_buffer_size;
        let batch_count = state.memtable.batch_count();
        let total_bytes = state.memtable.estimated_size();
        let batch_store = state.memtable.batch_store();
        let indexes = state.memtable.indexes_arc();
        let has_pending = batch_store.pending_wal_flush_count() > 0;
        // Time-based trigger: the very first call only arms the timer.
        let time_trigger = if let Some(interval) = self.config.max_wal_flush_interval {
            let interval_millis = interval.as_millis() as u64;
            let last_trigger = state.last_wal_flush_trigger_time;
            let now = now_millis();
            if last_trigger == 0 {
                state.last_wal_flush_trigger_time = now;
                None
            } else {
                let elapsed = now.saturating_sub(last_trigger);
                if elapsed >= interval_millis && has_pending {
                    state.last_wal_flush_trigger_time = now;
                    Some(now)
                } else {
                    None
                }
            }
        } else {
            None
        };
        if time_trigger.is_some() {
            // Send errors ignored: a closed channel means shutdown.
            let _ = self.wal_flush_tx.send(TriggerWalFlush {
                batch_store,
                indexes,
                end_batch_position: batch_count,
                done: None,
            });
            return;
        }
        if threshold == 0 {
            return;
        }
        // Size-based trigger: one flush per `threshold` bytes crossed.
        // NOTE(review): `wal_flush_trigger_count` is never reset when the
        // memtable is frozen, so after a freeze the size trigger stays idle
        // until the NEW memtable grows past the previous multiple — confirm
        // whether the counter should be reset in `freeze_memtable`.
        let thresholds_crossed = total_bytes / threshold;
        while state.wal_flush_trigger_count < thresholds_crossed {
            state.wal_flush_trigger_count += 1;
            state.last_wal_flush_trigger_time = now_millis();
            let _ = self.wal_flush_tx.send(TriggerWalFlush {
                batch_store: batch_store.clone(),
                indexes: indexes.clone(),
                end_batch_position: batch_count,
                done: None,
            });
        }
    }
}
impl SharedWriterState {
    /// Bytes held by the active memtable plus all frozen-but-unflushed
    /// memtables. Best effort: returns 0 when the state lock is contended.
    fn unflushed_memtable_bytes(&self) -> usize {
        match self.state.try_read() {
            Ok(s) => s.memtable.estimated_size() + s.frozen_memtable_bytes,
            Err(_) => 0,
        }
    }
    /// Flush watcher for the oldest frozen memtable, falling back to the
    /// active memtable's watcher. `None` when the lock is contended or no
    /// watcher exists.
    fn oldest_memtable_watcher(&self) -> Option<DurabilityWatcher> {
        let s = self.state.try_read().ok()?;
        match s.frozen_flush_watchers.front() {
            Some((_, watcher)) => Some(watcher.clone()),
            None => s.memtable.get_memtable_flush_watcher(),
        }
    }
}
/// Public write API for one region: accepts record batches and coordinates
/// WAL durability, memtable freezing/flushing, and manifest bookkeeping via
/// background dispatcher tasks.
pub struct RegionWriter {
    config: RegionWriterConfig,
    // Writer epoch claimed from the manifest store at open.
    epoch: u64,
    state: Arc<RwLock<WriterState>>,
    wal_flusher: Arc<WalFlusher>,
    task_executor: Arc<TaskExecutor>,
    manifest_store: Arc<RegionManifestStore>,
    stats: SharedWriteStats,
    writer_state: Arc<SharedWriterState>,
    backpressure: BackpressureController,
}
impl RegionWriter {
    /// Opens a writer for `config.region_id`: claims a writer epoch from the
    /// manifest store, restores the memtable generation and WAL position from
    /// the latest manifest, and spawns the WAL-flush and memtable-flush
    /// background tasks.
    pub async fn open(
        object_store: Arc<ObjectStore>,
        base_path: Path,
        base_uri: impl Into<String>,
        config: RegionWriterConfig,
        schema: Arc<ArrowSchema>,
        index_configs: Vec<MemIndexConfig>,
    ) -> Result<Self> {
        let base_uri = base_uri.into();
        let region_id = config.region_id;
        let manifest_store = Arc::new(RegionManifestStore::new(
            object_store.clone(),
            &base_path,
            region_id,
            config.manifest_scan_batch_size,
        ));
        let (epoch, manifest) = manifest_store.claim_epoch(config.region_spec_id).await?;
        info!(
            "Opened RegionWriter for region {} (epoch {}, generation {})",
            region_id, epoch, manifest.current_generation
        );
        let lance_schema = Schema::try_from(schema.as_ref())?;
        let pk_field_ids: Vec<i32> = lance_schema
            .unenforced_primary_key()
            .iter()
            .map(|f| f.id)
            .collect();
        let mut memtable = MemTable::with_capacity(
            schema.clone(),
            manifest.current_generation,
            pk_field_ids.clone(),
            CacheConfig::default(),
            config.max_memtable_batches,
        )?;
        if !index_configs.is_empty() {
            let indexes = Arc::new(IndexStore::from_configs(
                &index_configs,
                config.max_memtable_rows,
                config.ivf_index_partition_capacity_safety_factor,
            )?);
            memtable.set_indexes_arc(indexes);
        }
        let state = Arc::new(RwLock::new(WriterState {
            memtable,
            last_flushed_wal_entry_position: manifest.wal_entry_position_last_seen,
            frozen_memtable_bytes: 0,
            frozen_flush_watchers: VecDeque::new(),
            flush_requested: false,
            wal_flush_trigger_count: 0,
            last_wal_flush_trigger_time: 0,
        }));
        // WAL positions continue after the last position the manifest saw.
        let mut wal_flusher = WalFlusher::new(
            &base_path,
            region_id,
            epoch,
            manifest.wal_entry_position_last_seen + 1,
        );
        wal_flusher.set_object_store(object_store.clone());
        let (wal_flush_tx, wal_flush_rx) = mpsc::unbounded_channel();
        let (memtable_flush_tx, memtable_flush_rx) = mpsc::unbounded_channel();
        wal_flusher.set_flush_channel(wal_flush_tx.clone());
        let wal_flusher = Arc::new(wal_flusher);
        let flusher = Arc::new(MemTableFlusher::new(
            object_store.clone(),
            base_path,
            base_uri,
            region_id,
            manifest_store.clone(),
        ));
        let stats = new_shared_stats();
        let backpressure = BackpressureController::new(config.clone());
        // Background task wiring: one dispatcher per flush channel.
        let task_executor = Arc::new(TaskExecutor::new());
        let wal_handler = WalFlushHandler::new(wal_flusher.clone(), state.clone(), stats.clone());
        task_executor.add_handler(
            "wal_flusher".to_string(),
            Box::new(wal_handler),
            wal_flush_rx,
        )?;
        let memtable_handler =
            MemTableFlushHandler::new(state.clone(), flusher, epoch, stats.clone());
        task_executor.add_handler(
            "memtable_flusher".to_string(),
            Box::new(memtable_handler),
            memtable_flush_rx,
        )?;
        let writer_state = Arc::new(SharedWriterState::new(
            state.clone(),
            wal_flusher.clone(),
            wal_flush_tx,
            memtable_flush_tx,
            config.clone(),
            schema.clone(),
            pk_field_ids,
            config.max_memtable_batches,
            config.max_memtable_rows,
            config.ivf_index_partition_capacity_safety_factor,
            index_configs,
        ));
        Ok(Self {
            config,
            epoch,
            state,
            wal_flusher,
            task_executor,
            manifest_store,
            stats,
            writer_state,
            backpressure,
        })
    }
    /// Writes `batches` into the active memtable.
    ///
    /// Rejects an empty batch list and any zero-row batch. Applies
    /// backpressure first, then inserts under the state write-lock,
    /// opportunistically triggering WAL/memtable flushes. With
    /// `durable_write` enabled, a WAL flush is requested and the durability
    /// watcher awaited before returning.
    ///
    /// # Errors
    /// Invalid input, insert failures, or (durable mode) flush failures.
    pub async fn put(&self, batches: Vec<RecordBatch>) -> Result<WriteResult> {
        if batches.is_empty() {
            return Err(Error::invalid_input("Cannot write empty batch list"));
        }
        for (i, batch) in batches.iter().enumerate() {
            if batch.num_rows() == 0 {
                return Err(Error::invalid_input(format!("Batch {} is empty", i)));
            }
        }
        let writer_state = &self.writer_state;
        self.backpressure
            .maybe_apply_backpressure(|| {
                (
                    writer_state.unflushed_memtable_bytes(),
                    writer_state.oldest_memtable_watcher(),
                )
            })
            .await?;
        let start = std::time::Instant::now();
        // Keep the write-lock scope tight: insert, register durability,
        // trigger flushes, then release before awaiting anything.
        let (batch_positions, durable_watcher, batch_store, indexes) = {
            let mut state = self.state.write().await;
            let results = state.memtable.insert_batches_only(batches).await?;
            let start_pos = results.first().map(|(pos, _, _)| *pos).unwrap_or(0);
            let end_pos = results.last().map(|(pos, _, _)| pos + 1).unwrap_or(0);
            let batch_positions = start_pos..end_pos;
            // Track the last written batch position for durability.
            // NOTE(review): see `track_batch_for_wal` — the returned watcher
            // currently resolves as Durable immediately.
            let durable_watcher = self
                .writer_state
                .track_batch_for_wal(end_pos.saturating_sub(1));
            self.writer_state.maybe_trigger_wal_flush(&mut state);
            if let Err(e) = self.writer_state.maybe_trigger_memtable_flush(&mut state) {
                warn!("Failed to trigger memtable flush: {}", e);
            }
            let batch_store = state.memtable.batch_store();
            let indexes = state.memtable.indexes_arc();
            (batch_positions, durable_watcher, batch_store, indexes)
        };
        self.stats.record_put(start.elapsed());
        if self.config.durable_write {
            self.wal_flusher
                .trigger_flush(batch_store, indexes, batch_positions.end, None)?;
            durable_watcher.clone().await_value().await.into_result()?;
        }
        Ok(WriteResult { batch_positions })
    }
    /// Point-in-time copy of the write statistics.
    pub fn stats(&self) -> WriteStatsSnapshot {
        self.stats.snapshot()
    }
    /// Shared handle to the live write statistics.
    pub fn stats_handle(&self) -> SharedWriteStats {
        self.stats.clone()
    }
    /// Latest persisted manifest for this region, if any.
    pub async fn manifest(&self) -> Result<Option<RegionManifest>> {
        self.manifest_store.read_latest().await
    }
    /// Writer epoch claimed at open.
    pub fn epoch(&self) -> u64 {
        self.epoch
    }
    /// Region this writer serves.
    pub fn region_id(&self) -> Uuid {
        self.config.region_id
    }
    /// Point-in-time statistics of the active memtable.
    pub async fn memtable_stats(&self) -> MemTableStats {
        let state = self.state.read().await;
        MemTableStats {
            row_count: state.memtable.row_count(),
            batch_count: state.memtable.batch_count(),
            estimated_size: state.memtable.estimated_size(),
            generation: state.memtable.generation(),
        }
    }
    /// Scanner over the active memtable.
    pub async fn scan(&self) -> MemTableScanner {
        let state = self.state.read().await;
        state.memtable.scan()
    }
    /// Cheap shared references to the active memtable's stores for scanning.
    pub async fn active_memtable_ref(&self) -> crate::dataset::mem_wal::scanner::ActiveMemTableRef {
        let state = self.state.read().await;
        crate::dataset::mem_wal::scanner::ActiveMemTableRef {
            batch_store: state.memtable.batch_store(),
            index_store: state
                .memtable
                .indexes_arc()
                .unwrap_or_else(|| Arc::new(IndexStore::new())),
            schema: state.memtable.schema().clone(),
            generation: state.memtable.generation(),
        }
    }
    /// Point-in-time WAL statistics.
    pub fn wal_stats(&self) -> WalStats {
        WalStats {
            next_wal_entry_position: self.wal_flusher.next_wal_entry_position(),
        }
    }
    /// Gracefully shuts down: flushes any remaining batches to the WAL
    /// (waiting for completion), then stops all background tasks.
    pub async fn close(self) -> Result<()> {
        info!("Closing RegionWriter for region {}", self.config.region_id);
        let state = self.state.read().await;
        let batch_store = state.memtable.batch_store();
        let indexes = state.memtable.indexes_arc();
        let batch_count = state.memtable.batch_count();
        drop(state);
        if batch_count > 0 {
            let done = WatchableOnceCell::new();
            let reader = done.reader();
            if self
                .writer_state
                .wal_flush_tx
                .send(TriggerWalFlush {
                    batch_store,
                    indexes,
                    end_batch_position: batch_count,
                    done: Some(done),
                })
                .is_ok()
            {
                // Wait for the final flush to finish; its result is best-effort.
                let mut reader = reader;
                let _ = reader.await_value().await;
            }
        }
        self.task_executor.shutdown_all().await?;
        info!("RegionWriter closed for region {}", self.config.region_id);
        Ok(())
    }
}
/// Point-in-time view of the active memtable.
#[derive(Debug, Clone)]
pub struct MemTableStats {
    pub row_count: usize,
    pub batch_count: usize,
    pub estimated_size: usize,
    pub generation: u64,
}
/// Point-in-time view of the WAL writer.
#[derive(Debug, Clone)]
pub struct WalStats {
    pub next_wal_entry_position: u64,
}
/// Handles WAL flush requests on the WAL-flusher dispatcher task.
struct WalFlushHandler {
    wal_flusher: Arc<WalFlusher>,
    state: Arc<RwLock<WriterState>>,
    stats: SharedWriteStats,
}
impl WalFlushHandler {
    /// Bundles the WAL flusher with the writer state and stats it updates.
    fn new(
        wal_flusher: Arc<WalFlusher>,
        state: Arc<RwLock<WriterState>>,
        stats: SharedWriteStats,
    ) -> Self {
        // Field-init shorthand: parameter names already match the fields.
        Self {
            wal_flusher,
            state,
            stats,
        }
    }
}
#[async_trait]
impl MessageHandler<TriggerWalFlush> for WalFlushHandler {
    /// Runs one WAL flush request; when the sender asked for completion
    /// notification, publishes the (stringified-error) result into `done`.
    /// Flush errors never stop the dispatcher loop.
    async fn handle(&mut self, message: TriggerWalFlush) -> Result<()> {
        let TriggerWalFlush {
            batch_store,
            indexes,
            end_batch_position,
            done,
        } = message;
        let outcome = self
            .do_flush(batch_store, indexes, end_batch_position)
            .await;
        if let Some(cell) = done {
            cell.write(outcome.map_err(|e| e.to_string()));
        }
        Ok(())
    }
}
impl WalFlushHandler {
    /// Flushes batches up to `end_batch_position` of `batch_store` to the WAL
    /// and applies any pending index updates, recording stats on success.
    ///
    /// The flush is skipped (returning an empty result) only when the request
    /// targets the *active* memtable's store and everything up to
    /// `end_batch_position` is already flushed; frozen-store requests always
    /// proceed.
    async fn do_flush(
        &self,
        batch_store: Arc<BatchStore>,
        indexes: Option<Arc<IndexStore>>,
        end_batch_position: usize,
    ) -> Result<WalFlushResult> {
        let start = Instant::now();
        let max_flushed = batch_store.max_flushed_batch_position();
        // First unflushed position: positions are 0-based, so this is
        // max flushed + 1 (or 0 when nothing has been flushed yet).
        let flushed_up_to = max_flushed.map(|p| p + 1).unwrap_or(0);
        // A request is for a frozen memtable when its batch store is no
        // longer the active memtable's store (pointer-identity check).
        let is_frozen_flush = {
            let state = self.state.read().await;
            !Arc::ptr_eq(&batch_store, &state.memtable.batch_store())
        };
        if !is_frozen_flush && flushed_up_to >= end_batch_position {
            // Nothing new to flush; report an empty result.
            return Ok(WalFlushResult {
                entry: None,
                wal_io_duration: std::time::Duration::ZERO,
                index_update_duration: std::time::Duration::ZERO,
                index_update_duration_breakdown: std::collections::HashMap::new(),
                rows_indexed: 0,
                wal_bytes: 0,
            });
        }
        let flush_result = self
            .wal_flusher
            .flush_to_with_index_update(&batch_store, end_batch_position, indexes)
            .await?;
        let batches_flushed = flush_result
            .entry
            .as_ref()
            .map(|e| e.num_batches)
            .unwrap_or(0);
        // Only count flushes that actually wrote batches.
        if batches_flushed > 0 {
            self.stats
                .record_wal_flush(start.elapsed(), flush_result.wal_bytes);
            self.stats.record_wal_io(flush_result.wal_io_duration);
            self.stats.record_index_update(
                flush_result.index_update_duration,
                flush_result.rows_indexed,
            );
        }
        Ok(flush_result)
    }
}
/// Handles frozen-memtable flush requests on the memtable-flusher task.
struct MemTableFlushHandler {
    state: Arc<RwLock<WriterState>>,
    flusher: Arc<MemTableFlusher>,
    // Writer epoch stamped onto flushed data.
    epoch: u64,
    stats: SharedWriteStats,
}
impl MemTableFlushHandler {
    /// Bundles writer state, the flusher, the writer epoch, and stats.
    fn new(
        state: Arc<RwLock<WriterState>>,
        flusher: Arc<MemTableFlusher>,
        epoch: u64,
        stats: SharedWriteStats,
    ) -> Self {
        // Field-init shorthand: parameter names already match the fields.
        Self {
            state,
            flusher,
            epoch,
            stats,
        }
    }
}
#[async_trait]
impl MessageHandler<TriggerMemTableFlush> for MemTableFlushHandler {
    /// Flushes the frozen memtable carried by `message`. When a completion
    /// channel is present, the outcome is forwarded there; otherwise an
    /// error propagates and stops the dispatcher.
    async fn handle(&mut self, message: TriggerMemTableFlush) -> Result<()> {
        let TriggerMemTableFlush { memtable, done } = message;
        let outcome = self.flush_memtable(memtable).await;
        match done {
            Some(tx) => {
                let _ = tx.send(outcome);
            }
            None => {
                outcome?;
            }
        }
        Ok(())
    }
}
impl MemTableFlushHandler {
    /// Flushes one frozen memtable to storage, then updates the shared
    /// accounting of unflushed bytes and records flush stats.
    ///
    /// Waits for the memtable's in-flight WAL flush (if any) first, so data
    /// is WAL-durable before it can be dropped from memory.
    async fn flush_memtable(
        &mut self,
        memtable: Arc<MemTable>,
    ) -> Result<super::memtable::flush::FlushResult> {
        let start = Instant::now();
        if let Some(mut completion_reader) = memtable.take_wal_flush_completion() {
            completion_reader
                .await_value()
                .await
                .map_err(|e| Error::io(format!("WAL flush failed: {}", e)))?;
        }
        let result = self.flusher.flush(&memtable, self.epoch).await?;
        memtable.signal_memtable_flush_complete();
        {
            let mut state = self.state.write().await;
            // Subtract the size recorded when this memtable was frozen, NOT a
            // freshly re-estimated size: `frozen_memtable_bytes` was increased
            // by exactly the recorded amount at freeze time, so subtracting
            // anything else lets the counter drift and skews backpressure.
            // (The previous code subtracted `memtable.estimated_size()` and
            // discarded the recorded size popped from the deque.)
            if let Some((size, _watcher)) = state.frozen_flush_watchers.pop_front() {
                state.frozen_memtable_bytes = state.frozen_memtable_bytes.saturating_sub(size);
            }
        }
        self.stats
            .record_memtable_flush(start.elapsed(), result.rows_flushed);
        info!(
            "Flushed frozen memtable generation {} ({} rows in {:?})",
            result.generation.generation,
            result.rows_flushed,
            start.elapsed()
        );
        Ok(result)
    }
}
/// Cumulative write-path counters, updated with relaxed atomics.
#[derive(Debug, Default)]
pub struct WriteStats {
    // put() calls and their total latency.
    put_count: AtomicU64,
    put_time_nanos: AtomicU64,
    // WAL flushes: count, total latency, bytes written.
    wal_flush_count: AtomicU64,
    wal_flush_time_nanos: AtomicU64,
    wal_flush_bytes: AtomicU64,
    // I/O portion of WAL flushes.
    wal_io_time_nanos: AtomicU64,
    wal_io_count: AtomicU64,
    // Index updates: total latency, count, rows indexed.
    index_update_time_nanos: AtomicU64,
    index_update_count: AtomicU64,
    index_update_rows: AtomicU64,
    // Memtable flushes: count, total latency, rows flushed.
    memtable_flush_count: AtomicU64,
    memtable_flush_time_nanos: AtomicU64,
    memtable_flush_rows: AtomicU64,
}
/// Point-in-time copy of [`WriteStats`] with durations materialized.
#[derive(Debug, Clone)]
pub struct WriteStatsSnapshot {
    pub put_count: u64,
    pub put_time: Duration,
    pub wal_flush_count: u64,
    pub wal_flush_time: Duration,
    pub wal_flush_bytes: u64,
    pub wal_io_time: Duration,
    pub wal_io_count: u64,
    pub index_update_time: Duration,
    pub index_update_count: u64,
    pub index_update_rows: u64,
    pub memtable_flush_count: u64,
    pub memtable_flush_time: Duration,
    pub memtable_flush_rows: u64,
}
impl WriteStats {
pub fn new() -> Self {
Self::default()
}
pub fn record_put(&self, duration: Duration) {
self.put_count.fetch_add(1, Ordering::Relaxed);
self.put_time_nanos
.fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
}
pub fn record_wal_flush(&self, duration: Duration, bytes: usize) {
self.wal_flush_count.fetch_add(1, Ordering::Relaxed);
self.wal_flush_time_nanos
.fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
self.wal_flush_bytes
.fetch_add(bytes as u64, Ordering::Relaxed);
}
pub fn record_wal_io(&self, duration: Duration) {
self.wal_io_count.fetch_add(1, Ordering::Relaxed);
self.wal_io_time_nanos
.fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
}
pub fn record_index_update(&self, duration: Duration, rows: usize) {
self.index_update_count.fetch_add(1, Ordering::Relaxed);
self.index_update_time_nanos
.fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
self.index_update_rows
.fetch_add(rows as u64, Ordering::Relaxed);
}
pub fn record_memtable_flush(&self, duration: Duration, rows: usize) {
self.memtable_flush_count.fetch_add(1, Ordering::Relaxed);
self.memtable_flush_time_nanos
.fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
self.memtable_flush_rows
.fetch_add(rows as u64, Ordering::Relaxed);
}
pub fn snapshot(&self) -> WriteStatsSnapshot {
WriteStatsSnapshot {
put_count: self.put_count.load(Ordering::Relaxed),
put_time: Duration::from_nanos(self.put_time_nanos.load(Ordering::Relaxed)),
wal_flush_count: self.wal_flush_count.load(Ordering::Relaxed),
wal_flush_time: Duration::from_nanos(self.wal_flush_time_nanos.load(Ordering::Relaxed)),
wal_flush_bytes: self.wal_flush_bytes.load(Ordering::Relaxed),
wal_io_time: Duration::from_nanos(self.wal_io_time_nanos.load(Ordering::Relaxed)),
wal_io_count: self.wal_io_count.load(Ordering::Relaxed),
index_update_time: Duration::from_nanos(
self.index_update_time_nanos.load(Ordering::Relaxed),
),
index_update_count: self.index_update_count.load(Ordering::Relaxed),
index_update_rows: self.index_update_rows.load(Ordering::Relaxed),
memtable_flush_count: self.memtable_flush_count.load(Ordering::Relaxed),
memtable_flush_time: Duration::from_nanos(
self.memtable_flush_time_nanos.load(Ordering::Relaxed),
),
memtable_flush_rows: self.memtable_flush_rows.load(Ordering::Relaxed),
}
}
pub fn reset(&self) {
self.put_count.store(0, Ordering::Relaxed);
self.put_time_nanos.store(0, Ordering::Relaxed);
self.wal_flush_count.store(0, Ordering::Relaxed);
self.wal_flush_time_nanos.store(0, Ordering::Relaxed);
self.wal_flush_bytes.store(0, Ordering::Relaxed);
self.wal_io_time_nanos.store(0, Ordering::Relaxed);
self.wal_io_count.store(0, Ordering::Relaxed);
self.index_update_time_nanos.store(0, Ordering::Relaxed);
self.index_update_count.store(0, Ordering::Relaxed);
self.index_update_rows.store(0, Ordering::Relaxed);
self.memtable_flush_count.store(0, Ordering::Relaxed);
self.memtable_flush_time_nanos.store(0, Ordering::Relaxed);
self.memtable_flush_rows.store(0, Ordering::Relaxed);
}
}
impl WriteStatsSnapshot {
    /// Mean duration per event, or `None` when `count` is zero.
    ///
    /// Divides in nanosecond space instead of the previous
    /// `total / count as u32`: that `as u32` cast silently truncates
    /// counts >= 2^32, yielding wrong averages — and panicking with a
    /// division by zero when the truncated count is exactly 0 even
    /// though the real count was positive.
    fn avg_duration(total: Duration, count: u64) -> Option<Duration> {
        if count == 0 {
            None
        } else {
            // Totals are accumulated as u64 nanoseconds, so the quotient
            // always fits back into a u64.
            Some(Duration::from_nanos(
                (total.as_nanos() / count as u128) as u64,
            ))
        }
    }
    /// Mean amount per event, or `None` when `count` is zero.
    fn avg_u64(total: u64, count: u64) -> Option<u64> {
        if count == 0 { None } else { Some(total / count) }
    }
    /// `amount` per second over `elapsed`; 0.0 when no time has elapsed.
    fn rate(amount: f64, elapsed: Duration) -> f64 {
        let secs = elapsed.as_secs_f64();
        if secs > 0.0 { amount / secs } else { 0.0 }
    }
    /// Average latency of a single `put` call.
    pub fn avg_put_latency(&self) -> Option<Duration> {
        Self::avg_duration(self.put_time, self.put_count)
    }
    /// Put operations per second of cumulative put time.
    pub fn put_throughput(&self) -> f64 {
        Self::rate(self.put_count as f64, self.put_time)
    }
    /// Average latency of a single WAL flush.
    pub fn avg_wal_flush_latency(&self) -> Option<Duration> {
        Self::avg_duration(self.wal_flush_time, self.wal_flush_count)
    }
    /// Average number of bytes written per WAL flush.
    pub fn avg_wal_flush_bytes(&self) -> Option<u64> {
        Self::avg_u64(self.wal_flush_bytes, self.wal_flush_count)
    }
    /// WAL bytes written per second of cumulative WAL flush time.
    pub fn wal_throughput_bytes(&self) -> f64 {
        Self::rate(self.wal_flush_bytes as f64, self.wal_flush_time)
    }
    /// Average latency of a single WAL I/O operation.
    pub fn avg_wal_io_latency(&self) -> Option<Duration> {
        Self::avg_duration(self.wal_io_time, self.wal_io_count)
    }
    /// Average latency of a single index update.
    pub fn avg_index_update_latency(&self) -> Option<Duration> {
        Self::avg_duration(self.index_update_time, self.index_update_count)
    }
    /// Average number of rows per index update.
    pub fn avg_index_update_rows(&self) -> Option<u64> {
        Self::avg_u64(self.index_update_rows, self.index_update_count)
    }
    /// Average latency of a single memtable flush.
    pub fn avg_memtable_flush_latency(&self) -> Option<Duration> {
        Self::avg_duration(self.memtable_flush_time, self.memtable_flush_count)
    }
    /// Average number of rows per memtable flush.
    pub fn avg_memtable_flush_rows(&self) -> Option<u64> {
        Self::avg_u64(self.memtable_flush_rows, self.memtable_flush_count)
    }
    /// Emits a one-line tracing summary covering put, WAL-flush and
    /// memtable-flush activity. Missing averages log as 0.
    pub fn log_summary(&self, prefix: &str) {
        tracing::info!(
            prefix = prefix,
            put_count = self.put_count,
            put_throughput = self.put_throughput(),
            put_avg_latency_us = self.avg_put_latency().unwrap_or_default().as_micros() as u64,
            wal_flush_count = self.wal_flush_count,
            wal_flush_bytes = self.wal_flush_bytes,
            wal_avg_latency_us =
                self.avg_wal_flush_latency().unwrap_or_default().as_micros() as u64,
            memtable_flush_count = self.memtable_flush_count,
            memtable_flush_rows = self.memtable_flush_rows,
            memtable_avg_latency_us = self
                .avg_memtable_flush_latency()
                .unwrap_or_default()
                .as_micros() as u64,
            "MemWAL stats summary"
        );
    }
    /// Emits a breakdown of average WAL flush latency into raw I/O time
    /// and index-update time. Skipped entirely when no WAL flush has
    /// been recorded, so idle regions stay quiet in the logs.
    pub fn log_wal_breakdown(&self, prefix: &str) {
        if self.wal_flush_count > 0 {
            tracing::info!(
                prefix = prefix,
                wal_total_latency_us =
                    self.avg_wal_flush_latency().unwrap_or_default().as_micros() as u64,
                wal_io_latency_us =
                    self.avg_wal_io_latency().unwrap_or_default().as_micros() as u64,
                index_update_latency_us = self
                    .avg_index_update_latency()
                    .unwrap_or_default()
                    .as_micros() as u64,
                index_update_rows = self.index_update_rows,
                "MemWAL WAL flush breakdown"
            );
        }
    }
}
/// Shared, thread-safe handle to a [`WriteStats`] instance.
pub type SharedWriteStats = Arc<WriteStats>;
/// Creates a fresh [`WriteStats`] already wrapped for sharing across
/// writer tasks.
pub fn new_shared_stats() -> SharedWriteStats {
    SharedWriteStats::new(WriteStats::new())
}
// Unit tests for the region writer, backpressure controller, and
// write-statistics counters, all against a local file:// object store.
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, StringArray};
use arrow_schema::{DataType, Field};
use tempfile::TempDir;
// Builds a file://-backed ObjectStore rooted in a fresh temp dir. The
// TempDir is returned so the caller keeps it alive for the test's
// duration (dropping it deletes the directory).
async fn create_local_store() -> (Arc<ObjectStore>, Path, String, TempDir) {
let temp_dir = tempfile::tempdir().unwrap();
let uri = format!("file://{}", temp_dir.path().display());
let (store, path) = ObjectStore::from_uri(&uri).await.unwrap();
(store, path, uri, temp_dir)
}
// Two-column schema shared by the tests below: non-null i32 `id` and
// nullable utf8 `name`.
fn create_test_schema() -> Arc<ArrowSchema> {
Arc::new(ArrowSchema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]))
}
// Builds a batch of `num_rows` rows with ids [start_id, start_id + num_rows)
// and names "name_<id>".
fn create_test_batch(schema: &ArrowSchema, start_id: i32, num_rows: usize) -> RecordBatch {
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from_iter_values(
start_id..start_id + num_rows as i32,
)),
Arc::new(StringArray::from_iter_values(
(0..num_rows).map(|i| format!("name_{}", start_id as usize + i)),
)),
],
)
.unwrap()
}
// Single put of one 10-row batch: checks the returned batch-position
// range and the memtable's row/batch counters.
#[tokio::test]
async fn test_region_writer_basic_write() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let schema = create_test_schema();
let config = RegionWriterConfig {
region_id: Uuid::new_v4(),
region_spec_id: 0,
durable_write: false,
sync_indexed_write: false,
max_wal_buffer_size: 1024 * 1024,
max_wal_flush_interval: None,
// Large memtable cap so no auto-flush fires during the test.
max_memtable_size: 64 * 1024 * 1024,
manifest_scan_batch_size: 2,
..Default::default()
};
let writer = RegionWriter::open(
store,
base_path,
base_uri,
config.clone(),
schema.clone(),
vec![],
)
.await
.unwrap();
let batch = create_test_batch(&schema, 0, 10);
let result = writer.put(vec![batch]).await.unwrap();
assert_eq!(result.batch_positions, 0..1);
let stats = writer.memtable_stats().await;
assert_eq!(stats.row_count, 10);
assert_eq!(stats.batch_count, 1);
writer.close().await.unwrap();
}
// One put call carrying five 10-row batches: positions 0..5 and 50 rows
// total should land in the memtable.
#[tokio::test]
async fn test_region_writer_multiple_writes() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let schema = create_test_schema();
let config = RegionWriterConfig {
region_id: Uuid::new_v4(),
region_spec_id: 0,
durable_write: false,
sync_indexed_write: false,
max_wal_buffer_size: 1024 * 1024,
max_wal_flush_interval: None,
max_memtable_size: 64 * 1024 * 1024,
manifest_scan_batch_size: 2,
..Default::default()
};
let writer = RegionWriter::open(store, base_path, base_uri, config, schema.clone(), vec![])
.await
.unwrap();
let batches: Vec<_> = (0..5)
.map(|i| create_test_batch(&schema, i * 10, 10))
.collect();
let result = writer.put(batches).await.unwrap();
assert_eq!(result.batch_positions, 0..5);
let stats = writer.memtable_stats().await;
assert_eq!(stats.row_count, 50);
assert_eq!(stats.batch_count, 5);
writer.close().await.unwrap();
}
// Writes with a synchronously-maintained BTree index on `id`; only the
// row count is asserted here (index internals are covered elsewhere).
#[tokio::test]
async fn test_region_writer_with_indexes() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let schema = create_test_schema();
let config = RegionWriterConfig {
region_id: Uuid::new_v4(),
region_spec_id: 0,
durable_write: false,
sync_indexed_write: true,
max_wal_buffer_size: 1024 * 1024,
max_wal_flush_interval: None,
max_memtable_size: 64 * 1024 * 1024,
manifest_scan_batch_size: 2,
..Default::default()
};
let index_configs = vec![MemIndexConfig::BTree(BTreeIndexConfig {
name: "id_idx".to_string(),
field_id: 0,
column: "id".to_string(),
})];
let writer = RegionWriter::open(
store,
base_path,
base_uri,
config,
schema.clone(),
index_configs,
)
.await
.unwrap();
let batch = create_test_batch(&schema, 0, 10);
writer.put(vec![batch]).await.unwrap();
let stats = writer.memtable_stats().await;
assert_eq!(stats.row_count, 10);
writer.close().await.unwrap();
}
// A deliberately tiny memtable cap (1 KiB) forces a size-triggered
// auto-flush while 200 rows are written; the memtable generation must
// advance. NOTE(review): the 100ms sleep gives the background flush
// time to complete — timing-sensitive on very slow machines.
#[tokio::test]
async fn test_region_writer_auto_flush_by_size() {
let (store, base_path, base_uri, _temp_dir) = create_local_store().await;
let schema = create_test_schema();
let config = RegionWriterConfig {
region_id: Uuid::new_v4(),
region_spec_id: 0,
durable_write: false,
sync_indexed_write: false,
max_wal_buffer_size: 1024 * 1024,
max_wal_flush_interval: None,
max_memtable_size: 1024, manifest_scan_batch_size: 2,
..Default::default()
};
let writer = RegionWriter::open(store, base_path, base_uri, config, schema.clone(), vec![])
.await
.unwrap();
let initial_gen = writer.memtable_stats().await.generation;
for i in 0..20 {
let batch = create_test_batch(&schema, i * 10, 10);
writer.put(vec![batch]).await.unwrap();
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let stats = writer.memtable_stats().await;
assert!(
stats.generation > initial_gen,
"Generation should increment after auto-flush"
);
writer.close().await.unwrap();
}
// Probe reports 100 unflushed bytes against a 1 MiB threshold, so no
// backpressure event is recorded.
#[tokio::test]
async fn test_no_backpressure_when_under_threshold() {
let config = RegionWriterConfig::default().with_max_unflushed_memtable_bytes(1024 * 1024);
let controller = BackpressureController::new(config);
controller
.maybe_apply_backpressure(|| (100, None))
.await
.unwrap();
assert_eq!(controller.stats().count(), 0);
}
// The probe reports 1000, 600, 200, then (saturating) 0 unflushed bytes
// against a 100-byte threshold, so the controller must poll exactly 4
// times before the value drops under the limit, and the whole episode
// counts as a single backpressure event.
#[tokio::test]
async fn test_backpressure_loops_until_under_threshold() {
use std::sync::atomic::AtomicUsize;
use std::time::Duration;
let config = RegionWriterConfig::default()
.with_max_unflushed_memtable_bytes(100) .with_backpressure_log_interval(Duration::from_millis(50));
let controller = BackpressureController::new(config);
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
controller
.maybe_apply_backpressure(move || {
let count = call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let unflushed = 1000usize.saturating_sub(count * 400);
(unflushed, None)
})
.await
.unwrap();
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 4);
assert_eq!(controller.stats().count(), 1);
}
// Put counters accumulate time; 30ms over 2 puts averages to 15ms.
#[test]
fn test_record_put() {
let stats = WriteStats::new();
stats.record_put(Duration::from_millis(10));
stats.record_put(Duration::from_millis(20));
let snapshot = stats.snapshot();
assert_eq!(snapshot.put_count, 2);
assert_eq!(snapshot.put_time, Duration::from_millis(30));
assert_eq!(snapshot.avg_put_latency(), Some(Duration::from_millis(15)));
}
// WAL flush counters accumulate both time and bytes.
#[test]
fn test_record_wal_flush() {
let stats = WriteStats::new();
stats.record_wal_flush(Duration::from_millis(100), 1024);
stats.record_wal_flush(Duration::from_millis(200), 2048);
let snapshot = stats.snapshot();
assert_eq!(snapshot.wal_flush_count, 2);
assert_eq!(snapshot.wal_flush_time, Duration::from_millis(300));
assert_eq!(snapshot.wal_flush_bytes, 3072);
assert_eq!(snapshot.avg_wal_flush_bytes(), Some(1536));
}
// Memtable flush counters record count, duration, and rows.
#[test]
fn test_record_memtable_flush() {
let stats = WriteStats::new();
stats.record_memtable_flush(Duration::from_secs(1), 10000);
let snapshot = stats.snapshot();
assert_eq!(snapshot.memtable_flush_count, 1);
assert_eq!(snapshot.memtable_flush_time, Duration::from_secs(1));
assert_eq!(snapshot.memtable_flush_rows, 10000);
}
// reset() zeroes every counter.
#[test]
fn test_stats_reset() {
let stats = WriteStats::new();
stats.record_put(Duration::from_millis(10));
stats.record_wal_flush(Duration::from_millis(100), 1024);
stats.reset();
let snapshot = stats.snapshot();
assert_eq!(snapshot.put_count, 0);
assert_eq!(snapshot.wal_flush_count, 0);
}
}
// End-to-end tests that drive the region writer through the public
// Dataset MemWAL API (vector + scalar + FTS indexes, on-disk layout).
#[cfg(test)]
mod region_writer_tests {
use std::sync::Arc;
use arrow_array::{
FixedSizeListArray, Float32Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance_arrow::FixedSizeListArrayExt;
use lance_index::scalar::ScalarIndexParams;
use lance_index::scalar::inverted::InvertedIndexParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::builder::PQBuildParams;
use lance_index::{DatasetIndexExt, IndexType};
use lance_linalg::distance::MetricType;
use uuid::Uuid;
use crate::dataset::mem_wal::{DatasetMemWalExt, MemWalConfig};
use crate::dataset::{Dataset, WriteParams};
use crate::index::vector::VectorIndexParams;
use super::super::RegionWriterConfig;
// Schema with a non-null i64 `id` (tagged as an unenforced primary key
// via field metadata), a FixedSizeList<f32>[vector_dim] `vector`, and a
// nullable utf8 `text` column.
fn create_test_schema(vector_dim: i32) -> Arc<ArrowSchema> {
use std::collections::HashMap;
let mut id_metadata = HashMap::new();
id_metadata.insert(
"lance-schema:unenforced-primary-key".to_string(),
"true".to_string(),
);
let id_field = Field::new("id", DataType::Int64, false).with_metadata(id_metadata);
Arc::new(ArrowSchema::new(vec![
id_field,
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
vector_dim,
),
true,
),
Field::new("text", DataType::Utf8, true),
]))
}
// Builds `num_rows` rows starting at `start_id`. Vector values are
// derived deterministically via sin() of a per-row seed so test runs
// are reproducible without a RNG.
fn create_test_batch(
schema: &ArrowSchema,
start_id: i64,
num_rows: usize,
vector_dim: i32,
) -> RecordBatch {
let vectors: Vec<f32> = (0..num_rows)
.flat_map(|i| {
let seed = (start_id as usize + i) as f32;
(0..vector_dim as usize).map(move |d| (seed * 0.1 + d as f32 * 0.01).sin())
})
.collect();
let vector_array =
FixedSizeListArray::try_new_from_values(Float32Array::from(vectors), vector_dim)
.unwrap();
let texts: Vec<String> = (0..num_rows)
.map(|i| format!("Sample text for row {}", start_id as usize + i))
.collect();
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int64Array::from_iter_values(
start_id..start_id + num_rows as i64,
)),
Arc::new(vector_array),
Arc::new(StringArray::from_iter_values(texts)),
],
)
.unwrap()
}
// Smoke test: 100 batches x 20 rows into an in-memory dataset with
// MemWAL enabled but no maintained indexes; only checks that put/close
// succeed.
#[tokio::test]
async fn test_region_writer_smoke() {
let vector_dim = 128;
let batch_size = 20;
let num_batches = 100;
let schema = create_test_schema(vector_dim);
let uri = format!("memory://test_region_writer_{}", Uuid::new_v4());
let initial_batch = create_test_batch(&schema, 0, 100, vector_dim);
let batches = RecordBatchIterator::new([Ok(initial_batch)], schema.clone());
let mut dataset = Dataset::write(batches, &uri, Some(WriteParams::default()))
.await
.expect("Failed to create dataset");
dataset
.initialize_mem_wal(MemWalConfig {
region_spec: None,
maintained_indexes: vec![],
})
.await
.expect("Failed to initialize MemWAL");
let region_id = Uuid::new_v4();
let config = RegionWriterConfig::new(region_id)
.with_durable_write(false)
.with_sync_indexed_write(false);
let writer = dataset
.mem_wal_writer(region_id, config)
.await
.expect("Failed to create writer");
let batches: Vec<RecordBatch> = (0..num_batches)
.map(|i| create_test_batch(&schema, (i * batch_size) as i64, batch_size, vector_dim))
.collect();
writer.put(batches).await.expect("Failed to write");
writer.close().await.expect("Failed to close");
}
// Manual large-scale test against remote storage; `#[ignore]`d because
// it needs the DATASET_PREFIX env var (e.g. an s3:// prefix) and real
// credentials. Exercises BTree + FTS + IVF-PQ maintained indexes.
#[tokio::test]
#[ignore]
async fn test_region_writer_s3_ivfpq() {
let prefix = std::env::var("DATASET_PREFIX").expect("DATASET_PREFIX not set");
let vector_dim = 512;
let batch_size = 20;
let num_batches = 10000;
let num_partitions = 16;
let num_sub_vectors = 64;
let schema = create_test_schema(vector_dim);
let uri = format!(
"{}/test_s3_{}",
prefix.trim_end_matches('/'),
Uuid::new_v4()
);
// Seed the dataset with 1000 rows so the IVF-PQ training below has
// enough vectors.
let initial_batch = create_test_batch(&schema, 0, 1000, vector_dim);
let batches = RecordBatchIterator::new([Ok(initial_batch)], schema.clone());
let mut dataset = Dataset::write(batches, &uri, Some(WriteParams::default()))
.await
.expect("Failed to create dataset");
let scalar_params = ScalarIndexParams::default();
dataset
.create_index(
&["id"],
IndexType::BTree,
Some("id_btree".to_string()),
&scalar_params,
false,
)
.await
.expect("Failed to create BTree index");
let fts_params = InvertedIndexParams::default();
dataset
.create_index(
&["text"],
IndexType::Inverted,
Some("text_fts".to_string()),
&fts_params,
false,
)
.await
.expect("Failed to create FTS index");
let ivf_params = IvfBuildParams {
num_partitions: Some(num_partitions),
..Default::default()
};
let pq_params = PQBuildParams {
num_sub_vectors,
num_bits: 8,
..Default::default()
};
let vector_params =
VectorIndexParams::with_ivf_pq_params(MetricType::L2, ivf_params, pq_params);
dataset
.create_index(
&["vector"],
IndexType::Vector,
Some("vector_idx".to_string()),
&vector_params,
true,
)
.await
.expect("Failed to create IVF-PQ index");
// All three indexes are maintained by the MemWAL writer.
dataset
.initialize_mem_wal(MemWalConfig {
region_spec: None,
maintained_indexes: vec![
"id_btree".to_string(),
"text_fts".to_string(),
"vector_idx".to_string(),
],
})
.await
.expect("Failed to initialize MemWAL");
let region_id = Uuid::new_v4();
let config = RegionWriterConfig::new(region_id)
.with_durable_write(false)
.with_sync_indexed_write(false);
let writer = dataset
.mem_wal_writer(region_id, config)
.await
.expect("Failed to create writer");
let batches: Vec<RecordBatch> = (0..num_batches)
.map(|i| create_test_batch(&schema, (i * batch_size) as i64, batch_size, vector_dim))
.collect();
writer.put(batches).await.expect("Failed to write");
writer.close().await.expect("Failed to close");
}
// End-to-end correctness: durable, sync-indexed writes with deliberately
// tiny memtable/WAL limits to force flushes, then verifies the on-disk
// MemWAL layout (_mem_wal/<region>/{wal,manifest}, flushed generation
// dirs, .arrow WAL files), the persisted region manifest, and that a
// freshly opened dataset can serve a new region writer + scan.
#[tokio::test]
async fn test_region_writer_e2e_correctness() {
use std::time::Duration;
use tempfile::TempDir;
let vector_dim = 32;
let rows_per_batch = 50;
let num_write_rounds = 3;
let batches_per_round = 3;
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let uri = format!("file://{}", temp_dir.path().display());
let schema = create_test_schema(vector_dim);
let initial_batch = create_test_batch(&schema, 0, 500, vector_dim);
let batches = RecordBatchIterator::new([Ok(initial_batch)], schema.clone());
let mut dataset = Dataset::write(batches, &uri, Some(WriteParams::default()))
.await
.expect("Failed to create dataset");
dataset
.create_index(
&["id"],
IndexType::BTree,
Some("id_btree".to_string()),
&ScalarIndexParams::default(),
false,
)
.await
.expect("Failed to create BTree index");
dataset
.initialize_mem_wal(MemWalConfig {
region_spec: None,
maintained_indexes: vec!["id_btree".to_string()],
})
.await
.expect("Failed to initialize MemWAL");
let region_id = Uuid::new_v4();
// Tiny limits (50 KiB memtable, 10 KiB WAL buffer, 50ms flush
// interval) so both WAL and memtable flushes fire during the test.
let config = RegionWriterConfig::new(region_id)
.with_durable_write(true) .with_sync_indexed_write(true)
.with_max_memtable_size(50 * 1024) .with_max_wal_buffer_size(10 * 1024) .with_max_wal_flush_interval(Duration::from_millis(50));
let writer = dataset
.mem_wal_writer(region_id, config)
.await
.expect("Failed to create writer");
let mut total_rows_written = 0i64;
for _round in 0..num_write_rounds {
// Ids continue after the 500 seed rows so they stay unique.
let start_id = 500 + total_rows_written;
let batches_to_write: Vec<RecordBatch> = (0..batches_per_round)
.map(|i| {
create_test_batch(
&schema,
start_id + (i * rows_per_batch) as i64,
rows_per_batch,
vector_dim,
)
})
.collect();
writer.put(batches_to_write).await.expect("Failed to write");
total_rows_written += (batches_per_round * rows_per_batch) as i64;
// NOTE(review): 150ms sleep gives the 50ms WAL flush interval
// time to fire between rounds — timing-sensitive.
tokio::time::sleep(Duration::from_millis(150)).await;
}
writer.close().await.expect("Failed to close");
// On-disk layout checks: the writer roots region state under
// _mem_wal/<region_id>/ with wal/ and manifest/ subdirectories.
let mem_wal_dir = temp_dir.path().join("_mem_wal").join(region_id.to_string());
assert!(mem_wal_dir.exists(), "MemWAL directory should exist");
let wal_dir = mem_wal_dir.join("wal");
assert!(wal_dir.exists(), "WAL directory should exist");
let wal_files: Vec<_> = std::fs::read_dir(&wal_dir)
.expect("Failed to read WAL dir")
.filter_map(|e| e.ok())
.collect();
assert!(
!wal_files.is_empty(),
"WAL directory should contain at least one file"
);
let manifest_dir = mem_wal_dir.join("manifest");
assert!(manifest_dir.exists(), "Manifest directory should exist");
let manifest_files: Vec<_> = std::fs::read_dir(&manifest_dir)
.expect("Failed to read manifest dir")
.filter_map(|e| e.ok())
.collect();
assert!(
!manifest_files.is_empty(),
"Manifest directory should contain at least one file"
);
// Re-read the region manifest through the store API and verify each
// flushed generation it references exists on disk and is non-empty.
let (store, base_path) = lance_io::object_store::ObjectStore::from_uri(&uri)
.await
.expect("Failed to open store");
let manifest_store =
super::super::manifest::RegionManifestStore::new(store, &base_path, region_id, 2);
let manifest = manifest_store
.read_latest()
.await
.expect("Failed to read manifest")
.expect("Manifest should exist");
assert!(
!manifest.flushed_generations.is_empty(),
"Should have at least one flushed generation"
);
for flushed_gen in &manifest.flushed_generations {
let gen_path = temp_dir
.path()
.join("_mem_wal")
.join(region_id.to_string())
.join(&flushed_gen.path);
assert!(
gen_path.exists(),
"Flushed generation directory should exist at {:?}",
gen_path
);
let gen_contents_count = std::fs::read_dir(&gen_path)
.expect("Failed to read gen dir")
.filter_map(|e| e.ok())
.count();
assert!(
gen_contents_count > 0,
"Generation directory should have files"
);
}
// WAL segments are Arrow IPC files; spot-check the extension.
for wal_file in wal_files.iter().take(1) {
let wal_path = wal_file.path();
let file_name = wal_path.file_name().unwrap().to_string_lossy();
assert!(
file_name.ends_with(".arrow"),
"WAL file should have .arrow extension"
);
}
// Reopen the dataset cold and confirm a brand-new region can be
// written to and scanned back.
let dataset = Dataset::open(&uri).await.expect("Failed to reopen dataset");
let new_region_id = Uuid::new_v4();
let new_config = RegionWriterConfig::new(new_region_id)
.with_durable_write(false)
.with_sync_indexed_write(true);
let new_writer = dataset
.mem_wal_writer(new_region_id, new_config)
.await
.expect("Failed to create new writer");
let verify_batch = create_test_batch(&schema, 10000, 10, vector_dim);
new_writer
.put(vec![verify_batch])
.await
.expect("Failed to write to new region");
let scanner = new_writer.scan().await;
let result = scanner.try_into_batch().await.expect("Failed to scan");
assert_eq!(result.num_rows(), 10, "New region should have 10 rows");
new_writer
.close()
.await
.expect("Failed to close new writer");
}
}