use std::io::Cursor;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use arrow_array::RecordBatch;
use arrow_ipc::reader::StreamReader;
use arrow_ipc::writer::StreamWriter;
use arrow_schema::Schema as ArrowSchema;
use bytes::Bytes;
use lance_core::{Error, Result};
use lance_io::object_store::ObjectStore;
use object_store::path::Path;
use tokio::sync::{mpsc, watch};
use uuid::Uuid;
use super::util::{WatchableOnceCell, region_wal_path, wal_entry_filename};
use super::index::IndexStore;
use super::memtable::batch_store::{BatchStore, StoredBatch};
/// Schema-metadata key under which the writer epoch is stamped into every
/// WAL entry's Arrow IPC stream (read back during `WalEntryData::read`).
pub const WRITER_EPOCH_KEY: &str = "writer_epoch";
/// Handle for awaiting durability of a single batch position.
///
/// Clones share the flusher's durable-watermark channel; the tracked batch
/// counts as durable once the watermark reaches `target_batch_position`.
#[derive(Clone)]
pub struct BatchDurableWatcher {
    // Receiver side of the flusher's durable-watermark channel.
    rx: watch::Receiver<usize>,
    // Watermark value at which the tracked batch is durable
    // (batch position + 1; see `WalFlusher::track_batch`).
    target_batch_position: usize,
}
impl BatchDurableWatcher {
    /// Creates a watcher that resolves once the durable watermark on `rx`
    /// reaches `target_batch_position`.
    pub fn new(rx: watch::Receiver<usize>, target_batch_position: usize) -> Self {
        Self {
            rx,
            target_batch_position,
        }
    }

    /// Waits until the tracked batch has been made durable.
    ///
    /// Returns an error if the watermark channel closes before the target
    /// position is reached.
    pub async fn wait(&mut self) -> Result<()> {
        while *self.rx.borrow() < self.target_batch_position {
            self.rx
                .changed()
                .await
                .map_err(|_| Error::io("Durable watermark channel closed"))?;
        }
        Ok(())
    }

