use std::ops::RangeBounds;
use reifydb_core::{
common::CommitVersion,
encoded::{
key::{EncodedKey, EncodedKeyRange},
row::EncodedRow,
},
event::transaction::PostCommitEvent,
interface::store::{MultiVersionBatch, MultiVersionCommit, MultiVersionContains, MultiVersionGet},
};
use reifydb_type::{
Result,
util::{cowvec::CowVec, hex},
};
use tracing::instrument;
use super::{MultiTransaction, TransactionManagerCommand, version::StandardVersionProvider};
use crate::{delta::optimize_deltas, multi::types::TransactionValue};
/// Snapshot of a write transaction's in-flight write state.
///
/// Produced by `MultiWriteTransaction::savepoint` and consumed by
/// `MultiWriteTransaction::restore_savepoint`, which copies each field back
/// onto the transaction manager.
pub struct WriteSavepoint {
/// Clone of the transaction manager's `pending_writes` at capture time.
pub(crate) pending_writes: PendingWrites,
/// The transaction manager's `count` at capture time
/// (presumably the number of buffered writes; confirm against the manager).
pub(crate) count: u64,
/// The transaction manager's `size` at capture time
/// (presumably an accumulated byte size; confirm against the manager).
pub(crate) size: u64,
/// Clone of the transaction manager's `duplicates` list at capture time.
pub(crate) duplicates: Vec<Pending>,
}
/// A write transaction over the multi-version store.
///
/// Reads (`get`, `contains_key`, `range`, ...) consult the transaction
/// manager's pending writes first and fall back to `engine.store` at the
/// transaction's version. Mutations are handed to the transaction manager,
/// and `commit` persists the pending entries to the store.
pub struct MultiWriteTransaction {
/// Owning engine; this type uses its `tm`, `store` and `event_bus` members.
engine: MultiTransaction,
/// Transaction manager handle obtained from `engine.tm.write()` in `new`.
pub(crate) tm: TransactionManagerCommand<StandardVersionProvider>,
}
impl MultiWriteTransaction {
    /// Opens a write transaction on `engine`.
    ///
    /// # Errors
    /// Propagates the error returned by `engine.tm.write()`.
    #[instrument(name = "transaction::command::new", level = "debug", skip(engine))]
    pub fn new(engine: MultiTransaction) -> Result<Self> {
        // `tm` is listed first on purpose: it must be obtained while `engine`
        // is still only borrowed, before `engine` is moved into the struct.
        Ok(Self {
            tm: engine.tm.write()?,
            engine,
        })
    }
}
impl MultiWriteTransaction {
    /// Captures the transaction manager's current write state (pending
    /// writes, `count`, `size` and `duplicates`) so it can later be put back
    /// with [`restore_savepoint`](Self::restore_savepoint).
    pub fn savepoint(&self) -> WriteSavepoint {
        let tm = &self.tm;
        WriteSavepoint {
            duplicates: tm.duplicates.clone(),
            size: tm.size,
            count: tm.count,
            pending_writes: tm.pending_writes.clone(),
        }
    }

