use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
use manifoldb_core::TransactionError;
use manifoldb_storage::{StorageEngine, Transaction};
/// Tuning knobs for batched (group) commits.
///
/// A batch is flushed as soon as it holds `max_batch_size` queued transactions,
/// or once `flush_interval` has elapsed since its first transaction was queued.
/// With `enabled == false` every commit is written straight through on its own.
#[derive(Debug, Clone)]
pub struct BatchWriterConfig {
    /// Number of queued transactions that triggers an immediate flush.
    pub max_batch_size: usize,
    /// Maximum age of a batch before a waiting committer flushes it.
    pub flush_interval: Duration,
    /// Whether commits are batched at all.
    pub enabled: bool,
}

impl Default for BatchWriterConfig {
    /// Batches of up to 100 transactions, flushed after at most 10 ms, batching on.
    fn default() -> Self {
        Self {
            max_batch_size: 100,
            flush_interval: Duration::from_millis(10),
            enabled: true,
        }
    }
}

impl BatchWriterConfig {
    /// Same as [`BatchWriterConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy of `self` with `max_batch_size` replaced by `size`.
    #[must_use]
    pub const fn max_batch_size(self, size: usize) -> Self {
        Self { max_batch_size: size, ..self }
    }

    /// Returns a copy of `self` with `flush_interval` replaced by `interval`.
    #[must_use]
    pub const fn flush_interval(self, interval: Duration) -> Self {
        Self { flush_interval: interval, ..self }
    }

    /// Returns a copy of `self` with batching switched on or off.
    #[must_use]
    pub const fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Default settings with batching switched off (every commit is immediate).
    #[must_use]
    pub fn disabled() -> Self {
        Self::default().enabled(false)
    }
}
/// A single buffered mutation against one table.
#[derive(Debug, Clone)]
pub enum WriteOp {
    /// Write `value` under `key` in `table`.
    Put {
        table: String,
        key: Vec<u8>,
        value: Vec<u8>,
    },
    /// Remove `key` from `table`.
    Delete {
        table: String,
        key: Vec<u8>,
    },
}

/// Per-transaction write log plus a lookup index for read-your-writes.
///
/// Every `put`/`delete` is appended to the log in call order, so replaying the
/// log reproduces the final state. The index keeps only the latest operation
/// per `(table, key)`: `Some(i)` points at a `Put` in the log, `None` marks a delete.
#[derive(Debug, Default)]
pub struct WriteBuffer {
    ops: Vec<WriteOp>,
    index: HashMap<(String, Vec<u8>), Option<usize>>,
}

impl WriteBuffer {
    /// Creates an empty buffer.
    #[must_use]
    pub fn new() -> Self {
        Self { ops: Vec::new(), index: HashMap::new() }
    }

    /// Appends a write of `value` under `(table, key)`, shadowing any earlier entry.
    pub fn put(&mut self, table: String, key: Vec<u8>, value: Vec<u8>) {
        let position = self.ops.len();
        self.index.insert((table.clone(), key.clone()), Some(position));
        self.ops.push(WriteOp::Put { table, key, value });
    }

    /// Appends a delete of `(table, key)`, shadowing any earlier entry.
    pub fn delete(&mut self, table: String, key: Vec<u8>) {
        self.index.insert((table.clone(), key.clone()), None);
        self.ops.push(WriteOp::Delete { table, key });
    }

    /// Looks up the latest buffered state of `(table, key)`.
    ///
    /// Returns `None` when the key was never touched, `Some(None)` when its latest
    /// operation is a delete, and `Some(Some(value))` for the latest put.
    #[must_use]
    pub fn get(&self, table: &str, key: &[u8]) -> Option<Option<&[u8]>> {
        let slot = self.index.get(&(table.to_owned(), key.to_owned()))?;
        match *slot {
            None => Some(None),
            Some(position) => {
                let value = match &self.ops[position] {
                    WriteOp::Put { value, .. } => value.as_slice(),
                    // Not produced by `put`/`delete`: a `Some` slot always points at a put.
                    WriteOp::Delete { .. } => &[][..],
                };
                Some(Some(value))
            }
        }
    }