    /// Returns `true` if the tracked batch is already durable.
    pub fn is_durable(&self) -> bool {
        *self.rx.borrow() >= self.target_batch_position
    }
}
impl std::fmt::Debug for BatchDurableWatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Snapshot the watermark so the borrow guard is dropped before
        // formatting.
        let current_watermark = *self.rx.borrow();
        f.debug_struct("BatchDurableWatcher")
            .field("target_batch_position", &self.target_batch_position)
            .field("current_watermark", &current_watermark)
            .finish()
    }
}
/// Metadata describing one WAL entry that was written to object storage.
#[derive(Debug, Clone)]
pub struct WalEntry {
    /// Position of this entry within the region's WAL sequence.
    pub position: u64,
    /// Epoch of the writer that produced the entry (also stored in the
    /// entry's IPC schema metadata under `WRITER_EPOCH_KEY`).
    pub writer_epoch: u64,
    /// Number of record batches serialized into the entry.
    pub num_batches: usize,
}
/// Outcome and timing breakdown of a single WAL flush.
#[derive(Debug, Clone)]
pub struct WalFlushResult {
    /// The entry that was written, or `None` when there was nothing to flush.
    pub entry: Option<WalEntry>,
    /// Wall-clock time spent writing the WAL file to object storage.
    pub wal_io_duration: std::time::Duration,
    /// Total wall-clock time spent updating indexes.
    pub index_update_duration: std::time::Duration,
    /// Per-index update durations, keyed by index name.
    pub index_update_duration_breakdown: std::collections::HashMap<String, std::time::Duration>,
    /// Number of rows covered by the flushed batches.
    pub rows_indexed: usize,
    /// Size of the serialized WAL payload in bytes.
    pub wal_bytes: usize,
}
/// Message sent over the flush channel to request a background WAL flush.
pub struct TriggerWalFlush {
    /// Store holding the batches to be flushed.
    pub batch_store: Arc<BatchStore>,
    /// Indexes to update alongside the WAL write, if any.
    pub indexes: Option<Arc<IndexStore>>,
    /// Flush batches up to (but excluding) this position.
    pub end_batch_position: usize,
    /// Optional completion cell; receives the flush result or an error string.
    pub done: Option<WatchableOnceCell<std::result::Result<WalFlushResult, String>>>,
}
impl std::fmt::Debug for TriggerWalFlush {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TriggerWalFlush");
        builder.field(
            "pending_batches",
            &self.batch_store.pending_wal_flush_count(),
        );
        builder.field("end_batch_position", &self.end_batch_position);
        builder.finish()
    }
}
/// Writes batches from a `BatchStore` into WAL entries in object storage and
/// publishes a durable watermark that `BatchDurableWatcher`s observe.
pub struct WalFlusher {
    // Sender side of the durable watermark; advanced after each successful flush.
    durable_watermark_tx: watch::Sender<usize>,
    // Template receiver cloned out to `BatchDurableWatcher`s.
    durable_watermark_rx: watch::Receiver<usize>,
    // Destination store; must be injected via `set_object_store` before flushing.
    object_store: Option<Arc<ObjectStore>>,
    // Region this flusher serves; WAL paths are derived from it.
    region_id: Uuid,
    // Epoch stamped into every WAL entry's schema metadata.
    writer_epoch: u64,
    // Position assigned to the next WAL entry (atomically incremented).
    next_wal_entry_position: AtomicU64,
    // Channel used by `trigger_flush` to hand work to a background task, if set.
    flush_tx: Option<mpsc::UnboundedSender<TriggerWalFlush>>,
    // Directory under which WAL entry files are written.
    wal_dir: Path,
    // Cell signalled (and replaced) after each completed flush; readers obtain
    // it via `wal_flush_watcher`.
    wal_flush_cell: std::sync::Mutex<Option<WatchableOnceCell<super::write::DurabilityResult>>>,
}
impl WalFlusher {
    /// Creates a WAL flusher for `region_id` rooted at `base_path`.
    ///
    /// `next_wal_entry_position` is the position the next WAL entry will be
    /// written at, and `writer_epoch` is stamped into every entry. The object
    /// store and flush channel must be injected via `set_object_store` /
    /// `set_flush_channel` before flushing.
    pub fn new(
        base_path: &Path,
        region_id: Uuid,
        writer_epoch: u64,
        next_wal_entry_position: u64,
    ) -> Self {
        // NOTE: fixed mojibake — `®ion_id` was a corrupted `&region_id`
        // (the `&reg` prefix had been rendered as the ® character).
        let wal_dir = region_wal_path(base_path, &region_id);
        // Watermark starts at 0: nothing is durable yet.
        let (durable_watermark_tx, durable_watermark_rx) = watch::channel(0);
        let wal_flush_cell = WatchableOnceCell::new();
        Self {
            durable_watermark_tx,
            durable_watermark_rx,
            object_store: None,
            region_id,
            writer_epoch,
            next_wal_entry_position: AtomicU64::new(next_wal_entry_position),
            flush_tx: None,
            wal_dir,
            wal_flush_cell: std::sync::Mutex::new(Some(wal_flush_cell)),
        }
    }

    /// Injects the object store WAL entries are written to.
    pub fn set_object_store(&mut self, object_store: Arc<ObjectStore>) {
        self.object_store = Some(object_store);
    }

    /// Injects the channel used by `trigger_flush` to enqueue background flushes.
    pub fn set_flush_channel(&mut self, tx: mpsc::UnboundedSender<TriggerWalFlush>) {
        self.flush_tx = Some(tx);
    }

    /// Returns a watcher that resolves once `batch_position` is durable.
    ///
    /// The target watermark is `batch_position + 1` because the watermark
    /// counts batches flushed, not the last flushed position.
    pub fn track_batch(&self, batch_position: usize) -> BatchDurableWatcher {
        BatchDurableWatcher::new(self.durable_watermark_rx.clone(), batch_position + 1)
    }

    /// Current durable watermark (number of batches made durable so far).
    pub fn durable_watermark(&self) -> usize {
        *self.durable_watermark_rx.borrow()
    }

    /// Returns a reader for the cell signalled by the next completed flush,
    /// or `None` if the cell is momentarily unavailable.
    pub fn wal_flush_watcher(
        &self,
    ) -> Option<super::util::WatchableOnceCellReader<super::write::DurabilityResult>> {
        self.wal_flush_cell
            .lock()
            .unwrap()
            .as_ref()
            .map(|cell| cell.reader())
    }

