1use std::{thread, time::Duration};
11
12use reifydb_core::{execution::ExecutionResult, interface::catalog::token::Token};
13use reifydb_runtime::context::rng::Rng;
14use reifydb_type::{params::Params, value::identity::IdentityId};
15use tracing::{debug, instrument, warn};
16
17use crate::engine::StandardEngine;
18
/// How long to wait between retries of a conflicting transaction.
///
/// `attempt` in the descriptions below is the zero-based index of the attempt
/// that just failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backoff {
	/// Retry immediately without sleeping.
	None,
	/// Always wait the same duration.
	Fixed(Duration),
	/// Wait `base * 2^min(attempt, 30)`, never more than `max`.
	Exponential {
		/// Delay used after the first failed attempt.
		base: Duration,
		/// Upper bound on any single delay.
		max: Duration,
	},
	/// Like [`Backoff::Exponential`], but the actual delay is a random duration
	/// no larger than the exponential delay for that attempt.
	ExponentialJitter {
		/// Delay cap used after the first failed attempt.
		base: Duration,
		/// Upper bound on any single delay cap.
		max: Duration,
	},
}

/// Policy for re-running an operation that failed with a transaction conflict.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryStrategy {
	/// Total number of attempts, including the first one.
	pub max_attempts: u32,
	/// Delay schedule applied between attempts.
	pub backoff: Backoff,
}
41
42impl Default for RetryStrategy {
43 fn default() -> Self {
44 Self {
45 max_attempts: 10,
46 backoff: Backoff::ExponentialJitter {
47 base: Duration::from_millis(5),
48 max: Duration::from_millis(200),
49 },
50 }
51 }
52}
53
54impl RetryStrategy {
55 pub fn no_retry() -> Self {
57 Self {
58 max_attempts: 1,
59 backoff: Backoff::None,
60 }
61 }
62
63 pub fn default_conflict_retry() -> Self {
64 Self::default()
65 }
66
67 pub fn with_fixed_backoff(max_attempts: u32, delay: Duration) -> Self {
69 Self {
70 max_attempts,
71 backoff: Backoff::Fixed(delay),
72 }
73 }
74
75 pub fn with_exponential_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
77 Self {
78 max_attempts,
79 backoff: Backoff::Exponential {
80 base,
81 max,
82 },
83 }
84 }
85
86 pub fn with_jittered_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
87 Self {
88 max_attempts,
89 backoff: Backoff::ExponentialJitter {
90 base,
91 max,
92 },
93 }
94 }
95
96 pub fn execute<F>(&self, rng: &Rng, rql: &str, mut f: F) -> ExecutionResult
97 where
98 F: FnMut() -> ExecutionResult,
99 {
100 let mut last_result = None;
101 for attempt in 0..self.max_attempts {
102 let result = f();
103 match &result.error {
104 None => return result,
105 Some(err) if err.code == "TXN_001" => {
106 last_result = Some(result);
107 let is_last_attempt = attempt + 1 >= self.max_attempts;
108 if is_last_attempt {
109 warn!(
110 attempt = attempt + 1,
111 max_attempts = self.max_attempts,
112 rql = %rql,
113 "Transaction conflict retries exhausted"
114 );
115 } else {
116 let delay = compute_backoff(&self.backoff, attempt, rng);
117 debug!(
118 attempt = attempt + 1,
119 max_attempts = self.max_attempts,
120 delay_us = delay.as_micros() as u64,
121 rql = %rql,
122 "Transaction conflict detected, retrying after backoff"
123 );
124 if !delay.is_zero() {
125 thread::sleep(delay);
126 }
127 }
128 }
129 Some(_) => {
130 return result;
131 }
132 }
133 }
134 last_result.unwrap()
135 }
136}
137
138fn compute_backoff(backoff: &Backoff, attempt: u32, rng: &Rng) -> Duration {
139 match backoff {
140 Backoff::None => Duration::ZERO,
141 Backoff::Fixed(d) => *d,
142 Backoff::Exponential {
143 base,
144 max,
145 } => exponential_cap(*base, *max, attempt),
146 Backoff::ExponentialJitter {
147 base,
148 max,
149 } => {
150 let cap = exponential_cap(*base, *max, attempt);
151 let cap_nanos = cap.as_nanos().min(u64::MAX as u128) as u64;
152 if cap_nanos == 0 {
153 return Duration::ZERO;
154 }
155 let sampled = rng.infra_u64_inclusive(cap_nanos);
156 Duration::from_nanos(sampled)
157 }
158 }
159}
160
/// Returns `base * 2^min(attempt, 30)`, saturating on overflow and never
/// exceeding `max`.
fn exponential_cap(base: Duration, max: Duration, attempt: u32) -> Duration {
	// Keep the multiplier representable as a `u32`.
	const MAX_SHIFT: u32 = 30;
	let factor = 2u32.pow(attempt.min(MAX_SHIFT));
	std::cmp::min(base.saturating_mul(factor), max)
}
166
/// Runs RQL against a [`StandardEngine`] on behalf of a single identity.
///
/// `command` and `admin` calls go through the session's [`RetryStrategy`];
/// `query` calls are passed to the engine directly without retrying.
pub struct Session {
	/// Engine that executes every statement issued through this session.
	engine: StandardEngine,
	/// Identity passed to the engine's `query_as` / `command_as` / `admin_as`.
	identity: IdentityId,
	/// `true` for sessions built from a [`Token`]; `false` for `trusted` and
	/// `anonymous` sessions.
	authenticated: bool,
	/// Raw token string; only populated by `from_token_with_value`.
	token: Option<String>,
	/// Retry policy applied to `command` and `admin`.
	retry: RetryStrategy,
}
175
176impl Session {
177 pub fn from_token(engine: StandardEngine, info: &Token) -> Self {
179 Self {
180 engine,
181 identity: info.identity,
182 authenticated: true,
183 token: None,
184 retry: RetryStrategy::default(),
185 }
186 }
187
188 pub fn from_token_with_value(engine: StandardEngine, info: &Token) -> Self {
190 Self {
191 engine,
192 identity: info.identity,
193 authenticated: true,
194 token: Some(info.token.clone()),
195 retry: RetryStrategy::default(),
196 }
197 }
198
199 pub fn trusted(engine: StandardEngine, identity: IdentityId) -> Self {
201 Self {
202 engine,
203 identity,
204 authenticated: false,
205 token: None,
206 retry: RetryStrategy::default(),
207 }
208 }
209
210 pub fn anonymous(engine: StandardEngine) -> Self {
212 Self::trusted(engine, IdentityId::anonymous())
213 }
214
215 pub fn with_retry(mut self, strategy: RetryStrategy) -> Self {
217 self.retry = strategy;
218 self
219 }
220
221 #[inline]
223 pub fn identity(&self) -> IdentityId {
224 self.identity
225 }
226
227 #[inline]
229 pub fn token(&self) -> Option<&str> {
230 self.token.as_deref()
231 }
232
233 #[inline]
235 pub fn is_authenticated(&self) -> bool {
236 self.authenticated
237 }
238
239 #[instrument(name = "session::query", level = "debug", skip(self, params), fields(rql = %rql))]
241 pub fn query(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
242 self.engine.query_as(self.identity, rql, params.into())
243 }
244
245 #[instrument(name = "session::command", level = "debug", skip(self, params), fields(rql = %rql))]
247 pub fn command(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
248 let params = params.into();
249 self.retry
250 .execute(self.engine.rng(), rql, || self.engine.command_as(self.identity, rql, params.clone()))
251 }
252
253 #[instrument(name = "session::admin", level = "debug", skip(self, params), fields(rql = %rql))]
255 pub fn admin(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
256 let params = params.into();
257 self.retry.execute(self.engine.rng(), rql, || self.engine.admin_as(self.identity, rql, params.clone()))
258 }
259}
260
// Unit tests for `RetryStrategy::execute`, `compute_backoff` and `exponential_cap`.
#[cfg(test)]
mod retry_tests {
	use std::{cell::Cell, time::Duration};

