use core::mem;
use reifydb_core::{
common::CommitVersion,
delta::Delta,
encoded::{key::EncodedKey, row::EncodedRow},
};
use reifydb_type::util::hex;
use tracing::instrument;
use crate::{
TransactionId,
error::TransactionError,
multi::{
conflict::ConflictManager,
marker::Marker,
pending::PendingWrites,
transaction::{version::VersionProvider, *},
types::Pending,
},
};
/// How a query transaction picks the version it reads at.
///
/// Both variants carry the `CommitVersion` the query reads at. On drop, a
/// `TransactionManagerQuery` hands the version to the engine's `done_query` only for
/// `Current`; `TimeTravel` versions are not reported.
#[derive(Debug, Clone, Copy)]
pub enum TransactionKind {
    /// Reads at the version the query was started with.
    Current(CommitVersion),
    /// Reads at an explicitly requested version.
    TimeTravel(CommitVersion),
}
/// Read-only (query) transaction handle bound to a `TransactionManager`.
pub struct TransactionManagerQuery<L: VersionProvider> {
    /// Identifier of this transaction.
    id: TransactionId,
    /// Owning manager; used on drop to report a finished `Current` query.
    engine: TransactionManager<L>,
    /// Which version this query reads at, and whether it is current or time travel.
    transaction: TransactionKind,
}
impl<L> TransactionManagerQuery<L>
where
    L: VersionProvider,
{
    /// Creates a query that reads at `version` as a `TransactionKind::Current` query.
    ///
    /// When a `Current` query is dropped, its version is passed to
    /// `engine.inner.done_query`.
    /// NOTE(review): presumably the caller registered `version` with the engine before
    /// calling this; confirm at the call site.
    pub fn new_current(id: TransactionId, engine: TransactionManager<L>, version: CommitVersion) -> Self {
        Self {
            id,
            engine,
            transaction: TransactionKind::Current(version),
        }
    }

    /// Creates a time-travel query that reads at `version`.
    ///
    /// Time-travel queries are not reported to `done_query` on drop.
    pub fn new_time_travel(id: TransactionId, engine: TransactionManager<L>, version: CommitVersion) -> Self {
        Self {
            id,
            engine,
            transaction: TransactionKind::TimeTravel(version),
        }
    }

    /// Returns the identifier of this transaction.
    pub fn id(&self) -> TransactionId {
        self.id
    }

    /// Returns the version this query reads at, regardless of its kind.
    pub fn version(&self) -> CommitVersion {
        match self.transaction {
            TransactionKind::Current(version) | TransactionKind::TimeTravel(version) => version,
        }
    }

    /// Switches this query to time travel at `version`.
    ///
    /// If the query was `Current`, its current version is handed to the engine's
    /// `done_query` right away. Drop only releases `Current` queries, so without this
    /// step the version would never be released once the query became `TimeTravel`.
    pub fn read_as_of_version_exclusive(&mut self, version: CommitVersion) {
        if let TransactionKind::Current(current) = self.transaction {
            self.engine.inner.done_query(current);
        }
        self.transaction = TransactionKind::TimeTravel(version);
    }
}
impl<L: VersionProvider> Drop for TransactionManagerQuery<L> {
    /// Reports a `Current` query's version to the engine's `done_query`.
    ///
    /// Time-travel queries are ignored here.
    fn drop(&mut self) {
        match self.transaction {
            TransactionKind::Current(version) => {
                self.engine.inner.done_query(version);
            }
            TransactionKind::TimeTravel(_) => {}
        }
    }
}
/// Read-write (command) transaction that buffers writes until commit.
pub struct TransactionManagerCommand<L: VersionProvider> {
    /// Identifier of this transaction.
    pub(super) id: TransactionId,
    /// Base version; buffered writes are stamped with it until commit.
    pub(super) version: CommitVersion,
    /// Optional read-version override; when set, `version()` returns it instead of `version`.
    pub(super) read_version: Option<CommitVersion>,
    /// Running total of the estimated size of accepted writes.
    pub(super) size: u64,
    /// Number of accepted writes.
    pub(super) count: u64,
    /// Shared oracle used for commit-version allocation and query tracking.
    pub(super) oracle: Arc<Oracle<L>>,
    /// Keys read and written by this transaction, used for conflict detection at commit.
    pub(super) conflicts: ConflictManager,
    /// Buffered writes keyed by `EncodedKey`.
    pub(super) pending_writes: PendingWrites,
    /// Extra writes recorded when a key is rewritten with a different version.
    pub(super) duplicates: Vec<Pending>,
    /// Set once the transaction has been discarded; most operations then fail.
    pub(super) discarded: bool,
    /// Whether this transaction's query version has already been reported as done.
    pub(super) done_query: bool,
}
impl<L: VersionProvider> Drop for TransactionManagerCommand<L> {
    /// Discards the transaction if that has not already happened.
    fn drop(&mut self) {
        if self.discarded {
            return;
        }
        self.discard();
    }
}
impl<L: VersionProvider> TransactionManagerCommand<L> {
    /// Returns the identifier of this transaction.
    pub fn id(&self) -> TransactionId {
        self.id
    }

