reifydb_transaction/multi/transaction/
write.rs1use std::ops::RangeBounds;
13
14use reifydb_core::{
15 common::CommitVersion,
16 encoded::{
17 encoded::EncodedValues,
18 key::{EncodedKey, EncodedKeyRange},
19 },
20 event::transaction::PostCommitEvent,
21 interface::store::{MultiVersionBatch, MultiVersionCommit, MultiVersionContains, MultiVersionGet},
22};
23use reifydb_type::{
24 Result,
25 util::{cowvec::CowVec, hex},
26};
27use tracing::instrument;
28
29use super::{MultiTransaction, TransactionManagerCommand, version::StandardVersionProvider};
30use crate::{delta::optimize_deltas, multi::types::TransactionValue};
31
32pub struct WriteSavepoint {
34 pub(crate) pending_writes: PendingWrites,
35 pub(crate) count: u64,
36 pub(crate) size: u64,
37 pub(crate) duplicates: Vec<Pending>,
38}
39
40pub struct MultiWriteTransaction {
41 engine: MultiTransaction,
42 pub(crate) tm: TransactionManagerCommand<StandardVersionProvider>,
43}
44
45impl MultiWriteTransaction {
46 #[instrument(name = "transaction::command::new", level = "debug", skip(engine))]
47 pub fn new(engine: MultiTransaction) -> Result<Self> {
48 let tm = engine.tm.write()?;
49 Ok(Self {
50 engine,
51 tm,
52 })
53 }
54}
55
56impl MultiWriteTransaction {
57 pub fn savepoint(&self) -> WriteSavepoint {
59 WriteSavepoint {
60 pending_writes: self.tm.pending_writes.clone(),
61 count: self.tm.count,
62 size: self.tm.size,
63 duplicates: self.tm.duplicates.clone(),
64 }
65 }
66
67 pub fn restore_savepoint(&mut self, sp: WriteSavepoint) {
69 self.tm.pending_writes = sp.pending_writes;
70 self.tm.count = sp.count;
71 self.tm.size = sp.size;
72 self.tm.duplicates = sp.duplicates;
73 }
74}
75
76impl MultiWriteTransaction {
77 #[instrument(name = "transaction::command::commit", level = "debug", skip(self), fields(pending_count = self.tm.pending_writes().len()))]
78 pub fn commit(&mut self) -> Result<CommitVersion> {
79 if self.tm.pending_writes().is_empty() {
81 self.tm.discard();
82 return Ok(CommitVersion(0));
83 }
84
85 let (commit_version, entries) = self.tm.commit_pending()?;
88
89 if entries.is_empty() {
90 self.tm.discard();
91 return Ok(CommitVersion(0));
92 }
93
94 let mut raw_deltas = CowVec::with_capacity(entries.len());
96 for pending in &entries {
97 raw_deltas.push(pending.delta.clone());
98 }
99 let optimized = optimize_deltas(raw_deltas.iter().cloned());
100 let deltas = CowVec::new(optimized);
101
102 MultiVersionCommit::commit(&self.engine.store, deltas.clone(), commit_version)?;
103
104 self.tm.oracle.done_commit(commit_version);
105 self.tm.discard();
106
107 self.engine.event_bus.emit(PostCommitEvent::new(deltas, commit_version));
108
109 Ok(commit_version)
110 }
111}
112
113impl MultiWriteTransaction {
114 pub fn version(&self) -> CommitVersion {
115 self.tm.version()
116 }
117
118 pub fn pending_writes(&self) -> &PendingWrites {
119 self.tm.pending_writes()
120 }
121
122 pub fn read_as_of_version_exclusive(&mut self, version: CommitVersion) {
123 self.tm.read_as_of_version_exclusive(version);
124 }
125
126 pub fn read_as_of_version_inclusive(&mut self, version: CommitVersion) -> Result<()> {
127 self.read_as_of_version_exclusive(CommitVersion(version.0 + 1));
128 Ok(())
129 }
130
131 #[instrument(name = "transaction::command::rollback", level = "debug", skip(self))]
132 pub fn rollback(&mut self) -> Result<()> {
133 self.tm.rollback()
134 }
135
136 #[instrument(name = "transaction::command::contains_key", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
137 pub fn contains_key(&mut self, key: &EncodedKey) -> Result<bool> {
138 let version = self.tm.version();
139 match self.tm.contains_key(key)? {
140 Some(true) => Ok(true),
141 Some(false) => Ok(false),
142 None => MultiVersionContains::contains(&self.engine.store, key, version),
143 }
144 }
145
146 #[instrument(name = "transaction::command::get", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
147 pub fn get(&mut self, key: &EncodedKey) -> Result<Option<TransactionValue>> {
148 let version = self.tm.version();
149 match self.tm.get(key)? {
150 Some(v) => {
151 if v.values().is_some() {
152 Ok(Some(v.into()))
153 } else {
154 Ok(None)
155 }
156 }
157 None => Ok(MultiVersionGet::get(&self.engine.store, key, version)?.map(Into::into)),
158 }
159 }
160
161 #[instrument(name = "transaction::command::set", level = "trace", skip(self, values), fields(key_hex = %hex::display(key.as_ref()), value_len = values.as_ref().len()))]
162 pub fn set(&mut self, key: &EncodedKey, values: EncodedValues) -> Result<()> {
163 self.tm.set(key, values)
164 }
165
166 #[instrument(name = "transaction::command::unset", level = "trace", skip(self, values), fields(key_hex = %hex::display(key.as_ref()), value_len = values.len()))]
167 pub fn unset(&mut self, key: &EncodedKey, values: EncodedValues) -> Result<()> {
168 self.tm.unset(key, values)
169 }
170
171 #[instrument(name = "transaction::command::remove", level = "trace", skip(self), fields(key_hex = %hex::display(key.as_ref())))]
172 pub fn remove(&mut self, key: &EncodedKey) -> Result<()> {
173 self.tm.remove(key)
174 }
175
176 pub fn prefix(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
177 let items: Vec<_> = self.range(EncodedKeyRange::prefix(prefix), 1024).collect::<Result<Vec<_>>>()?;
178 Ok(MultiVersionBatch {
179 items,
180 has_more: false,
181 })
182 }
183
184 pub fn prefix_rev(&mut self, prefix: &EncodedKey) -> Result<MultiVersionBatch> {
185 let items: Vec<_> =
186 self.range_rev(EncodedKeyRange::prefix(prefix), 1024).collect::<Result<Vec<_>>>()?;
187 Ok(MultiVersionBatch {
188 items,
189 has_more: false,
190 })
191 }
192
193 pub fn range(
200 &mut self,
201 range: EncodedKeyRange,
202 batch_size: usize,
203 ) -> Box<dyn Iterator<Item = Result<MultiVersionValues>> + Send + '_> {
204 let version = self.tm.version();
205 let (mut marker, pw) = self.tm.marker_with_pending_writes();
206 let start = range.start_bound();
207 let end = range.end_bound();
208
209 marker.mark_range(range.clone());
210
211 let pending: Vec<(EncodedKey, Pending)> =
213 pw.range((start, end)).map(|(k, v)| (k.clone(), v.clone())).collect();
214
215 let storage_iter = self.engine.store.range(range, version, batch_size);
216
217 Box::new(MergePendingIterator::new(pending, storage_iter, false))
218 }
219
220 pub fn range_rev(
226 &mut self,
227 range: EncodedKeyRange,
228 batch_size: usize,
229 ) -> Box<dyn Iterator<Item = Result<MultiVersionValues>> + Send + '_> {
230 let version = self.tm.version();
231 let (mut marker, pw) = self.tm.marker_with_pending_writes();
232 let start = range.start_bound();
233 let end = range.end_bound();
234
235 marker.mark_range(range.clone());
236
237 let pending: Vec<(EncodedKey, Pending)> =
239 pw.range((start, end)).rev().map(|(k, v)| (k.clone(), v.clone())).collect();
240
241 let storage_iter = self.engine.store.range_rev(range, version, batch_size);
242
243 Box::new(MergePendingIterator::new(pending, storage_iter, true))
244 }
245}
246
247use std::{cmp::Ordering, iter, vec};
248
249use reifydb_core::interface::store::MultiVersionValues;
250
251use crate::multi::{pending::PendingWrites, types::Pending};
252
253struct MergePendingIterator<I> {
255 pending_iter: iter::Peekable<vec::IntoIter<(EncodedKey, Pending)>>,
256 storage_iter: I,
257 next_storage: Option<MultiVersionValues>,
258 reverse: bool,
259}
260
261impl<I> MergePendingIterator<I>
262where
263 I: Iterator<Item = Result<MultiVersionValues>>,
264{
265 fn new(pending: Vec<(EncodedKey, Pending)>, storage_iter: I, reverse: bool) -> Self {
266 Self {
267 pending_iter: pending.into_iter().peekable(),
268 storage_iter,
269 next_storage: None,
270 reverse,
271 }
272 }
273}
274
275impl<I> Iterator for MergePendingIterator<I>
276where
277 I: Iterator<Item = Result<MultiVersionValues>>,
278{
279 type Item = Result<MultiVersionValues>;
280
281 fn next(&mut self) -> Option<Self::Item> {
282 loop {
283 if self.next_storage.is_none() {
285 self.next_storage = match self.storage_iter.next() {
286 Some(Ok(v)) => Some(v),
287 Some(Err(e)) => return Some(Err(e)),
288 None => None,
289 };
290 }
291
292 match (self.pending_iter.peek(), &self.next_storage) {
293 (Some((pending_key, _)), Some(storage_val)) => {
294 let cmp = pending_key.cmp(&storage_val.key);
295 let should_yield_pending = if self.reverse {
296 matches!(cmp, Ordering::Greater)
298 } else {
299 matches!(cmp, Ordering::Less)
301 };
302
303 if should_yield_pending {
304 let (key, value) = self.pending_iter.next().unwrap();
306 if let Some(values) = value.values() {
307 return Some(Ok(MultiVersionValues {
308 key,
309 values: values.clone(),
310 version: value.version,
311 }));
312 }
313 } else if matches!(cmp, Ordering::Equal) {
315 let (key, value) = self.pending_iter.next().unwrap();
317 self.next_storage = None; if let Some(values) = value.values() {
319 return Some(Ok(MultiVersionValues {
320 key,
321 values: values.clone(),
322 version: value.version,
323 }));
324 }
325 } else {
327 return Some(Ok(self.next_storage.take().unwrap()));
329 }
330 }
331 (Some(_), None) => {
332 let (key, value) = self.pending_iter.next().unwrap();
334 if let Some(values) = value.values() {
335 return Some(Ok(MultiVersionValues {
336 key,
337 values: values.clone(),
338 version: value.version,
339 }));
340 }
341 }
343 (None, Some(_)) => {
344 return Some(Ok(self.next_storage.take().unwrap()));
346 }
347 (None, None) => return None,
348 }
349 }
350 }
351}