	use reifydb_core::{execution::ExecutionResult, metric::ExecutionMetrics};
	use reifydb_runtime::context::rng::Rng;
	use reifydb_type::{
		error::{Diagnostic, Error},
		fragment::Fragment,
	};

	use super::{Backoff, RetryStrategy, compute_backoff, exponential_cap};

	/// A successful, empty execution result.
	fn ok() -> ExecutionResult {
		ExecutionResult {
			frames: vec![],
			error: None,
			metrics: ExecutionMetrics::default(),
		}
	}

	/// A failed execution result whose diagnostic carries the given error `code`.
	fn err(code: &str) -> ExecutionResult {
		ExecutionResult {
			frames: vec![],
			error: Some(Error(Box::new(Diagnostic {
				code: code.to_string(),
				rql: None,
				message: format!("{} test", code),
				column: None,
				fragment: Fragment::None,
				label: None,
				help: None,
				notes: vec![],
				cause: None,
				operator_chain: None,
			}))),
			metrics: ExecutionMetrics::default(),
		}
	}

	/// A strategy with `Backoff::None`, so retry tests never sleep.
	fn no_sleep_strategy(max_attempts: u32) -> RetryStrategy {
		RetryStrategy {
			max_attempts,
			backoff: Backoff::None,
		}
	}

