use bytes::Bytes;
use std::sync::atomic::Ordering;
use crate::db::{DBError, Result, DB};
use crate::wal::{BatchOp, Record};
/// One mutation buffered inside a [`Batch`], held until the batch is committed.
#[derive(Debug, Clone)]
enum Operation {
    /// Store `value` under `key`.
    Put { key: Bytes, value: Bytes },
    /// Remove `key`.
    Delete { key: Bytes },
}
/// An ordered, in-memory list of writes against a [`DB`], submitted as a
/// single `Record::Batch` by [`Batch::commit`].
///
/// [`Batch::put`] and [`Batch::delete`] only buffer operations. A batch that
/// is dropped without being committed never touches the database.
pub struct Batch<'db> {
    /// Database the batch is committed to.
    db: &'db DB,
    /// Queued operations, kept in insertion order.
    operations: Vec<Operation>,
}
impl<'db> Batch<'db> {
    /// Creates an empty batch that will be committed to `db`.
    pub(crate) const fn new(db: &'db DB) -> Self {
        Self {
            db,
            operations: Vec::new(),
        }
    }

    /// Creates an empty batch that will be committed to `db`, pre-allocating
    /// room for `capacity` operations.
    pub fn with_capacity(db: &'db DB, capacity: usize) -> Self {
        Self {
            db,
            operations: Vec::with_capacity(capacity),
        }
    }

    /// Queues a write of `value` under `key`.
    ///
    /// `key` and `value` are copied into owned [`Bytes`], so the caller's
    /// buffers do not need to outlive the batch. Nothing reaches the database
    /// until [`Batch::commit`] is called.
    pub fn put(&mut self, key: impl AsRef<[u8]>, value: impl AsRef<[u8]>) {
        let key = Bytes::copy_from_slice(key.as_ref());
        let value = Bytes::copy_from_slice(value.as_ref());
        self.operations.push(Operation::Put { key, value });
    }

    /// Queues a deletion of `key`.
    ///
    /// `key` is copied into an owned [`Bytes`]. Nothing reaches the database
    /// until [`Batch::commit`] is called.
    pub fn delete(&mut self, key: impl AsRef<[u8]>) {
        let key = Bytes::copy_from_slice(key.as_ref());
        self.operations.push(Operation::Delete { key });
    }

    /// Returns the number of queued operations (puts and deletes combined).
    #[must_use]
    pub const fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` if no operations are queued.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }

    /// Discards every queued operation. The allocated capacity is retained.
    pub fn clear(&mut self) {
        self.operations.clear();
    }

    /// Commits all queued operations as a single `Record::Batch`.
    ///
    /// An empty batch returns `Ok(())` immediately and reserves no sequence
    /// numbers. Otherwise, one sequence number per operation is reserved from
    /// `next_seq`, and the first of them becomes the record's `base_seq`.
    ///
    /// The record is then handled in one of two ways:
    /// - if `options.skip_wal` is set, it is applied directly;
    /// - otherwise, it is submitted to the pipelined WAL, whose callback
    ///   applies the records it is handed.
    ///
    /// # Errors
    ///
    /// Returns `DBError::Wal` if the pipelined WAL `put` fails. The sequence
    /// numbers reserved for this batch are not released in that case.
    pub fn commit(self) -> Result<()> {
        // `commit` consumes the batch, so take ownership of the buffered
        // operations and move their `Bytes` into the WAL ops instead of
        // cloning each handle.
        let Self { db, operations } = self;
        if operations.is_empty() {
            return Ok(());
        }

        // `usize` -> `u64` is lossless on every target with <= 64-bit pointers.
        let op_count = operations.len() as u64;
        // Reserve the contiguous range [base_seq, base_seq + op_count).
        let base_seq = db.next_seq.fetch_add(op_count, Ordering::SeqCst);

        let wal_ops: Vec<BatchOp> = operations
            .into_iter()
            .map(|op| match op {
                Operation::Put { key, value } => BatchOp::Put { key, value },
                Operation::Delete { key } => BatchOp::Delete { key },
            })
            .collect();
        let batch_record = Record::Batch {
            base_seq,
            operations: wal_ops,
        };

        if db.options.skip_wal {
            db.apply_wal_records(&[batch_record]);
        } else {
            // NOTE(review): `records` may contain more than this batch if the
            // WAL groups concurrent writers — confirm against `pipelined_wal`.
            db.pipelined_wal
                .put(batch_record, |records| {
                    db.apply_wal_records(records);
                })
                .map_err(DBError::Wal)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DB;
    use tempfile::tempdir;

    /// Two puts and a delete of a never-written key commit together.
    /// Afterwards the puts are readable and the deleted key is absent.
    #[test]
    fn test_batch_basic() {
        let dir = tempdir().unwrap();
        let db = DB::open(dir.path()).unwrap();

        let mut batch = db.batch();
        batch.put(b"key1", b"value1");
        batch.put(b"key2", b"value2");
        batch.delete(b"key3");
        assert!(!batch.is_empty());
        assert_eq!(batch.len(), 3);
        batch.commit().unwrap();

        let expectations = [
            (b"key1".as_slice(), Some(Bytes::from("value1"))),
            (b"key2".as_slice(), Some(Bytes::from("value2"))),
            (b"key3".as_slice(), None),
        ];
        for (key, expected) in expectations {
            assert_eq!(db.get(key).unwrap(), expected);
        }
    }

    /// Committing a batch with no operations succeeds.
    #[test]
    fn test_batch_empty() {
        let dir = tempdir().unwrap();
        let db = DB::open(dir.path()).unwrap();

        let batch = db.batch();
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());
        batch.commit().unwrap();
    }

    /// A pre-sized batch of 100 puts commits, and every key reads back.
    #[test]
    fn test_batch_with_capacity() {
        let dir = tempdir().unwrap();
        let db = DB::open(dir.path()).unwrap();
        let keys: Vec<String> = (0..100).map(|i| format!("key_{i}")).collect();

        let mut batch = db.batch_with_capacity(keys.len());
        keys.iter()
            .for_each(|key| batch.put(key.as_bytes(), b"value"));
        assert_eq!(batch.len(), keys.len());
        batch.commit().unwrap();

        for key in &keys {
            assert_eq!(db.get(key.as_bytes()).unwrap(), Some(Bytes::from("value")));
        }
    }
}