// reifydb_transaction/multi/transaction/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4// This file includes and modifies code from the skipdb project (https://github.com/al8n/skipdb),
5// originally licensed under the Apache License, Version 2.0.
6// Original copyright:
7//   Copyright (c) 2024 Al Liu
8//
9// The original Apache License can be found at:
10//   http://www.apache.org/licenses/LICENSE-2.0
11
12use std::{ops::Deref, sync::Arc, time::Duration};
13
14use reifydb_core::{
15	common::CommitVersion,
16	config::SystemConfig,
17	encoded::key::EncodedKey,
18	event::EventBus,
19	interface::store::{MultiVersionContains, MultiVersionGet},
20};
21use reifydb_runtime::{actor::system::ActorSystem, clock::Clock};
22use reifydb_store_multi::MultiStore;
23use reifydb_type::{Result, util::hex};
24use tracing::instrument;
25use version::{StandardVersionProvider, VersionProvider};
26
27use crate::{
28	TransactionId,
29	multi::{oracle, oracle::*, types::*},
30	single::SingleTransaction,
31};
32
33pub mod manager;
34pub mod read;
35pub(crate) mod version;
36pub mod write;
37
38use reifydb_runtime::SharedRuntimeConfig;
39use reifydb_store_single::SingleStore;
40
41use crate::multi::{
42	MultiReadTransaction, MultiWriteTransaction,
43	conflict::ConflictManager,
44	pending::PendingWrites,
45	transaction::manager::{TransactionManagerCommand, TransactionManagerQuery},
46};
47
48pub struct TransactionManager<L>
49where
50	L: VersionProvider,
51{
52	inner: Arc<Oracle<L>>,
53}
54
55impl<L> Clone for TransactionManager<L>
56where
57	L: VersionProvider,
58{
59	fn clone(&self) -> Self {
60		Self {
61			inner: self.inner.clone(),
62		}
63	}
64}
65
66impl<L> TransactionManager<L>
67where
68	L: VersionProvider,
69{
70	#[instrument(name = "transaction::manager::write", level = "debug", skip(self))]
71	pub fn write(&self) -> Result<TransactionManagerCommand<L>> {
72		Ok(TransactionManagerCommand {
73			id: TransactionId::generate(),
74			oracle: self.inner.clone(),
75			version: self.inner.version()?,
76			read_version: None,
77			size: 0,
78			count: 0,
79			conflicts: ConflictManager::new(),
80			pending_writes: PendingWrites::new(),
81			duplicates: Vec::new(),
82			discarded: false,
83			done_query: false,
84		})
85	}
86}
87
88impl<L> TransactionManager<L>
89where
90	L: VersionProvider,
91{
92	#[instrument(
93		name = "transaction::manager::new",
94		level = "debug",
95		skip(clock, actor_system, metrics_clock, config)
96	)]
97	pub fn new(clock: L, actor_system: ActorSystem, metrics_clock: Clock, config: SystemConfig) -> Result<Self> {
98		let version = clock.next()?;
99		let oracle = Oracle::new(clock, actor_system, metrics_clock, config);
100		oracle.query.done(version);
101		oracle.command.done(version);
102		Ok(Self {
103			inner: Arc::new(oracle),
104		})
105	}
106
107	/// Get the actor system
108	pub fn actor_system(&self) -> ActorSystem {
109		self.inner.actor_system()
110	}
111
112	/// Get the shared system config from the oracle.
113	pub fn system_config(&self) -> SystemConfig {
114		self.inner.system_config()
115	}
116
117	#[instrument(name = "transaction::manager::version", level = "trace", skip(self))]
118	pub fn version(&self) -> Result<CommitVersion> {
119		self.inner.version()
120	}
121}
122
123impl<L> TransactionManager<L>
124where
125	L: VersionProvider,
126{
127	#[instrument(name = "transaction::manager::query", level = "debug", skip(self), fields(as_of_version = ?version))]
128	pub fn query(&self, version: Option<CommitVersion>) -> Result<TransactionManagerQuery<L>> {
129		let safe_version = self.inner.version()?;
130
131		Ok(if let Some(version) = version {
132			assert!(version <= safe_version);
133			TransactionManagerQuery::new_time_travel(TransactionId::generate(), self.clone(), version)
134		} else {
135			TransactionManagerQuery::new_current(TransactionId::generate(), self.clone(), safe_version)
136		})
137	}
138
139	/// Returns the highest version where ALL prior versions have completed.
140	/// This is useful for CDC polling to know the safe upper bound for fetching
141	/// CDC events - all events up to this version are guaranteed to be in storage.
142	#[instrument(name = "transaction::manager::done_until", level = "trace", skip(self))]
143	pub fn done_until(&self) -> CommitVersion {
144		self.inner.command.done_until()
145	}
146
147	/// Wait for the watermark to reach the given version with a timeout.
148	/// Returns true if the watermark reached the target, false if timeout occurred.
149	#[instrument(name = "transaction::manager::wait_for_mark_timeout", level = "trace", skip(self))]
150	pub fn wait_for_mark_timeout(&self, version: CommitVersion, timeout: Duration) -> bool {
151		self.inner.command.wait_for_mark_timeout(version, timeout)
152	}
153}
154
155pub struct MultiTransaction(Arc<Inner>);
156
157pub struct Inner {
158	pub(crate) tm: TransactionManager<StandardVersionProvider>,
159	pub(crate) store: MultiStore,
160	pub(crate) event_bus: EventBus,
161}
162
163impl Deref for MultiTransaction {
164	type Target = Inner;
165
166	fn deref(&self) -> &Self::Target {
167		&self.0
168	}
169}
170
171impl Clone for MultiTransaction {
172	fn clone(&self) -> Self {
173		Self(self.0.clone())
174	}
175}
176
177impl Inner {
178	fn new(
179		store: MultiStore,
180		single: SingleTransaction,
181		event_bus: EventBus,
182		actor_system: ActorSystem,
183		metrics_clock: Clock,
184		config: SystemConfig,
185	) -> Result<Self> {
186		let version_provider = StandardVersionProvider::new(single)?;
187		let tm = TransactionManager::new(version_provider, actor_system, metrics_clock, config)?;
188
189		Ok(Self {
190			tm,
191			store,
192			event_bus,
193		})
194	}
195
196	fn version(&self) -> Result<CommitVersion> {
197		self.tm.version()
198	}
199
200	fn actor_system(&self) -> ActorSystem {
201		self.tm.actor_system()
202	}
203}
204
205impl MultiTransaction {
206	pub fn testing() -> Self {
207		let multi_store = MultiStore::testing_memory();
208		let single_store = SingleStore::testing_memory();
209		let actor_system = ActorSystem::new(SharedRuntimeConfig::default().actor_system_config());
210		let event_bus = EventBus::new(&actor_system);
211		let system_config = SystemConfig::new();
212		oracle::register_defaults(&system_config);
213		Self::new(
214			multi_store,
215			SingleTransaction::new(single_store, event_bus.clone()),
216			event_bus,
217			actor_system,
218			Clock::default(),
219			system_config,
220		)
221		.unwrap()
222	}
223}
224
225impl MultiTransaction {
226	#[instrument(
227		name = "transaction::new",
228		level = "debug",
229		skip(store, single, event_bus, actor_system, metrics_clock, system_config)
230	)]
231	pub fn new(
232		store: MultiStore,
233		single: SingleTransaction,
234		event_bus: EventBus,
235		actor_system: ActorSystem,
236		metrics_clock: Clock,
237		system_config: SystemConfig,
238	) -> Result<Self> {
239		Ok(Self(Arc::new(Inner::new(store, single, event_bus, actor_system, metrics_clock, system_config)?)))
240	}
241
242	/// Get the actor system
243	pub fn actor_system(&self) -> ActorSystem {
244		self.0.actor_system()
245	}
246
247	/// Get the shared system config from the oracle.
248	pub fn system_config(&self) -> SystemConfig {
249		self.0.tm.system_config()
250	}
251}
252
253/// Register oracle config defaults into a SystemConfig registry.
254pub fn register_oracle_defaults(config: &SystemConfig) {
255	oracle::register_defaults(config)
256}
257
258impl MultiTransaction {
259	#[instrument(name = "transaction::version", level = "trace", skip(self))]
260	pub fn version(&self) -> Result<CommitVersion> {
261		self.0.version()
262	}
263
264	#[instrument(name = "transaction::begin_query", level = "debug", skip(self))]
265	pub fn begin_query(&self) -> Result<MultiReadTransaction> {
266		MultiReadTransaction::new(self.clone(), None)
267	}
268
269	/// Begin a query transaction at a specific version.
270	///
271	/// This is used for parallel query execution where multiple tasks need to
272	/// read from the same snapshot (same CommitVersion) for consistency.
273	#[instrument(name = "transaction::begin_query_at_version", level = "debug", skip(self), fields(version = %version.0))]
274	pub fn begin_query_at_version(&self, version: CommitVersion) -> Result<MultiReadTransaction> {
275		MultiReadTransaction::new(self.clone(), Some(version))
276	}
277}
278
279impl MultiTransaction {
280	#[instrument(name = "transaction::begin_command", level = "debug", skip(self))]
281	pub fn begin_command(&self) -> Result<MultiWriteTransaction> {
282		MultiWriteTransaction::new(self.clone())
283	}
284}
285
286pub enum TransactionType {
287	Query(MultiReadTransaction),
288	Command(MultiWriteTransaction),
289}
290
291impl MultiTransaction {
292	#[instrument(name = "transaction::get", level = "trace", skip(self), fields(key_hex = %hex::encode(key.as_ref()), version = version.0))]
293	pub fn get(&self, key: &EncodedKey, version: CommitVersion) -> Result<Option<Committed>> {
294		Ok(MultiVersionGet::get(&self.store, key, version)?.map(|sv| sv.into()))
295	}
296
297	#[instrument(name = "transaction::contains_key", level = "trace", skip(self), fields(key_hex = %hex::encode(key.as_ref()), version = version.0))]
298	pub fn contains_key(&self, key: &EncodedKey, version: CommitVersion) -> Result<bool> {
299		MultiVersionContains::contains(&self.store, key, version)
300	}
301
302	/// Get a reference to the underlying transaction store.
303	pub fn store(&self) -> &MultiStore {
304		&self.store
305	}
306}