1use core::mem;
13use std::{cmp::Ordering, collections::HashSet, iter, ops::RangeBounds, sync::Arc, vec};
14
15use reifydb_core::{
16 common::CommitVersion,
17 delta::Delta,
18 encoded::{
19 key::{EncodedKey, EncodedKeyRange},
20 row::EncodedRow,
21 },
22 event::transaction::PostCommitEvent,
23 interface::store::{
24 MultiVersionBatch, MultiVersionCommit, MultiVersionContains, MultiVersionGet, MultiVersionRow,
25 },
26};
27use reifydb_type::{
28 Result,
29 util::{cowvec::CowVec, hex},
30};
31use tracing::instrument;
32
33use super::{MultiTransaction, version::StandardVersionProvider};
34use crate::{
35 TransactionId,
36 delta::optimize_deltas,
37 error::TransactionError,
38 multi::{
39 conflict::ConflictManager,
40 marker::Marker,
41 oracle::{CreateCommitResult, Oracle},
42 pending::PendingWrites,
43 types::{DeltaEntry, TransactionValue},
44 },
45};
46
/// Snapshot of a [`MultiWriteTransaction`]'s buffered write state.
///
/// Produced by `MultiWriteTransaction::savepoint` and consumed by
/// `MultiWriteTransaction::restore_savepoint`.
pub struct WriteSavepoint {
	// Copy of the pending-writes buffer at the time of the savepoint.
	pub(crate) pending_writes: PendingWrites,
	// Entry counter and estimated size checked against the batch limits in `modify`.
	pub(crate) count: u64,
	pub(crate) size: u64,
	pub(crate) duplicates: Vec<DeltaEntry>,
	// Only the length is kept: restoring truncates the delta log back to this length.
	pub(crate) delta_log_len: usize,
	pub(crate) conflicts: ConflictManager,
	pub(crate) preexisting_keys: HashSet<Vec<u8>>,
}
56
/// Lifecycle state of a [`MultiWriteTransaction`].
///
/// Allowed transitions (checked in debug builds by `transition_to`):
/// `Active -> QueryDone`, `Active -> Discarded`, `QueryDone -> Discarded`.
#[derive(Clone, Copy, PartialEq)]
pub(crate) enum Lifecycle {
	/// The read snapshot registered in `new` has not been released yet.
	Active,
	/// The read snapshot was released during commit; the transaction is not yet discarded.
	QueryDone,
	/// Terminal state; buffered-write and read operations return `TransactionError::RolledBack`.
	Discarded,
}
63
/// A read-write transaction over the multi-version store.
///
/// Writes are buffered locally and only reach storage on `commit`; reads consult
/// the buffer first and then storage at `version()`. Dropping a transaction that
/// was not yet discarded discards it.
pub struct MultiWriteTransaction {
	// Engine handle; supplies the store, the event bus and the oracle.
	engine: MultiTransaction,

	// Identifier generated in `new`.
	pub(crate) id: TransactionId,
	// Snapshot version taken from the oracle in `new` and registered as in flight.
	pub(crate) version: CommitVersion,
	// Optional read-version override set by `read_as_of_version_*`.
	pub(crate) read_version: Option<CommitVersion>,
	// Running estimated size and entry count, checked against the batch limits in `modify`.
	pub(crate) size: u64,
	pub(crate) count: u64,
	pub(crate) oracle: Arc<Oracle<StandardVersionProvider>>,
	// Read/write key tracking handed to the oracle for conflict detection at commit.
	pub(crate) conflicts: ConflictManager,
	// Latest buffered write per key; consulted first by `get` and `contains_key`.
	pub(crate) pending_writes: PendingWrites,
	// Extra entries recorded in `modify` when a replaced pending write had a different version.
	pub(crate) duplicates: Vec<DeltaEntry>,
	// Every buffered write in order; becomes the committed delta list.
	pub(crate) delta_log: Vec<DeltaEntry>,
	// Raw keys marked via `mark_preexisting`; passed to `optimize_deltas` at commit.
	pub(crate) preexisting_keys: HashSet<Vec<u8>>,