    /// Returns the version reads are performed at.
    ///
    /// This is the override installed by [`Self::read_as_of_version_exclusive`] if one is
    /// set, otherwise [`Self::base_version`].
    pub fn version(&self) -> CommitVersion {
        match self.read_version {
            Some(read_version) => read_version,
            None => self.base_version(),
        }
    }

    /// Returns the version this transaction started at.
    ///
    /// Buffered writes are stamped with this version until commit.
    pub fn base_version(&self) -> CommitVersion {
        self.version
    }

    /// Makes [`Self::version`] return `version` from now on.
    ///
    /// Writes are unaffected and keep using [`Self::base_version`].
    pub fn read_as_of_version_exclusive(&mut self, version: CommitVersion) {
        self.read_version = Some(version);
    }

    /// Borrows the buffered writes.
    pub fn pending_writes(&self) -> &PendingWrites {
        &self.pending_writes
    }

    /// Borrows the conflict-tracking state.
    pub fn conflicts(&self) -> &ConflictManager {
        &self.conflicts
    }
}
impl<L: VersionProvider> TransactionManagerCommand<L> {
    /// Returns a [`Marker`] that borrows this transaction's conflict set mutably.
    pub fn marker(&mut self) -> Marker<'_> {
        let conflicts = &mut self.conflicts;
        Marker::new(conflicts)
    }

    /// Returns a [`Marker`] over the conflict set together with a shared borrow of the
    /// buffered writes.
    ///
    /// The two borrows are taken from disjoint fields, so they can coexist.
    pub fn marker_with_pending_writes(&mut self) -> (Marker<'_>, &PendingWrites) {
        let Self {
            conflicts,
            pending_writes,
            ..
        } = self;
        (Marker::new(conflicts), &*pending_writes)
    }

    /// Records `k` as read for conflict detection.
    pub fn mark_read(&mut self, k: &EncodedKey) {
        let conflicts = &mut self.conflicts;
        conflicts.mark_read(k);
    }

    /// Records `k` as written for conflict detection.
    pub fn mark_write(&mut self, k: &EncodedKey) {
        let conflicts = &mut self.conflicts;
        conflicts.mark_write(k);
    }
}
impl<L> TransactionManagerCommand<L>
where
    L: VersionProvider,
{
    /// Buffers a `Delta::Set` of `row` under `key` at the transaction's base version.
    ///
    /// # Errors
    /// Returns `TransactionError::RolledBack` once the transaction is discarded. Errors
    /// from the internal write path are propagated, e.g. `TransactionError::TooLarge` when
    /// the batch limits are reached.
    #[instrument(name = "transaction::command::set", level = "debug", skip(self, row), fields(
        txn_id = %self.id,
        key_hex = %hex::display(key.as_ref()),
        value_len = row.len()
    ))]
    pub fn set(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }
        self.set_internal(key, row)
    }

    /// Buffers a `Delta::Unset` of `key`, carrying `row`, at the transaction's base version.
    ///
    /// # Errors
    /// Returns `TransactionError::RolledBack` once the transaction is discarded. Errors
    /// from `modify` are propagated, e.g. `TransactionError::TooLarge`.
    #[instrument(name = "transaction::command::unset", level = "debug", skip(self, row), fields(
        txn_id = %self.id,
        key_hex = %hex::display(key.as_ref()),
        value_len = row.len()
    ))]
    pub fn unset(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }
        self.modify(Pending {
            delta: Delta::Unset {
                key: key.clone(),
                row,
            },
            version: self.base_version(),
        })
    }

    /// Buffers a `Delta::Remove` of `key` at the transaction's base version.
    ///
    /// # Errors
    /// Returns `TransactionError::RolledBack` once the transaction is discarded. Errors
    /// from `modify` are propagated, e.g. `TransactionError::TooLarge`.
    #[instrument(name = "transaction::command::remove", level = "trace", skip(self), fields(
        txn_id = %self.id,
        key_len = key.len()
    ))]
    pub fn remove(&mut self, key: &EncodedKey) -> Result<()> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }
        self.modify(Pending {
            delta: Delta::Remove {
                key: key.clone(),
            },
            version: self.base_version(),
        })
    }

    /// Rolls back all buffered state without discarding the transaction.
    ///
    /// This rolls back the pending writes and conflict marks. It also clears
    /// `duplicates` and resets the `count`/`size` budget accumulated by `modify`.
    ///
    /// # Errors
    /// Returns `TransactionError::RolledBack` once the transaction is discarded.
    #[instrument(name = "transaction::command::rollback", level = "debug", skip(self), fields(txn_id = %self.id))]
    pub fn rollback(&mut self) -> Result<()> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }
        self.pending_writes.rollback();
        self.conflicts.rollback();
        // `duplicates`, `count` and `size` all derive from writes accepted by `modify`.
        // Keeping them after a rollback would let `commit_pending` emit rolled-back
        // duplicate writes. It would also keep charging discarded writes against the
        // batch limits.
        // NOTE(review): assumes `PendingWrites::rollback` drops every buffered write.
        self.duplicates.clear();
        self.count = 0;
        self.size = 0;
        Ok(())
    }

    /// Checks this transaction's own buffered writes for `key`.
    ///
    /// Returns:
    /// - `Some(true)` if a buffered write for `key` exists and is not a removal;
    /// - `Some(false)` if the buffered write is a removal;
    /// - `None` if nothing is buffered. In that case `key` is also marked as read for
    ///   conflict detection (the caller presumably falls back to storage).
    ///
    /// # Errors
    /// Returns `TransactionError::RolledBack` once the transaction is discarded.
    #[instrument(name = "transaction::command::contains_key", level = "trace", skip(self), fields(
        txn_id = %self.id,
        key_hex = %hex::display(key.as_ref())
    ))]
    pub fn contains_key(&mut self, key: &EncodedKey) -> Result<Option<bool>> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }
        match self.pending_writes.get(key) {
            Some(pending) => Ok(Some(!pending.was_removed())),
            None => {
                self.conflicts.mark_read(key);
                Ok(None)
            }
        }
    }

    /// Returns this transaction's buffered write for `key`, if any.
    ///
    /// A buffered write with a row is reported as `Delta::Set`, and one without a row
    /// as `Delta::Remove`. Each keeps its buffered version. When nothing is buffered,
    /// `key` is marked as read for conflict detection and `None` is returned.
    ///
    /// # Errors
    /// Returns `TransactionError::RolledBack` once the transaction is discarded.
    #[instrument(name = "transaction::command::get", level = "trace", skip(self), fields(
        txn_id = %self.id,
        key_hex = %hex::display(key.as_ref())
    ))]
    pub fn get(&mut self, key: &EncodedKey) -> Result<Option<Pending>> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }
        if let Some(v) = self.pending_writes.get(key) {
            Ok(Some(Pending {
                delta: match v.row() {
                    Some(row) => Delta::Set {
                        key: key.clone(),
                        row: row.clone(),
                    },
                    None => Delta::Remove {
                        key: key.clone(),
                    },
                },
                version: v.version,
            }))
        } else {
            self.conflicts.mark_read(key);
            Ok(None)
        }
    }
}
impl<L> TransactionManagerCommand<L>
where
L: VersionProvider,
{
/// Buffers a `Delta::Set` of `row` under `key` at the transaction's base version.
///
/// Fails with `TransactionError::RolledBack` once the transaction is discarded, and
/// otherwise forwards to `modify` (which may fail with `TransactionError::TooLarge`).
#[instrument(name = "transaction::command::set_internal", level = "trace", skip(self, row), fields(
txn_id = %self.id,
key_hex = %hex::display(key.as_ref())
))]
fn set_internal(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
if self.discarded {
return Err(TransactionError::RolledBack.into());
}
self.modify(Pending {
delta: Delta::Set {
key: key.clone(),
row,
},
version: self.base_version(),
})
}
/// Accepts one buffered write: enforces the batch budget, marks the key as written
/// for conflict detection, and stores `pending` in `pending_writes`, replacing any
/// earlier write for the same key.
///
/// # Errors
/// - `TransactionError::RolledBack` once the transaction is discarded.
/// - `TransactionError::TooLarge` when the incremented write count reaches
///   `max_batch_entries()` or the accumulated estimated size reaches `max_batch_size()`.
///   The comparison is `>=`, so both limits are exclusive. Counters are left untouched
///   on rejection.
#[instrument(name = "transaction::command::modify", level = "trace", skip(self, pending), fields(
txn_id = %self.id,
key_hex = %hex::display(pending.key().as_ref()),
is_remove = pending.was_removed()
))]
fn modify(&mut self, pending: Pending) -> Result<()> {
if self.discarded {
return Err(TransactionError::RolledBack.into());
}
let pending_writes = &mut self.pending_writes;
// Every accepted write counts, including ones that overwrite an already-buffered key.
let cnt = self.count + 1;
let size = self.size + pending_writes.estimate_size(&pending);
if cnt >= pending_writes.max_batch_entries() || size >= pending_writes.max_batch_size() {
return Err(TransactionError::TooLarge.into());
}
// Only commit the new budget once the write is known to fit.
self.count = cnt;
self.size = size;
self.conflicts.mark_write(pending.key());
let key = pending.key();
let row = pending.row();
let version = pending.version;
// An earlier write for this key is always removed here. It is only kept as a
// duplicate when its version differs from the new write's version.
// NOTE(review): the duplicate is built from the NEW write's row and version
// (only the key comes from the old entry), and it is rebuilt from `row()`. A delta
// that is neither Set nor Remove (e.g. Unset) is therefore recorded as Set or
// Remove, depending on what `row()` returns. Confirm this is intended.
// NOTE(review): every caller visible in this file passes `self.base_version()`.
// If the base version never changes during a transaction, this branch is
// unreachable. TODO confirm.
if let Some((old_key, old_value)) = pending_writes.remove_entry(key)
&& old_value.version != version
{
self.duplicates.push(Pending {
delta: match row {
Some(row) => Delta::Set {
key: old_key,
row: row.clone(),
},
None => Delta::Remove {
key: old_key,
},
},
version,
})
}
pending_writes.insert(key.clone(), pending);
Ok(())
}
}
impl<L: VersionProvider> TransactionManagerCommand<L> {
    /// Asks the oracle for a commit version and, on success, drains every buffered write.
    ///
    /// On success, returns the commit version together with all buffered writes, each
    /// re-stamped with that commit version. The pending writes come first, in insertion
    /// order, followed by `duplicates`.
    ///
    /// # Errors
    /// - `TransactionError::RolledBack` once the transaction is discarded.
    /// - `TransactionError::Conflict` when the oracle reports a conflict. The conflict set
    ///   it hands back is reinstalled on the transaction.
    /// - `TransactionError::TooOld` when the oracle reports `TooOld`.
    /// - Any error returned by `Oracle::new_commit` itself.
    #[instrument(name = "transaction::command::commit_pending", level = "debug", skip(self), fields(
        txn_id = %self.id,
        pending_count = self.pending_writes.len()
    ))]
    pub(crate) fn commit_pending(&mut self) -> Result<(CommitVersion, Vec<Pending>)> {
        if self.discarded {
            return Err(TransactionError::RolledBack.into());
        }

        // The oracle takes ownership of the conflict set to validate it.
        let conflict_manager = mem::take(&mut self.conflicts);
        let base_version = self.base_version();
        let outcome = self.oracle.new_commit(&mut self.done_query, base_version, conflict_manager)?;

        let commit_version = match outcome {
            CreateCommitResult::Success(version) => version,
            CreateCommitResult::Conflict(conflicts) => {
                self.conflicts = conflicts;
                return Err(TransactionError::Conflict.into());
            }
            CreateCommitResult::TooOld => return Err(TransactionError::TooOld.into()),
        };

        let buffered = mem::take(&mut self.pending_writes);
        let duplicates = mem::take(&mut self.duplicates);
        let mut all = Vec::with_capacity(buffered.len() + duplicates.len());

        all.extend(buffered.into_iter_insertion_order().map(|(_key, value)| {
            let (_, delta) = value.into_components();
            Pending {
                delta,
                version: commit_version,
            }
        }));
        all.extend(duplicates.into_iter().map(|mut duplicate| {
            duplicate.version = commit_version;
            duplicate
        }));

        debug_assert_ne!(commit_version, 0);
        Ok((commit_version, all))
    }
}
impl<L: VersionProvider> TransactionManagerCommand<L> {
    /// Reports the base `version` to the oracle's query tracker, at most once.
    #[instrument(name = "transaction::command::done", level = "trace", skip(self), fields(txn_id = %self.id))]
    fn done_query(&mut self) {
        if self.done_query {
            return;
        }
        self.done_query = true;
        self.oracle().query.done(self.version);
    }

    /// Borrows the shared oracle.
    fn oracle(&self) -> &Oracle<L> {
        &*self.oracle
    }

    /// Marks the transaction as discarded and reports its query as done.
    ///
    /// Calling this more than once has no further effect.
    #[instrument(name = "transaction::command::discard", level = "trace", skip(self), fields(txn_id = %self.id))]
    pub fn discard(&mut self) {
        if !self.discarded {
            self.discarded = true;
            self.done_query();
        }
    }

    /// Returns whether the transaction has been discarded.
    pub fn is_discard(&self) -> bool {
        self.discarded
    }
}