use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::marker::PhantomData;
use crate::storage_traits::OwnedKv;
use crate::types::{Key, Value};
use alloc::sync::Arc;
use super::adapter::BfTreeAdapter;
use super::database::{BfTreeTableScan, TableKind, table_prefix, table_prefix_end};
use super::error::BfTreeError;
use super::verification::{VerifyMode, should_verify, unwrap_value};
/// Outcome of looking up an encoded key in a [`WriteBuffer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum BufferLookup {
    /// The buffer holds a pending value for the key (returned as an owned copy).
    Found(Vec<u8>),
    /// The buffer holds a pending delete (tombstone) for the key.
    Tombstone,
    /// The buffer has no pending write for the key.
    NotInBuffer,
}
/// Key cap used by `WriteBuffer::new` (one million distinct encoded keys).
const DEFAULT_MAX_BUFFER_ENTRIES: usize = 1000 * 1000;

/// In-memory staging area for writes that have not been applied to the tree yet.
///
/// Writes stay here until `flush` applies them to the adapter or `discard` drops them.
/// Keys are kept in their encoded form and ordered bytewise by the underlying `BTreeMap`.
pub(crate) struct WriteBuffer {
    /// Encoded key -> `Some(value)` for a pending put, `None` for a pending delete.
    entries: BTreeMap<Vec<u8>, Option<Vec<u8>>>,
    /// Maximum number of distinct keys that `put` and `merge_from` will admit.
    max_buffer_entries: usize,
}
impl WriteBuffer {
    /// Creates an empty buffer that holds at most [`DEFAULT_MAX_BUFFER_ENTRIES`] distinct keys.
    pub(crate) fn new() -> Self {
        Self::with_max_entries(DEFAULT_MAX_BUFFER_ENTRIES)
    }

    /// Creates an empty buffer that holds at most `max_buffer_entries` distinct keys.
    pub(crate) fn with_max_entries(max_buffer_entries: usize) -> Self {
        Self {
            entries: BTreeMap::new(),
            max_buffer_entries,
        }
    }

    /// Stages `value` for `encoded_key`, replacing any earlier put or delete of that key.
    ///
    /// # Errors
    ///
    /// Returns [`BfTreeError::InvalidKV`] when `encoded_key` is not buffered yet and the
    /// buffer already holds `max_buffer_entries` keys. Overwriting a buffered key always
    /// succeeds.
    pub(crate) fn put(&mut self, encoded_key: Vec<u8>, value: Vec<u8>) -> Result<(), BfTreeError> {
        if self.entries.len() >= self.max_buffer_entries && !self.entries.contains_key(&encoded_key)
        {
            return Err(BfTreeError::InvalidKV(alloc::format!(
                "write buffer full: {} entries (limit {})",
                self.entries.len(),
                self.max_buffer_entries
            )));
        }
        self.entries.insert(encoded_key, Some(value));
        Ok(())
    }

    /// Stages a delete (tombstone) for `encoded_key`, replacing any earlier put.
    ///
    /// NOTE(review): unlike `put`, this does not enforce `max_buffer_entries`, so
    /// tombstones for previously unseen keys can grow the buffer past the limit.
    pub(crate) fn delete(&mut self, encoded_key: Vec<u8>) {
        self.entries.insert(encoded_key, None);
    }

    /// Reports the staged state of `encoded_key`; a found value is returned as an owned copy.
    pub(crate) fn get(&self, encoded_key: &[u8]) -> BufferLookup {
        match self.entries.get(encoded_key) {
            Some(Some(value)) => BufferLookup::Found(value.clone()),
            Some(None) => BufferLookup::Tombstone,
            None => BufferLookup::NotInBuffer,
        }
    }

