use std::{
collections::{
BTreeMap, HashMap,
btree_map::{IntoIter as BTreeMapIntoIter, Iter as BTreeMapIter, Range as BTreeMapRange},
},
mem::size_of,
ops::RangeBounds,
};
use reifydb_core::encoded::{key::EncodedKey, row::EncodedRow};
use crate::multi::types::Pending;
/// Collection of pending writes, addressable by key and replayable in insertion order.
///
/// Entries are stored in a key-ordered map so that lookups, iteration and range scans
/// observe them sorted by key. A secondary index records the order in which keys were
/// first inserted so that [`PendingWrites::into_iter_insertion_order`] can drain the
/// collection in that order.
///
/// Insertion order is tracked with a monotonically increasing sequence number rather
/// than a vector position. Removing a key therefore never disturbs the relative order
/// of the keys that remain.
#[derive(Debug, Default, Clone)]
pub struct PendingWrites {
    /// Pending entries, ordered by key.
    writes: BTreeMap<EncodedKey, Pending>,
    /// Live keys, ordered by the sequence number assigned when they were inserted.
    insertion_order: BTreeMap<u64, EncodedKey>,
    /// Reverse lookup from a live key to its sequence number in `insertion_order`.
    position_index: HashMap<EncodedKey, u64>,
    /// Sequence number that the next newly inserted key will receive.
    next_position: u64,
    /// Running sum of [`PendingWrites::estimate_size`] over all stored entries.
    estimated_size: u64,
}