    /// `true` when the latest buffered operation on `(table, key)` is a delete.
    #[must_use]
    pub fn is_deleted(&self, table: &str, key: &[u8]) -> bool {
        self.index.get(&(table.to_owned(), key.to_owned())) == Some(&None)
    }

    /// Number of logged operations (overwrites are counted separately).
    #[must_use]
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// `true` when nothing has been logged.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consumes the buffer, yielding the operation log in call order.
    #[must_use]
    pub fn into_ops(self) -> Vec<WriteOp> {
        let Self { ops, .. } = self;
        ops
    }

    /// Borrows the operation log in call order.
    #[must_use]
    pub fn ops(&self) -> &[WriteOp] {
        self.ops.as_slice()
    }
}
/// Lifecycle of one submitted transaction inside the write queue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BatchEntryStatus {
    /// Submitted; its batch has not finished committing yet.
    Pending,
    /// Its batch was committed successfully.
    Committed,
    /// Its batch failed; the reason is kept in `PendingEntry::error`.
    Failed,
}

/// A submitted transaction's operations together with its commit outcome.
struct PendingEntry {
    /// Id of the submitting transaction; its waiter looks the entry up by this.
    tx_id: u64,
    /// Operations to apply, in submission order.
    ops: Vec<WriteOp>,
    /// Current lifecycle state.
    status: BatchEntryStatus,
    /// Error text recorded when the batch failed.
    error: Option<String>,
}

/// Mutable queue contents, always accessed under the queue's mutex.
struct WriteQueueState {
    /// Entries awaiting commit, plus finished entries not yet collected by their waiter.
    pending: Vec<PendingEntry>,
    /// When the current batch began accumulating; drives timeout-based flushes.
    batch_start: Instant,
}

