use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use arrow_array::RecordBatch;
/// A record batch held in one [`BatchStore`] slot, together with its
/// placement metadata within the store.
#[derive(Clone)]
pub struct StoredBatch {
    /// The Arrow data itself.
    pub data: RecordBatch,
    /// Row count of `data`, captured once at construction.
    pub num_rows: usize,
    /// Store-wide running row total at the time this batch was appended
    /// (i.e. the global index of this batch's first row).
    pub row_offset: u64,
    /// Slot index of this batch inside the owning store.
    pub batch_position: usize,
}
impl StoredBatch {
    /// Wraps `data` for storage, capturing its row count at construction
    /// time so later reads don't have to re-derive it.
    pub fn new(data: RecordBatch, row_offset: u64, batch_position: usize) -> Self {
        Self {
            // Evaluated before `data` is moved into the struct below.
            num_rows: data.num_rows(),
            data,
            row_offset,
            batch_position,
        }
    }
}
/// Error returned when an append is attempted on a full [`BatchStore`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StoreFull;

impl std::fmt::Display for StoreFull {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("BatchStore is full")
    }
}

impl std::error::Error for StoreFull {}
/// Append-only, fixed-capacity store of record batches.
///
/// Slots are written in place and then published by bumping
/// `committed_len` with a Release store; readers Acquire-load the length
/// before dereferencing any slot, so they only ever see initialized data.
pub struct BatchStore {
    // Fixed slot array; entries at index < committed_len are initialized.
    slots: Box<[UnsafeCell<MaybeUninit<StoredBatch>>]>,
    // Number of published (readable) slots. Release-stored by the writer,
    // Acquire-loaded by readers.
    committed_len: AtomicUsize,
    // Total number of slots; never changes after construction.
    capacity: usize,
    // Running sum of rows across appended batches.
    total_rows: AtomicUsize,
    // Running sum of estimated batch sizes in bytes.
    estimated_bytes: AtomicUsize,
    // Highest batch position recorded as WAL-flushed; usize::MAX is the
    // sentinel for "nothing flushed yet".
    max_flushed_batch_position: AtomicUsize,
}
// SAFETY: a slot is only written before being published via the Release
// store on `committed_len`, and readers only dereference slots below an
// Acquire-loaded length, so shared access never observes a slot under
// mutation. NOTE(review): soundness also assumes appends are serialized
// (single writer) — the append path races otherwise; confirm call sites.
unsafe impl Sync for BatchStore {}
// SAFETY: all fields are owned; moving the store between threads is fine
// provided the contained RecordBatch data is itself Send — presumably it
// holds atomically refcounted buffers; TODO confirm against arrow docs.
unsafe impl Send for BatchStore {}
impl BatchStore {
    /// Creates a store with exactly `capacity` slots.
    ///
    /// # Panics
    /// Panics if `capacity` is zero.
    pub fn with_capacity(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be > 0");
        let mut slots = Vec::with_capacity(capacity);
        for _ in 0..capacity {
            slots.push(UnsafeCell::new(MaybeUninit::uninit()));
        }
        Self {
            slots: slots.into_boxed_slice(),
            committed_len: AtomicUsize::new(0),
            capacity,
            total_rows: AtomicUsize::new(0),
            estimated_bytes: AtomicUsize::new(0),
            // usize::MAX is the sentinel for "nothing flushed yet".
            max_flushed_batch_position: AtomicUsize::new(usize::MAX),
        }
    }

    /// Heuristic slot count for a memtable limited to `max_memtable_bytes`:
    /// assumes ~64 KiB per batch, adds 20% headroom, floors at 16.
    pub fn recommended_capacity(max_memtable_bytes: usize) -> usize {
        const AVG_BATCH_SIZE: usize = 64 * 1024;
        const BUFFER_FACTOR: f64 = 1.2;
        let estimated_batches = max_memtable_bytes / AVG_BATCH_SIZE;
        let capacity = ((estimated_batches as f64) * BUFFER_FACTOR) as usize;
        capacity.max(16)
    }

    /// Total number of slots.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Whether every slot is committed (advisory; Relaxed load).
    #[inline]
    pub fn is_full(&self) -> bool {
        self.committed_len.load(Ordering::Relaxed) >= self.capacity
    }