impl PendingWrites {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when no entry is stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.writes.is_empty()
    }

    /// Returns the number of stored entries.
    #[inline]
    pub fn len(&self) -> usize {
        self.writes.len()
    }

    /// Returns the fixed batch size limit, `1024 * 1024 * 1024`.
    ///
    /// NOTE(review): presumably expressed in bytes (1 GiB) - confirm against callers.
    #[inline]
    pub fn max_batch_size(&self) -> u64 {
        1024 * 1024 * 1024
    }

    /// Returns the fixed limit on the number of entries in a batch, `1_000_000`.
    #[inline]
    pub fn max_batch_entries(&self) -> u64 {
        1_000_000
    }

    /// Returns the size charged for one entry.
    ///
    /// The estimate is the combined in-memory size of the `EncodedKey` and `EncodedRow`
    /// types themselves; the entry that is passed in is not inspected, so every entry is
    /// charged the same amount.
    #[inline]
    pub fn estimate_size(&self, _entry: &Pending) -> u64 {
        (size_of::<EncodedKey>() + size_of::<EncodedRow>()) as u64
    }

    /// Returns the entry stored under `key`, if any.
    #[inline]
    pub fn get(&self, key: &EncodedKey) -> Option<&Pending> {
        self.writes.get(key)
    }

    /// Returns the stored key together with its entry, if `key` is present.
    #[inline]
    pub fn get_entry(&self, key: &EncodedKey) -> Option<(&EncodedKey, &Pending)> {
        self.writes.get_key_value(key)
    }

    /// Returns `true` when an entry is stored under `key`.
    #[inline]
    pub fn contains_key(&self, key: &EncodedKey) -> bool {
        self.writes.contains_key(key)
    }

    /// Stores `value` under `key`.
    ///
    /// Overwriting an existing key keeps the place that key already has in the insertion
    /// order; a key that is not present yet is appended at the end of it.
    pub fn insert(&mut self, key: EncodedKey, value: Pending) {
        let new_size = self.estimate_size(&value);

        match self.writes.insert(key.clone(), value) {
            Some(previous) => {
                // Overwrite: only the size accounting may need to change.
                let old_size = self.estimate_size(&previous);
                if new_size != old_size {
                    self.estimated_size =
                        self.estimated_size.saturating_sub(old_size).saturating_add(new_size);
                }
            }
            None => {
                // First time this key is seen: hand out the next sequence number.
                let position = self.next_position;
                self.next_position += 1;
                self.insertion_order.insert(position, key.clone());
                self.position_index.insert(key, position);
                self.estimated_size = self.estimated_size.saturating_add(new_size);
            }
        }
    }

    /// Removes `key` and returns the stored key and entry, or `None` if it was absent.
    ///
    /// The relative insertion order of all remaining keys is preserved. Inserting the
    /// same key again later places it at the end of the insertion order.
    pub fn remove_entry(&mut self, key: &EncodedKey) -> Option<(EncodedKey, Pending)> {
        let (removed_key, removed_value) = self.writes.remove_entry(key)?;

        // Drop the key's sequence slot; every other key keeps its own sequence number,
        // so nothing else has to be re-indexed.
        if let Some(position) = self.position_index.remove(key) {
            self.insertion_order.remove(&position);
        }

        let freed = self.estimate_size(&removed_value);
        self.estimated_size = self.estimated_size.saturating_sub(freed);

        Some((removed_key, removed_value))
    }

    /// Iterates over all entries in ascending key order.
    pub fn iter(&self) -> BTreeMapIter<'_, EncodedKey, Pending> {
        self.writes.iter()
    }

    /// Consumes the collection and yields its entries in the order their keys were
    /// first inserted.
    pub fn into_iter_insertion_order(self) -> impl Iterator<Item = (EncodedKey, Pending)> {
        let mut writes = self.writes;
        self.insertion_order.into_values().filter_map(move |key| writes.remove_entry(&key))
    }

    /// Discards every entry and resets all bookkeeping to the empty state.
    pub fn rollback(&mut self) {
        self.writes.clear();
        self.insertion_order.clear();
        self.position_index.clear();
        self.next_position = 0;
        self.estimated_size = 0;
    }

    /// Returns the accumulated size estimate of all stored entries.
    #[inline]
    pub fn total_estimated_size(&self) -> u64 {
        self.estimated_size
    }

    /// Iterates, in ascending key order, over the entries whose keys fall inside `range`.
    pub fn range<R>(&self, range: R) -> BTreeMapRange<'_, EncodedKey, Pending>
    where
        R: RangeBounds<EncodedKey>,
    {
        self.writes.range(range)
    }

    /// Same as [`PendingWrites::range`].
    pub fn range_comparable<R>(&self, range: R) -> BTreeMapRange<'_, EncodedKey, Pending>
    where
        R: RangeBounds<EncodedKey>,
    {
        self.range(range)
    }

    /// Same as [`PendingWrites::get`].
    #[inline]
    pub fn get_comparable(&self, key: &EncodedKey) -> Option<&Pending> {
        self.get(key)
    }

    /// Same as [`PendingWrites::get_entry`].
    #[inline]
    pub fn get_entry_comparable(&self, key: &EncodedKey) -> Option<(&EncodedKey, &Pending)> {
        self.get_entry(key)
    }

    /// Same as [`PendingWrites::contains_key`].
    #[inline]
    pub fn contains_key_comparable(&self, key: &EncodedKey) -> bool {
        self.contains_key(key)
    }

    /// Same as [`PendingWrites::remove_entry`].
    #[inline]
    pub fn remove_entry_comparable(&mut self, key: &EncodedKey) -> Option<(EncodedKey, Pending)> {
        self.remove_entry(key)
    }
}
/// Consumes the pending writes and yields every `(key, entry)` pair in ascending key
/// order.
///
/// Only the key-ordered map is drained; the insertion-order bookkeeping is dropped. Use
/// `into_iter_insertion_order` when the order of insertion matters.
impl IntoIterator for PendingWrites {
    type Item = (EncodedKey, Pending);
    type IntoIter = BTreeMapIntoIter<EncodedKey, Pending>;

    fn into_iter(self) -> Self::IntoIter {
        // Take the key-ordered map out of `self`; the remaining fields are simply dropped.
        let Self { writes, .. } = self;
        writes.into_iter()
    }
}
#[cfg(test)]
pub mod tests {
    use reifydb_core::{common::CommitVersion, delta::Delta, encoded::key::EncodedKey};
    use reifydb_type::util::cowvec::CowVec;

    use super::*;

    /// Builds a key from the bytes of `s`.
    fn create_test_key(s: &str) -> EncodedKey {
        EncodedKey::new(s.as_bytes())
    }

    /// Builds a row whose payload is the bytes of `s`.
    fn create_test_row(s: &str) -> EncodedRow {
        let payload = s.as_bytes().to_vec();
        EncodedRow(CowVec::new(payload))
    }

    /// Builds a pending `Set` of `values_data` under `key` at the given `version`.
    fn create_test_pending(version: CommitVersion, key: &str, values_data: &str) -> Pending {
        let delta = Delta::Set {
            key: create_test_key(key),
            row: create_test_row(values_data),
        };
        Pending {
            delta,
            version,
        }
    }