/// Group-commit queue: concurrently submitted transactions are applied to the
/// storage engine together inside a single write transaction.
pub struct WriteQueue<E: StorageEngine> {
    /// Engine that batches (and immediate commits) are written to.
    engine: Arc<E>,
    /// Batching configuration.
    config: BatchWriterConfig,
    /// Queued and finished entries.
    state: Mutex<WriteQueueState>,
    /// Notified after every flush so waiters re-check their entry.
    commit_complete: Condvar,
    /// Set while a flush runs so that only one thread flushes at a time.
    flushing: AtomicBool,
    /// Source of transaction ids.
    tx_counter: AtomicU64,
}
impl<E: StorageEngine> WriteQueue<E> {
pub fn new(engine: Arc<E>, config: BatchWriterConfig) -> Self {
Self {
engine,
config,
state: Mutex::new(WriteQueueState { pending: Vec::new(), batch_start: Instant::now() }),
commit_complete: Condvar::new(),
flushing: AtomicBool::new(false),
tx_counter: AtomicU64::new(0),
}
}
#[must_use]
pub fn next_tx_id(&self) -> u64 {
self.tx_counter.fetch_add(1, Ordering::Relaxed)
}
pub fn submit(&self, tx_id: u64, ops: Vec<WriteOp>) -> Result<(), TransactionError> {
if !self.config.enabled || ops.is_empty() {
return self.commit_immediately(ops);
}
let should_flush = {
let mut state = self.state.lock().map_err(|e| {
TransactionError::Internal(format!("failed to acquire write queue lock: {e}"))
})?;
if state.pending.is_empty() {
state.batch_start = Instant::now();
}
state.pending.push(PendingEntry {
tx_id,
ops,
status: BatchEntryStatus::Pending,
error: None,
});
state.pending.len() >= self.config.max_batch_size
};
if should_flush {
self.flush()?;
} else {
self.maybe_flush_on_timeout()?;
}
self.wait_for_commit(tx_id)
}
fn commit_immediately(&self, ops: Vec<WriteOp>) -> Result<(), TransactionError> {
if ops.is_empty() {
return Ok(());
}
let mut tx = self.engine.begin_write().map_err(|e| {
TransactionError::Storage(format!("failed to begin write transaction: {e}"))
})?;
for op in ops {
match op {
WriteOp::Put { table, key, value } => {
tx.put(&table, &key, &value)
.map_err(|e| TransactionError::Storage(format!("put failed: {e}")))?;
}
WriteOp::Delete { table, key } => {
tx.delete(&table, &key)
.map_err(|e| TransactionError::Storage(format!("delete failed: {e}")))?;
}
}
}
tx.commit().map_err(|e| TransactionError::Storage(format!("commit failed: {e}")))
}
fn maybe_flush_on_timeout(&self) -> Result<(), TransactionError> {
let should_flush = {
let state = self.state.lock().map_err(|e| {
TransactionError::Internal(format!("failed to acquire write queue lock: {e}"))
})?;
!state.pending.is_empty() && state.batch_start.elapsed() >= self.config.flush_interval
};
if should_flush {
self.flush()?;
}
Ok(())
}
pub fn flush(&self) -> Result<(), TransactionError> {
if self.flushing.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err()
{
return Ok(());
}
let result = self.do_flush();
self.flushing.store(false, Ordering::SeqCst);
self.commit_complete.notify_all();
result
}
fn do_flush(&self) -> Result<(), TransactionError> {
let entries: Vec<PendingEntry> = {
let mut state = self.state.lock().map_err(|e| {
TransactionError::Internal(format!("failed to acquire write queue lock: {e}"))
})?;
std::mem::take(&mut state.pending)
};
if entries.is_empty() {
return Ok(());
}
let commit_result = self.apply_batch(&entries);
{
let mut state = self.state.lock().map_err(|e| {
TransactionError::Internal(format!("failed to acquire write queue lock: {e}"))
})?;
for mut entry in entries {
match &commit_result {
Ok(()) => {
entry.status = BatchEntryStatus::Committed;
}
Err(e) => {
entry.status = BatchEntryStatus::Failed;
entry.error = Some(e.to_string());
}
}
state.pending.push(entry);
}
}
commit_result
}
fn apply_batch(&self, entries: &[PendingEntry]) -> Result<(), TransactionError> {
let mut tx = self.engine.begin_write().map_err(|e| {
TransactionError::Storage(format!("failed to begin write transaction: {e}"))
})?;
for entry in entries {
for op in &entry.ops {
match op {
WriteOp::Put { table, key, value } => {
tx.put(table, key, value)
.map_err(|e| TransactionError::Storage(format!("put failed: {e}")))?;
}
WriteOp::Delete { table, key } => {
tx.delete(table, key).map_err(|e| {
TransactionError::Storage(format!("delete failed: {e}"))
})?;
}
}
}
}
tx.commit().map_err(|e| TransactionError::Storage(format!("commit failed: {e}")))
}
fn wait_for_commit(&self, tx_id: u64) -> Result<(), TransactionError> {
loop {
{
let mut state = self.state.lock().map_err(|e| {
TransactionError::Internal(format!("failed to acquire write queue lock: {e}"))
})?;
let mut found_idx = None;
for (i, entry) in state.pending.iter().enumerate() {
if entry.tx_id == tx_id {
match entry.status {
BatchEntryStatus::Pending => {
}
BatchEntryStatus::Committed => {
found_idx = Some(i);
break;
}
BatchEntryStatus::Failed => {
let error =
entry.error.clone().unwrap_or_else(|| "unknown".to_string());
state.pending.remove(i);
return Err(TransactionError::Storage(format!(
"batch commit failed: {error}"
)));
}
}
}
}
if let Some(idx) = found_idx {
state.pending.remove(idx);
return Ok(());
}
}
self.maybe_flush_on_timeout()?;
{
let state = self.state.lock().map_err(|e| {
TransactionError::Internal(format!("failed to acquire write queue lock: {e}"))
})?;
let _result = self
.commit_complete
.wait_timeout(state, Duration::from_millis(1))
.map_err(|e| {
TransactionError::Internal(format!("condition variable wait failed: {e}"))
})?;
}
}
}
#[must_use]
pub fn pending_count(&self) -> usize {
self.state.lock().map(|s| s.pending.len()).unwrap_or(0)
}
#[must_use]
pub const fn config(&self) -> &BatchWriterConfig {
&self.config
}
}
/// Cheaply cloneable handle to a shared write queue; the entry point for
/// starting batched transactions. All clones feed the same queue.
pub struct BatchWriter<E: StorageEngine> {
    queue: Arc<WriteQueue<E>>,
}