    /// Overwrites the transaction manager's write state with the values held
    /// in `sp`, replacing whatever was recorded after the savepoint was taken.
    pub fn restore_savepoint(&mut self, sp: WriteSavepoint) {
        let WriteSavepoint {
            pending_writes,
            count,
            size,
            duplicates,
        } = sp;

        let tm = &mut self.tm;
        tm.pending_writes = pending_writes;
        tm.count = count;
        tm.size = size;
        tm.duplicates = duplicates;
    }
}
impl MultiWriteTransaction {
    /// Commits the transaction's pending writes to the store.
    ///
    /// Steps, in order:
    /// 1. If there are no pending writes, the transaction is discarded and
    ///    `CommitVersion(0)` is returned.
    /// 2. `commit_pending` yields the commit version and the pending entries.
    ///    If it yields no entries, the transaction is likewise discarded and
    ///    `CommitVersion(0)` is returned.
    /// 3. The entries' deltas are run through `optimize_deltas` and written
    ///    with `MultiVersionCommit::commit` at the commit version.
    /// 4. The transaction is discarded, a `PostCommitEvent` carrying the
    ///    optimized deltas is emitted, and the oracle is told the commit is done.
    ///
    /// # Errors
    /// Propagates any error from `commit_pending` or from the store commit.
    #[instrument(name = "transaction::command::commit", level = "debug", skip(self), fields(pending_count = self.tm.pending_writes().len()))]
    pub fn commit(&mut self) -> Result<CommitVersion> {
        if self.tm.pending_writes().is_empty() {
            self.tm.discard();
            return Ok(CommitVersion(0));
        }

        let (commit_version, entries) = self.tm.commit_pending()?;
        if entries.is_empty() {
            self.tm.discard();
            return Ok(CommitVersion(0));
        }

        // Stream the deltas straight into the optimizer: each delta is cloned
        // exactly once and no intermediate buffer is allocated.
        let optimized = optimize_deltas(entries.iter().map(|pending| pending.delta.clone()));
        let deltas = CowVec::new(optimized);

        // NOTE(review): when the store commit fails, `?` returns before
        // `discard` and `done_commit` run; confirm the oracle tolerates a
        // commit version that is never marked done.
        MultiVersionCommit::commit(&self.engine.store, deltas.clone(), commit_version)?;

        self.tm.discard();
        self.engine.event_bus.emit(PostCommitEvent::new(deltas, commit_version));
        self.tm.oracle.done_commit(commit_version);
        Ok(commit_version)
    }
}
impl MultiWriteTransaction {
/// Returns the transaction manager's version.
///
/// This is the version that `get`, `contains_key`, `range` and `range_rev`
/// pass to the store when they read from it.
pub fn version(&self) -> CommitVersion {
self.tm.version()
}
/// Borrows the transaction manager's pending writes.
pub fn pending_writes(&self) -> &PendingWrites {
self.tm.pending_writes()
}
/// Forwards `version` to the transaction manager's
/// `read_as_of_version_exclusive`.
///
/// NOTE(review): presumably subsequent reads then only see data committed
/// strictly before `version`; confirm against the transaction manager.
pub fn read_as_of_version_exclusive(&mut self, version: CommitVersion) {
self.tm.read_as_of_version_exclusive(version);
}
/// Inclusive counterpart of `read_as_of_version_exclusive`: passes
/// `version + 1` as the exclusive bound.
///
/// The increment saturates, so at the maximum representable version the
/// bound stays at that maximum instead of overflowing (a plain `+ 1` would
/// panic in debug builds and wrap to version 0 in release builds).
///
/// # Errors
/// Currently always returns `Ok(())`.
pub fn read_as_of_version_inclusive(&mut self, version: CommitVersion) -> Result<()> {
    let exclusive_bound = CommitVersion(version.0.saturating_add(1));
    self.read_as_of_version_exclusive(exclusive_bound);
    Ok(())
}
/// Rolls the transaction back by delegating to the transaction manager's
/// `rollback`, returning its result unchanged.
#[instrument(name = "transaction::command::rollback", level = "debug", skip(self))]
pub fn rollback(&mut self) -> Result<()> {
self.tm.rollback()
}
/// Reports whether `key` is visible to this transaction.
///
/// If the transaction manager has a definite answer for `key`, that answer
/// is returned as is. Otherwise the store is consulted at the transaction's
/// version.
///
/// # Errors
/// Propagates errors from the transaction manager or the store.
#[instrument(name = "transaction::command::contains_key", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
pub fn contains_key(&mut self, key: &EncodedKey) -> Result<bool> {
    let version = self.tm.version();
    if let Some(known) = self.tm.contains_key(key)? {
        return Ok(known);
    }
    MultiVersionContains::contains(&self.engine.store, key, version)
}
/// Looks up `key` as seen by this transaction.
///
/// * If the transaction manager returns an entry that carries a row, it is
///   converted into a `TransactionValue`.
/// * If it returns an entry without a row, the result is `None` and the
///   store is not consulted.
/// * If it returns nothing, the store is read at the transaction's version.
///
/// # Errors
/// Propagates errors from the transaction manager or the store.
#[instrument(name = "transaction::command::get", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
pub fn get(&mut self, key: &EncodedKey) -> Result<Option<TransactionValue>> {
    let version = self.tm.version();
    match self.tm.get(key)? {
        Some(pending) if pending.row().is_some() => Ok(Some(pending.into())),
        Some(_) => Ok(None),
        None => {
            let stored = MultiVersionGet::get(&self.engine.store, key, version)?;
            Ok(stored.map(Into::into))
        }
    }
}
/// Hands a write of `row` under `key` to the transaction manager's `set`
/// and returns its result.
///
/// The tracing span records the key as hex and the row length; the row
/// itself is skipped.
#[instrument(name = "transaction::command::set", level = "trace", skip(self, row), fields(key_hex = %hex::display(key.as_ref()), value_len = row.len()))]
pub fn set(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
self.tm.set(key, row)
}
/// Delegates to the transaction manager's `unset` with `key` and `row` and
/// returns its result.
///
/// NOTE(review): how `unset` (which takes a row) differs from `remove`
/// (which does not) is defined by the transaction manager and is not
/// visible in this file.
#[instrument(name = "transaction::command::unset", level = "trace", skip(self, row), fields(key_hex = %hex::display(key.as_ref()), value_len = row.len()))]
pub fn unset(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
self.tm.unset(key, row)
}
/// Delegates to the transaction manager's `remove` for `key` and returns
/// its result.
#[instrument(name = "transaction::command::remove", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
pub fn remove(&mut self, key: &EncodedKey) -> Result<()> {
self.tm.remove(key)
}
/// Collects every row visible to this transaction whose key lies in
/// `EncodedKeyRange::prefix(prefix)`, in the order produced by `range`.
///
/// The underlying iterator is drained completely, so `has_more` is always
/// `false`.
///
/// # Errors
/// Returns the first error yielded while iterating.
pub fn prefix(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
    // Batch size handed to `range` for the storage scan.
    const SCAN_BATCH_SIZE: usize = 1024;

    let scan = self.range(EncodedKeyRange::prefix(prefix), SCAN_BATCH_SIZE);
    let items = scan.collect::<Result<Vec<_>>>()?;
    Ok(MultiVersionBatch {
        items,
        has_more: false,
    })
}
/// Collects every row visible to this transaction whose key lies in
/// `EncodedKeyRange::prefix(prefix)`, in the order produced by `range_rev`.
///
/// The underlying iterator is drained completely, so `has_more` is always
/// `false`.
///
/// # Errors
/// Returns the first error yielded while iterating.
pub fn prefix_rev(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
    // Batch size handed to `range_rev` for the storage scan.
    const SCAN_BATCH_SIZE: usize = 1024;

    let scan = self.range_rev(EncodedKeyRange::prefix(prefix), SCAN_BATCH_SIZE);
    let items = scan.collect::<Result<Vec<_>>>()?;
    Ok(MultiVersionBatch {
        items,
        has_more: false,
    })
}
/// Iterates over `range` as seen by this transaction, merging its pending
/// writes with the rows read from storage (forward direction).
///
/// The range is recorded on the transaction manager's marker via
/// `mark_range` (presumably for conflict tracking; confirm against the
/// manager). Pending writes inside the range are cloned eagerly when this
/// method is called. The store is read at the transaction's version, with
/// `batch_size` forwarded to the store's `range`.
pub fn range(
    &mut self,
    range: EncodedKeyRange,
    batch_size: usize,
) -> Box<dyn Iterator<Item = Result<MultiVersionRow>> + Send + '_> {
    let version = self.tm.version();
    let (mut marker, pending_writes) = self.tm.marker_with_pending_writes();
    marker.mark_range(range.clone());