    /// A fresh collection is empty; one insert makes the entry visible through every
    /// read accessor.
    #[test]
    fn test_basic_operations() {
        let mut pending_writes = PendingWrites::new();
        assert!(pending_writes.is_empty());
        assert_eq!(pending_writes.len(), 0);

        let key = create_test_key("key1");
        let staged = create_test_pending(CommitVersion(1), "key1", "value1");
        pending_writes.insert(key.clone(), staged.clone());

        assert!(!pending_writes.is_empty());
        assert_eq!(pending_writes.len(), 1);
        assert!(pending_writes.contains_key(&key));
        assert_eq!(pending_writes.get(&key).unwrap(), &staged);
    }

    /// Inserting twice under one key replaces the entry instead of adding a second one.
    #[test]
    fn test_update_operations() {
        let mut pending_writes = PendingWrites::new();
        let key = create_test_key("key");
        let first = create_test_pending(CommitVersion(1), "key", "value1");
        let second = create_test_pending(CommitVersion(2), "key", "value2");

        pending_writes.insert(key.clone(), first);
        assert_eq!(pending_writes.len(), 1);

        pending_writes.insert(key.clone(), second.clone());
        // Still a single entry, now holding the second value.
        assert_eq!(pending_writes.len(), 1);
        assert_eq!(pending_writes.get(&key).unwrap(), &second);
    }

    /// A half-open key range returns exactly the keys inside it.
    #[test]
    fn test_range_operations() {
        let mut pending_writes = PendingWrites::new();
        for i in 0..10 {
            let name = format!("key{:02}", i);
            let staged = create_test_pending(CommitVersion(i), &name, &format!("value{}", i));
            pending_writes.insert(create_test_key(&name), staged);
        }

        let lower = create_test_key("key03");
        let upper = create_test_key("key07");
        // key03, key04, key05 and key06; the upper bound is excluded.
        assert_eq!(pending_writes.range(lower..upper).count(), 4);
    }

    /// `iter` walks entries in ascending key order and `range` honours its bounds.
    #[test]
    fn test_iterator_compatibility() {
        let mut pending_writes = PendingWrites::new();
        for i in 0..5 {
            let name = format!("key{}", i);
            let staged = create_test_pending(CommitVersion(i), &name, &format!("value{}", i));
            pending_writes.insert(create_test_key(&name), staged);
        }

        let entries: Vec<_> = pending_writes.iter().collect();
        assert_eq!(entries.len(), 5);
        // Every key must be less than or equal to its successor, i.e. the keys are sorted.
        assert!(entries.windows(2).all(|pair| pair[0].0 <= pair[1].0));

        let lower = create_test_key("key1");
        let upper = create_test_key("key4");
        // key1, key2 and key3; the upper bound is excluded.
        assert_eq!(pending_writes.range(lower..upper).count(), 3);
    }

    /// Lookups and removal keep working on a collection holding a thousand entries.
    #[test]
    fn test_performance_operations() {
        let mut pending_writes = PendingWrites::new();
        for i in 0..1000 {
            let name = format!("key{:06}", i);
            let staged = create_test_pending(CommitVersion(i), &name, &format!("value{}", i));
            pending_writes.insert(create_test_key(&name), staged);
        }
        assert_eq!(pending_writes.len(), 1000);

        let probe = create_test_key("key000500");
        assert!(pending_writes.contains_key(&probe));
        assert!(pending_writes.get(&probe).is_some());

        assert!(pending_writes.remove_entry(&probe).is_some());
        assert_eq!(pending_writes.len(), 999);
        assert!(!pending_writes.contains_key(&probe));
    }

    /// `rollback` empties the collection and zeroes the size estimate.
    #[test]
    fn test_rollback() {
        let mut pending_writes = PendingWrites::new();
        for i in 0..10 {
            let name = format!("key{}", i);
            let staged = create_test_pending(CommitVersion(i), &name, &format!("value{}", i));
            pending_writes.insert(create_test_key(&name), staged);
        }
        assert_eq!(pending_writes.len(), 10);
        assert!(pending_writes.total_estimated_size() > 0);

        pending_writes.rollback();

        assert!(pending_writes.is_empty());
        assert_eq!(pending_writes.total_estimated_size(), 0);
    }
}