impl<E: StorageEngine> BatchWriter<E> {
    /// Builds a writer over `engine` using `config`.
    pub fn new(engine: Arc<E>, config: BatchWriterConfig) -> Self {
        let queue = WriteQueue::new(engine, config);
        Self { queue: Arc::new(queue) }
    }

    /// Builds a writer with `BatchWriterConfig::default()`.
    pub fn with_defaults(engine: Arc<E>) -> Self {
        let config = BatchWriterConfig::default();
        Self::new(engine, config)
    }

    /// The queue backing this writer (shared by every clone).
    #[must_use]
    pub fn queue(&self) -> &Arc<WriteQueue<E>> {
        &self.queue
    }

    /// Starts a new transaction, with an id drawn from the shared queue.
    #[must_use]
    pub fn begin(&self) -> BatchedTransaction<E> {
        let queue = Arc::clone(&self.queue);
        let tx_id = queue.next_tx_id();
        BatchedTransaction::new(tx_id, queue)
    }

    /// Delegates to `WriteQueue::flush`.
    ///
    /// That call returns `Ok(())` without waiting if another thread is already flushing.
    pub fn flush(&self) -> Result<(), TransactionError> {
        let queue: &WriteQueue<E> = &self.queue;
        queue.flush()
    }

    /// Delegates to `WriteQueue::pending_count`.
    #[must_use]
    pub fn pending_count(&self) -> usize {
        let queue: &WriteQueue<E> = &self.queue;
        queue.pending_count()
    }
}

impl<E: StorageEngine> Clone for BatchWriter<E> {
    /// Returns another handle to the same queue; no buffered data is copied.
    fn clone(&self) -> Self {
        let queue = Arc::clone(&self.queue);
        Self { queue }
    }
}
/// One transaction whose writes are buffered locally until commit.
///
/// Reads first consult this transaction's own buffered writes. Otherwise they open
/// a fresh read transaction on the storage engine. Writes are handed to the queue
/// only by `commit`; `rollback` or dropping the transaction discards them.
pub struct BatchedTransaction<E: StorageEngine> {
    tx_id: u64,
    queue: Arc<WriteQueue<E>>,
    buffer: WriteBuffer,
    completed: bool,
}

impl<E: StorageEngine> BatchedTransaction<E> {
    fn new(tx_id: u64, queue: Arc<WriteQueue<E>>) -> Self {
        Self { tx_id, queue, buffer: WriteBuffer::new(), completed: false }
    }

    /// Fails with `AlreadyCompleted` once the transaction has been finished.
    fn ensure_active(&self) -> Result<(), TransactionError> {
        if self.completed {
            Err(TransactionError::AlreadyCompleted)
        } else {
            Ok(())
        }
    }

    /// This transaction's id.
    #[must_use]
    pub const fn id(&self) -> u64 {
        self.tx_id
    }

    /// Reads `key` from `table`, preferring this transaction's buffered writes.
    ///
    /// Returns `Ok(None)` if the key was deleted earlier in this transaction, or if
    /// it is absent from storage.
    pub fn get(&self, table: &str, key: &[u8]) -> Result<Option<Vec<u8>>, TransactionError> {
        self.ensure_active()?;
        match self.buffer.get(table, key) {
            Some(buffered) => Ok(buffered.map(<[u8]>::to_vec)),
            None => {
                let reader = self.queue.engine.begin_read().map_err(|e| {
                    TransactionError::Storage(format!("failed to begin read transaction: {e}"))
                })?;
                reader
                    .get(table, key)
                    .map_err(|e| TransactionError::Storage(format!("get failed: {e}")))
            }
        }
    }

    /// Buffers a write of `value` under `key` in `table`.
    pub fn put(&mut self, table: &str, key: &[u8], value: &[u8]) -> Result<(), TransactionError> {
        self.ensure_active()?;
        self.buffer.put(table.to_owned(), key.to_owned(), value.to_owned());
        Ok(())
    }