	#[test]
	fn success_first_try_runs_closure_once() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			calls.set(calls.get() + 1);
			ok()
		});
		assert!(result.is_ok());
		assert_eq!(calls.get(), 1);
	}

	// Only TXN_001 (transaction conflict) is retryable; other codes return at once.
	#[test]
	fn non_conflict_error_is_not_retried() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			calls.set(calls.get() + 1);
			err("TXN_002")
		});
		assert!(result.is_err());
		assert_eq!(calls.get(), 1);
	}

	// Two conflicts followed by success: three calls in total.
	#[test]
	fn conflict_retries_then_succeeds() {
		let strategy = no_sleep_strategy(5);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			let n = calls.get();
			calls.set(n + 1);
			if n < 2 {
				err("TXN_001")
			} else {
				ok()
			}
		});
		assert!(result.is_ok());
		assert_eq!(calls.get(), 3);
	}

	// Every attempt conflicts: the closure runs exactly `max_attempts` times and
	// the final conflict error is returned.
	#[test]
	fn conflict_exhausts_attempts_returns_last_error() {
		let strategy = no_sleep_strategy(4);
		let rng = Rng::default();
		let calls = Cell::new(0u32);
		let result = strategy.execute(&rng, "", || {
			calls.set(calls.get() + 1);
			err("TXN_001")
		});
		assert!(result.is_err());
		assert_eq!(result.error.as_ref().unwrap().code, "TXN_001");
		assert_eq!(calls.get(), 4);
	}

	// Many samples per attempt must never exceed that attempt's exponential cap.
	#[test]
	fn jittered_backoff_stays_within_cap() {
		let base = Duration::from_millis(10);
		let max = Duration::from_millis(100);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		let rng = Rng::default();
		for attempt in 0..8 {
			let cap = exponential_cap(base, max, attempt);
			for _ in 0..50 {
				let d = compute_backoff(&backoff, attempt, &rng);
				assert!(d <= cap, "attempt {}: {:?} exceeds cap {:?}", attempt, d, cap);
			}
		}
	}

	// Same seed gives the same sequence; a different seed gives a different one.
	#[test]
	fn seeded_rng_produces_deterministic_jitter() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		let sample = |seed: u64| -> Vec<Duration> {
			let rng = Rng::seeded(seed);
			(0..8).map(|attempt| compute_backoff(&backoff, attempt, &rng)).collect()
		};
		assert_eq!(sample(42), sample(42));
		assert_ne!(sample(42), sample(43));
	}

	// Pins the exact nanosecond delays for two seeds so that changes to the RNG
	// or to the jitter computation are caught.
	#[test]
	fn seeded_rng_produces_exact_pinned_jitter_values() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		let backoff = Backoff::ExponentialJitter {
			base,
			max,
		};
		// Each call builds a fresh seeded Rng and samples attempts 0..8 in order.
		let nanos = |seed: u64| -> Vec<u64> {
			let rng = Rng::seeded(seed);
			(0..8).map(|attempt| compute_backoff(&backoff, attempt, &rng).as_nanos() as u64).collect()
		};

		let expected_42: Vec<u64> = vec![
			3_848_394,
			113_809,
			2_934_288,
			23_292_485,
			77_680_508,
			31_066_617,
			36_519_179,
			190_866_841,
		];
		let expected_43: Vec<u64> = vec![
			3_974_671, 4_842_103, 12_057_439, 29_830_325, 72_334_216, 22_229_100, 36_417_439, 81_417_246,
		];

		assert_eq!(nanos(42), expected_42);
		assert_eq!(nanos(43), expected_43);

		// Repeated on purpose: a second, freshly seeded Rng must reproduce the
		// same values even after the other seed has been sampled.
		assert_eq!(nanos(42), expected_42);
		assert_eq!(nanos(43), expected_43);
	}

	// 5ms doubles per attempt (5, 10, ..., 160) until it is clamped to 200ms.
	#[test]
	fn exponential_cap_saturates_at_max() {
		let base = Duration::from_millis(5);
		let max = Duration::from_millis(200);
		assert_eq!(exponential_cap(base, max, 0), Duration::from_millis(5));
		assert_eq!(exponential_cap(base, max, 1), Duration::from_millis(10));
		assert_eq!(exponential_cap(base, max, 5), Duration::from_millis(160));
		assert_eq!(exponential_cap(base, max, 6), max);
		assert_eq!(exponential_cap(base, max, 100), max);
	}

	#[test]
	fn default_uses_jittered_backoff() {
		let s = RetryStrategy::default();
		assert_eq!(s.max_attempts, 10);
		match s.backoff {
			Backoff::ExponentialJitter {
				base,
				max,
			} => {
				assert_eq!(base, Duration::from_millis(5));
				assert_eq!(max, Duration::from_millis(200));
			}
			_ => panic!("expected ExponentialJitter default"),
		}
	}
}
460}