    /// Number of slots still free (advisory; Relaxed load).
    #[inline]
    pub fn remaining_capacity(&self) -> usize {
        self.capacity
            .saturating_sub(self.committed_len.load(Ordering::Relaxed))
    }

    /// Appends one batch, returning
    /// `(batch_position, row_offset, estimated_bytes)`.
    ///
    /// NOTE(review): the load-write-publish sequence here is only sound
    /// when appends are serialized (single writer); concurrent appenders
    /// would race on `committed_len`. Confirm callers uphold this.
    ///
    /// # Errors
    /// Returns [`StoreFull`] when no slot is free.
    pub fn append(&self, batch: RecordBatch) -> Result<(usize, u64, usize), StoreFull> {
        let idx = self.committed_len.load(Ordering::Relaxed);
        if idx >= self.capacity {
            return Err(StoreFull);
        }
        let num_rows = batch.num_rows();
        let estimated_size = Self::estimate_batch_size(&batch);
        let row_offset = self.total_rows.load(Ordering::Relaxed) as u64;
        let stored = StoredBatch::new(batch, row_offset, idx);
        // SAFETY: slot `idx` is above the committed length, so no reader
        // dereferences it until the Release store below publishes it.
        unsafe {
            let slot_ptr = self.slots[idx].get();
            std::ptr::write(slot_ptr, MaybeUninit::new(stored));
        }
        self.total_rows.fetch_add(num_rows, Ordering::Relaxed);
        self.estimated_bytes
            .fetch_add(estimated_size, Ordering::Relaxed);
        // Publish the slot: pairs with the Acquire loads in readers.
        self.committed_len.store(idx + 1, Ordering::Release);
        Ok((idx, row_offset, estimated_size))
    }

    /// Appends several batches as a group: either they all fit and become
    /// visible to readers together (single Release store at the end), or
    /// none are written.
    ///
    /// # Errors
    /// Returns [`StoreFull`] when the whole group does not fit.
    pub fn append_batches(
        &self,
        batches: Vec<RecordBatch>,
    ) -> Result<Vec<(usize, u64, usize)>, StoreFull> {
        if batches.is_empty() {
            return Ok(vec![]);
        }
        let start_idx = self.committed_len.load(Ordering::Relaxed);
        let count = batches.len();
        // Written as a subtraction so a pathological `count` cannot make
        // `start_idx + count` overflow before the comparison.
        if count > self.capacity.saturating_sub(start_idx) {
            return Err(StoreFull);
        }
        let mut results = Vec::with_capacity(count);
        let mut total_rows_added = 0usize;
        let mut total_bytes_added = 0usize;
        let mut row_offset = self.total_rows.load(Ordering::Relaxed) as u64;
        for (i, batch) in batches.into_iter().enumerate() {
            let idx = start_idx + i;
            let num_rows = batch.num_rows();
            let estimated_size = Self::estimate_batch_size(&batch);
            let stored = StoredBatch::new(batch, row_offset, idx);
            // SAFETY: `idx < capacity` (checked above) and is not yet
            // published, so no reader can observe this slot.
            unsafe {
                let slot_ptr = self.slots[idx].get();
                std::ptr::write(slot_ptr, MaybeUninit::new(stored));
            }
            results.push((idx, row_offset, estimated_size));
            row_offset += num_rows as u64;
            total_rows_added += num_rows;
            total_bytes_added += estimated_size;
        }
        self.total_rows
            .fetch_add(total_rows_added, Ordering::Relaxed);
        self.estimated_bytes
            .fetch_add(total_bytes_added, Ordering::Relaxed);
        // Publish all new slots at once.
        self.committed_len
            .store(start_idx + count, Ordering::Release);
        Ok(results)
    }

    /// Estimates the in-memory footprint of `batch`: the sum of each
    /// column's reported array memory plus the `RecordBatch` struct
    /// itself. NOTE(review): buffers shared between columns are presumably
    /// counted per column, which can over-estimate — confirm if exact
    /// accounting matters.
    fn estimate_batch_size(batch: &RecordBatch) -> usize {
        batch
            .columns()
            .iter()
            .map(|col| col.get_array_memory_size())
            .sum::<usize>()
            + std::mem::size_of::<RecordBatch>()
    }