    // Snapshot the pending writes that fall inside the requested bounds.
    let bounds = (range.start_bound(), range.end_bound());
    let pending: Vec<(EncodedKey, Pending)> = pending_writes
        .range(bounds)
        .map(|(key, write)| (key.clone(), write.clone()))
        .collect();

    let storage_iter = self.engine.store.range(range, version, batch_size);
    Box::new(MergePendingIterator::new(pending, storage_iter, false))
}
/// Iterates over `range` as seen by this transaction, merging its pending
/// writes with the rows read from storage (reverse direction).
///
/// The range is recorded on the transaction manager's marker via
/// `mark_range` (presumably for conflict tracking; confirm against the
/// manager). Pending writes inside the range are cloned eagerly, in
/// reversed order, when this method is called. The store is read at the
/// transaction's version, with `batch_size` forwarded to the store's
/// `range_rev`.
pub fn range_rev(
    &mut self,
    range: EncodedKeyRange,
    batch_size: usize,
) -> Box<dyn Iterator<Item = Result<MultiVersionRow>> + Send + '_> {
    let version = self.tm.version();
    let (mut marker, pending_writes) = self.tm.marker_with_pending_writes();
    marker.mark_range(range.clone());

    // Snapshot the pending writes inside the requested bounds, last key first.
    let bounds = (range.start_bound(), range.end_bound());
    let pending: Vec<(EncodedKey, Pending)> = pending_writes
        .range(bounds)
        .rev()
        .map(|(key, write)| (key.clone(), write.clone()))
        .collect();

    let storage_iter = self.engine.store.range_rev(range, version, batch_size);
    Box::new(MergePendingIterator::new(pending, storage_iter, true))
}
}
use std::{cmp::Ordering, iter, vec};
use reifydb_core::interface::store::MultiVersionRow;
use crate::multi::{pending::PendingWrites, types::Pending};
/// Iterator that merges a transaction's pending writes with rows coming
/// from storage into a single stream.
///
/// The merge only ever compares the heads of the two inputs, so it relies on
/// both being ordered in the direction indicated by `reverse`. When a key is
/// present on both sides the pending entry takes precedence and the storage
/// row is dropped.
pub(crate) struct MergePendingIterator<I> {
/// Pending writes still to be merged; peeked to compare against storage.
pending_iter: iter::Peekable<vec::IntoIter<(EncodedKey, Pending)>>,
/// Source of rows read from storage.
storage_iter: I,
/// One-row lookahead pulled from `storage_iter` and not yet emitted.
next_storage: Option<MultiVersionRow>,
/// `true` when the inputs are in descending key order, which flips the
/// comparison used to decide which side is emitted first.
reverse: bool,
}
impl<I> MergePendingIterator<I>
where
    I: Iterator<Item = Result<MultiVersionRow>>,
{
    /// Builds a merge iterator over `pending` and `storage_iter`.
    ///
    /// `reverse` tells the merge which way the inputs are ordered: `false`
    /// for ascending keys, `true` for descending keys. No storage row is
    /// fetched until the iterator is first advanced.
    pub(crate) fn new(pending: Vec<(EncodedKey, Pending)>, storage_iter: I, reverse: bool) -> Self {
        let pending_iter = pending.into_iter().peekable();
        Self {
            pending_iter,
            storage_iter,
            next_storage: None,
            reverse,
        }
    }
}
impl<I> Iterator for MergePendingIterator<I>
where
    I: Iterator<Item = Result<MultiVersionRow>>,
{
    type Item = Result<MultiVersionRow>;

    /// Yields the next merged row.
    ///
    /// * A storage error is returned as soon as it is pulled.
    /// * When only one side has an entry left, that side is consumed.
    /// * When both sides have an entry, the one that comes first in iteration
    ///   order is consumed. On equal keys the pending entry is consumed and
    ///   the storage row is dropped.
    /// * A consumed pending entry without a row yields nothing and the merge
    ///   continues with the next candidate.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Keep a one-row lookahead from storage.
            if self.next_storage.is_none() {
                match self.storage_iter.next() {
                    Some(Ok(stored)) => self.next_storage = Some(stored),
                    Some(Err(err)) => return Some(Err(err)),
                    None => {}
                }
            }

            // Decide whose turn it is, expressed as an ordering in iteration
            // direction: `Less` = pending first, `Equal` = same key,
            // `Greater` = storage first.
            let turn = match (self.pending_iter.peek(), self.next_storage.as_ref()) {
                (None, None) => return None,
                (None, Some(_)) => Ordering::Greater,
                (Some(_), None) => Ordering::Less,
                (Some((pending_key, _)), Some(stored)) => {
                    let natural = pending_key.cmp(&stored.key);
                    if self.reverse {
                        natural.reverse()
                    } else {
                        natural
                    }
                }
            };

            match turn {
                Ordering::Greater => return self.next_storage.take().map(Ok),
                // The pending entry shadows the storage row with the same key.
                Ordering::Equal => self.next_storage = None,
                Ordering::Less => {}
            }

            // Every remaining path consumes the pending head, which the peek
            // above proved to exist.
            let (key, value) = self.pending_iter.next().unwrap();
            if let Some(row) = value.row() {
                return Some(Ok(MultiVersionRow {
                    key,
                    row: row.clone(),
                    version: value.version,
                }));
            }
            // The pending entry carries no row: emit nothing and keep merging.
        }
    }
}