Skip to main content

reifydb_engine/
session.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4//! Unified session type for database access.
5//!
6//! A `Session` binds an identity to an engine and provides query, command, and
7//! admin methods. Sessions are created either from a validated auth token
8//! (server path) or directly from an `IdentityId` (embedded/trusted path).
9
10use std::{thread, time::Duration};
11
12use reifydb_core::{execution::ExecutionResult, interface::catalog::token::Token};
13use reifydb_runtime::context::rng::Rng;
14use reifydb_type::{params::Params, value::identity::IdentityId};
15use tracing::{debug, instrument, warn};
16
17use crate::engine::StandardEngine;
18
/// Backoff strategy between retry attempts.
///
/// Derives the common traits so a strategy can be logged, duplicated, and
/// compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Backoff {
	/// No delay between retries.
	None,
	/// Fixed delay between each retry attempt.
	Fixed(Duration),
	/// Exponential backoff: delay doubles each attempt, capped at `max`.
	Exponential {
		/// Delay before the first retry, prior to any doubling.
		base: Duration,
		/// Upper bound the doubled delay is clamped to.
		max: Duration,
	},
	/// Exponential backoff with jitter: the cap grows exactly like
	/// `Exponential`, but the actual delay is drawn from the RNG and never
	/// exceeds that cap. Randomising the wait keeps conflicting writers from
	/// retrying in lock-step.
	ExponentialJitter {
		/// Cap used for the first retry, prior to any doubling.
		base: Duration,
		/// Upper bound the doubled cap is clamped to.
		max: Duration,
	},
}
35
/// Controls how many times a write transaction is retried on conflict (`TXN_001`).
pub struct RetryStrategy {
	/// Total number of attempts, counting the initial one; `1` therefore means
	/// "never retry" (see `no_retry`).
	pub max_attempts: u32,
	/// Delay policy applied between consecutive attempts.
	pub backoff: Backoff,
}
41
42impl Default for RetryStrategy {
43	fn default() -> Self {
44		Self {
45			max_attempts: 10,
46			backoff: Backoff::ExponentialJitter {
47				base: Duration::from_millis(5),
48				max: Duration::from_millis(200),
49			},
50		}
51	}
52}
53
54impl RetryStrategy {
55	/// No retries — fail immediately on conflict.
56	pub fn no_retry() -> Self {
57		Self {
58			max_attempts: 1,
59			backoff: Backoff::None,
60		}
61	}
62
63	pub fn default_conflict_retry() -> Self {
64		Self::default()
65	}
66
67	/// Fixed delay between retry attempts.
68	pub fn with_fixed_backoff(max_attempts: u32, delay: Duration) -> Self {
69		Self {
70			max_attempts,
71			backoff: Backoff::Fixed(delay),
72		}
73	}
74
75	/// Exponential backoff: delay doubles each attempt, capped at `max`.
76	pub fn with_exponential_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
77		Self {
78			max_attempts,
79			backoff: Backoff::Exponential {
80				base,
81				max,
82			},
83		}
84	}
85
86	pub fn with_jittered_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
87		Self {
88			max_attempts,
89			backoff: Backoff::ExponentialJitter {
90				base,
91				max,
92			},
93		}
94	}
95
96	pub fn execute<F>(&self, rng: &Rng, rql: &str, mut f: F) -> ExecutionResult
97	where
98		F: FnMut() -> ExecutionResult,
99	{
100		let mut last_result = None;
101		for attempt in 0..self.max_attempts {
102			let result = f();
103			match &result.error {
104				None => return result,
105				Some(err) if err.code == "TXN_001" => {
106					last_result = Some(result);
107					let is_last_attempt = attempt + 1 >= self.max_attempts;
108					if is_last_attempt {
109						warn!(
110							attempt = attempt + 1,
111							max_attempts = self.max_attempts,
112							rql = %rql,
113							"Transaction conflict retries exhausted"
114						);
115					} else {
116						let delay = compute_backoff(&self.backoff, attempt, rng);
117						debug!(
118							attempt = attempt + 1,
119							max_attempts = self.max_attempts,
120							delay_us = delay.as_micros() as u64,
121							rql = %rql,
122							"Transaction conflict detected, retrying after backoff"
123						);
124						if !delay.is_zero() {
125							thread::sleep(delay);
126						}
127					}
128				}
129				Some(_) => {
130					return result;
131				}
132			}
133		}
134		last_result.unwrap()
135	}
136}
137
138fn compute_backoff(backoff: &Backoff, attempt: u32, rng: &Rng) -> Duration {
139	match backoff {
140		Backoff::None => Duration::ZERO,
141		Backoff::Fixed(d) => *d,
142		Backoff::Exponential {
143			base,
144			max,
145		} => exponential_cap(*base, *max, attempt),
146		Backoff::ExponentialJitter {
147			base,
148			max,
149		} => {
150			let cap = exponential_cap(*base, *max, attempt);
151			let cap_nanos = cap.as_nanos().min(u64::MAX as u128) as u64;
152			if cap_nanos == 0 {
153				return Duration::ZERO;
154			}
155			let sampled = rng.infra_u64_inclusive(cap_nanos);
156			Duration::from_nanos(sampled)
157		}
158	}
159}
160
/// Largest delay allowed for the zero-based `attempt`: `base * 2^attempt`,
/// clamped to `max`.
///
/// The exponent is limited to 30 so the `u32` multiplier cannot overflow, and
/// the multiplication saturates instead of wrapping for very large `base`.
fn exponential_cap(base: Duration, max: Duration, attempt: u32) -> Duration {
	let exponent = attempt.min(30);
	let factor = 2u32.pow(exponent);
	let grown = base.saturating_mul(factor);
	if grown > max {
		max
	} else {
		grown
	}
}
166
/// A unified session binding an identity to a database engine.
pub struct Session {
	/// Engine that runs this session's queries, commands, and admin operations.
	engine: StandardEngine,
	/// Identity every operation of this session is executed as.
	identity: IdentityId,
	/// `true` when built from a validated token, `false` for trusted and
	/// anonymous sessions.
	authenticated: bool,
	/// Token string; only populated by `from_token_with_value`.
	token: Option<String>,
	/// Conflict-retry policy applied by `command` and `admin`.
	retry: RetryStrategy,
}
175
176impl Session {
177	/// Create a session from a validated auth token (server path).
178	pub fn from_token(engine: StandardEngine, info: &Token) -> Self {
179		Self {
180			engine,
181			identity: info.identity,
182			authenticated: true,
183			token: None,
184			retry: RetryStrategy::default(),
185		}
186	}
187
188	/// Create a session from a validated auth token, preserving the token string.
189	pub fn from_token_with_value(engine: StandardEngine, info: &Token) -> Self {
190		Self {
191			engine,
192			identity: info.identity,
193			authenticated: true,
194			token: Some(info.token.clone()),
195			retry: RetryStrategy::default(),
196		}
197	}
198
199	/// Create a trusted session (embedded path, no authentication required).
200	pub fn trusted(engine: StandardEngine, identity: IdentityId) -> Self {
201		Self {
202			engine,
203			identity,
204			authenticated: false,
205			token: None,
206			retry: RetryStrategy::default(),
207		}
208	}
209
210	/// Create an anonymous session.
211	pub fn anonymous(engine: StandardEngine) -> Self {
212		Self::trusted(engine, IdentityId::anonymous())
213	}
214
215	/// Set the retry strategy for command and admin operations.
216	pub fn with_retry(mut self, strategy: RetryStrategy) -> Self {
217		self.retry = strategy;
218		self
219	}
220
221	/// The identity associated with this session.
222	#[inline]
223	pub fn identity(&self) -> IdentityId {
224		self.identity
225	}
226
227	/// The auth token, if this session was created from a validated token.
228	#[inline]
229	pub fn token(&self) -> Option<&str> {
230		self.token.as_deref()
231	}
232
233	/// Whether this session was created from authenticated credentials.
234	#[inline]
235	pub fn is_authenticated(&self) -> bool {
236		self.authenticated
237	}
238
239	/// Execute a read-only query.
240	#[instrument(name = "session::query", level = "debug", skip(self, params), fields(rql = %rql))]
241	pub fn query(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
242		self.engine.query_as(self.identity, rql, params.into())
243	}
244
245	/// Execute a transactional command (DML + Query) with retry on conflict.
246	#[instrument(name = "session::command", level = "debug", skip(self, params), fields(rql = %rql))]
247	pub fn command(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
248		let params = params.into();
249		self.retry
250			.execute(self.engine.rng(), rql, || self.engine.command_as(self.identity, rql, params.clone()))
251	}
252
253	/// Execute an admin (DDL + DML + Query) operation with retry on conflict.
254	#[instrument(name = "session::admin", level = "debug", skip(self, params), fields(rql = %rql))]
255	pub fn admin(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
256		let params = params.into();
257		self.retry.execute(self.engine.rng(), rql, || self.engine.admin_as(self.identity, rql, params.clone()))
258	}
259}
260
// Unit tests for the retry loop (`RetryStrategy::execute`) and the backoff
// helpers (`compute_backoff`, `exponential_cap`).
#[cfg(test)]
mod retry_tests {
	use std::{cell::Cell, time::Duration};