    /// Number of committed (readable) batches.
    #[inline]
    pub fn len(&self) -> usize {
        self.committed_len.load(Ordering::Acquire)
    }

    /// True when no batch has been committed yet.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Position of the newest committed batch, or `None` when empty.
    #[inline]
    pub fn max_buffered_batch_position(&self) -> Option<usize> {
        let len = self.len();
        if len == 0 { None } else { Some(len - 1) }
    }

    /// Running total of rows across appended batches (advisory).
    #[inline]
    pub fn total_rows(&self) -> usize {
        self.total_rows.load(Ordering::Relaxed)
    }

    /// Running total of estimated bytes across appended batches (advisory).
    #[inline]
    pub fn estimated_bytes(&self) -> usize {
        self.estimated_bytes.load(Ordering::Relaxed)
    }

    /// Highest batch position recorded as WAL-flushed, or `None` while the
    /// `usize::MAX` sentinel is still in place.
    #[inline]
    pub fn max_flushed_batch_position(&self) -> Option<usize> {
        let watermark = self.max_flushed_batch_position.load(Ordering::Acquire);
        if watermark == usize::MAX {
            None
        } else {
            Some(watermark)
        }
    }

    /// Records `batch_position` as the highest WAL-flushed batch.
    #[inline]
    pub fn set_max_flushed_batch_position(&self, batch_position: usize) {
        debug_assert!(
            batch_position != usize::MAX,
            "batch_position cannot be usize::MAX (reserved as sentinel)"
        );
        self.max_flushed_batch_position
            .store(batch_position, Ordering::Release);
    }

    /// Number of committed batches not yet recorded as WAL-flushed.
    #[inline]
    pub fn pending_wal_flush_count(&self) -> usize {
        let committed = self.committed_len.load(Ordering::Acquire);
        let watermark = self.max_flushed_batch_position.load(Ordering::Acquire);
        if watermark == usize::MAX {
            committed
        } else {
            // Sentinel excluded above, so `watermark + 1` cannot overflow.
            committed.saturating_sub(watermark + 1)
        }
    }

    /// True when every committed batch has been recorded as WAL-flushed.
    #[inline]
    pub fn is_wal_flush_complete(&self) -> bool {
        self.pending_wal_flush_count() == 0
    }

    /// Half-open `(start, end)` range of batch positions awaiting WAL
    /// flush, or `None` when nothing is pending.
    #[inline]
    pub fn pending_wal_flush_range(&self) -> Option<(usize, usize)> {
        let committed = self.committed_len.load(Ordering::Acquire);
        let watermark = self.max_flushed_batch_position.load(Ordering::Acquire);
        let start = if watermark == usize::MAX {
            0
        } else {
            watermark + 1
        };
        if committed > start {
            Some((start, committed))
        } else {
            None
        }
    }

    /// Returns the committed batch at `index`, or `None` if out of range.
    #[inline]
    pub fn get(&self, index: usize) -> Option<&StoredBatch> {
        let len = self.committed_len.load(Ordering::Acquire);
        if index >= len {
            return None;
        }
        // SAFETY: `index < len` and `len` was Acquire-loaded, so the slot
        // was fully initialized before the writer's Release store
        // published it.
        unsafe {
            let slot_ptr = self.slots[index].get();
            Some((*slot_ptr).assume_init_ref())
        }
    }

    /// Returns just the Arrow data at `index`.
    #[inline]
    pub fn get_batch(&self, index: usize) -> Option<&RecordBatch> {
        self.get(index).map(|s| &s.data)
    }

