Skip to main content

reifydb_engine/
session.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{thread, time::Duration};
5
6use reifydb_core::{execution::ExecutionResult, interface::catalog::token::Token};
7use reifydb_runtime::context::rng::Rng;
8use reifydb_type::{params::Params, value::identity::IdentityId};
9use tracing::{debug, instrument, warn};
10
11use crate::engine::StandardEngine;
12
/// Delay policy applied between retry attempts.
///
/// The attempt index used by the exponential variants is zero-based: the
/// wait after the first failed attempt uses exponent `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Backoff {
	/// Retry immediately, without waiting.
	None,

	/// Wait the same duration before every retry.
	Fixed(Duration),

	/// Wait `base * 2^attempt`, never longer than `max`.
	Exponential {
		/// Delay used for the first retry.
		base: Duration,
		/// Upper bound on the delay.
		max: Duration,
	},
	/// Wait a randomly sampled duration that is no longer than the
	/// exponential delay `min(base * 2^attempt, max)`.
	ExponentialJitter {
		/// Exponential delay ceiling used for the first retry.
		base: Duration,
		/// Upper bound on the delay ceiling.
		max: Duration,
	},
}

/// How many times an operation is attempted and how long to wait between
/// attempts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryStrategy {
	/// Total number of attempts, counting the first one.
	pub max_attempts: u32,
	/// Delay policy applied between consecutive attempts.
	pub backoff: Backoff,
}
32
33impl Default for RetryStrategy {
34	fn default() -> Self {
35		Self {
36			max_attempts: 10,
37			backoff: Backoff::ExponentialJitter {
38				base: Duration::from_millis(5),
39				max: Duration::from_millis(200),
40			},
41		}
42	}
43}
44
45impl RetryStrategy {
46	pub fn no_retry() -> Self {
47		Self {
48			max_attempts: 1,
49			backoff: Backoff::None,
50		}
51	}
52
53	pub fn default_conflict_retry() -> Self {
54		Self::default()
55	}
56
57	pub fn with_fixed_backoff(max_attempts: u32, delay: Duration) -> Self {
58		Self {
59			max_attempts,
60			backoff: Backoff::Fixed(delay),
61		}
62	}
63
64	pub fn with_exponential_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
65		Self {
66			max_attempts,
67			backoff: Backoff::Exponential {
68				base,
69				max,
70			},
71		}
72	}
73
74	pub fn with_jittered_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
75		Self {
76			max_attempts,
77			backoff: Backoff::ExponentialJitter {
78				base,
79				max,
80			},
81		}
82	}
83
84	pub fn execute<F>(&self, rng: &Rng, rql: &str, mut f: F) -> ExecutionResult
85	where
86		F: FnMut() -> ExecutionResult,
87	{
88		let mut last_result = None;
89		for attempt in 0..self.max_attempts {
90			let result = f();
91			match &result.error {
92				None => return result,
93				Some(err) if err.code == "TXN_001" => {
94					last_result = Some(result);
95					let is_last_attempt = attempt + 1 >= self.max_attempts;
96					if is_last_attempt {
97						warn!(
98							attempt = attempt + 1,
99							max_attempts = self.max_attempts,
100							rql = %rql,
101							"Transaction conflict retries exhausted"
102						);
103					} else {
104						let delay = compute_backoff(&self.backoff, attempt, rng);
105						debug!(
106							attempt = attempt + 1,
107							max_attempts = self.max_attempts,
108							delay_us = delay.as_micros() as u64,
109							rql = %rql,
110							"Transaction conflict detected, retrying after backoff"
111						);
112						if !delay.is_zero() {
113							thread::sleep(delay);
114						}
115					}
116				}
117				Some(_) => {
118					return result;
119				}
120			}
121		}
122		last_result.unwrap()
123	}
124}
125
126fn compute_backoff(backoff: &Backoff, attempt: u32, rng: &Rng) -> Duration {
127	match backoff {
128		Backoff::None => Duration::ZERO,
129		Backoff::Fixed(d) => *d,
130		Backoff::Exponential {
131			base,
132			max,
133		} => exponential_cap(*base, *max, attempt),
134		Backoff::ExponentialJitter {
135			base,
136			max,
137		} => {
138			let cap = exponential_cap(*base, *max, attempt);
139			let cap_nanos = cap.as_nanos().min(u64::MAX as u128) as u64;
140			if cap_nanos == 0 {
141				return Duration::ZERO;
142			}
143			let sampled = rng.infra_u64_inclusive(cap_nanos);
144			Duration::from_nanos(sampled)
145		}
146	}
147}
148
/// Computes `base * 2^attempt`, limited to at most `max`.
///
/// The exponent is clamped to 30 so the multiplier always fits in a `u32`,
/// and the multiplication saturates instead of overflowing.
fn exponential_cap(base: Duration, max: Duration, attempt: u32) -> Duration {
	let exponent = if attempt > 30 {
		30
	} else {
		attempt
	};
	let scaled = base.saturating_mul(2u32.pow(exponent));
	if scaled > max {
		max
	} else {
		scaled
	}
}
154
155pub struct Session {
156	engine: StandardEngine,
157	identity: IdentityId,
158	authenticated: bool,
159	token: Option<String>,
160	retry: RetryStrategy,
161}
162
163impl Session {
164	pub fn from_token(engine: StandardEngine, info: &Token) -> Self {
165		Self {
166			engine,
167			identity: info.identity,
168			authenticated: true,
169			token: None,
170			retry: RetryStrategy::default(),
171		}
172	}
173
174	pub fn from_token_with_value(engine: StandardEngine, info: &Token) -> Self {
175		Self {
176			engine,
177			identity: info.identity,
178			authenticated: true,
179			token: Some(info.token.clone()),
180			retry: RetryStrategy::default(),
181		}
182	}
183
184	pub fn trusted(engine: StandardEngine, identity: IdentityId) -> Self {
185		Self {
186			engine,
187			identity,
188			authenticated: false,
189			token: None,
190			retry: RetryStrategy::default(),
191		}
192	}
193
194	pub fn anonymous(engine: StandardEngine) -> Self {
195		Self::trusted(engine, IdentityId::anonymous())
196	}
197
198	pub fn with_retry(mut self, strategy: RetryStrategy) -> Self {
199		self.retry = strategy;
200		self
201	}
202
203	#[inline]
204	pub fn identity(&self) -> IdentityId {
205		self.identity
206	}
207
208	#[inline]
209	pub fn token(&self) -> Option<&str> {
210		self.token.as_deref()
211	}
212
213	#[inline]
214	pub fn is_authenticated(&self) -> bool {
215		self.authenticated
216	}
217
218	#[instrument(name = "session::query", level = "debug", skip(self, params), fields(rql = %rql))]
219	pub fn query(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
220		self.engine.query_as(self.identity, rql, params.into())
221	}
222
223	#[instrument(name = "session::command", level = "debug", skip(self, params), fields(rql = %rql))]
224	pub fn command(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
225		let params = params.into();
226		self.retry
227			.execute(self.engine.rng(), rql, || self.engine.command_as(self.identity, rql, params.clone()))
228	}
229
230	#[instrument(name = "session::admin", level = "debug", skip(self, params), fields(rql = %rql))]
231	pub fn admin(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
232		let params = params.into();
233		self.retry.execute(self.engine.rng(), rql, || self.engine.admin_as(self.identity, rql, params.clone()))
234	}
235}
236
#[cfg(test)]
mod retry_tests {
	use std::time::Duration;