	use reifydb_core::{execution::ExecutionResult, metric::ExecutionMetrics};
	use reifydb_runtime::context::rng::Rng;
	use reifydb_type::{
		error::{Diagnostic, Error},
		fragment::Fragment,
	};

	use super::{Backoff, RetryStrategy, compute_backoff, exponential_cap};

	/// A successful, empty execution result.
	fn ok() -> ExecutionResult {
		ExecutionResult {
			frames: vec![],
			error: None,
			metrics: ExecutionMetrics::default(),
		}
	}

	/// A failed execution result whose diagnostic carries the given error `code`.
	fn err(code: &str) -> ExecutionResult {
		ExecutionResult {
			frames: vec![],
			error: Some(Error(Box::new(Diagnostic {
				code: code.to_string(),
				rql: None,
				message: format!("{} test", code),
				column: None,
				fragment: Fragment::None,
				label: None,
				help: None,
				notes: vec![],
				cause: None,
				operator_chain: None,
			}))),
			metrics: ExecutionMetrics::default(),
		}
	}

	/// A strategy with `Backoff::None`, so retry tests never sleep.
	fn no_sleep_strategy(max_attempts: u32) -> RetryStrategy {
		RetryStrategy {
			max_attempts,
			backoff: Backoff::None,
		}
	}

