1use std::{thread, time::Duration};
5
6use reifydb_core::{execution::ExecutionResult, interface::catalog::token::Token};
7use reifydb_runtime::context::rng::Rng;
8use reifydb_type::{params::Params, value::identity::IdentityId};
9use tracing::{debug, instrument, warn};
10
11use crate::engine::StandardEngine;
12
/// How long to wait between two attempts of a retried operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backoff {
    /// Retry immediately, with no delay.
    None,
    /// Wait the same duration before every retry.
    Fixed(Duration),
    /// Wait `base` doubled once per previous attempt, never exceeding `max`.
    Exponential {
        /// Delay used after the first failed attempt.
        base: Duration,
        /// Upper bound for the delay.
        max: Duration,
    },
    /// Like [`Backoff::Exponential`], but the actual delay is drawn from the
    /// session RNG with the exponential value as its upper bound.
    ExponentialJitter {
        /// Upper bound of the delay after the first failed attempt.
        base: Duration,
        /// Largest upper bound the delay can ever have.
        max: Duration,
    },
}

/// Retry policy: how many times to try and how long to wait in between.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryStrategy {
    /// Total number of tries, counting the first one; `1` means "never retry".
    pub max_attempts: u32,
    /// Delay policy applied between consecutive tries.
    pub backoff: Backoff,
}
35
36impl Default for RetryStrategy {
37 fn default() -> Self {
38 Self {
39 max_attempts: 10,
40 backoff: Backoff::ExponentialJitter {
41 base: Duration::from_millis(5),
42 max: Duration::from_millis(200),
43 },
44 }
45 }
46}
47
48impl RetryStrategy {
49 pub fn no_retry() -> Self {
51 Self {
52 max_attempts: 1,
53 backoff: Backoff::None,
54 }
55 }
56
57 pub fn default_conflict_retry() -> Self {
58 Self::default()
59 }
60
61 pub fn with_fixed_backoff(max_attempts: u32, delay: Duration) -> Self {
63 Self {
64 max_attempts,
65 backoff: Backoff::Fixed(delay),
66 }
67 }
68
69 pub fn with_exponential_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
71 Self {
72 max_attempts,
73 backoff: Backoff::Exponential {
74 base,
75 max,
76 },
77 }
78 }
79
80 pub fn with_jittered_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
81 Self {
82 max_attempts,
83 backoff: Backoff::ExponentialJitter {
84 base,
85 max,
86 },
87 }
88 }
89
90 pub fn execute<F>(&self, rng: &Rng, rql: &str, mut f: F) -> ExecutionResult
91 where
92 F: FnMut() -> ExecutionResult,
93 {
94 let mut last_result = None;
95 for attempt in 0..self.max_attempts {
96 let result = f();
97 match &result.error {
98 None => return result,
99 Some(err) if err.code == "TXN_001" => {
100 last_result = Some(result);
101 let is_last_attempt = attempt + 1 >= self.max_attempts;
102 if is_last_attempt {
103 warn!(
104 attempt = attempt + 1,
105 max_attempts = self.max_attempts,
106 rql = %rql,
107 "Transaction conflict retries exhausted"
108 );
109 } else {
110 let delay = compute_backoff(&self.backoff, attempt, rng);
111 debug!(
112 attempt = attempt + 1,
113 max_attempts = self.max_attempts,
114 delay_us = delay.as_micros() as u64,
115 rql = %rql,
116 "Transaction conflict detected, retrying after backoff"
117 );
118 if !delay.is_zero() {
119 thread::sleep(delay);
120 }
121 }
122 }
123 Some(_) => {
124 return result;
125 }
126 }
127 }
128 last_result.unwrap()
129 }
130}
131
132fn compute_backoff(backoff: &Backoff, attempt: u32, rng: &Rng) -> Duration {
133 match backoff {
134 Backoff::None => Duration::ZERO,
135 Backoff::Fixed(d) => *d,
136 Backoff::Exponential {
137 base,
138 max,
139 } => exponential_cap(*base, *max, attempt),
140 Backoff::ExponentialJitter {
141 base,
142 max,
143 } => {
144 let cap = exponential_cap(*base, *max, attempt);
145 let cap_nanos = cap.as_nanos().min(u64::MAX as u128) as u64;
146 if cap_nanos == 0 {
147 return Duration::ZERO;
148 }
149 let sampled = rng.infra_u64_inclusive(cap_nanos);
150 Duration::from_nanos(sampled)
151 }
152 }
153}
154
/// Returns `base` doubled `attempt` times, limited to `max`.
///
/// The number of doublings is clamped to 30 so the multiplier always fits in
/// a `u32`; the multiplication itself saturates instead of overflowing.
fn exponential_cap(base: Duration, max: Duration, attempt: u32) -> Duration {
    const MAX_DOUBLINGS: u32 = 30;

    let doublings = std::cmp::min(attempt, MAX_DOUBLINGS);
    let scaled = base.saturating_mul(2u32.pow(doublings));
    if scaled < max {
        scaled
    } else {
        max
    }
}
160
161pub struct Session {
163 engine: StandardEngine,
164 identity: IdentityId,
165 authenticated: bool,
166 token: Option<String>,
167 retry: RetryStrategy,
168}
169
170impl Session {
171 pub fn from_token(engine: StandardEngine, info: &Token) -> Self {
173 Self {
174 engine,
175 identity: info.identity,
176 authenticated: true,
177 token: None,
178 retry: RetryStrategy::default(),
179 }
180 }
181
182 pub fn from_token_with_value(engine: StandardEngine, info: &Token) -> Self {
184 Self {
185 engine,
186 identity: info.identity,
187 authenticated: true,
188 token: Some(info.token.clone()),
189 retry: RetryStrategy::default(),
190 }
191 }
192
193 pub fn trusted(engine: StandardEngine, identity: IdentityId) -> Self {
195 Self {
196 engine,
197 identity,
198 authenticated: false,
199 token: None,
200 retry: RetryStrategy::default(),
201 }
202 }
203
204 pub fn anonymous(engine: StandardEngine) -> Self {
206 Self::trusted(engine, IdentityId::anonymous())
207 }
208
209 pub fn with_retry(mut self, strategy: RetryStrategy) -> Self {
211 self.retry = strategy;
212 self
213 }
214
215 #[inline]
217 pub fn identity(&self) -> IdentityId {
218 self.identity
219 }
220
221 #[inline]
223 pub fn token(&self) -> Option<&str> {
224 self.token.as_deref()
225 }
226
227 #[inline]
229 pub fn is_authenticated(&self) -> bool {
230 self.authenticated
231 }
232
233 #[instrument(name = "session::query", level = "debug", skip(self, params), fields(rql = %rql))]
235 pub fn query(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
236 self.engine.query_as(self.identity, rql, params.into())
237 }
238
239 #[instrument(name = "session::command", level = "debug", skip(self, params), fields(rql = %rql))]
241 pub fn command(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
242 let params = params.into();
243 self.retry
244 .execute(self.engine.rng(), rql, || self.engine.command_as(self.identity, rql, params.clone()))
245 }
246
247 #[instrument(name = "session::admin", level = "debug", skip(self, params), fields(rql = %rql))]
249 pub fn admin(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
250 let params = params.into();
251 self.retry.execute(self.engine.rng(), rql, || self.engine.admin_as(self.identity, rql, params.clone()))
252 }
253}
254
#[cfg(test)]
mod retry_tests {
    use std::time::Duration;

