reifydb_transaction/multi/transaction/
mod.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4// This file includes and modifies code from the skipdb project (https://github.com/al8n/skipdb),
5// originally licensed under the Apache License, Version 2.0.
6// Original copyright:
7//   Copyright (c) 2024 Al Liu
8//
9// The original Apache License can be found at:
10//   http://www.apache.org/licenses/LICENSE-2.0
11
12use core::mem;
13use std::{ops::Deref, sync::Arc, time::Duration};
14
15pub use command::*;
16use oracle::*;
17use reifydb_core::{CommitVersion, EncodedKey, EncodedKeyRange, event::EventBus, interface::TransactionId};
18use reifydb_store_transaction::{
19	MultiVersionBatch, MultiVersionContains, MultiVersionGet, MultiVersionRange, MultiVersionRangeRev,
20	TransactionStore,
21};
22use reifydb_type::util::hex;
23use tracing::instrument;
24use version::{StandardVersionProvider, VersionProvider};
25
26pub use crate::multi::types::*;
27use crate::single::{TransactionSingle, TransactionSvl};
28
29mod command;
30mod command_tx;
31mod oracle;
32mod oracle_cleanup;
33pub mod query;
34mod query_tx;
35mod version;
36
37pub use command_tx::CommandTransaction;
38pub use oracle::MAX_COMMITTED_TXNS;
39pub use query_tx::QueryTransaction;
40
41use crate::multi::{
42	AwaitWatermarkError, conflict::ConflictManager, pending::PendingWrites,
43	transaction::query::TransactionManagerQuery,
44};
45
46pub struct TransactionManager<L>
47where
48	L: VersionProvider,
49{
50	inner: Arc<Oracle<L>>,
51}
52
53impl<L> Clone for TransactionManager<L>
54where
55	L: VersionProvider,
56{
57	fn clone(&self) -> Self {
58		Self {
59			inner: self.inner.clone(),
60		}
61	}
62}
63
64impl<L> TransactionManager<L>
65where
66	L: VersionProvider,
67{
68	#[instrument(name = "transaction::manager::write", level = "debug", skip(self))]
69	pub async fn write(&self) -> Result<TransactionManagerCommand<L>, reifydb_type::Error> {
70		Ok(TransactionManagerCommand {
71			id: TransactionId::generate(),
72			oracle: self.inner.clone(),
73			version: self.inner.version().await?,
74			read_version: None,
75			size: 0,
76			count: 0,
77			conflicts: ConflictManager::new(),
78			pending_writes: PendingWrites::new(),
79			duplicates: Vec::new(),
80			discarded: false,
81			done_query: false,
82		})
83	}
84}
85
86impl<L> TransactionManager<L>
87where
88	L: VersionProvider,
89{
90	#[instrument(name = "transaction::manager::new", level = "debug", skip(clock))]
91	pub async fn new(clock: L) -> crate::Result<Self> {
92		let version = clock.next().await?;
93		let oracle = Oracle::new(clock).await;
94		oracle.query.done(version);
95		oracle.command.done(version);
96		Ok(Self {
97			inner: Arc::new(oracle),
98		})
99	}
100
101	#[instrument(name = "transaction::manager::version", level = "trace", skip(self))]
102	pub async fn version(&self) -> crate::Result<CommitVersion> {
103		self.inner.version().await
104	}
105}
106
107impl<L> TransactionManager<L>
108where
109	L: VersionProvider,
110{
111	#[instrument(name = "transaction::manager::discard_hint", level = "trace", skip(self))]
112	pub fn discard_hint(&self) -> CommitVersion {
113		self.inner.discard_at_or_below()
114	}
115
116	#[instrument(name = "transaction::manager::query", level = "debug", skip(self), fields(as_of_version = ?version))]
117	pub async fn query(&self, version: Option<CommitVersion>) -> crate::Result<TransactionManagerQuery<L>> {
118		Ok(if let Some(version) = version {
119			TransactionManagerQuery::new_time_travel(TransactionId::generate(), self.clone(), version)
120		} else {
121			TransactionManagerQuery::new_current(
122				TransactionId::generate(),
123				self.clone(),
124				self.inner.version().await?,
125			)
126		})
127	}
128
129	/// Wait for the command watermark to reach the specified version.
130	/// Returns Ok(()) if the watermark reaches the version within the timeout,
131	/// or Err(AwaitWatermarkError) if the timeout expires.
132	///
133	/// This is useful for CDC polling to ensure all in-flight commits have
134	/// completed their storage writes before querying for CDC events.
135	#[instrument(name = "transaction::manager::wait_for_watermark", level = "debug", skip(self))]
136	pub async fn try_wait_for_watermark(
137		&self,
138		version: CommitVersion,
139		timeout: Duration,
140	) -> Result<(), AwaitWatermarkError> {
141		if self.inner.command.wait_for_mark_timeout(version, timeout).await {
142			Ok(())
143		} else {
144			Err(AwaitWatermarkError {
145				version,
146				timeout,
147			})
148		}
149	}
150
151	/// Returns the highest version where ALL prior versions have completed.
152	/// This is useful for CDC polling to know the safe upper bound for fetching
153	/// CDC events - all events up to this version are guaranteed to be in storage.
154	#[instrument(name = "transaction::manager::done_until", level = "trace", skip(self))]
155	pub fn done_until(&self) -> CommitVersion {
156		self.inner.command.done_until()
157	}
158
159	/// Returns (query_done_until, command_done_until) for debugging watermark state.
160	pub fn watermarks(&self) -> (CommitVersion, CommitVersion) {
161		(self.inner.query.done_until(), self.inner.command.done_until())
162	}
163}
164
165// ============================================================================
166// Transaction - The main multi-version transaction type
167// ============================================================================
168
169pub struct TransactionMulti(Arc<Inner>);
170
171pub struct Inner {
172	pub(crate) tm: TransactionManager<StandardVersionProvider>,
173	pub(crate) store: TransactionStore,
174	pub(crate) event_bus: EventBus,
175}
176
177impl Deref for TransactionMulti {
178	type Target = Inner;
179
180	fn deref(&self) -> &Self::Target {
181		&self.0
182	}
183}
184
185impl Clone for TransactionMulti {
186	fn clone(&self) -> Self {
187		Self(self.0.clone())
188	}
189}
190
191impl Inner {
192	async fn new(store: TransactionStore, single: TransactionSingle, event_bus: EventBus) -> crate::Result<Self> {
193		let version_provider = StandardVersionProvider::new(single).await?;
194		let tm = TransactionManager::new(version_provider).await?;
195
196		Ok(Self {
197			tm,
198			store,
199			event_bus,
200		})
201	}
202
203	async fn version(&self) -> crate::Result<CommitVersion> {
204		self.tm.version().await
205	}
206}
207
208impl TransactionMulti {
209	pub async fn testing() -> Self {
210		let store = TransactionStore::testing_memory().await;
211		let event_bus = EventBus::new();
212		Self::new(
213			store.clone(),
214			TransactionSingle::SingleVersionLock(TransactionSvl::new(store, event_bus.clone())),
215			event_bus,
216		)
217		.await
218		.unwrap()
219	}
220}
221
222impl TransactionMulti {
223	#[instrument(name = "transaction::new", level = "debug", skip(store, single, event_bus))]
224	pub async fn new(
225		store: TransactionStore,
226		single: TransactionSingle,
227		event_bus: EventBus,
228	) -> crate::Result<Self> {
229		Ok(Self(Arc::new(Inner::new(store, single, event_bus).await?)))
230	}
231}
232
233impl TransactionMulti {
234	#[instrument(name = "transaction::version", level = "trace", skip(self))]
235	pub async fn version(&self) -> crate::Result<CommitVersion> {
236		self.0.version().await
237	}
238
239	#[instrument(name = "transaction::begin_query", level = "debug", skip(self))]
240	pub async fn begin_query(&self) -> crate::Result<QueryTransaction> {
241		QueryTransaction::new(self.clone(), None).await
242	}
243}
244
245impl TransactionMulti {
246	#[instrument(name = "transaction::begin_command", level = "debug", skip(self))]
247	pub async fn begin_command(&self) -> crate::Result<CommandTransaction> {
248		CommandTransaction::new(self.clone()).await
249	}
250}
251
252pub enum TransactionType {
253	Query(QueryTransaction),
254	Command(CommandTransaction),
255}
256
257impl TransactionMulti {
258	#[instrument(name = "transaction::get", level = "trace", skip(self), fields(key_hex = %hex::encode(key.as_ref()), version = version.0))]
259	pub async fn get(
260		&self,
261		key: &EncodedKey,
262		version: CommitVersion,
263	) -> Result<Option<Committed>, reifydb_type::Error> {
264		Ok(MultiVersionGet::get(&self.store, key, version).await?.map(|sv| sv.into()))
265	}
266
267	#[instrument(name = "transaction::contains_key", level = "trace", skip(self), fields(key_hex = %hex::encode(key.as_ref()), version = version.0))]
268	pub async fn contains_key(
269		&self,
270		key: &EncodedKey,
271		version: CommitVersion,
272	) -> Result<bool, reifydb_type::Error> {
273		MultiVersionContains::contains(&self.store, key, version).await
274	}
275
276	#[instrument(name = "transaction::range_batch", level = "trace", skip(self), fields(version = version.0, batch_size = batch_size))]
277	pub async fn range_batch(
278		&self,
279		range: EncodedKeyRange,
280		version: CommitVersion,
281		batch_size: u64,
282	) -> reifydb_type::Result<MultiVersionBatch> {
283		MultiVersionRange::range_batch(&self.store, range, version, batch_size).await
284	}
285
286	pub async fn range(
287		&self,
288		range: EncodedKeyRange,
289		version: CommitVersion,
290	) -> reifydb_type::Result<MultiVersionBatch> {
291		self.range_batch(range, version, 1024).await
292	}
293
294	pub async fn range_rev_batch(
295		&self,
296		range: EncodedKeyRange,
297		version: CommitVersion,
298		batch_size: u64,
299	) -> reifydb_type::Result<MultiVersionBatch> {
300		MultiVersionRangeRev::range_rev_batch(&self.store, range, version, batch_size).await
301	}
302
303	pub async fn range_rev(
304		&self,
305		range: EncodedKeyRange,
306		version: CommitVersion,
307	) -> reifydb_type::Result<MultiVersionBatch> {
308		self.range_rev_batch(range, version, 1024).await
309	}
310
311	/// Get a reference to the underlying transaction store.
312	pub fn store(&self) -> &TransactionStore {
313		&self.store
314	}
315}