	/// A closure that succeeds immediately is invoked exactly once.
	#[test]
	fn success_first_try_runs_closure_once() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			calls.set(calls.get() + 1);
			ok()
		});
		assert!(result.is_ok());
		assert_eq!(calls.get(), 1);
	}

	/// Errors other than `TXN_001` are returned without retrying.
	#[test]
	fn non_conflict_error_is_not_retried() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			calls.set(calls.get() + 1);
			err("TXN_002")
		});
		assert!(result.is_err());
		assert_eq!(calls.get(), 1);
	}

	/// Two conflicts followed by a success: three calls, final result is ok.
	#[test]
	fn conflict_retries_then_succeeds() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			let n = calls.get();
			calls.set(n + 1);
			if n < 2 {
				err("TXN_001")
			} else {
				ok()
			}
		});
		assert!(result.is_ok());
		assert_eq!(calls.get(), 3);
	}

	/// Permanent conflicts use the whole attempt budget and surface the last
	/// conflict error.
	#[test]
	fn conflict_exhausts_attempts_returns_last_error() {
		let strategy = no_sleep_strategy(4);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			calls.set(calls.get() + 1);
			err("TXN_001")
		});
		assert!(result.is_err());
		assert_eq!(result.error.as_ref().unwrap().code, "TXN_001");
		assert_eq!(calls.get(), 4);
	}

	/// Jittered delays never exceed the exponential cap of their attempt.
	#[test]
	fn jittered_backoff_stays_within_cap() {
		let base = Duration::from_millis(10);
		let max = Duration::from_millis(100);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		let rng = Rng::default();
		for attempt in 0..8 {
			let cap = exponential_cap(base, max, attempt);
			// Sample repeatedly: a single draw could stay below the cap by chance.
			for _ in 0..50 {
				let d = compute_backoff(&backoff, attempt, &rng);
				assert!(d <= cap, "attempt {}: {:?} exceeds cap {:?}", attempt, d, cap);
			}
		}
	}

	/// The same seed yields the same jitter sequence; different seeds differ.
	#[test]
	fn seeded_rng_produces_deterministic_jitter() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		// Each call builds a fresh RNG from `seed` and draws one delay per attempt.
		let sample = |seed: u64| -> Vec<Duration> {
			let rng = Rng::seeded(seed);
			(0..8).map(|attempt| compute_backoff(&backoff, attempt, &rng)).collect()
		};
		assert_eq!(sample(42), sample(42));
		assert_ne!(sample(42), sample(43));
	}

	/// Pins the exact nanosecond values produced for seeds 42 and 43, so any
	/// change to the sampling path shows up as a test failure.
	#[test]
	fn seeded_rng_produces_exact_pinned_jitter_values() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		// Each call builds a fresh RNG from `seed` and draws one delay per attempt.
		let nanos = |seed: u64| -> Vec<u64> {
			let rng = Rng::seeded(seed);
			(0..8).map(|attempt| compute_backoff(&backoff, attempt, &rng).as_nanos() as u64).collect()
		};

		let expected_42: Vec<u64> = vec![
			3_848_394,
			113_809,
			2_934_288,
			23_292_485,
			77_680_508,
			31_066_617,
			36_519_179,
			190_866_841,
		];
		let expected_43: Vec<u64> = vec![
			3_974_671, 4_842_103, 12_057_439, 29_830_325, 72_334_216, 22_229_100, 36_417_439, 81_417_246,
		];

		assert_eq!(nanos(42), expected_42);
		assert_eq!(nanos(43), expected_43);

		// Repeat with freshly seeded RNGs: earlier draws must not affect the sequence.
		assert_eq!(nanos(42), expected_42);
		assert_eq!(nanos(43), expected_43);
	}

	/// The cap doubles per attempt (5, 10, ..., 160 ms) and then stays at `max`.
	#[test]
	fn exponential_cap_saturates_at_max() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		assert_eq!(exponential_cap(base, max, 0), Duration::from_millis(5));
		assert_eq!(exponential_cap(base, max, 1), Duration::from_millis(10));
		assert_eq!(exponential_cap(base, max, 5), Duration::from_millis(160));
		assert_eq!(exponential_cap(base, max, 6), max);
		assert_eq!(exponential_cap(base, max, 100), max);
	}

	/// The default strategy is 10 attempts with 5 ms..200 ms jittered backoff.
	#[test]
	fn default_uses_jittered_backoff() {
		let s = RetryStrategy::default();
		assert_eq!(s.max_attempts, 10);
		match s.backoff {
			Backoff::ExponentialJitter {
				base,
				max,
			} => {
				assert_eq!(base, Duration::from_millis(5));
				assert_eq!(max, Duration::from_millis(200));
			}
			_ => panic!("expected ExponentialJitter default"),
		}
	}
}
460}