Skip to main content

reifydb_transaction/multi/transaction/
write.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4// This file includes and modifies code from the skipdb project (https://github.com/al8n/skipdb),
5// originally licensed under the Apache License, Version 2.0.
6// Original copyright:
7//   Copyright (c) 2024 Al Liu
8//
9// The original Apache License can be found at:
10//   http://www.apache.org/licenses/LICENSE-2.0
11
12use std::ops::RangeBounds;
13
14use reifydb_core::{
15	common::CommitVersion,
16	encoded::{
17		encoded::EncodedValues,
18		key::{EncodedKey, EncodedKeyRange},
19	},
20	event::transaction::PostCommitEvent,
21	interface::store::{MultiVersionBatch, MultiVersionCommit, MultiVersionContains, MultiVersionGet},
22};
23use reifydb_type::{
24	Result,
25	util::{cowvec::CowVec, hex},
26};
27use tracing::instrument;
28
29use super::{MultiTransaction, TransactionManagerCommand, version::StandardVersionProvider};
30use crate::{delta::optimize_deltas, multi::types::TransactionValue};
31
32/// Snapshot of write transaction state for savepoint/restore.
33pub struct WriteSavepoint {
34	pub(crate) pending_writes: PendingWrites,
35	pub(crate) count: u64,
36	pub(crate) size: u64,
37	pub(crate) duplicates: Vec<Pending>,
38}
39
40pub struct MultiWriteTransaction {
41	engine: MultiTransaction,
42	pub(crate) tm: TransactionManagerCommand<StandardVersionProvider>,
43}
44
45impl MultiWriteTransaction {
46	#[instrument(name = "transaction::command::new", level = "debug", skip(engine))]
47	pub fn new(engine: MultiTransaction) -> Result<Self> {
48		let tm = engine.tm.write()?;
49		Ok(Self {
50			engine,
51			tm,
52		})
53	}
54}
55
56impl MultiWriteTransaction {
57	/// Snapshot pending writes for later restore.
58	pub fn savepoint(&self) -> WriteSavepoint {
59		WriteSavepoint {
60			pending_writes: self.tm.pending_writes.clone(),
61			count: self.tm.count,
62			size: self.tm.size,
63			duplicates: self.tm.duplicates.clone(),
64		}
65	}
66
67	/// Restore pending writes from a savepoint.
68	pub fn restore_savepoint(&mut self, sp: WriteSavepoint) {
69		self.tm.pending_writes = sp.pending_writes;
70		self.tm.count = sp.count;
71		self.tm.size = sp.size;
72		self.tm.duplicates = sp.duplicates;
73	}
74}
75
76impl MultiWriteTransaction {
77	#[instrument(name = "transaction::command::commit", level = "debug", skip(self), fields(pending_count = self.tm.pending_writes().len()))]
78	pub fn commit(&mut self) -> Result<CommitVersion> {
79		// For read-only transactions (no pending writes), skip conflict detection
80		if self.tm.pending_writes().is_empty() {
81			self.tm.discard();
82			return Ok(CommitVersion(0));
83		}
84
85		// Use commit_pending to allocate the commit version via oracle BEFORE writing to storage
86		// This ensures entries have the correct commit version
87		let (commit_version, entries) = self.tm.commit_pending()?;
88
89		if entries.is_empty() {
90			self.tm.discard();
91			return Ok(CommitVersion(0));
92		}
93
94		// Collect and optimize deltas for storage commit
95		let mut raw_deltas = CowVec::with_capacity(entries.len());
96		for pending in &entries {
97			raw_deltas.push(pending.delta.clone());
98		}
99		let optimized = optimize_deltas(raw_deltas.iter().cloned());
100		let deltas = CowVec::new(optimized);
101
102		MultiVersionCommit::commit(&self.engine.store, deltas.clone(), commit_version)?;
103
104		self.tm.oracle.done_commit(commit_version);
105		self.tm.discard();
106
107		self.engine.event_bus.emit(PostCommitEvent::new(deltas, commit_version));
108
109		Ok(commit_version)
110	}
111}
112
113impl MultiWriteTransaction {
114	pub fn version(&self) -> CommitVersion {
115		self.tm.version()
116	}
117
118	pub fn pending_writes(&self) -> &PendingWrites {
119		self.tm.pending_writes()
120	}
121
122	pub fn read_as_of_version_exclusive(&mut self, version: CommitVersion) {
123		self.tm.read_as_of_version_exclusive(version);
124	}
125
126	pub fn read_as_of_version_inclusive(&mut self, version: CommitVersion) -> Result<()> {
127		self.read_as_of_version_exclusive(CommitVersion(version.0 + 1));
128		Ok(())
129	}
130
131	#[instrument(name = "transaction::command::rollback", level = "debug", skip(self))]
132	pub fn rollback(&mut self) -> Result<()> {
133		self.tm.rollback()
134	}
135
136	#[instrument(name = "transaction::command::contains_key", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
137	pub fn contains_key(&mut self, key: &EncodedKey) -> Result<bool> {
138		let version = self.tm.version();
139		match self.tm.contains_key(key)? {
140			Some(true) => Ok(true),
141			Some(false) => Ok(false),
142			None => MultiVersionContains::contains(&self.engine.store, key, version),
143		}
144	}
145
146	#[instrument(name = "transaction::command::get", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
147	pub fn get(&mut self, key: &EncodedKey) -> Result<Option<TransactionValue>> {
148		let version = self.tm.version();
149		match self.tm.get(key)? {
150			Some(v) => {
151				if v.values().is_some() {
152					Ok(Some(v.into()))
153				} else {
154					Ok(None)
155				}
156			}
157			None => Ok(MultiVersionGet::get(&self.engine.store, key, version)?.map(Into::into)),
158		}
159	}
160
161	#[instrument(name = "transaction::command::set", level = "trace", skip(self, values), fields(key_hex = %hex::display(key.as_ref()), value_len = values.as_ref().len()))]
162	pub fn set(&mut self, key: &EncodedKey, values: EncodedValues) -> Result<()> {
163		self.tm.set(key, values)
164	}
165
166	#[instrument(name = "transaction::command::unset", level = "trace", skip(self, values), fields(key_hex = %hex::display(key.as_ref()), value_len = values.len()))]
167	pub fn unset(&mut self, key: &EncodedKey, values: EncodedValues) -> Result<()> {
168		self.tm.unset(key, values)
169	}
170
171	#[instrument(name = "transaction::command::remove", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
172	pub fn remove(&mut self, key: &EncodedKey) -> Result<()> {
173		self.tm.remove(key)
174	}
175
176	pub fn prefix(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
177		let items: Vec<_> = self.range(EncodedKeyRange::prefix(prefix), 1024).collect::<Result<Vec<_>>>()?;
178		Ok(MultiVersionBatch {
179			items,
180			has_more: false,
181		})
182	}
183
184	pub fn prefix_rev(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
185		let items: Vec<_> =
186			self.range_rev(EncodedKeyRange::prefix(prefix), 1024).collect::<Result<Vec<_>>>()?;
187		Ok(MultiVersionBatch {
188			items,
189			has_more: false,
190		})
191	}
192
193	/// Create a streaming iterator for forward range queries, merging pending writes.
194	///
195	/// This properly handles high version density by scanning until batch_size
196	/// unique logical keys are collected. The stream yields individual entries
197	/// and maintains cursor state internally. Pending writes are merged with
198	/// committed storage data.
199	pub fn range(
200		&mut self,
201		range: EncodedKeyRange,
202		batch_size: usize,
203	) -> Box<dyn Iterator<Item = Result<MultiVersionValues>> + Send + '_> {
204		let version = self.tm.version();
205		let (mut marker, pw) = self.tm.marker_with_pending_writes();
206		let start = range.start_bound();
207		let end = range.end_bound();
208
209		marker.mark_range(range.clone());
210
211		// Collect pending writes in range as owned data
212		let pending: Vec<(EncodedKey, Pending)> =
213			pw.range((start, end)).map(|(k, v)| (k.clone(), v.clone())).collect();
214
215		let storage_iter = self.engine.store.range(range, version, batch_size);
216
217		Box::new(MergePendingIterator::new(pending, storage_iter, false))
218	}
219
220	/// Create a streaming iterator for reverse range queries, merging pending writes.
221	///
222	/// This properly handles high version density by scanning until batch_size
223	/// unique logical keys are collected. The stream yields individual entries
224	/// in reverse key order and maintains cursor state internally.
225	pub fn range_rev(
226		&mut self,
227		range: EncodedKeyRange,
228		batch_size: usize,
229	) -> Box<dyn Iterator<Item = Result<MultiVersionValues>> + Send + '_> {
230		let version = self.tm.version();
231		let (mut marker, pw) = self.tm.marker_with_pending_writes();
232		let start = range.start_bound();
233		let end = range.end_bound();
234
235		marker.mark_range(range.clone());
236
237		// Collect pending writes in range as owned data (reversed)
238		let pending: Vec<(EncodedKey, Pending)> =
239			pw.range((start, end)).rev().map(|(k, v)| (k.clone(), v.clone())).collect();
240
241		let storage_iter = self.engine.store.range_rev(range, version, batch_size);
242
243		Box::new(MergePendingIterator::new(pending, storage_iter, true))
244	}
245}
246
247use std::{cmp::Ordering, iter, vec};
248
249use reifydb_core::interface::store::MultiVersionValues;
250
251use crate::multi::{pending::PendingWrites, types::Pending};
252
253/// Iterator that merges pending writes with storage iterator.
254struct MergePendingIterator<I> {
255	pending_iter: iter::Peekable<vec::IntoIter<(EncodedKey, Pending)>>,
256	storage_iter: I,
257	next_storage: Option<MultiVersionValues>,
258	reverse: bool,
259}
260
261impl<I> MergePendingIterator<I>
262where
263	I: Iterator<Item = Result<MultiVersionValues>>,
264{
265	fn new(pending: Vec<(EncodedKey, Pending)>, storage_iter: I, reverse: bool) -> Self {
266		Self {
267			pending_iter: pending.into_iter().peekable(),
268			storage_iter,
269			next_storage: None,
270			reverse,
271		}
272	}
273}
274
275impl<I> Iterator for MergePendingIterator<I>
276where
277	I: Iterator<Item = Result<MultiVersionValues>>,
278{
279	type Item = Result<MultiVersionValues>;
280
281	fn next(&mut self) -> Option<Self::Item> {
282		loop {
283			// Fetch next storage item if needed
284			if self.next_storage.is_none() {
285				self.next_storage = match self.storage_iter.next() {
286					Some(Ok(v)) => Some(v),
287					Some(Err(e)) => return Some(Err(e)),
288					None => None,
289				};
290			}
291
292			match (self.pending_iter.peek(), &self.next_storage) {
293				(Some((pending_key, _)), Some(storage_val)) => {
294					let cmp = pending_key.cmp(&storage_val.key);
295					let should_yield_pending = if self.reverse {
296						// Reverse: larger keys first
297						matches!(cmp, Ordering::Greater)
298					} else {
299						// Forward: smaller keys first
300						matches!(cmp, Ordering::Less)
301					};
302
303					if should_yield_pending {
304						// Pending key comes first
305						let (key, value) = self.pending_iter.next().unwrap();
306						if let Some(values) = value.values() {
307							return Some(Ok(MultiVersionValues {
308								key,
309								values: values.clone(),
310								version: value.version,
311							}));
312						}
313						// Tombstone: skip (continue loop)
314					} else if matches!(cmp, Ordering::Equal) {
315						// Same key - pending shadows storage
316						let (key, value) = self.pending_iter.next().unwrap();
317						self.next_storage = None; // Consume storage entry
318						if let Some(values) = value.values() {
319							return Some(Ok(MultiVersionValues {
320								key,
321								values: values.clone(),
322								version: value.version,
323							}));
324						}
325						// Tombstone: skip (continue loop)
326					} else {
327						// Storage key comes first
328						return Some(Ok(self.next_storage.take().unwrap()));
329					}
330				}
331				(Some(_), None) => {
332					// Only pending left
333					let (key, value) = self.pending_iter.next().unwrap();
334					if let Some(values) = value.values() {
335						return Some(Ok(MultiVersionValues {
336							key,
337							values: values.clone(),
338							version: value.version,
339						}));
340					}
341					// Tombstone: skip (continue loop)
342				}
343				(None, Some(_)) => {
344					// Only storage left
345					return Some(Ok(self.next_storage.take().unwrap()));
346				}
347				(None, None) => return None,
348			}
349		}
350	}
351}