    /// Iterates staged entries with `start <= key <= end` in ascending key order.
    ///
    /// # Panics
    ///
    /// Panics if `start > end` (inherited from `BTreeMap::range`).
    pub(crate) fn range(
        &self,
        start: &[u8],
        end: &[u8],
    ) -> alloc::collections::btree_map::Range<'_, Vec<u8>, Option<Vec<u8>>> {
        use core::ops::Bound;
        // Borrowed `[u8]` bounds: no per-call `Vec` allocations for the range limits.
        self.entries
            .range::<[u8], _>((Bound::Included(start), Bound::Included(end)))
    }

    /// Iterates staged entries with `start <= key < end` in ascending key order.
    ///
    /// # Panics
    ///
    /// Panics if `start > end` (inherited from `BTreeMap::range`).
    pub(crate) fn range_excluded_end(
        &self,
        start: &[u8],
        end: &[u8],
    ) -> alloc::collections::btree_map::Range<'_, Vec<u8>, Option<Vec<u8>>> {
        use core::ops::Bound;
        self.entries
            .range::<[u8], _>((Bound::Included(start), Bound::Excluded(end)))
    }

    /// Iterates, in ascending order, every staged entry whose key starts with `prefix`.
    pub(crate) fn prefix_range(
        &self,
        prefix: &[u8],
    ) -> impl Iterator<Item = (&Vec<u8>, &Option<Vec<u8>>)> {
        use core::ops::Bound;
        // The range only needs the bound during the call; the owned copy is for the
        // lazily evaluated `take_while` predicate.
        let lower: (Bound<&[u8]>, Bound<&[u8]>) = (Bound::Included(prefix), Bound::Unbounded);
        let prefix_owned = prefix.to_vec();
        self.entries
            .range::<[u8], _>(lower)
            .take_while(move |(k, _)| k.starts_with(&prefix_owned))
    }

    /// Number of staged keys (puts and tombstones).
    // Allowed: may have no callers outside tests/diagnostics.
    #[allow(dead_code)]
    pub(crate) fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when nothing is staged.
    // Allowed: may have no callers outside tests/diagnostics.
    #[allow(dead_code)]
    pub(crate) fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Applies every staged write to the tree, then empties the buffer.
    ///
    /// Puts are applied first as one sorted batch, then deletes as a second sorted batch.
    /// Before anything is written, the current tree value of every buffered key is
    /// captured. If a batch fails, those snapshots are used to compensate: keys that
    /// existed get their previous value back, and keys that did not exist are deleted.
    /// This costs one extra read per buffered key. Any read error is treated as "key
    /// absent".
    ///
    /// With `DurabilityMode::Sync`, the WAL is flushed after both batches succeed.
    ///
    /// # Errors
    ///
    /// * [`BfTreeError::InvalidKV`] if any key is empty or longer than the adapter's
    ///   maximum key length, or any value exceeds the maximum record size. This check
    ///   runs before any write.
    /// * The batch's own error when a batch fails and compensation fully succeeds.
    /// * [`BfTreeError::PartialFlushRollbackFailed`] when a batch fails and at least one
    ///   compensating insert also fails.
    /// * Any WAL flush error. NOTE(review): in that case the tree writes are already
    ///   applied and the buffer is left intact.
    pub(crate) fn flush(
        &mut self,
        adapter: &BfTreeAdapter,
        durability: super::config::DurabilityMode,
    ) -> Result<(), BfTreeError> {
        let max_record_size = adapter.inner().config().get_cb_max_record_size();
        self.validate_for_flush(adapter.max_key_len(), max_record_size)?;

        let mut insert_pairs: Vec<(&[u8], &[u8])> = Vec::new();
        let mut delete_keys: Vec<&[u8]> = Vec::new();
        // Pre-flush tree value of every key each batch touches (`None` = absent/unreadable).
        let mut insert_prev_values: Vec<(Vec<u8>, Option<Vec<u8>>)> = Vec::new();
        let mut delete_prev_values: Vec<(Vec<u8>, Option<Vec<u8>>)> = Vec::new();
        let mut read_buf = vec![0u8; max_record_size];
        for (key, value) in &self.entries {
            // Snapshot before any write: once a batch overwrites the key, the old value is
            // gone. Puts need this too; otherwise rolling back an overwrite would delete
            // committed data instead of restoring it.
            let prev = match adapter.read(key, &mut read_buf) {
                Ok(len) => Some(read_buf[..len as usize].to_vec()),
                Err(_) => None,
            };
            match value {
                Some(val) => {
                    insert_pairs.push((key.as_slice(), val.as_slice()));
                    insert_prev_values.push((key.clone(), prev));
                }
                None => {
                    delete_keys.push(key.as_slice());
                    delete_prev_values.push((key.clone(), prev));
                }
            }
        }

        if !insert_pairs.is_empty()
            && let Err(flush_err) = adapter.batch_insert_sorted_deferred_wal(&insert_pairs)
        {
            // The delete batch has not run yet, so only the inserts need undoing.
            return Err(Self::rollback_error(
                adapter,
                flush_err,
                &insert_prev_values,
                &[],
            ));
        }
        if !delete_keys.is_empty()
            && let Err(flush_err) = adapter.batch_delete_sorted_deferred_wal(&delete_keys)
        {
            return Err(Self::rollback_error(
                adapter,
                flush_err,
                &insert_prev_values,
                &delete_prev_values,
            ));
        }
        if durability == super::config::DurabilityMode::Sync {
            adapter.flush_wal().map_err(BfTreeError::from)?;
        }
        self.entries.clear();
        Ok(())
    }

    /// Checks every staged entry against the adapter's key/value size limits.
    fn validate_for_flush(
        &self,
        max_key_len: usize,
        max_record_size: usize,
    ) -> Result<(), BfTreeError> {
        for (key, value) in &self.entries {
            if key.is_empty() {
                return Err(BfTreeError::InvalidKV(alloc::string::String::from(
                    "key must not be empty",
                )));
            }
            if key.len() > max_key_len {
                return Err(BfTreeError::InvalidKV(alloc::format!(
                    "key size {} exceeds max {}",
                    key.len(),
                    max_key_len
                )));
            }
            if let Some(val) = value.as_ref().filter(|v| v.len() > max_record_size) {
                return Err(BfTreeError::InvalidKV(alloc::format!(
                    "value size {} exceeds max {}",
                    val.len(),
                    max_record_size
                )));
            }
        }
        Ok(())
    }

    /// Compensates a failed batch and chooses the error to report for it.
    ///
    /// Returns `flush_err` unchanged if compensation succeeds. Otherwise returns
    /// `PartialFlushRollbackFailed`, which carries the batch error text.
    fn rollback_error(
        adapter: &BfTreeAdapter,
        flush_err: BfTreeError,
        insert_images: &[(Vec<u8>, Option<Vec<u8>>)],
        delete_images: &[(Vec<u8>, Option<Vec<u8>>)],
    ) -> BfTreeError {
        match Self::compensate_rollback(adapter, insert_images, delete_images) {
            Ok(()) => flush_err,
            Err((rollback_failures, last_rollback_error)) => {
                BfTreeError::PartialFlushRollbackFailed {
                    flush_error: alloc::format!("{flush_err}"),
                    rollback_failures,
                    last_rollback_error,
                }
            }
        }
    }

    /// Restores the pre-flush tree state for the keys of a failed flush.
    ///
    /// `insert_images` pairs each key of the insert batch with its pre-flush value.
    /// `Some` is written back; `None` means the key was new, so it is deleted.
    /// `delete_images` pairs each key of the delete batch with its pre-flush value;
    /// only `Some` values need writing back.
    ///
    /// Returns `Err((failure_count, last_error))` if any compensating insert fails.
    /// Compensating deletes report no errors here.
    fn compensate_rollback(
        adapter: &BfTreeAdapter,
        insert_images: &[(Vec<u8>, Option<Vec<u8>>)],
        delete_images: &[(Vec<u8>, Option<Vec<u8>>)],
    ) -> Result<(), (usize, alloc::string::String)> {
        let mut failure_count: usize = 0;
        let mut last_error = alloc::string::String::new();
        for (key, prev) in insert_images {
            if prev.is_none() {
                adapter.delete(key);
            }
        }
        // Keys are unique across both lists (they come from one BTreeMap), so order is free.
        let restores = insert_images
            .iter()
            .chain(delete_images)
            .filter_map(|(key, prev)| prev.as_ref().map(|val| (key, val)));
        for (key, val) in restores {
            if let Err(e) = adapter.insert(key, val) {
                failure_count += 1;
                last_error = alloc::format!("{e}");
            }
        }
        if failure_count > 0 {
            Err((failure_count, last_error))
        } else {
            Ok(())
        }
    }

    /// Drops every staged write without touching the tree.
    pub(crate) fn discard(&mut self) {
        self.entries.clear();
    }

    /// Moves all of `other`'s staged writes into `self`; `other` wins on shared keys.
    ///
    /// # Errors
    ///
    /// Returns [`BfTreeError::InvalidKV`] and leaves `self` unchanged when the merged
    /// buffer would exceed `max_buffer_entries` distinct keys.
    pub(crate) fn merge_from(&mut self, other: WriteBuffer) -> Result<(), BfTreeError> {
        let new_keys = other
            .entries
            .keys()
            .filter(|k| !self.entries.contains_key(*k))
            .count();
        let post_merge_len = self.entries.len() + new_keys;
        if post_merge_len > self.max_buffer_entries {
            return Err(BfTreeError::InvalidKV(alloc::format!(
                "merged buffer would have {} entries, exceeding limit of {}",
                post_merge_len,
                self.max_buffer_entries,
            )));
        }
        // `extend` inserts in iteration order, overwriting existing keys.
        self.entries.extend(other.entries);
        Ok(())
    }

    /// Tombstones a table's rows and returns how many live rows were removed.
    ///
    /// Each key in `bftree_encoded_keys` is tombstoned. It is counted unless the buffer
    /// had already tombstoned it. Then every key in `prefix..prefix_end` that is still
    /// buffered with a value is tombstoned and counted as well.
    pub(crate) fn drain_table(
        &mut self,
        bftree_encoded_keys: &[Vec<u8>],
        prefix: &[u8],
        prefix_end: &[u8],
    ) -> u64 {
        use core::ops::Bound;
        let mut count = 0u64;
        for key in bftree_encoded_keys {
            let previous = self.entries.insert(key.clone(), None);
            // A key already tombstoned in this buffer was removed earlier: don't recount it.
            if !matches!(previous, Some(None)) {
                count += 1;
            }
        }
        let table_range: (Bound<&[u8]>, Bound<&[u8]>) =
            (Bound::Included(prefix), Bound::Excluded(prefix_end));
        for (_, pending) in self.entries.range_mut::<[u8], _>(table_range) {
            // `take` leaves a tombstone behind; count only slots that still held a value.
            if pending.take().is_some() {
                count += 1;
            }
        }
        count
    }
}
/// Iterator over one table's key range that merges buffered writes with a tree scan.
///
/// Both inputs are walked in ascending key order (assumed for the scan side; TODO
/// confirm). When a key appears in both, the buffer entry wins. Buffer tombstones
/// (`None`) hide the key. A key equal to `exclude_start` is skipped, and reaching a key
/// equal to `exclude_end` ends the iteration.
// NOTE(review): field declaration order is also drop order (the scan handle is dropped
// in between the buffers), so fields are intentionally left in their original order.
pub struct BufferedScanIter<'a, K: Key + 'static, V: Value + 'static> {
// Buffered `(key, Some(value) | None)` pairs, consumed front to back via `buf_idx`.
buf_entries: Vec<(Vec<u8>, Option<Vec<u8>>)>,
// Index of the next unconsumed entry in `buf_entries`.
buf_idx: usize,
// Underlying table scan over the tree.
scan: BfTreeTableScan<'a>,
// Scratch buffer handed to `scan.next`.
scan_buf: Vec<u8>,
// Owned one-entry look-ahead from `scan`.
scan_peek: Option<(Vec<u8>, Vec<u8>)>,
// Set once `scan.next` has returned `None`.
scan_exhausted: bool,
// Key to skip if it is produced; cleared once passed.
exclude_start: Option<Vec<u8>>,
// Key at which iteration stops (not yielded).
exclude_end: Option<Vec<u8>>,
// Controls whether/how values are unwrapped via `unwrap_value` before being yielded.
verify_mode: Arc<VerifyMode>,
_key: PhantomData<K>,
_val: PhantomData<V>,
}
impl<'a, K: Key + 'static, V: Value + 'static> BufferedScanIter<'a, K, V> {
    /// Builds a merging iterator over `buf_entries` and `scan`.
    ///
    /// `buf_entries` is expected in ascending key order. `max_record_size` sizes the scan
    /// scratch buffer, which is allocated at twice that size. The scan is not read until
    /// the first call to `next`.
    pub(crate) fn new(
        buf_entries: Vec<(Vec<u8>, Option<Vec<u8>>)>,
        scan: BfTreeTableScan<'a>,
        max_record_size: usize,
        exclude_start: Option<Vec<u8>>,
        exclude_end: Option<Vec<u8>>,
        verify_mode: Arc<VerifyMode>,
    ) -> Self {
        let scan_buf = vec![0u8; 2 * max_record_size];
        Self {
            scan,
            scan_buf,
            scan_peek: None,
            scan_exhausted: false,
            buf_entries,
            buf_idx: 0,
            exclude_start,
            exclude_end,
            verify_mode,
            _key: PhantomData,
            _val: PhantomData,
        }
    }

