use alloc::vec::Vec;
use std::sync::Arc;
use super::database::{BfTreeDatabase, BfTreeDatabaseWriteTxn};
use super::error::BfTreeError;
/// A boxed, one-shot unit of write work.
///
/// The closure receives a borrowed write transaction, performs its writes
/// through it, and reports success or a [`BfTreeError`]. It is `Send` so that
/// it can be moved onto another thread before being invoked.
pub type WriteBatchFn = Box<dyn FnOnce(&BfTreeDatabaseWriteTxn) -> Result<(), BfTreeError> + Send>;
/// Collects write batches and runs them together in one write transaction.
///
/// Batches are queued with `add` and consumed by `execute` or
/// `execute_with_snapshot`, which run them in the order they were added.
pub struct GroupCommit {
/// Database on which the write transaction is opened.
db: Arc<BfTreeDatabase>,
/// Queued batches, in insertion order.
batches: Vec<WriteBatchFn>,
}
impl GroupCommit {
pub fn new(db: Arc<BfTreeDatabase>) -> Self {
Self {
db,
batches: Vec::new(),
}
}
pub fn add<F>(&mut self, batch: F)
where
F: FnOnce(&BfTreeDatabaseWriteTxn) -> Result<(), BfTreeError> + Send + 'static,
{
self.batches.push(Box::new(batch));
}
pub fn execute(self) -> Result<usize, BfTreeError> {
let count = self.batches.len();
let wtxn = self.db.begin_write();
for batch in self.batches {
batch(&wtxn)?;
}
wtxn.commit()?;
Ok(count)
}
pub fn execute_with_snapshot(self) -> Result<(usize, std::path::PathBuf), BfTreeError> {
let count = self.batches.len();
let wtxn = self.db.begin_write();
for batch in self.batches {
batch(&wtxn)?;
}
wtxn.commit()?;
let path = self.db.snapshot()?;
Ok((count, path))
}
}
pub fn concurrent_group_commit(
db: Arc<BfTreeDatabase>,
batches: Vec<WriteBatchFn>,
) -> Result<usize, BfTreeError> {
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
let count = batches.len();
let abort = Arc::new(AtomicBool::new(false));
let mut handles = Vec::with_capacity(count);
for batch in batches {
let db = db.clone();
let abort = abort.clone();
handles.push(thread::spawn(
move || -> Result<BfTreeDatabaseWriteTxn, BfTreeError> {
if abort.load(Ordering::Acquire) {
return Err(BfTreeError::InvalidOperation(
"batch aborted due to earlier failure".into(),
));
}
let batch_txn = db.begin_write();
match batch(&batch_txn) {
Ok(()) => Ok(batch_txn),
Err(e) => {
abort.store(true, Ordering::Release);
Err(e)
}
}
},
));
}
let mut batch_txns: Vec<BfTreeDatabaseWriteTxn> = Vec::with_capacity(count);
let mut first_error: Option<BfTreeError> = None;
for handle in handles {
match handle.join() {
Ok(Ok(txn)) => batch_txns.push(txn),
Ok(Err(e)) => {
if first_error.is_none() {
first_error = Some(e);
}
}
Err(_) => {
if first_error.is_none() {
first_error = Some(BfTreeError::InvalidOperation(
"batch thread panicked".into(),
));
}
}
}
}
if let Some(err) = first_error {
return Err(err);
}
let commit_txn = db.begin_write();
for batch_txn in &batch_txns {
commit_txn.merge_buffer_from(batch_txn)?;
}
commit_txn.commit()?;
Ok(count)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TableDefinition;
    use crate::bf_tree_store::config::BfTreeConfig;

    /// Table used by every test: string keys mapped to `u64` values.
    const DATA: TableDefinition<&str, u64> = TableDefinition::new("data");

    /// Builds a fresh shared database from `BfTreeConfig::new_memory(8)`.
    fn test_db() -> Arc<BfTreeDatabase> {
        let config = BfTreeConfig::new_memory(8);
        let database = BfTreeDatabase::create(config).unwrap();
        Arc::new(database)
    }

    /// Three queued batches are all visible after a single `execute`.
    #[test]
    fn sequential_group_commit() {
        let db = test_db();
        let mut group = GroupCommit::new(Arc::clone(&db));
        for (key, value) in [("a", 1u64), ("b", 2u64), ("c", 3u64)] {
            group.add(move |txn| {
                let mut table = txn.open_table(DATA)?;
                table.insert(&key, &value)?;
                Ok(())
            });
        }

        let committed = group.execute().unwrap();
        assert_eq!(committed, 3);

        let rtxn = db.begin_read();
        let mut table = rtxn.open_table(DATA).unwrap();
        for key in ["a", "b", "c"] {
            assert!(table.get(&key).unwrap().is_some());
        }
    }

    /// Four batches run through `concurrent_group_commit` are all visible afterwards.
    #[test]
    fn concurrent_group_commit_test() {
        let db = test_db();
        let mut batches: Vec<WriteBatchFn> = Vec::new();
        for i in 0u64..4 {
            batches.push(Box::new(
                move |txn: &BfTreeDatabaseWriteTxn| -> Result<(), BfTreeError> {
                    let mut table = txn.open_table(DATA)?;
                    let key = alloc::format!("key_{i}");
                    table.insert(&key.as_str(), &(i * 10))?;
                    Ok(())
                },
            ));
        }

        let committed = concurrent_group_commit(Arc::clone(&db), batches).unwrap();
        assert_eq!(committed, 4);

        let rtxn = db.begin_read();
        let mut table = rtxn.open_table(DATA).unwrap();
        for i in 0u64..4 {
            let expected_key = alloc::format!("key_{i}");
            assert!(table.get(&expected_key.as_str()).unwrap().is_some());
        }
    }

    /// Executing a group with no batches succeeds and reports zero.
    #[test]
    fn empty_group_commit() {
        let committed = GroupCommit::new(test_db()).execute().unwrap();
        assert_eq!(committed, 0);
    }
}