Skip to main content

reifydb_engine/
session.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{thread, time::Duration};
5
6use reifydb_core::{execution::ExecutionResult, interface::catalog::token::Token};
7use reifydb_runtime::context::rng::Rng;
8use reifydb_type::{params::Params, value::identity::IdentityId};
9use tracing::{debug, instrument, warn};
10
11use crate::engine::StandardEngine;
12
/// Backoff strategy between retry attempts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Backoff {
	/// No delay between retries.
	None,
	/// Fixed delay between each retry attempt.
	Fixed(Duration),
	/// Exponential backoff: delay doubles each attempt, capped at `max`.
	Exponential {
		/// Delay before the first retry.
		base: Duration,
		/// Upper bound the delay never exceeds.
		max: Duration,
	},
	/// Exponential backoff with jitter: the exponential delay for the attempt
	/// (capped at `max`) is used as the upper bound of a delay drawn from the RNG.
	ExponentialJitter {
		/// Upper bound of the sampled delay before the first retry.
		base: Duration,
		/// Upper bound the sampled delay never exceeds.
		max: Duration,
	},
}

/// Controls how many times a write transaction is retried on conflict (`TXN_001`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryStrategy {
	/// Total number of attempts, counting the initial one (`1` means no retry).
	pub max_attempts: u32,
	/// Delay policy applied between two consecutive attempts.
	pub backoff: Backoff,
}
35
36impl Default for RetryStrategy {
37	fn default() -> Self {
38		Self {
39			max_attempts: 10,
40			backoff: Backoff::ExponentialJitter {
41				base: Duration::from_millis(5),
42				max: Duration::from_millis(200),
43			},
44		}
45	}
46}
47
48impl RetryStrategy {
49	/// No retries - fail immediately on conflict.
50	pub fn no_retry() -> Self {
51		Self {
52			max_attempts: 1,
53			backoff: Backoff::None,
54		}
55	}
56
57	pub fn default_conflict_retry() -> Self {
58		Self::default()
59	}
60
61	/// Fixed delay between retry attempts.
62	pub fn with_fixed_backoff(max_attempts: u32, delay: Duration) -> Self {
63		Self {
64			max_attempts,
65			backoff: Backoff::Fixed(delay),
66		}
67	}
68
69	/// Exponential backoff: delay doubles each attempt, capped at `max`.
70	pub fn with_exponential_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
71		Self {
72			max_attempts,
73			backoff: Backoff::Exponential {
74				base,
75				max,
76			},
77		}
78	}
79
80	pub fn with_jittered_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
81		Self {
82			max_attempts,
83			backoff: Backoff::ExponentialJitter {
84				base,
85				max,
86			},
87		}
88	}
89
90	pub fn execute<F>(&self, rng: &Rng, rql: &str, mut f: F) -> ExecutionResult
91	where
92		F: FnMut() -> ExecutionResult,
93	{
94		let mut last_result = None;
95		for attempt in 0..self.max_attempts {
96			let result = f();
97			match &result.error {
98				None => return result,
99				Some(err) if err.code == "TXN_001" => {
100					last_result = Some(result);
101					let is_last_attempt = attempt + 1 >= self.max_attempts;
102					if is_last_attempt {
103						warn!(
104							attempt = attempt + 1,
105							max_attempts = self.max_attempts,
106							rql = %rql,
107							"Transaction conflict retries exhausted"
108						);
109					} else {
110						let delay = compute_backoff(&self.backoff, attempt, rng);
111						debug!(
112							attempt = attempt + 1,
113							max_attempts = self.max_attempts,
114							delay_us = delay.as_micros() as u64,
115							rql = %rql,
116							"Transaction conflict detected, retrying after backoff"
117						);
118						if !delay.is_zero() {
119							thread::sleep(delay);
120						}
121					}
122				}
123				Some(_) => {
124					return result;
125				}
126			}
127		}
128		last_result.unwrap()
129	}
130}
131
132fn compute_backoff(backoff: &Backoff, attempt: u32, rng: &Rng) -> Duration {
133	match backoff {
134		Backoff::None => Duration::ZERO,
135		Backoff::Fixed(d) => *d,
136		Backoff::Exponential {
137			base,
138			max,
139		} => exponential_cap(*base, *max, attempt),
140		Backoff::ExponentialJitter {
141			base,
142			max,
143		} => {
144			let cap = exponential_cap(*base, *max, attempt);
145			let cap_nanos = cap.as_nanos().min(u64::MAX as u128) as u64;
146			if cap_nanos == 0 {
147				return Duration::ZERO;
148			}
149			let sampled = rng.infra_u64_inclusive(cap_nanos);
150			Duration::from_nanos(sampled)
151		}
152	}
153}
154
/// Returns `base * 2^attempt`, never exceeding `max`.
///
/// The exponent is clamped to 30 so the `u32` multiplier cannot overflow, and the
/// multiplication saturates rather than wrapping.
fn exponential_cap(base: Duration, max: Duration, attempt: u32) -> Duration {
	const MAX_EXPONENT: u32 = 30;
	let factor = 2u32.pow(attempt.min(MAX_EXPONENT));
	let scaled = base.saturating_mul(factor);
	if scaled < max {
		scaled
	} else {
		max
	}
}
160
161/// A unified session binding an identity to a database engine.
162pub struct Session {
163	engine: StandardEngine,
164	identity: IdentityId,
165	authenticated: bool,
166	token: Option<String>,
167	retry: RetryStrategy,
168}
169
170impl Session {
171	/// Create a session from a validated auth token (server path).
172	pub fn from_token(engine: StandardEngine, info: &Token) -> Self {
173		Self {
174			engine,
175			identity: info.identity,
176			authenticated: true,
177			token: None,
178			retry: RetryStrategy::default(),
179		}
180	}
181
182	/// Create a session from a validated auth token, preserving the token string.
183	pub fn from_token_with_value(engine: StandardEngine, info: &Token) -> Self {
184		Self {
185			engine,
186			identity: info.identity,
187			authenticated: true,
188			token: Some(info.token.clone()),
189			retry: RetryStrategy::default(),
190		}
191	}
192
193	/// Create a trusted session (embedded path, no authentication required).
194	pub fn trusted(engine: StandardEngine, identity: IdentityId) -> Self {
195		Self {
196			engine,
197			identity,
198			authenticated: false,
199			token: None,
200			retry: RetryStrategy::default(),
201		}
202	}
203
204	/// Create an anonymous session.
205	pub fn anonymous(engine: StandardEngine) -> Self {
206		Self::trusted(engine, IdentityId::anonymous())
207	}
208
209	/// Set the retry strategy for command and admin operations.
210	pub fn with_retry(mut self, strategy: RetryStrategy) -> Self {
211		self.retry = strategy;
212		self
213	}
214
215	/// The identity associated with this session.
216	#[inline]
217	pub fn identity(&self) -> IdentityId {
218		self.identity
219	}
220
221	/// The auth token, if this session was created from a validated token.
222	#[inline]
223	pub fn token(&self) -> Option<&str> {
224		self.token.as_deref()
225	}
226
227	/// Whether this session was created from authenticated credentials.
228	#[inline]
229	pub fn is_authenticated(&self) -> bool {
230		self.authenticated
231	}
232
233	/// Execute a read-only query.
234	#[instrument(name = "session::query", level = "debug", skip(self, params), fields(rql = %rql))]
235	pub fn query(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
236		self.engine.query_as(self.identity, rql, params.into())
237	}
238
239	/// Execute a transactional command (DML + Query) with retry on conflict.
240	#[instrument(name = "session::command", level = "debug", skip(self, params), fields(rql = %rql))]
241	pub fn command(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
242		let params = params.into();
243		self.retry
244			.execute(self.engine.rng(), rql, || self.engine.command_as(self.identity, rql, params.clone()))
245	}
246
247	/// Execute an admin (DDL + DML + Query) operation with retry on conflict.
248	#[instrument(name = "session::admin", level = "debug", skip(self, params), fields(rql = %rql))]
249	pub fn admin(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
250		let params = params.into();
251		self.retry.execute(self.engine.rng(), rql, || self.engine.admin_as(self.identity, rql, params.clone()))
252	}
253}
254
#[cfg(test)]
mod retry_tests {
	use std::time::Duration;