    /// Refills the scan look-ahead with an owned copy of the next entry.
    /// Does nothing once the scan has reported its end.
    fn advance_scan(&mut self) {
        if self.scan_exhausted {
            return;
        }
        let upcoming = self
            .scan
            .next(&mut self.scan_buf)
            .map(|(k, v)| (k.to_vec(), v.to_vec()));
        self.scan_exhausted = upcoming.is_none();
        self.scan_peek = upcoming;
    }

    /// Borrows the current buffer entry, or `None` when the buffer is used up.
    fn buf_peek(&self) -> Option<(&[u8], &Option<Vec<u8>>)> {
        self.buf_entries
            .get(self.buf_idx)
            .map(|(key, val)| (key.as_slice(), val))
    }

    /// Moves past the current buffer entry.
    fn advance_buf(&mut self) {
        self.buf_idx += 1;
    }
}
impl<K: Key + 'static, V: Value + 'static> Iterator for BufferedScanIter<'_, K, V> {
    type Item = crate::Result<(OwnedKv<K>, OwnedKv<V>)>;

    /// Yields the next live `(key, value)` pair of the merged view.
    ///
    /// The buffer and the scan are merged by key. On equal keys the buffer entry is used
    /// and the scan entry is skipped. Buffer tombstones produce nothing. When verification
    /// is enabled, each value is passed through `unwrap_value`. A failure is yielded as
    /// `Some(Err(..))`, and iteration can continue afterwards.
    fn next(&mut self) -> Option<Self::Item> {
        use core::cmp::Ordering;

        // Prime the scan look-ahead on the first call.
        if self.scan_peek.is_none() && !self.scan_exhausted {
            self.advance_scan();
        }
        loop {
            // Decide which source supplies the next key before taking ownership of anything.
            let order = match (self.buf_peek(), self.scan_peek.as_ref()) {
                (None, None) => return None,
                (Some(_), None) => Ordering::Less,
                (None, Some(_)) => Ordering::Greater,
                (Some((bk, _)), Some((sk, _))) => bk.cmp(sk.as_slice()),
            };
            let entry: Option<(Vec<u8>, Vec<u8>)> = match order {
                Ordering::Greater => {
                    // Scan-only key: move the look-ahead out instead of cloning it.
                    // `advance_scan` always overwrites `scan_peek` here, because the scan
                    // cannot be exhausted while it still had a peeked entry.
                    let scanned = self.scan_peek.take();
                    self.advance_scan();
                    scanned
                }
                Ordering::Less | Ordering::Equal => {
                    if order == Ordering::Equal {
                        // The buffered write shadows the tree's value for this key.
                        self.advance_scan();
                    }
                    // Each buffer slot is visited exactly once, so it can be moved out.
                    let (key, val) = core::mem::take(&mut self.buf_entries[self.buf_idx]);
                    self.advance_buf();
                    val.map(|v| (key, v))
                }
            };
            let Some((key, val)) = entry else {
                // Buffered tombstone: the key is deleted in the buffer.
                continue;
            };
            if let Some(excl) = self.exclude_start.as_deref() {
                match key.as_slice().cmp(excl) {
                    Ordering::Equal => {
                        self.exclude_start = None;
                        continue;
                    }
                    Ordering::Greater => self.exclude_start = None,
                    // NOTE(review): keys below the excluded start are still yielded;
                    // presumably both sources are bounded below by it, so this is not hit.
                    Ordering::Less => {}
                }
            }
            if self.exclude_end.as_deref() == Some(key.as_slice()) {
                return None;
            }
            let k = OwnedKv::new(key);
            let v = if self.verify_mode.is_enabled() {
                let verify = should_verify(self.verify_mode.as_ref());
                match unwrap_value(&val, verify) {
                    Ok(data) => OwnedKv::new(data.to_vec()),
                    Err(e) => return Some(Err(e.into())),
                }
            } else {
                OwnedKv::new(val)
            };
            return Some(Ok((k, v)));
        }
    }
}
/// Gathers one table's buffered writes whose encoded keys lie in
/// `start_encoded..=end_encoded`, with the table prefix removed from each key.
///
/// Each item is `(user_key, Some(value))` for a pending put or `(user_key, None)` for a
/// pending delete. Items come out in ascending key order. Keys that do not start with
/// the table prefix, or that consist of the prefix alone, are left out.
pub(crate) fn collect_buffer_entries_for_table(
    buffer: &WriteBuffer,
    table_name: &str,
    kind: TableKind,
    start_encoded: &[u8],
    end_encoded: &[u8],
) -> Vec<(Vec<u8>, Option<Vec<u8>>)> {
    let table_key_prefix = table_prefix(table_name, kind);
    let mut gathered = Vec::new();
    for (encoded, pending) in buffer.range(start_encoded, end_encoded) {
        match encoded.strip_prefix(&table_key_prefix[..]) {
            Some(user_key) if !user_key.is_empty() => {
                gathered.push((user_key.to_vec(), pending.clone()));
            }
            _ => {}
        }
    }
    gathered
}
/// Gathers every buffered write of one table, with the table prefix removed from each key.
///
/// This covers encoded keys from `table_prefix` through `table_prefix_end`, inclusive.
// Allowed: this helper may currently have no non-test callers.
#[allow(dead_code)]
pub(crate) fn collect_all_buffer_entries_for_table(
    buffer: &WriteBuffer,
    table_name: &str,
    kind: TableKind,
) -> Vec<(Vec<u8>, Option<Vec<u8>>)> {
    let (lowest, highest) = (
        table_prefix(table_name, kind),
        table_prefix_end(table_name, kind),
    );
    collect_buffer_entries_for_table(buffer, table_name, kind, &lowest, &highest)
}
// Unit tests for `WriteBuffer` and the buffered read/scan paths. The database-backed
// tests use an in-memory `BfTreeDatabase` built from `BfTreeConfig::new_memory(4)`.
#[cfg(test)]
mod tests {
use super::*;
use crate::TableDefinition;
use crate::bf_tree_store::config::{BfTreeConfig, DurabilityMode};
use crate::bf_tree_store::database::{BfTreeDatabase, TableKind, encode_table_key};
use crate::storage_traits::WriteTable;
// `&str` -> `u64` table shared by the transaction-level tests below.
const ITEMS: TableDefinition<&str, u64> = TableDefinition::new("items");
// `put` stores a value, and a second `put` on the same key overwrites it.
#[test]
fn buffer_put_get() {
let mut buf = WriteBuffer::new();
let key = b"test_key".to_vec();
assert!(matches!(buf.get(&key), BufferLookup::NotInBuffer));
buf.put(key.clone(), b"value1".to_vec()).unwrap();
match buf.get(&key) {
BufferLookup::Found(v) => assert_eq!(v, b"value1"),
_ => panic!("expected Found"),
}
buf.put(key.clone(), b"value2".to_vec()).unwrap();
match buf.get(&key) {
BufferLookup::Found(v) => assert_eq!(v, b"value2"),
_ => panic!("expected Found"),
}
}
// `delete` after `put` leaves a tombstone instead of removing the entry.
#[test]
fn buffer_delete_tombstone() {
let mut buf = WriteBuffer::new();
let key = b"key".to_vec();
buf.put(key.clone(), b"val".to_vec()).unwrap();
buf.delete(key.clone());
assert!(matches!(buf.get(&key), BufferLookup::Tombstone));
}
// `flush` writes buffered puts through to the adapter.
#[test]
fn buffer_flush_applies_to_adapter() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let adapter = db.adapter();
let mut buf = WriteBuffer::new();
let key1 = encode_table_key("test", TableKind::Regular, b"k1");
let key2 = encode_table_key("test", TableKind::Regular, b"k2");
buf.put(key1.clone(), b"v1".to_vec()).unwrap();
buf.put(key2.clone(), b"v2".to_vec()).unwrap();
buf.flush(adapter, DurabilityMode::Sync).unwrap();
let max = adapter.inner().config().get_cb_max_record_size();
let mut rbuf = vec![0u8; max];
let len = adapter.read(&key1, &mut rbuf).unwrap();
assert_eq!(&rbuf[..len as usize], b"v1");
}
// `discard` drops staged writes; nothing reaches the adapter.
#[test]
fn buffer_discard_rollback() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let adapter = db.adapter();
let mut buf = WriteBuffer::new();
let key = encode_table_key("test", TableKind::Regular, b"k1");
buf.put(key.clone(), b"val".to_vec()).unwrap();
buf.discard();
assert!(buf.is_empty());
let max = adapter.inner().config().get_cb_max_record_size();
let mut rbuf = vec![0u8; max];
assert!(adapter.read(&key, &mut rbuf).is_err());
}
// A buffered delete removes a key that already exists in the tree once flushed.
#[test]
fn buffer_flush_with_deletes() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let adapter = db.adapter();
let key = encode_table_key("test", TableKind::Regular, b"existing");
adapter.insert(&key, b"old_val").unwrap();
let mut buf = WriteBuffer::new();
buf.delete(key.clone());
buf.flush(adapter, DurabilityMode::Sync).unwrap();
let max = adapter.inner().config().get_cb_max_record_size();
let mut rbuf = vec![0u8; max];
assert!(adapter.read(&key, &mut rbuf).is_err());
}
// Reads inside a write transaction see that transaction's own uncommitted insert.
#[test]
fn buffered_write_txn_read_your_writes() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"hello", &42u64).unwrap();
let val = WriteTable::st_get(&table, &"hello").unwrap();
assert!(val.is_some());
assert_eq!(val.unwrap().value(), 42u64);
}
// A write transaction dropped without `commit` leaves nothing visible to later readers.
#[test]
fn buffered_write_txn_abort_on_drop() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
{
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"temp", &99u64).unwrap();
drop(table);
}
let rtxn = db.begin_read();
let mut ro = rtxn.open_table(ITEMS).unwrap();
assert!(ro.get(&"temp").unwrap().is_none());
}
// Committed writes are visible to a subsequent read transaction (value stored as LE bytes).
#[test]
fn buffered_write_txn_commit_visible() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"committed", &77u64).unwrap();
drop(table);
wtxn.commit().unwrap();
let rtxn = db.begin_read();
let mut ro = rtxn.open_table(ITEMS).unwrap();
let val = ro.get(&"committed").unwrap().unwrap();
assert_eq!(u64::from_le_bytes(val.as_slice().try_into().unwrap()), 77);
}
// A range scan inside a write transaction merges committed rows (a, c, e) with buffered
// inserts (b, d). It also hides the buffered delete of c, and everything comes out in
// key order.
#[test]
fn buffered_scan_merges_buffer_and_bftree() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
{
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"a", &1u64).unwrap();
WriteTable::st_insert(&mut table, &"c", &3u64).unwrap();
WriteTable::st_insert(&mut table, &"e", &5u64).unwrap();
drop(table);
wtxn.commit().unwrap();
}
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"b", &2u64).unwrap();
WriteTable::st_insert(&mut table, &"d", &4u64).unwrap();
WriteTable::st_remove(&mut table, &"c").unwrap();
let iter = WriteTable::st_range(&table, None, None, true, true).unwrap();
let entries: Vec<_> = iter.collect::<Result<Vec<_>, _>>().unwrap();
let keys: Vec<&str> = entries.iter().map(|(k, _)| k.value()).collect();
assert_eq!(keys, vec!["a", "b", "d", "e"]);
let vals: Vec<u64> = entries.iter().map(|(_, v)| v.value()).collect();
assert_eq!(vals, vec![1, 2, 4, 5]);
}
// A buffered overwrite wins over the committed value for both point reads and scans,
// and the scan yields the key only once.
#[test]
fn buffered_overwrite_supersedes_bftree() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
{
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"key", &100u64).unwrap();
drop(table);
wtxn.commit().unwrap();
}
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"key", &200u64).unwrap();
let val = WriteTable::st_get(&table, &"key").unwrap().unwrap();
assert_eq!(val.value(), 200u64);
let iter = WriteTable::st_range(&table, None, None, true, true).unwrap();
let entries: Vec<_> = iter.collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].1.value(), 200u64);
}
// A buffered delete hides a committed row from both point reads and scans.
#[test]
fn buffered_delete_hides_bftree_entry() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
{
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_insert(&mut table, &"visible", &1u64).unwrap();
WriteTable::st_insert(&mut table, &"hidden", &2u64).unwrap();
drop(table);
wtxn.commit().unwrap();
}
let wtxn = db.begin_write();
let mut table = wtxn.open_table(ITEMS).unwrap();
WriteTable::st_remove(&mut table, &"hidden").unwrap();
assert!(WriteTable::st_get(&table, &"hidden").unwrap().is_none());
let iter = WriteTable::st_range(&table, None, None, true, true).unwrap();
let entries: Vec<_> = iter.collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(entries.len(), 1);
let keys: Vec<&str> = entries.iter().map(|(k, _)| k.value()).collect();
assert_eq!(keys, vec!["visible"]);
}
// `flush` validates key length and value length separately. key_b and val_b each pass
// that check, but together they exceed max_record_size by one byte. The insert batch is
// therefore expected to fail (assumes the adapter rejects such a record) after key_a may
// already be written, and the rollback must remove key_a again.
// `max_key_len - 4` presumably leaves room for the bytes `encode_table_key` adds; the
// assert below checks the encoded key still fits.
#[test]
fn flush_rollback_undoes_partial_writes() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let adapter = db.adapter();
let max_record_size = adapter.inner().config().get_cb_max_record_size();
let max_key_len = adapter.max_key_len();
let key_a = encode_table_key("t", TableKind::Regular, b"aaa");
let val_a = vec![1u8; 8];
let raw_key_b = vec![b'b'; max_key_len - 4]; let key_b = encode_table_key("t", TableKind::Regular, &raw_key_b);
assert!(key_b.len() <= max_key_len, "key_b must pass pre-validation");
let val_b_len = max_record_size - key_b.len() + 1;
assert!(
val_b_len <= max_record_size,
"val_b must pass pre-validation"
);
let val_b = vec![2u8; val_b_len];
let mut buf = WriteBuffer::new();
buf.put(key_a.clone(), val_a.clone()).unwrap();
buf.put(key_b.clone(), val_b).unwrap();
let result = buf.flush(adapter, DurabilityMode::Sync);
assert!(
result.is_err(),
"flush must fail on oversized combined record"
);
let mut rbuf = vec![0u8; max_record_size];
assert!(
adapter.read(&key_a, &mut rbuf).is_err(),
"key_a must be rolled back after partial flush failure"
);
}
// Same oversized-record trick, now with a buffered delete of an existing key_a.
// NOTE(review): `flush` runs the insert batch before the delete batch, so when the
// key_z insert fails the delete of key_a is never attempted. This test therefore passes
// without exercising the restore-deleted-value path of the rollback.
#[test]
fn flush_rollback_restores_deleted_values() {
let db = BfTreeDatabase::create(BfTreeConfig::new_memory(4)).unwrap();
let adapter = db.adapter();
let max_record_size = adapter.inner().config().get_cb_max_record_size();
let max_key_len = adapter.max_key_len();
let key_a = encode_table_key("t", TableKind::Regular, b"aaa");
let original_val = b"original_value";
adapter.insert(&key_a, original_val).unwrap();
let raw_key_z = vec![b'z'; max_key_len - 4];
let key_z = encode_table_key("t", TableKind::Regular, &raw_key_z);
let val_z_len = max_record_size - key_z.len() + 1;
let val_z = vec![3u8; val_z_len];
let mut buf = WriteBuffer::new();
buf.delete(key_a.clone());
buf.put(key_z.clone(), val_z).unwrap();
let result = buf.flush(adapter, DurabilityMode::Sync);
assert!(
result.is_err(),
"flush must fail on oversized combined record"
);
let mut rbuf = vec![0u8; max_record_size];
let len = adapter
.read(&key_a, &mut rbuf)
.expect("key_a must be restored after rollback");
assert_eq!(
&rbuf[..len as usize],
original_val,
"restored value must match original"
);
}
}