    use reifydb_core::{execution::ExecutionResult, metric::ExecutionMetrics};
    use reifydb_runtime::context::rng::Rng;
    use reifydb_type::{
        error::{Diagnostic, Error},
        fragment::Fragment,
    };

    use super::{Backoff, RetryStrategy, compute_backoff, exponential_cap};

    /// Empty, successful execution result.
    fn ok() -> ExecutionResult {
        ExecutionResult {
            frames: vec![],
            error: None,
            metrics: ExecutionMetrics::default(),
        }
    }

    /// Failed execution result whose diagnostic carries `code`.
    fn err(code: &str) -> ExecutionResult {
        let diagnostic = Diagnostic {
            code: code.to_string(),
            rql: None,
            message: format!("{} test", code),
            column: None,
            fragment: Fragment::None,
            label: None,
            help: None,
            notes: vec![],
            cause: None,
            operator_chain: None,
        };
        ExecutionResult {
            frames: vec![],
            error: Some(Error(Box::new(diagnostic))),
            metrics: ExecutionMetrics::default(),
        }
    }

    /// Strategy with the given budget that never sleeps between attempts.
    fn no_sleep_strategy(max_attempts: u32) -> RetryStrategy {
        RetryStrategy {
            backoff: Backoff::None,
            max_attempts,
        }
    }

    /// Jittered backoff with the given bounds.
    fn jitter(base: Duration, max: Duration) -> Backoff {
        Backoff::ExponentialJitter {
            base,
            max,
        }
    }