	pub(crate) lifecycle: Lifecycle,
}
90
91impl MultiWriteTransaction {
92 #[instrument(name = "transaction::command::new", level = "debug", skip(engine))]
93 pub fn new(engine: MultiTransaction) -> Result<Self> {
94 let oracle = engine.tm.oracle().clone();
95 let version = oracle.version()?;
96 oracle.query.register_in_flight(version);
97
98 let id = TransactionId::generate(oracle.metrics_clock(), oracle.rng());
99 Ok(Self {
100 engine,
101 id,
102 version,
103 read_version: None,
104 size: 0,
105 count: 0,
106 oracle,
107 conflicts: ConflictManager::new(),
108 pending_writes: PendingWrites::new(),
109 duplicates: Vec::new(),
110 delta_log: Vec::new(),
111 preexisting_keys: HashSet::new(),
112 lifecycle: Lifecycle::Active,
113 })
114 }
115
116 fn transition_to(&mut self, next: Lifecycle) {
117 debug_assert!(matches!(
118 (self.lifecycle, next),
119 (Lifecycle::Active, Lifecycle::QueryDone)
120 | (Lifecycle::Active, Lifecycle::Discarded)
121 | (Lifecycle::QueryDone, Lifecycle::Discarded)
122 ));
123 self.lifecycle = next;
124 }
125}
126
127impl Drop for MultiWriteTransaction {
128 fn drop(&mut self) {
129 if self.lifecycle != Lifecycle::Discarded {
130 self.discard();
131 }
132 }
133}
134
135impl MultiWriteTransaction {
136 pub fn id(&self) -> TransactionId {
137 self.id
138 }
139
140 pub fn version(&self) -> CommitVersion {
141 self.read_version.unwrap_or(self.version)
142 }
143
144 pub fn base_version(&self) -> CommitVersion {
145 self.version
146 }
147
148 pub fn read_as_of_version_exclusive(&mut self, version: CommitVersion) {
149 self.read_version = Some(version);
150 }
151
152 pub fn read_as_of_version_inclusive(&mut self, version: CommitVersion) -> Result<()> {
153 self.read_as_of_version_exclusive(CommitVersion(version.0 + 1));
154 Ok(())
155 }
156
157 pub fn pending_writes(&self) -> &PendingWrites {
158 &self.pending_writes
159 }
160
161 pub fn conflicts(&self) -> &ConflictManager {
162 &self.conflicts
163 }
164
165 pub fn mark_preexisting(&mut self, key: &EncodedKey) {
166 self.preexisting_keys.insert(key.as_ref().to_vec());
167 }
168
169 pub fn preexisting_keys(&self) -> &HashSet<Vec<u8>> {
170 &self.preexisting_keys
171 }
172}
173
174impl MultiWriteTransaction {
175 pub fn savepoint(&self) -> WriteSavepoint {
176 WriteSavepoint {
177 pending_writes: self.pending_writes.clone(),
178 count: self.count,
179 size: self.size,
180 duplicates: self.duplicates.clone(),
181 delta_log_len: self.delta_log.len(),
182 conflicts: self.conflicts.clone(),
183 preexisting_keys: self.preexisting_keys.clone(),
184 }
185 }
186
187 pub fn restore_savepoint(&mut self, sp: WriteSavepoint) {
188 self.pending_writes = sp.pending_writes;
189 self.count = sp.count;
190 self.size = sp.size;
191 self.duplicates = sp.duplicates;
192 self.delta_log.truncate(sp.delta_log_len);
193 self.conflicts = sp.conflicts;
194 self.preexisting_keys = sp.preexisting_keys;
195 }
196}
197
198impl MultiWriteTransaction {
199 pub fn marker(&mut self) -> Marker<'_> {
200 Marker::new(&mut self.conflicts)
201 }
202
203 pub fn marker_with_pending_writes(&mut self) -> (Marker<'_>, &PendingWrites) {
204 (Marker::new(&mut self.conflicts), &self.pending_writes)
205 }
206
207 pub fn mark_read(&mut self, k: &EncodedKey) {
208 self.conflicts.mark_read(k);
209 }
210
211 pub fn mark_write(&mut self, k: &EncodedKey) {
212 self.conflicts.mark_write(k);
213 }
214
215 pub fn reserve_writes(&mut self, additional: usize) {
216 self.conflicts.reserve_writes(additional);
217 }
218
219 pub(crate) fn disable_conflict_tracking(&mut self) {
221 self.conflicts.set_disabled();
222 }
223}
224
225impl MultiWriteTransaction {
226 #[instrument(name = "transaction::command::set", level = "debug", skip(self, row), fields(
227 txn_id = %self.id,
228 key_hex = %hex::display(key.as_ref()),
229 value_len = row.len()
230 ))]
231 pub fn set(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
232 if self.lifecycle == Lifecycle::Discarded {
233 return Err(TransactionError::RolledBack.into());
234 }
235 self.modify(DeltaEntry {
236 delta: Delta::Set {
237 key: key.clone(),
238 row,
239 },
240 version: self.base_version(),
241 })
242 }
243
244 #[instrument(name = "transaction::command::unset", level = "debug", skip(self, row), fields(
247 txn_id = %self.id,
248 key_hex = %hex::display(key.as_ref()),
249 value_len = row.len()
250 ))]
251 pub fn unset(&mut self, key: &EncodedKey, row: EncodedRow) -> Result<()> {
252 if self.lifecycle == Lifecycle::Discarded {
253 return Err(TransactionError::RolledBack.into());
254 }
255 self.modify(DeltaEntry {
256 delta: Delta::Unset {
257 key: key.clone(),
258 row,
259 },
260 version: self.base_version(),
261 })
262 }
263
264 #[instrument(name = "transaction::command::remove", level = "trace", skip(self), fields(
268 txn_id = %self.id,
269 key_len = key.len()
270 ))]
271 pub fn remove(&mut self, key: &EncodedKey) -> Result<()> {
272 if self.lifecycle == Lifecycle::Discarded {
273 return Err(TransactionError::RolledBack.into());
274 }
275 self.modify(DeltaEntry {
276 delta: Delta::Remove {
277 key: key.clone(),
278 },
279 version: self.base_version(),
280 })
281 }
282
283 #[instrument(name = "transaction::command::rollback", level = "debug", skip(self), fields(txn_id = %self.id))]
284 pub fn rollback(&mut self) -> Result<()> {
285 if self.lifecycle == Lifecycle::Discarded {
286 return Err(TransactionError::RolledBack.into());
287 }
288
289 self.pending_writes.rollback();
290 self.conflicts.rollback();
291 self.delta_log.clear();
292 self.duplicates.clear();
293 Ok(())
294 }
295
296 #[instrument(name = "transaction::command::contains_key", level = "trace", skip(self), fields(
297 txn_id = %self.id,
298 key_hex = %hex::display(key.as_ref())
299 ))]
300 pub fn contains_key(&mut self, key: &EncodedKey) -> Result<bool> {
301 if self.lifecycle == Lifecycle::Discarded {
302 return Err(TransactionError::RolledBack.into());
303 }
304
305 let version = self.version();
306 match self.pending_writes.get(key) {
307 Some(pending) => {
308 if pending.was_removed() {
309 return Ok(false);
310 }
311 Ok(true)
312 }
313 None => {
314 self.conflicts.mark_read(key);
315 MultiVersionContains::contains(&self.engine.store, key, version)
316 }
317 }
318 }
319
320 #[instrument(name = "transaction::command::get", level = "trace", skip(self), fields(
321 txn_id = %self.id,
322 key_hex = %hex::display(key.as_ref())
323 ))]
324 pub fn get(&mut self, key: &EncodedKey) -> Result<Option<TransactionValue>> {
325 if self.lifecycle == Lifecycle::Discarded {
326 return Err(TransactionError::RolledBack.into());
327 }
328
329 let version = self.version();
330 if let Some(v) = self.pending_writes.get(key) {
331 if v.row().is_some() {
332 return Ok(Some(DeltaEntry {
333 delta: match v.row() {
334 Some(row) => Delta::Set {
335 key: key.clone(),
336 row: row.clone(),
337 },
338 None => Delta::Remove {
339 key: key.clone(),
340 },
341 },
342 version: v.version,
343 }
344 .into()));
345 }
346 return Ok(None);
347 }
348 self.conflicts.mark_read(key);
349 Ok(MultiVersionGet::get(&self.engine.store, key, version)?.map(Into::into))
350 }
351
352 #[instrument(name = "transaction::command::get_committed", level = "trace", skip(self), fields(
359 txn_id = %self.id,
360 key_hex = %hex::display(key.as_ref())
361 ))]
362 pub fn get_committed(&mut self, key: &EncodedKey) -> Result<Option<TransactionValue>> {
363 if self.lifecycle == Lifecycle::Discarded {
364 return Err(TransactionError::RolledBack.into());
365 }
366 let version = self.version();
367 self.conflicts.mark_read(key);
368 Ok(MultiVersionGet::get(&self.engine.store, key, version)?.map(Into::into))
369 }
370}
371
372impl MultiWriteTransaction {
373 #[instrument(name = "transaction::command::modify", level = "trace", skip(self, pending), fields(
374 txn_id = %self.id,
375 key_hex = %hex::display(pending.key().as_ref()),
376 is_remove = pending.was_removed()
377 ))]
378 fn modify(&mut self, pending: DeltaEntry) -> Result<()> {
379 let cnt = self.count + 1;
380 let size = self.size + self.pending_writes.estimate_size(&pending);
381 if cnt >= self.pending_writes.max_batch_entries() || size >= self.pending_writes.max_batch_size() {
382 return Err(TransactionError::TooLarge.into());
383 }
384
385 self.count = cnt;
386 self.size = size;
387
388 self.conflicts.mark_write(pending.key());
389
390 let key = pending.key();
391 let row = pending.row();
392 let version = pending.version;
393
394 if let Some((old_key, old_value)) = self.pending_writes.remove_entry(key)
395 && old_value.version != version
396 {
397 self.duplicates.push(DeltaEntry {
398 delta: match row {
399 Some(row) => Delta::Set {
400 key: old_key,
401 row: row.clone(),
402 },
403 None => Delta::Remove {
404 key: old_key,
405 },
406 },
407 version,
408 })
409 }
410 self.delta_log.push(pending.clone());
414 self.pending_writes.insert(key.clone(), pending);
415
416 Ok(())
417 }
418}
419
420impl MultiWriteTransaction {
421 #[instrument(name = "transaction::command::commit_pending", level = "debug", skip(self), fields(
422 txn_id = %self.id,
423 pending_count = self.pending_writes.len()
424 ))]
425 fn commit_pending(&mut self) -> Result<(CommitVersion, Vec<DeltaEntry>)> {
426 if self.lifecycle == Lifecycle::Discarded {
427 return Err(TransactionError::RolledBack.into());
428 }
429 let conflict_manager = mem::take(&mut self.conflicts);
430 let base_version = self.base_version();
431
432 let result = self.oracle.new_commit(base_version, conflict_manager);
433 self.release_read_snapshot(base_version);
434
435 match result? {
436 CreateCommitResult::Conflict(conflicts) => {
437 self.conflicts = conflicts;
438 Err(TransactionError::Conflict.into())
439 }
440 CreateCommitResult::TooOld => Err(TransactionError::TooOld.into()),
441 CreateCommitResult::Success(version) => Ok((version, self.assemble_committed_deltas(version))),
442 }
443 }
444
445 #[instrument(name = "transaction::command::commit_pending_unchecked", level = "debug", skip(self), fields(
447 txn_id = %self.id,
448 pending_count = self.pending_writes.len()
449 ))]
450 fn commit_pending_unchecked(&mut self) -> Result<(CommitVersion, Vec<DeltaEntry>)> {
451 if self.lifecycle == Lifecycle::Discarded {
452 return Err(TransactionError::RolledBack.into());
453 }
454 let _ = mem::take(&mut self.conflicts);
455 let base_version = self.base_version();
456
457 let result = self.oracle.advance_unchecked(base_version);
458 self.release_read_snapshot(base_version);
459
460 match result? {
461 CreateCommitResult::Conflict(_) => unreachable!("advance_unchecked never reports a conflict"),
462 CreateCommitResult::TooOld => Err(TransactionError::TooOld.into()),
463 CreateCommitResult::Success(version) => Ok((version, self.assemble_committed_deltas(version))),
464 }
465 }
466
467 #[inline]
471 fn release_read_snapshot(&mut self, base_version: CommitVersion) {
472 if self.lifecycle == Lifecycle::Active {
473 self.oracle.query.mark_finished(base_version);
474 self.transition_to(Lifecycle::QueryDone);
475 }
476 }
477
478 #[inline]
482 fn assemble_committed_deltas(&mut self, version: CommitVersion) -> Vec<DeltaEntry> {
483 debug_assert_ne!(version, 0);
484 let _ = mem::take(&mut self.pending_writes);
485 let duplicate_writes = mem::take(&mut self.duplicates);
486 let mut all = mem::take(&mut self.delta_log);
487 all.reserve(duplicate_writes.len());
488
489 for pending in all.iter_mut() {
490 pending.version = version;
491 }
492 for mut pending in duplicate_writes {
493 pending.version = version;
494 all.push(pending);
495 }
496 all
497 }
498}
499
500impl MultiWriteTransaction {
501 #[instrument(name = "transaction::command::commit", level = "debug", skip(self), fields(pending_count = self.pending_writes().len()))]
502 pub fn commit(&mut self) -> Result<CommitVersion> {
503 if self.pending_writes.is_empty() {
504 self.discard();
505 return Ok(CommitVersion(0));
506 }
507 let (commit_version, entries) = self.commit_pending()?;
508 self.finalize_commit(commit_version, entries)
509 }
510
511 #[instrument(name = "transaction::command::commit_unchecked", level = "debug", skip(self), fields(pending_count = self.pending_writes().len()))]
513 pub(crate) fn commit_unchecked(&mut self) -> Result<CommitVersion> {
514 if self.pending_writes.is_empty() {
515 self.discard();
516 return Ok(CommitVersion(0));
517 }
518 let (commit_version, entries) = self.commit_pending_unchecked()?;
519 self.finalize_commit(commit_version, entries)
520 }
521
522 #[inline]
526 fn finalize_commit(
527 &mut self,
528 commit_version: CommitVersion,
529 entries: Vec<DeltaEntry>,
530 ) -> Result<CommitVersion> {
531 if entries.is_empty() {
532 self.discard();
533 return Ok(CommitVersion(0));
534 }
535 let deltas = self.optimize_for_storage(&entries);
536 MultiVersionCommit::commit(&self.engine.store, deltas.clone(), commit_version)?;
537 self.discard();
538 self.publish(commit_version, deltas);
539 Ok(commit_version)
540 }
541
542 #[inline]
543 fn optimize_for_storage(&self, entries: &[DeltaEntry]) -> CowVec<Delta> {
544 let mut raw_deltas = CowVec::with_capacity(entries.len());
545 for pending in entries {
546 raw_deltas.push(pending.delta.clone());
547 }
548 let optimized = optimize_deltas(raw_deltas.iter().cloned(), self.preexisting_keys());
549 CowVec::new(optimized)
550 }
551
552 #[inline]
559 fn publish(&self, commit_version: CommitVersion, deltas: CowVec<Delta>) {
560 self.engine.event_bus.emit(PostCommitEvent::new(deltas, commit_version));
561 self.oracle.done_commit(commit_version);
562 }
563}
564
565impl MultiWriteTransaction {
566 #[instrument(name = "transaction::command::discard", level = "trace", skip(self), fields(txn_id = %self.id))]
567 pub fn discard(&mut self) {
568 match self.lifecycle {
569 Lifecycle::Discarded => return,
570 Lifecycle::Active => self.oracle.query.mark_finished(self.version),
571 Lifecycle::QueryDone => {}
572 }
573 self.transition_to(Lifecycle::Discarded);
574 }
575
576 pub fn is_discard(&self) -> bool {
577 self.lifecycle == Lifecycle::Discarded
578 }
579}
580
581impl MultiWriteTransaction {
582 pub fn prefix(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
583 let items: Vec<_> = self.range(EncodedKeyRange::prefix(prefix), 1024).collect::<Result<Vec<_>>>()?;
584 Ok(MultiVersionBatch {
585 items,
586 has_more: false,
587 })
588 }
589
590 pub fn prefix_rev(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
591 let items: Vec<_> =
592 self.range_rev(EncodedKeyRange::prefix(prefix), 1024).collect::<Result<Vec<_>>>()?;
593 Ok(MultiVersionBatch {
594 items,
595 has_more: false,
596 })
597 }
598
599 pub fn range(
600 &mut self,
601 range: EncodedKeyRange,
602 batch_size: usize,
603 ) -> Box<dyn Iterator<Item = Result<MultiVersionRow>> + Send + '_> {
604 let version = self.version();
605 let (mut marker, pw) = self.marker_with_pending_writes();
606 let start = range.start_bound();
607 let end = range.end_bound();
608
609 marker.mark_range(range.clone());
610
611 let pending: Vec<(EncodedKey, DeltaEntry)> =
612 pw.range((start, end)).map(|(k, v)| (k.clone(), v.clone())).collect();
613
614 let storage_iter = self.engine.store.range(range, version, batch_size);
615
616 Box::new(MergePendingIterator::new(pending, storage_iter, false))
617 }
618
619 pub fn range_rev(
620 &mut self,
621 range: EncodedKeyRange,
622 batch_size: usize,
623 ) -> Box<dyn Iterator<Item = Result<MultiVersionRow>> + Send + '_> {
624 let version = self.version();
625 let (mut marker, pw) = self.marker_with_pending_writes();
626 let start = range.start_bound();
627 let end = range.end_bound();
628
629 marker.mark_range(range.clone());
630
631 let pending: Vec<(EncodedKey, DeltaEntry)> =
632 pw.range((start, end)).rev().map(|(k, v)| (k.clone(), v.clone())).collect();
633
634 let storage_iter = self.engine.store.range_rev(range, version, batch_size);
635
636 Box::new(MergePendingIterator::new(pending, storage_iter, true))
637 }
638}
639
640pub(crate) struct MergePendingIterator<I> {
641 pending_iter: iter::Peekable<vec::IntoIter<(EncodedKey, DeltaEntry)>>,
642 storage_iter: I,
643 next_storage: Option<MultiVersionRow>,
644 reverse: bool,
645}
646
647impl<I> MergePendingIterator<I>
648where
649 I: Iterator<Item = Result<MultiVersionRow>>,
650{
651 pub(crate) fn new(pending: Vec<(EncodedKey, DeltaEntry)>, storage_iter: I, reverse: bool) -> Self {
652 Self {
653 pending_iter: pending.into_iter().peekable(),
654 storage_iter,
655 next_storage: None,
656 reverse,
657 }
658 }
659}
660
661impl<I> Iterator for MergePendingIterator<I>
662where
663 I: Iterator<Item = Result<MultiVersionRow>>,
664{
665 type Item = Result<MultiVersionRow>;
666
667 fn next(&mut self) -> Option<Self::Item> {
668 loop {
669 if self.next_storage.is_none() {
670 self.next_storage = match self.storage_iter.next() {
671 Some(Ok(v)) => Some(v),
672 Some(Err(e)) => return Some(Err(e)),
673 None => None,
674 };
675 }
676
677 match (self.pending_iter.peek(), &self.next_storage) {
678 (Some((pending_key, _)), Some(storage_val)) => {
679 let cmp = pending_key.cmp(&storage_val.key);
680 let should_yield_pending = if self.reverse {
681 matches!(cmp, Ordering::Greater)
682 } else {
683 matches!(cmp, Ordering::Less)
684 };
685
686 if should_yield_pending {
687 let (key, value) = self.pending_iter.next().unwrap();
688 if let Some(row) = value.row() {
689 return Some(Ok(MultiVersionRow {
690 key,
691 row: row.clone(),
692 version: value.version,
693 }));
694 }
695 } else if matches!(cmp, Ordering::Equal) {
696 let (key, value) = self.pending_iter.next().unwrap();
698 self.next_storage = None;
699 if let Some(row) = value.row() {
700 return Some(Ok(MultiVersionRow {
701 key,
702 row: row.clone(),
703 version: value.version,
704 }));
705 }
706 } else {
707 return Some(Ok(self.next_storage.take().unwrap()));
708 }
709 }
710 (Some(_), None) => {
711 let (key, value) = self.pending_iter.next().unwrap();
712 if let Some(row) = value.row() {
713 return Some(Ok(MultiVersionRow {
714 key,
715 row: row.clone(),
716 version: value.version,
717 }));
718 }
719 }
720 (None, Some(_)) => {
721 return Some(Ok(self.next_storage.take().unwrap()));
722 }
723 (None, None) => return None,
724 }
725 }
726 }
727}