	use reifydb_core::{execution::ExecutionResult, metric::ExecutionMetrics};
	use reifydb_runtime::context::rng::Rng;
	use reifydb_type::{
		error::{Diagnostic, Error},
		fragment::Fragment,
	};

	use super::{Backoff, RetryStrategy, compute_backoff, exponential_cap};

	/// Builds a frame-less result that carries the given error slot.
	fn result_with(error: Option<Error>) -> ExecutionResult {
		ExecutionResult {
			frames: vec![],
			error,
			metrics: ExecutionMetrics::default(),
		}
	}

	/// A successful, empty result.
	fn ok() -> ExecutionResult {
		result_with(None)
	}

	/// A failed result whose diagnostic carries `code`.
	fn err(code: &str) -> ExecutionResult {
		let diagnostic = Diagnostic {
			code: code.to_string(),
			rql: None,
			message: format!("{} test", code),
			column: None,
			fragment: Fragment::None,
			label: None,
			help: None,
			notes: vec![],
			cause: None,
			operator_chain: None,
		};
		result_with(Some(Error(Box::new(diagnostic))))
	}

	/// Strategy that never sleeps, so the retry tests stay fast.
	fn no_sleep_strategy(max_attempts: u32) -> RetryStrategy {
		RetryStrategy {
			max_attempts,
			backoff: Backoff::None,
		}
	}