    /// Signals the current flush cell with success and installs a fresh cell
    /// for the next flush cycle.
    fn signal_wal_flush_complete(&self) {
        let mut guard = self.wal_flush_cell.lock().unwrap();
        if let Some(cell) = guard.take() {
            cell.write(super::write::DurabilityResult::ok());
        }
        *guard = Some(WatchableOnceCell::new());
    }

    /// Enqueues a background flush request on the flush channel.
    ///
    /// Silently a no-op when no flush channel has been set; returns an error
    /// only when the channel exists but is closed.
    pub fn trigger_flush(
        &self,
        batch_store: Arc<BatchStore>,
        indexes: Option<Arc<IndexStore>>,
        end_batch_position: usize,
        done: Option<WatchableOnceCell<std::result::Result<WalFlushResult, String>>>,
    ) -> Result<()> {
        if let Some(tx) = &self.flush_tx {
            tx.send(TriggerWalFlush {
                batch_store,
                indexes,
                end_batch_position,
                done,
            })
            .map_err(|_| Error::io("WAL flush channel closed"))?;
        }
        Ok(())
    }

    /// A `WalFlushResult` describing a flush that had nothing to write.
    fn empty_flush_result() -> WalFlushResult {
        WalFlushResult {
            entry: None,
            wal_io_duration: std::time::Duration::ZERO,
            index_update_duration: std::time::Duration::ZERO,
            index_update_duration_breakdown: std::collections::HashMap::new(),
            rows_indexed: 0,
            wal_bytes: 0,
        }
    }

    /// Flushes unflushed batches up to (excluding) `end_batch_position` into a
    /// single WAL entry, optionally updating `indexes` concurrently with the
    /// WAL write.
    ///
    /// On success the batch store's flushed position and the durable
    /// watermark are advanced and any pending flush watcher is signalled.
    /// Returns a result with `entry: None` when there is nothing to flush.
    pub async fn flush_to_with_index_update(
        &self,
        batch_store: &BatchStore,
        end_batch_position: usize,
        indexes: Option<Arc<IndexStore>>,
    ) -> Result<WalFlushResult> {
        // Resume immediately after the last flushed position (or from 0).
        let start_batch_position = batch_store
            .max_flushed_batch_position()
            .map(|w| w + 1)
            .unwrap_or(0);
        if start_batch_position >= end_batch_position {
            return Ok(Self::empty_flush_result());
        }
        let object_store = self
            .object_store
            .as_ref()
            .ok_or_else(|| Error::io("Object store not set on WAL flusher"))?;
        // Collect the batches to flush; positions missing from the store are
        // skipped, matching the store's `get` semantics.
        let mut stored_batches: Vec<StoredBatch> =
            Vec::with_capacity(end_batch_position - start_batch_position);
        for batch_position in start_batch_position..end_batch_position {
            if let Some(stored) = batch_store.get(batch_position) {
                stored_batches.push(stored.clone());
            }
        }
        if stored_batches.is_empty() {
            return Ok(Self::empty_flush_result());
        }
        // Allocate the WAL entry position only once we know there is data to
        // write. Allocating before the empty check (as previously done) would
        // burn a position and leave a permanent hole in the WAL entry
        // sequence whenever the selected range turned out to be empty.
        let wal_entry_position = self.next_wal_entry_position.fetch_add(1, Ordering::SeqCst);
        let final_path = self.wal_entry_path(wal_entry_position);
        let rows_to_index: usize = stored_batches.iter().map(|b| b.num_rows).sum();
        let num_batches = stored_batches.len();
        // Stamp the writer epoch into the IPC schema metadata so recovery can
        // attribute the entry to the writer that produced it.
        let schema = stored_batches[0].data.schema();
        let mut metadata = schema.metadata().clone();
        metadata.insert(WRITER_EPOCH_KEY.to_string(), self.writer_epoch.to_string());
        let schema_with_epoch = Arc::new(ArrowSchema::new_with_metadata(
            schema.fields().to_vec(),
            metadata,
        ));
        // Serialize all batches into a single in-memory Arrow IPC stream.
        let mut buffer = Vec::new();
        {
            let mut writer =
                StreamWriter::try_new(&mut buffer, &schema_with_epoch).map_err(|e| {
                    Error::io(format!("Failed to create Arrow IPC stream writer: {}", e))
                })?;
            for stored in &stored_batches {
                writer.write(&stored.data).map_err(|e| {
                    Error::io(format!("Failed to write batch to Arrow IPC stream: {}", e))
                })?;
            }
            writer
                .finish()
                .map_err(|e| Error::io(format!("Failed to finish Arrow IPC stream: {}", e)))?;
        }
        let wal_bytes = buffer.len();
        let wal_path = final_path.clone();
        let wal_data = Bytes::from(buffer);
        let store = object_store.clone();
        // The WAL write future is lazy: it is either joined with the index
        // update below or awaited on its own (previously this block was
        // duplicated in both branches).
        let wal_future = async {
            let start = Instant::now();
            store
                .inner
                .put(&wal_path, wal_data.into())
                .await
                .map_err(|e| Error::io(format!("Failed to write WAL file: {}", e)))?;
            Ok::<_, Error>(start.elapsed())
        };
        let (wal_result, index_result) = if let Some(idx_registry) = indexes {
            // Index maintenance is CPU-bound, so run it on the blocking pool
            // concurrently with the WAL upload.
            let index_future = async {
                let start = Instant::now();
                let per_index = tokio::task::spawn_blocking(move || {
                    idx_registry.insert_batches_parallel(&stored_batches)
                })
                .await
                .map_err(|e| Error::internal(format!("Index update task panicked: {}", e)))??;
                Ok::<_, Error>((start.elapsed(), per_index))
            };
            tokio::join!(wal_future, index_future)
        } else {
            (
                wal_future.await,
                Ok((std::time::Duration::ZERO, std::collections::HashMap::new())),
            )
        };
        let wal_io_duration = wal_result?;
        let (index_update_duration, index_update_duration_breakdown) = index_result?;
        // Only after both the WAL write and the index update succeed do we
        // advance the flushed position and publish the new durable watermark.
        batch_store.set_max_flushed_batch_position(end_batch_position - 1);
        let _ = self.durable_watermark_tx.send(end_batch_position);
        self.signal_wal_flush_complete();
        let entry = WalEntry {
            position: wal_entry_position,
            writer_epoch: self.writer_epoch,
            num_batches,
        };
        Ok(WalFlushResult {
            entry: Some(entry),
            wal_io_duration,
            index_update_duration,
            index_update_duration_breakdown,
            rows_indexed: rows_to_index,
            wal_bytes,
        })
    }