    /// Iterates over the batches committed at the time of this call;
    /// batches appended afterwards are not observed.
    pub fn iter(&self) -> BatchStoreIter<'_> {
        let len = self.committed_len.load(Ordering::Acquire);
        BatchStoreIter {
            store: self,
            current: 0,
            len,
        }
    }

    /// Clones all committed batches into a `Vec`, in append order.
    pub fn to_vec(&self) -> Vec<RecordBatch> {
        self.iter().map(|b| b.data.clone()).collect()
    }

    /// Clones all committed `StoredBatch` entries, in append order.
    pub fn to_stored_vec(&self) -> Vec<StoredBatch> {
        self.iter().cloned().collect()
    }

    /// Iterates over committed batches from newest to oldest.
    pub fn iter_reversed(&self) -> BatchStoreIterReversed<'_> {
        let len = self.committed_len.load(Ordering::Acquire);
        BatchStoreIterReversed {
            store: self,
            current: len,
        }
    }

    /// Returns the committed batches newest-first, with the rows inside
    /// each batch reversed as well (via the Arrow `take` kernel).
    ///
    /// # Errors
    /// Propagates any `ArrowError` from the `take` kernel or from batch
    /// reconstruction.
    pub fn to_vec_reversed(&self) -> Result<Vec<RecordBatch>, arrow::error::ArrowError> {
        use arrow::compute::kernels::take::take;
        use arrow_array::UInt32Array;
        self.iter_reversed()
            .map(|b| {
                let num_rows = b.data.num_rows();
                if num_rows == 0 {
                    return Ok(b.data.clone());
                }
                let indices: Vec<u32> = (0..num_rows as u32).rev().collect();
                let indices_array = UInt32Array::from(indices);
                let columns: Result<Vec<_>, _> = b
                    .data
                    .columns()
                    .iter()
                    .map(|col| take(col.as_ref(), &indices_array, None))
                    .collect();
                RecordBatch::try_new(b.data.schema(), columns?)
            })
            .collect()
    }

    /// Clones all committed `StoredBatch` entries, newest-first.
    pub fn to_stored_vec_reversed(&self) -> Vec<StoredBatch> {
        self.iter_reversed().cloned().collect()
    }

    /// Committed batches whose position is `<= max_visible_batch_position`.
    pub fn visible_batches(&self, max_visible_batch_position: usize) -> Vec<&StoredBatch> {
        let len = self.committed_len.load(Ordering::Acquire);
        // saturating_add: a horizon of usize::MAX means "everything
        // visible" and must not overflow the `+ 1`.
        let end = max_visible_batch_position.saturating_add(1).min(len);
        (0..end).filter_map(|i| self.get(i)).collect()
    }

    /// Positions of the batches that are committed and within the
    /// visibility horizon.
    pub fn max_visible_batch_positions(&self, max_visible_batch_position: usize) -> Vec<usize> {
        let len = self.committed_len.load(Ordering::Acquire);
        // saturating_add: see `visible_batches` — a horizon of usize::MAX
        // must not overflow.
        let end = max_visible_batch_position.saturating_add(1).min(len);
        (0..end).collect()
    }

    /// True when `batch_position` is committed and within the visibility
    /// horizon.
    #[inline]
    pub fn is_batch_visible(
        &self,
        batch_position: usize,
        max_visible_batch_position: usize,
    ) -> bool {
        let len = self.committed_len.load(Ordering::Acquire);
        batch_position < len && batch_position <= max_visible_batch_position
    }

    /// Clones the visible batches' Arrow data.
    pub fn visible_record_batches(&self, max_visible_batch_position: usize) -> Vec<RecordBatch> {
        self.visible_batches(max_visible_batch_position)
            .into_iter()
            .map(|b| b.data.clone())
            .collect()
    }

    /// Clones the visible batches' Arrow data paired with their row
    /// offsets.
    pub fn visible_batches_with_offsets(
        &self,
        max_visible_batch_position: usize,
    ) -> Vec<(RecordBatch, u64)> {
        self.visible_batches(max_visible_batch_position)
            .into_iter()
            .map(|b| (b.data.clone(), b.row_offset))
            .collect()
    }
}
impl Drop for BatchStore {
    // Drop exactly the committed slots; slots past the committed length
    // were never initialized and must not be touched.
    fn drop(&mut self) {
        let committed = *self.committed_len.get_mut();
        for slot in &mut self.slots[..committed] {
            // SAFETY: every slot below `committed` was initialized by an
            // append, and `&mut self` guarantees exclusive access, so each
            // value is dropped exactly once.
            unsafe { slot.get_mut().assume_init_drop() };
        }
    }
}
/// Forward iterator over the batches that were committed when the
/// iterator was created (a consistent snapshot via `committed_len`).
pub struct BatchStoreIter<'a> {
    store: &'a BatchStore,
    // Next slot index to yield.
    current: usize,
    // Snapshot of committed_len, Acquire-loaded at creation.
    len: usize,
}
impl<'a> Iterator for BatchStoreIter<'a> {
    type Item = &'a StoredBatch;
    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.len {
            return None;
        }
        // SAFETY: `current < len` and `len` was Acquire-loaded from
        // committed_len, so this slot was initialized before it was
        // published.
        let batch = unsafe {
            let slot_ptr = self.store.slots[self.current].get();
            (*slot_ptr).assume_init_ref()
        };
        self.current += 1;
        Some(batch)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: `current` never exceeds `len`.
        let remaining = self.len - self.current;
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for BatchStoreIter<'_> {}
/// Newest-to-oldest iterator over the batches that were committed when
/// the iterator was created.
pub struct BatchStoreIterReversed<'a> {
    store: &'a BatchStore,
    // One past the next slot to yield; counts down to 0. Initialized to
    // the Acquire-loaded committed length.
    current: usize,
}
impl<'a> Iterator for BatchStoreIterReversed<'a> {
    type Item = &'a StoredBatch;
    fn next(&mut self) -> Option<Self::Item> {
        if self.current == 0 {
            return None;
        }
        self.current -= 1;
        // SAFETY: `current` is now below the committed length that was
        // Acquire-loaded at creation, so the slot is initialized and
        // published.
        let batch = unsafe {
            let slot_ptr = self.store.slots[self.current].get();
            (*slot_ptr).assume_init_ref()
        };
        Some(batch)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: `current` items remain.
        (self.current, Some(self.current))
    }
}
impl ExactSizeIterator for BatchStoreIterReversed<'_> {}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use std::sync::Arc;

    // Two-column (id: Int32, value: Int32) schema shared by all tests.
    fn create_test_schema() -> Arc<ArrowSchema> {
        Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Int32, false),
        ]))
    }

    // Builds a batch with id = 0..num_rows and value = id * 10.
    fn create_test_batch(num_rows: usize) -> RecordBatch {
        let schema = create_test_schema();
        let ids: Vec<i32> = (0..num_rows as i32).collect();
        let values: Vec<i32> = ids.iter().map(|id| id * 10).collect();
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(Int32Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_store() {
        let store = BatchStore::with_capacity(10);
        assert_eq!(store.capacity(), 10);
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
        assert!(!store.is_full());
        assert_eq!(store.remaining_capacity(), 10);
    }

    #[test]
    fn test_append_single() {
        let store = BatchStore::with_capacity(10);
        let batch = create_test_batch(100);
        let (id, row_offset, _size) = store.append(batch).unwrap();
        assert_eq!(id, 0);
        assert_eq!(row_offset, 0);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert_eq!(store.total_rows(), 100);
    }

    #[test]
    fn test_append_multiple() {
        let store = BatchStore::with_capacity(10);
        let mut expected_row_offset = 0u64;
        // Batches of 10, 20, ..., 50 rows; offsets must accumulate.
        for i in 0..5 {
            let num_rows = 10 * (i + 1);
            let batch = create_test_batch(num_rows);
            let (id, row_offset, _size) = store.append(batch).unwrap();
            assert_eq!(id, i);
            assert_eq!(row_offset, expected_row_offset);
            expected_row_offset += num_rows as u64;
        }
        assert_eq!(store.len(), 5);
        assert_eq!(store.total_rows(), 10 + 20 + 30 + 40 + 50);
    }

    #[test]
    fn test_capacity_limit() {
        let store = BatchStore::with_capacity(3);
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        assert!(store.is_full());
        assert_eq!(store.remaining_capacity(), 0);
        // The fourth append must be rejected.
        let result = store.append(create_test_batch(10));
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StoreFull);
    }

    #[test]
    fn test_get_batch() {
        let store = BatchStore::with_capacity(10);
        let batch1 = create_test_batch(10);
        let batch2 = create_test_batch(20);
        store.append(batch1).unwrap();
        store.append(batch2).unwrap();
        let retrieved1 = store.get(0).unwrap();
        assert_eq!(retrieved1.num_rows, 10);
        assert_eq!(retrieved1.row_offset, 0);
        let retrieved2 = store.get(1).unwrap();
        assert_eq!(retrieved2.num_rows, 20);
        assert_eq!(retrieved2.row_offset, 10);
        // Out-of-range lookups return None rather than panicking.
        assert!(store.get(2).is_none());
        assert!(store.get(100).is_none());
    }

    #[test]
    fn test_iter() {
        let store = BatchStore::with_capacity(10);
        for _ in 0..5 {
            store.append(create_test_batch(10)).unwrap();
        }
        let batches: Vec<_> = store.iter().collect();
        assert_eq!(batches.len(), 5);
    }

    #[test]
    fn test_visibility_filtering() {
        let store = BatchStore::with_capacity(10);
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        // The horizon is inclusive: position <= horizon is visible.
        let visible = store.max_visible_batch_positions(2);
        assert_eq!(visible, vec![0, 1, 2]);
        let visible = store.max_visible_batch_positions(4);
        assert_eq!(visible, vec![0, 1, 2, 3, 4]);
        let visible = store.max_visible_batch_positions(0);
        assert_eq!(visible, vec![0]);
    }

    #[test]
    fn test_is_batch_visible() {
        let store = BatchStore::with_capacity(10);
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        store.append(create_test_batch(10)).unwrap();
        assert!(store.is_batch_visible(0, 0));
        assert!(store.is_batch_visible(0, 1));
        assert!(store.is_batch_visible(0, 2));
        assert!(!store.is_batch_visible(2, 1));
        assert!(store.is_batch_visible(2, 2));
        assert!(store.is_batch_visible(2, 3));
        // Position 3 was never committed, so it is invisible regardless
        // of the horizon.
        assert!(!store.is_batch_visible(3, 10));
    }

    #[test]
    fn test_recommended_capacity() {
        // 64 MiB / 64 KiB = 1024 batches, * 1.2 headroom ≈ 1228.
        let cap = BatchStore::recommended_capacity(64 * 1024 * 1024);
        assert!(
            (1200..=1300).contains(&cap),
            "capacity should be around 1200, got {}",
            cap
        );
        // Tiny budgets hit the floor of 16.
        let cap = BatchStore::recommended_capacity(1024);
        assert_eq!(cap, 16);
    }

    #[test]
    fn test_to_vec() {
        let store = BatchStore::with_capacity(10);
        let batch1 = create_test_batch(10);
        let batch2 = create_test_batch(20);
        store.append(batch1).unwrap();
        store.append(batch2).unwrap();
        let vec = store.to_vec();
        assert_eq!(vec.len(), 2);
        assert_eq!(vec[0].num_rows(), 10);
        assert_eq!(vec[1].num_rows(), 20);
    }

    #[test]
    fn test_to_vec_reversed() {
        let store = BatchStore::with_capacity(10);
        let batch1 = create_test_batch(10);
        let batch2 = create_test_batch(5);
        store.append(batch1).unwrap();
        store.append(batch2).unwrap();
        let forward = store.to_vec();
        assert_eq!(forward.len(), 2);
        assert_eq!(forward[0].num_rows(), 10);
        assert_eq!(forward[1].num_rows(), 5);
        let ids = forward[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 0);
        assert_eq!(ids.value(9), 9);
        // Reversal applies at both levels: batch order AND row order
        // within each batch.
        let reversed = store.to_vec_reversed().unwrap();
        assert_eq!(reversed.len(), 2);
        assert_eq!(reversed[0].num_rows(), 5);
        assert_eq!(reversed[1].num_rows(), 10);
        let ids = reversed[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 4);
        assert_eq!(ids.value(4), 0);
        let ids = reversed[1]
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 9);
        assert_eq!(ids.value(9), 0);
    }

    #[test]
    fn test_iter_reversed() {
        let store = BatchStore::with_capacity(10);
        for i in 0..5 {
            store.append(create_test_batch(10 * (i + 1))).unwrap();
        }
        let forward: Vec<_> = store.iter().map(|b| b.batch_position).collect();
        assert_eq!(forward, vec![0, 1, 2, 3, 4]);
        let reversed: Vec<_> = store.iter_reversed().map(|b| b.batch_position).collect();
        assert_eq!(reversed, vec![4, 3, 2, 1, 0]);
        let forward_rows: Vec<_> = store.iter().map(|b| b.num_rows).collect();
        let reversed_rows: Vec<_> = store.iter_reversed().map(|b| b.num_rows).collect();
        assert_eq!(forward_rows, vec![10, 20, 30, 40, 50]);
        assert_eq!(reversed_rows, vec![50, 40, 30, 20, 10]);
    }

    #[test]
    fn test_iter_reversed_empty() {
        let store = BatchStore::with_capacity(10);
        let reversed: Vec<_> = store.iter_reversed().collect();
        assert!(reversed.is_empty());
    }

    #[test]
    fn test_concurrent_readers() {
        use std::sync::Arc;
        use std::thread;
        let store = Arc::new(BatchStore::with_capacity(100));
        for _ in 0..50 {
            store.append(create_test_batch(10)).unwrap();
        }
        // Four readers repeatedly walk the fully-populated store.
        let readers: Vec<_> = (0..4)
            .map(|_| {
                let reader_store = store.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        let len = reader_store.len();
                        assert_eq!(len, 50);
                        for i in 0..len {
                            let batch = reader_store.get(i);
                            assert!(batch.is_some());
                            assert_eq!(batch.unwrap().num_rows, 10);
                        }
                        let count = reader_store.iter().count();
                        assert_eq!(count, 50);
                        thread::yield_now();
                    }
                })
            })
            .collect();
        for r in readers {
            r.join().unwrap();
        }
    }

    #[test]
    fn test_append_batches() {
        let store = BatchStore::with_capacity(10);
        let batches: Vec<_> = (0..5).map(|i| create_test_batch(10 * (i + 1))).collect();
        let results = store.append_batches(batches).unwrap();
        assert_eq!(results.len(), 5);
        assert_eq!(store.len(), 5);
        for (i, (batch_pos, _, _)) in results.iter().enumerate() {
            assert_eq!(*batch_pos, i);
        }
        // Row offsets accumulate across the group: 0, 10, 30, 60, 100.
        assert_eq!(results[0].1, 0);
        assert_eq!(results[1].1, 10);
        assert_eq!(results[2].1, 30);
        assert_eq!(results[3].1, 60);
        assert_eq!(results[4].1, 100);
        assert_eq!(store.total_rows(), 10 + 20 + 30 + 40 + 50);
    }

    #[test]
    fn test_append_batches_capacity_check() {
        let store = BatchStore::with_capacity(3);
        let batches: Vec<_> = (0..2).map(|_| create_test_batch(10)).collect();
        store.append_batches(batches).unwrap();
        assert_eq!(store.len(), 2);
        // A group that doesn't fully fit is rejected wholesale; the store
        // must be left unchanged.
        let batches: Vec<_> = (0..2).map(|_| create_test_batch(10)).collect();
        let result = store.append_batches(batches);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), StoreFull);
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_append_batches_empty() {
        let store = BatchStore::with_capacity(10);
        let results = store.append_batches(vec![]).unwrap();
        assert!(results.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_concurrent_read_write() {
        use std::sync::Arc;
        use std::sync::atomic::AtomicBool;
        use std::thread;
        let store = Arc::new(BatchStore::with_capacity(200));
        let done = Arc::new(AtomicBool::new(false));
        // Single writer (matches the store's single-producer design)...
        let writer_store = store.clone();
        let writer_done = done.clone();
        let writer = thread::spawn(move || {
            for _ in 0..100 {
                writer_store.append(create_test_batch(10)).unwrap();
                thread::yield_now();
            }
            writer_done.store(true, Ordering::Release);
        });
        // ...with concurrent readers that must only ever observe
        // fully-published slots.
        let readers: Vec<_> = (0..4)
            .map(|_| {
                let reader_store = store.clone();
                let reader_done = done.clone();
                thread::spawn(move || {
                    while !reader_done.load(Ordering::Acquire) {
                        let len = reader_store.len();
                        for i in 0..len {
                            let batch = reader_store.get(i);
                            assert!(batch.is_some());
                        }
                        thread::yield_now();
                    }
                    assert_eq!(reader_store.len(), 100);
                })
            })
            .collect();
        writer.join().unwrap();
        for r in readers {
            r.join().unwrap();
        }
    }
}