    /// Runs `execute` with a no-sleep strategy. `outcome` receives the
    /// zero-based call index and decides what that call returns. Yields the
    /// final result together with the number of times the operation ran.
    fn run_counted(
        max_attempts: u32,
        mut outcome: impl FnMut(u32) -> ExecutionResult,
    ) -> (ExecutionResult, u32) {
        let strategy = no_sleep_strategy(max_attempts);
        let rng = Rng::default();
        let mut calls = 0u32;
        let result = strategy.execute(&rng, "", || {
            let index = calls;
            calls += 1;
            outcome(index)
        });
        (result, calls)
    }

    #[test]
    fn success_first_try_runs_closure_once() {
        let (result, calls) = run_counted(5, |_| ok());
        assert!(result.is_ok());
        assert_eq!(calls, 1);
    }

    #[test]
    fn non_conflict_error_is_not_retried() {
        let (result, calls) = run_counted(5, |_| err("TXN_002"));
        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[test]
    fn conflict_retries_then_succeeds() {
        // First two calls conflict, the third goes through.
        let (result, calls) = run_counted(5, |index| match index {
            0 | 1 => err("TXN_001"),
            _ => ok(),
        });
        assert!(result.is_ok());
        assert_eq!(calls, 3);
    }

    #[test]
    fn conflict_exhausts_attempts_returns_last_error() {
        let (result, calls) = run_counted(4, |_| err("TXN_001"));
        assert!(result.is_err());
        assert_eq!(result.error.as_ref().unwrap().code, "TXN_001");
        assert_eq!(calls, 4);
    }

    #[test]
    fn jittered_backoff_stays_within_cap() {
        let base = Duration::from_millis(10);
        let max = Duration::from_millis(100);
        let backoff = jitter(base, max);
        let rng = Rng::default();
        for attempt in 0..8 {
            let cap = exponential_cap(base, max, attempt);
            for _ in 0..50 {
                let d = compute_backoff(&backoff, attempt, &rng);
                assert!(d <= cap, "attempt {}: {:?} exceeds cap {:?}", attempt, d, cap);
            }
        }
    }

    #[test]
    fn seeded_rng_produces_deterministic_jitter() {
        let backoff = jitter(Duration::from_millis(5), Duration::from_millis(200));
        let delays_for = |seed: u64| -> Vec<Duration> {
            let rng = Rng::seeded(seed);
            let mut delays = Vec::with_capacity(8);
            for attempt in 0..8 {
                delays.push(compute_backoff(&backoff, attempt, &rng));
            }
            delays
        };
        assert_eq!(delays_for(42), delays_for(42));
        assert_ne!(delays_for(42), delays_for(43));
    }

    #[test]
    fn seeded_rng_produces_exact_pinned_jitter_values() {
        const EXPECTED_42: [u64; 8] = [
            3_848_394,
            113_809,
            2_934_288,
            23_292_485,
            77_680_508,
            31_066_617,
            36_519_179,
            190_866_841,
        ];
        const EXPECTED_43: [u64; 8] = [
            3_974_671, 4_842_103, 12_057_439, 29_830_325, 72_334_216, 22_229_100, 36_417_439, 81_417_246,
        ];

        let backoff = jitter(Duration::from_millis(5), Duration::from_millis(200));
        let nanos = |seed: u64| -> Vec<u64> {
            let rng = Rng::seeded(seed);
            (0..8).map(|attempt| compute_backoff(&backoff, attempt, &rng).as_nanos() as u64).collect()
        };

        // Two passes: each call builds a fresh seeded RNG, so the second pass
        // checks that the same seed reproduces the same sequence.
        for _ in 0..2 {
            assert_eq!(nanos(42), EXPECTED_42);
            assert_eq!(nanos(43), EXPECTED_43);
        }
    }

    #[test]
    fn exponential_cap_saturates_at_max() {
        let base = Duration::from_millis(5);
        let max = Duration::from_millis(200);
        let cases = [
            (0, Duration::from_millis(5)),
            (1, Duration::from_millis(10)),
            (5, Duration::from_millis(160)),
            (6, max),
            (100, max),
        ];
        for (attempt, expected) in cases {
            assert_eq!(exponential_cap(base, max, attempt), expected);
        }
    }

    #[test]
    fn default_uses_jittered_backoff() {
        let s = RetryStrategy::default();
        assert_eq!(s.max_attempts, 10);
        if let Backoff::ExponentialJitter {
            base,
            max,
        } = s.backoff
        {
            assert_eq!(base, Duration::from_millis(5));
            assert_eq!(max, Duration::from_millis(200));
        } else {
            panic!("expected ExponentialJitter default");
        }
    }
}