    /// Buffers a delete of `key` from `table` if the key is currently visible to this
    /// transaction; returns whether it was.
    pub fn delete(&mut self, table: &str, key: &[u8]) -> Result<bool, TransactionError> {
        self.ensure_active()?;
        let existed = self.get(table, key)?.is_some();
        if existed {
            self.buffer.delete(table.to_owned(), key.to_owned());
        }
        Ok(existed)
    }

    /// Hands the buffered operations to the shared queue via `WriteQueue::submit`
    /// and returns its result.
    pub fn commit(mut self) -> Result<(), TransactionError> {
        self.ensure_active()?;
        self.completed = true;
        let buffered = std::mem::replace(&mut self.buffer, WriteBuffer::new());
        let ops = buffered.into_ops();
        self.queue.submit(self.tx_id, ops)
    }

    /// Discards the buffered operations; nothing reaches storage.
    pub fn rollback(mut self) -> Result<(), TransactionError> {
        self.ensure_active()?;
        self.completed = true;
        Ok(())
    }

    /// Number of operations buffered so far (overwrites counted separately).
    #[must_use]
    pub fn buffered_ops(&self) -> usize {
        self.buffer.len()
    }
}

impl<E: StorageEngine> Drop for BatchedTransaction<E> {
    /// Dropping an unfinished transaction is an implicit rollback.
    fn drop(&mut self) {
        // Nothing was handed to the queue, so releasing the buffer is all it takes.
        self.completed = true;
    }
}
// Unit tests for `WriteBuffer`, `BatchWriter`/`BatchedTransaction`, and `BatchWriterConfig`.
#[cfg(test)]
mod tests {
use super::*;
use manifoldb_storage::backends::RedbEngine;
use std::sync::atomic::AtomicUsize;
use std::thread;
// Builds a fresh in-memory `RedbEngine` for each test.
fn create_test_engine() -> RedbEngine {
RedbEngine::in_memory().expect("failed to create in-memory engine")
}
// Distinct keys are each retrievable; a key never written yields `None`.
#[test]
fn test_write_buffer_basic() {
let mut buffer = WriteBuffer::new();
buffer.put("table".to_string(), b"key1".to_vec(), b"value1".to_vec());
buffer.put("table".to_string(), b"key2".to_vec(), b"value2".to_vec());
assert_eq!(buffer.len(), 2);
assert_eq!(buffer.get("table", b"key1"), Some(Some(b"value1".as_slice())));
assert_eq!(buffer.get("table", b"key2"), Some(Some(b"value2".as_slice())));
assert_eq!(buffer.get("table", b"key3"), None);
}
// A later put on the same key shadows the earlier one.
#[test]
fn test_write_buffer_overwrite() {
let mut buffer = WriteBuffer::new();
buffer.put("table".to_string(), b"key".to_vec(), b"value1".to_vec());
buffer.put("table".to_string(), b"key".to_vec(), b"value2".to_vec());
assert_eq!(buffer.get("table", b"key"), Some(Some(b"value2".as_slice())));
}
// A delete after a put leaves a tombstone: `get` reports `Some(None)`.
#[test]
fn test_write_buffer_delete() {
let mut buffer = WriteBuffer::new();
buffer.put("table".to_string(), b"key".to_vec(), b"value".to_vec());
buffer.delete("table".to_string(), b"key".to_vec());
assert_eq!(buffer.get("table", b"key"), Some(None));
assert!(buffer.is_deleted("table", b"key"));
}
// With batching disabled, commit writes straight through to the engine.
#[test]
fn test_batch_writer_immediate_commit() {
let engine = Arc::new(create_test_engine());
let writer = BatchWriter::new(engine.clone(), BatchWriterConfig::disabled());
let mut tx = writer.begin();
tx.put("test", b"key", b"value").expect("put failed");
tx.commit().expect("commit failed");
let read_tx = engine.begin_read().expect("begin_read failed");
let value = read_tx.get("test", b"key").expect("get failed");
assert_eq!(value, Some(b"value".to_vec()));
}
// Buffered writes are visible to the same transaction before it commits.
#[test]
fn test_batch_writer_read_your_writes() {
let engine = Arc::new(create_test_engine());
let writer = BatchWriter::new(engine, BatchWriterConfig::default());
let mut tx = writer.begin();
tx.put("test", b"key", b"value").expect("put failed");
let value = tx.get("test", b"key").expect("get failed");
assert_eq!(value, Some(b"value".to_vec()));
tx.commit().expect("commit failed");
}
// Batching is disabled, so every commit is immediate. tx2 must not see tx1's
// uncommitted write, and the last commit (tx2) determines the stored value.
#[test]
fn test_batch_writer_isolation() {
let engine = Arc::new(create_test_engine());
let writer = BatchWriter::new(engine.clone(), BatchWriterConfig::disabled());
{
let mut tx = writer.begin();
tx.put("test", b"key", b"initial").expect("put failed");
tx.commit().expect("commit failed");
}
let mut tx1 = writer.begin();
let mut tx2 = writer.begin();
tx1.put("test", b"key", b"tx1_value").expect("put failed");
let value = tx2.get("test", b"key").expect("get failed");
assert_eq!(value, Some(b"initial".to_vec()));
tx2.put("test", b"key", b"tx2_value").expect("put failed");
let value = tx2.get("test", b"key").expect("get failed");
assert_eq!(value, Some(b"tx2_value".to_vec()));
tx1.commit().expect("commit failed");
tx2.commit().expect("commit failed");
let read_tx = engine.begin_read().expect("begin_read failed");
let value = read_tx.get("test", b"key").expect("get failed");
assert_eq!(value, Some(b"tx2_value".to_vec()));
}
// Rolled-back writes never reach storage.
#[test]
fn test_batch_writer_rollback() {
let engine = Arc::new(create_test_engine());
let writer = BatchWriter::new(engine.clone(), BatchWriterConfig::disabled());
let mut tx = writer.begin();
tx.put("test", b"key", b"value").expect("put failed");
tx.rollback().expect("rollback failed");
let read_tx = engine.begin_read().expect("begin_read failed");
let value = read_tx.get("test", b"key").expect("get failed");
assert_eq!(value, None);
}
// Several threads commit through one batched writer, and every write must land.
// The small batch size and short interval make it likely that both the
// size-triggered and the timeout-triggered flush paths are exercised.
#[test]
fn test_batch_writer_concurrent() {
let engine = Arc::new(create_test_engine());
let writer = BatchWriter::new(
engine.clone(),
BatchWriterConfig::default()
.max_batch_size(10)
.flush_interval(Duration::from_millis(5)),
);
let num_threads = 4;
let writes_per_thread = 25;
let counter = Arc::new(AtomicUsize::new(0));
let handles: Vec<_> = (0..num_threads)
.map(|thread_id| {
let writer = writer.clone();
let counter = Arc::clone(&counter);
thread::spawn(move || {
for i in 0..writes_per_thread {
// Keys are unique per thread, so no two transactions conflict.
let key = format!("thread{thread_id}_key{i}");
let value = format!("value{i}");
let mut tx = writer.begin();
tx.put("test", key.as_bytes(), value.as_bytes()).expect("put failed");
tx.commit().expect("commit failed");
counter.fetch_add(1, Ordering::Relaxed);
}
})
})
.collect();
for handle in handles {
handle.join().expect("thread panicked");
}
assert_eq!(counter.load(Ordering::Relaxed), num_threads * writes_per_thread);
// Every committed write must be visible in storage afterwards.
let read_tx = engine.begin_read().expect("begin_read failed");
for thread_id in 0..num_threads {
for i in 0..writes_per_thread {
let key = format!("thread{thread_id}_key{i}");
let expected_value = format!("value{i}");
let value = read_tx.get("test", key.as_bytes()).expect("get failed");
assert_eq!(
value,
Some(expected_value.into_bytes()),
"missing or wrong value for {key}"
);
}
}
}
// Builder setters store their arguments; `disabled()` turns batching off.
#[test]
fn test_batch_writer_config() {
let config = BatchWriterConfig::new()
.max_batch_size(50)
.flush_interval(Duration::from_millis(20))
.enabled(true);
assert_eq!(config.max_batch_size, 50);
assert_eq!(config.flush_interval, Duration::from_millis(20));
assert!(config.enabled);
let disabled = BatchWriterConfig::disabled();
assert!(!disabled.enabled);
}
}