	use reifydb_core::{execution::ExecutionResult, metric::ExecutionMetrics};
	use reifydb_runtime::context::rng::Rng;
	use reifydb_type::{
		error::{Diagnostic, Error},
		fragment::Fragment,
	};

	use super::{Backoff, RetryStrategy, compute_backoff, exponential_cap};

	/// Successful result with no frames.
	fn ok() -> ExecutionResult {
		ExecutionResult {
			frames: vec![],
			error: None,
			metrics: ExecutionMetrics::default(),
		}
	}

	/// Failed result whose diagnostic carries the given error `code`.
	fn err(code: &str) -> ExecutionResult {
		let diagnostic = Diagnostic {
			code: code.to_string(),
			rql: None,
			message: format!("{} test", code),
			column: None,
			fragment: Fragment::None,
			label: None,
			help: None,
			notes: vec![],
			cause: None,
			operator_chain: None,
		};
		ExecutionResult {
			frames: vec![],
			error: Some(Error(Box::new(diagnostic))),
			metrics: ExecutionMetrics::default(),
		}
	}

	/// Strategy that never waits, so the retry tests run instantly.
	fn no_sleep_strategy(max_attempts: u32) -> RetryStrategy {
		RetryStrategy {
			max_attempts,
			backoff: Backoff::None,
		}
	}