	/// Jittered exponential backoff with bounds given in milliseconds.
	fn jitter_ms(base_ms: u64, max_ms: u64) -> Backoff {
		Backoff::ExponentialJitter {
			base: Duration::from_millis(base_ms),
			max: Duration::from_millis(max_ms),
		}
	}

	/// Delays for attempts 0..8 drawn from a freshly seeded generator.
	fn jitter_sequence(backoff: &Backoff, seed: u64) -> Vec<Duration> {
		let rng = Rng::seeded(seed);
		(0..8).map(|attempt| compute_backoff(backoff, attempt, &rng)).collect()
	}

	#[test]
	fn success_first_try_runs_closure_once() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			ok()
		});
		assert!(result.is_ok());
		assert_eq!(calls, 1);
	}

	#[test]
	fn non_conflict_error_is_not_retried() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			err("TXN_002")
		});
		assert!(result.is_err());
		assert_eq!(calls, 1);
	}

	#[test]
	fn conflict_retries_then_succeeds() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			// The first two invocations conflict, the third one goes through.
			if calls <= 2 {
				err("TXN_001")
			} else {
				ok()
			}
		});
		assert!(result.is_ok());
		assert_eq!(calls, 3);
	}

	#[test]
	fn conflict_exhausts_attempts_returns_last_error() {
		let strategy = no_sleep_strategy(4);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			err("TXN_001")
		});
		assert!(result.is_err());
		assert_eq!(result.error.as_ref().unwrap().code, "TXN_001");
		assert_eq!(calls, 4);
	}

	#[test]
	fn jittered_backoff_stays_within_cap() {
		let base = Duration::from_millis(10);
		let max = Duration::from_millis(100);
		let backoff = jitter_ms(10, 100);
		let rng = Rng::default();
		for attempt in 0..8 {
			let cap = exponential_cap(base, max, attempt);
			for _ in 0..50 {
				let d = compute_backoff(&backoff, attempt, &rng);
				assert!(d <= cap, "attempt {}: {:?} exceeds cap {:?}", attempt, d, cap);
			}
		}
	}

	#[test]
	fn seeded_rng_produces_deterministic_jitter() {
		let backoff = jitter_ms(5, 200);
		assert_eq!(jitter_sequence(&backoff, 42), jitter_sequence(&backoff, 42));
		assert_ne!(jitter_sequence(&backoff, 42), jitter_sequence(&backoff, 43));
	}

	#[test]
	fn seeded_rng_produces_exact_pinned_jitter_values() {
		let backoff = jitter_ms(5, 200);
		let nanos = |seed: u64| -> Vec<u64> {
			jitter_sequence(&backoff, seed).into_iter().map(|d| d.as_nanos() as u64).collect()
		};

		let expected_42: Vec<u64> = vec![
			3_848_394,
			113_809,
			2_934_288,
			23_292_485,
			77_680_508,
			31_066_617,
			36_519_179,
			190_866_841,
		];
		let expected_43: Vec<u64> = vec![
			3_974_671, 4_842_103, 12_057_439, 29_830_325, 72_334_216, 22_229_100, 36_417_439, 81_417_246,
		];

		// Checked twice on purpose: every round builds fresh generators from the same seeds.
		for _ in 0..2 {
			assert_eq!(nanos(42), expected_42);
			assert_eq!(nanos(43), expected_43);
		}
	}

	#[test]
	fn exponential_cap_saturates_at_max() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		let expectations = [
			(0, Duration::from_millis(5)),
			(1, Duration::from_millis(10)),
			(5, Duration::from_millis(160)),
			(6, max),
			(100, max),
		];
		for &(attempt, expected) in &expectations {
			assert_eq!(exponential_cap(base, max, attempt), expected);
		}
	}

	#[test]
	fn default_uses_jittered_backoff() {
		let s = RetryStrategy::default();
		assert_eq!(s.max_attempts, 10);
		if let Backoff::ExponentialJitter {
			base,
			max,
		} = s.backoff
		{
			assert_eq!(base, Duration::from_millis(5));
			assert_eq!(max, Duration::from_millis(200));
		} else {
			panic!("expected ExponentialJitter default");
		}
	}
}