    /// Position the next WAL entry will be written at.
    pub fn next_wal_entry_position(&self) -> u64 {
        self.next_wal_entry_position.load(Ordering::SeqCst)
    }

    /// Region this flusher serves.
    pub fn region_id(&self) -> Uuid {
        self.region_id
    }

    /// Epoch stamped into every WAL entry written by this flusher.
    pub fn writer_epoch(&self) -> u64 {
        self.writer_epoch
    }

    /// Object-store path of the WAL entry at `wal_entry_position`.
    pub fn wal_entry_path(&self, wal_entry_position: u64) -> Path {
        let filename = wal_entry_filename(wal_entry_position);
        self.wal_dir.child(filename.as_str())
    }
}
/// A WAL entry decoded back from object storage.
#[derive(Debug)]
pub struct WalEntryData {
    /// Writer epoch recovered from the entry's schema metadata
    /// (`WRITER_EPOCH_KEY`); 0 when absent or unparsable.
    pub writer_epoch: u64,
    /// The record batches contained in the entry, in write order.
    pub batches: Vec<RecordBatch>,
}
impl WalEntryData {
    /// Loads and decodes the WAL entry at `path` from `object_store`.
    ///
    /// The writer epoch is read from the IPC stream's schema metadata and
    /// defaults to 0 when the key is missing or does not parse.
    pub async fn read(object_store: &ObjectStore, path: &Path) -> Result<Self> {
        let data = object_store
            .inner
            .get(path)
            .await
            .map_err(|e| Error::io(format!("Failed to read WAL file: {}", e)))?
            .bytes()
            .await
            .map_err(|e| Error::io(format!("Failed to get WAL file bytes: {}", e)))?;
        let reader = StreamReader::try_new(Cursor::new(data), None)
            .map_err(|e| Error::io(format!("Failed to open Arrow IPC stream reader: {}", e)))?;
        let writer_epoch = reader
            .schema()
            .metadata()
            .get(WRITER_EPOCH_KEY)
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(0);
        // The reader is an iterator of batch results; short-circuit on the
        // first decode failure.
        let batches = reader
            .map(|batch_result| {
                batch_result.map_err(|e| {
                    Error::io(format!("Failed to read batch from Arrow IPC stream: {}", e))
                })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            writer_epoch,
            batches,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Builds a file-backed object store rooted in a fresh temp directory.
    async fn create_local_store() -> (Arc<ObjectStore>, Path, TempDir) {
        let temp_dir = tempfile::tempdir().unwrap();
        let uri = format!("file://{}", temp_dir.path().display());
        let (store, path) = ObjectStore::from_uri(&uri).await.unwrap();
        (store, path, temp_dir)
    }

    /// Two-column test schema: (id: Int32 non-null, name: Utf8 nullable).
    fn create_test_schema() -> Arc<Schema> {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a batch of `num_rows` rows matching `create_test_schema`.
    fn create_test_batch(schema: &Schema, num_rows: usize) -> RecordBatch {
        let ids = Int32Array::from_iter_values(0..num_rows as i32);
        let names =
            StringArray::from_iter_values((0..num_rows).map(|i| format!("name_{}", i)));
        RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ids), Arc::new(names)])
            .unwrap()
    }

    #[tokio::test]
    async fn test_wal_flusher_track_batch() {
        let (store, base_path, _temp_dir) = create_local_store().await;
        let mut flusher = WalFlusher::new(&base_path, Uuid::new_v4(), 1, 1);
        flusher.set_object_store(store);
        // Nothing has been flushed, so batch 0 cannot be durable.
        let watcher = flusher.track_batch(0);
        assert!(!watcher.is_durable());
    }

    #[tokio::test]
    async fn test_wal_flusher_flush_to_with_index_update() {
        let (store, base_path, _temp_dir) = create_local_store().await;
        let mut flusher = WalFlusher::new(&base_path, Uuid::new_v4(), 1, 1);
        flusher.set_object_store(store);

        let schema = create_test_schema();
        let batch_store = BatchStore::with_capacity(10);
        batch_store.append(create_test_batch(&schema, 10)).unwrap();
        batch_store.append(create_test_batch(&schema, 5)).unwrap();

        let mut watcher1 = flusher.track_batch(0);
        let mut watcher2 = flusher.track_batch(1);
        assert!(!watcher1.is_durable());
        assert!(!watcher2.is_durable());
        assert!(batch_store.max_flushed_batch_position().is_none());

        let result = flusher
            .flush_to_with_index_update(&batch_store, batch_store.len(), None)
            .await
            .unwrap();

        let entry = result.entry.unwrap();
        assert_eq!(entry.position, 1);
        assert_eq!(entry.writer_epoch, 1);
        assert_eq!(entry.num_batches, 2);
        assert_eq!(batch_store.max_flushed_batch_position(), Some(1));

        // Both watchers resolve now that the watermark covers their batches.
        watcher1.wait().await.unwrap();
        watcher2.wait().await.unwrap();
        assert!(watcher1.is_durable());
        assert!(watcher2.is_durable());
    }

    #[tokio::test]
    async fn test_wal_entry_read() {
        let (store, base_path, _temp_dir) = create_local_store().await;
        let mut flusher = WalFlusher::new(&base_path, Uuid::new_v4(), 42, 1);
        flusher.set_object_store(store.clone());

        let schema = create_test_schema();
        let batch_store = BatchStore::with_capacity(10);
        batch_store.append(create_test_batch(&schema, 10)).unwrap();
        batch_store.append(create_test_batch(&schema, 5)).unwrap();
        let _watcher1 = flusher.track_batch(0);
        let _watcher2 = flusher.track_batch(1);

        let result = flusher
            .flush_to_with_index_update(&batch_store, batch_store.len(), None)
            .await
            .unwrap();
        let entry = result.entry.unwrap();

        // Round-trip: the file written above decodes to the same batches and
        // carries the writer epoch in its schema metadata.
        let wal_path = flusher.wal_entry_path(entry.position);
        let wal_data = WalEntryData::read(&store, &wal_path).await.unwrap();
        assert_eq!(wal_data.writer_epoch, 42);
        assert_eq!(wal_data.batches.len(), 2);
        assert_eq!(wal_data.batches[0].num_rows(), 10);
        assert_eq!(wal_data.batches[1].num_rows(), 5);
    }
}