	#[test]
	fn success_first_try_runs_closure_once() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			ok()
		});
		assert!(result.is_ok());
		assert_eq!(calls, 1);
	}

	#[test]
	fn non_conflict_error_is_not_retried() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			err("TXN_002")
		});
		assert!(result.is_err());
		assert_eq!(calls, 1);
	}

	#[test]
	fn conflict_retries_then_succeeds() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			// The first two invocations conflict, the third succeeds.
			if calls <= 2 {
				err("TXN_001")
			} else {
				ok()
			}
		});
		assert!(result.is_ok());
		assert_eq!(calls, 3);
	}

	#[test]
	fn conflict_exhausts_attempts_returns_last_error() {
		let strategy = no_sleep_strategy(4);
		let rng = Rng::default();
		let mut calls = 0u32;
		let result = strategy.execute(&rng, "", || {
			calls += 1;
			err("TXN_001")
		});
		assert!(result.is_err());
		assert_eq!(result.error.as_ref().unwrap().code, "TXN_001");
		assert_eq!(calls, 4);
	}

	#[test]
	fn jittered_backoff_stays_within_cap() {
		let base = Duration::from_millis(10);
		let max = Duration::from_millis(100);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		let rng = Rng::default();
		for attempt in 0..8 {
			let cap = exponential_cap(base, max, attempt);
			// Draw 50 samples per attempt; none may exceed the cap.
			for _ in 0..50 {
				let d = compute_backoff(&backoff, attempt, &rng);
				assert!(d <= cap, "attempt {}: {:?} exceeds cap {:?}", attempt, d, cap);
			}
		}
	}

	#[test]
	fn seeded_rng_produces_deterministic_jitter() {
		let backoff = Backoff::ExponentialJitter {
			base: Duration::from_millis(5),
			max: Duration::from_millis(200),
		};
		let sample = |seed: u64| -> Vec<Duration> {
			let rng = Rng::seeded(seed);
			let mut delays = Vec::new();
			for attempt in 0..8 {
				delays.push(compute_backoff(&backoff, attempt, &rng));
			}
			delays
		};
		assert_eq!(sample(42), sample(42));
		assert_ne!(sample(42), sample(43));
	}

	#[test]
	fn seeded_rng_produces_exact_pinned_jitter_values() {
		let backoff = Backoff::ExponentialJitter {
			base: Duration::from_millis(5),
			max: Duration::from_millis(200),
		};
		let nanos = |seed: u64| -> Vec<u64> {
			let rng = Rng::seeded(seed);
			let mut samples = Vec::new();
			for attempt in 0..8 {
				samples.push(compute_backoff(&backoff, attempt, &rng).as_nanos() as u64);
			}
			samples
		};

		let expected_42: [u64; 8] = [
			3_848_394,
			113_809,
			2_934_288,
			23_292_485,
			77_680_508,
			31_066_617,
			36_519_179,
			190_866_841,
		];
		let expected_43: [u64; 8] = [
			3_974_671, 4_842_103, 12_057_439, 29_830_325, 72_334_216, 22_229_100, 36_417_439, 81_417_246,
		];

		// Checked twice: a fresh generator with the same seed must reproduce
		// the pinned sequence every time.
		for _ in 0..2 {
			assert_eq!(nanos(42), expected_42);
			assert_eq!(nanos(43), expected_43);
		}
	}

	#[test]
	fn exponential_cap_saturates_at_max() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		let cases = [
			(0, Duration::from_millis(5)),
			(1, Duration::from_millis(10)),
			(5, Duration::from_millis(160)),
			(6, max),
			(100, max),
		];
		for (attempt, expected) in cases {
			assert_eq!(exponential_cap(base, max, attempt), expected);
		}
	}

	#[test]
	fn default_uses_jittered_backoff() {
		let s = RetryStrategy::default();
		assert_eq!(s.max_attempts, 10);
		let Backoff::ExponentialJitter {
			base,
			max,
		} = s.backoff
		else {
			panic!("expected ExponentialJitter default");
		};
		assert_eq!(base, Duration::from_millis(5));
		assert_eq!(max, Duration::from_millis(200));
	}
}