1use std::{thread, time::Duration};
5
6use reifydb_core::{execution::ExecutionResult, interface::catalog::token::Token};
7use reifydb_runtime::context::rng::Rng;
8use reifydb_type::{params::Params, value::identity::IdentityId};
9use tracing::{debug, instrument, warn};
10
11use crate::engine::StandardEngine;
12
/// Delay policy applied between retry attempts.
///
/// The delay for a given (zero-based) attempt is derived by `compute_backoff`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Backoff {
    /// Retry immediately; the computed delay is always zero.
    None,

    /// Wait the same duration before every retry.
    Fixed(Duration),

    /// Wait `base * 2^attempt`, never exceeding `max`.
    Exponential {
        /// Delay used for the first retry.
        base: Duration,
        /// Upper bound on the delay.
        max: Duration,
    },

    /// Like [`Backoff::Exponential`], but the exponential value is only used
    /// as the bound passed to the random number generator, and the generated
    /// value becomes the delay.
    ExponentialJitter {
        /// Bound used for the first retry.
        base: Duration,
        /// Upper limit on the bound.
        max: Duration,
    },
}
27
28pub struct RetryStrategy {
29 pub max_attempts: u32,
30 pub backoff: Backoff,
31}
32
33impl Default for RetryStrategy {
34 fn default() -> Self {
35 Self {
36 max_attempts: 10,
37 backoff: Backoff::ExponentialJitter {
38 base: Duration::from_millis(5),
39 max: Duration::from_millis(200),
40 },
41 }
42 }
43}
44
45impl RetryStrategy {
46 pub fn no_retry() -> Self {
47 Self {
48 max_attempts: 1,
49 backoff: Backoff::None,
50 }
51 }
52
53 pub fn default_conflict_retry() -> Self {
54 Self::default()
55 }
56
57 pub fn with_fixed_backoff(max_attempts: u32, delay: Duration) -> Self {
58 Self {
59 max_attempts,
60 backoff: Backoff::Fixed(delay),
61 }
62 }
63
64 pub fn with_exponential_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
65 Self {
66 max_attempts,
67 backoff: Backoff::Exponential {
68 base,
69 max,
70 },
71 }
72 }
73
74 pub fn with_jittered_backoff(max_attempts: u32, base: Duration, max: Duration) -> Self {
75 Self {
76 max_attempts,
77 backoff: Backoff::ExponentialJitter {
78 base,
79 max,
80 },
81 }
82 }
83
84 pub fn execute<F>(&self, rng: &Rng, rql: &str, mut f: F) -> ExecutionResult
85 where
86 F: FnMut() -> ExecutionResult,
87 {
88 let mut last_result = None;
89 for attempt in 0..self.max_attempts {
90 let result = f();
91 match &result.error {
92 None => return result,
93 Some(err) if err.code == "TXN_001" => {
94 last_result = Some(result);
95 let is_last_attempt = attempt + 1 >= self.max_attempts;
96 if is_last_attempt {
97 warn!(
98 attempt = attempt + 1,
99 max_attempts = self.max_attempts,
100 rql = %rql,
101 "Transaction conflict retries exhausted"
102 );
103 } else {
104 let delay = compute_backoff(&self.backoff, attempt, rng);
105 debug!(
106 attempt = attempt + 1,
107 max_attempts = self.max_attempts,
108 delay_us = delay.as_micros() as u64,
109 rql = %rql,
110 "Transaction conflict detected, retrying after backoff"
111 );
112 if !delay.is_zero() {
113 thread::sleep(delay);
114 }
115 }
116 }
117 Some(_) => {
118 return result;
119 }
120 }
121 }
122 last_result.unwrap()
123 }
124}
125
126fn compute_backoff(backoff: &Backoff, attempt: u32, rng: &Rng) -> Duration {
127 match backoff {
128 Backoff::None => Duration::ZERO,
129 Backoff::Fixed(d) => *d,
130 Backoff::Exponential {
131 base,
132 max,
133 } => exponential_cap(*base, *max, attempt),
134 Backoff::ExponentialJitter {
135 base,
136 max,
137 } => {
138 let cap = exponential_cap(*base, *max, attempt);
139 let cap_nanos = cap.as_nanos().min(u64::MAX as u128) as u64;
140 if cap_nanos == 0 {
141 return Duration::ZERO;
142 }
143 let sampled = rng.infra_u64_inclusive(cap_nanos);
144 Duration::from_nanos(sampled)
145 }
146 }
147}
148
/// Returns `base * 2^attempt`, limited to at most `max`.
///
/// The exponent is clamped to 30 so the factor always fits in a `u32`, and
/// the multiplication saturates instead of overflowing.
fn exponential_cap(base: Duration, max: Duration, attempt: u32) -> Duration {
    let exponent = std::cmp::min(attempt, 30);
    let factor = 2u32.pow(exponent);
    let scaled = base.saturating_mul(factor);
    std::cmp::min(scaled, max)
}
154
155pub struct Session {
156 engine: StandardEngine,
157 identity: IdentityId,
158 authenticated: bool,
159 token: Option<String>,
160 retry: RetryStrategy,
161}
162
163impl Session {
164 pub fn from_token(engine: StandardEngine, info: &Token) -> Self {
165 Self {
166 engine,
167 identity: info.identity,
168 authenticated: true,
169 token: None,
170 retry: RetryStrategy::default(),
171 }
172 }
173
174 pub fn from_token_with_value(engine: StandardEngine, info: &Token) -> Self {
175 Self {
176 engine,
177 identity: info.identity,
178 authenticated: true,
179 token: Some(info.token.clone()),
180 retry: RetryStrategy::default(),
181 }
182 }
183
184 pub fn trusted(engine: StandardEngine, identity: IdentityId) -> Self {
185 Self {
186 engine,
187 identity,
188 authenticated: false,
189 token: None,
190 retry: RetryStrategy::default(),
191 }
192 }
193
194 pub fn anonymous(engine: StandardEngine) -> Self {
195 Self::trusted(engine, IdentityId::anonymous())
196 }
197
198 pub fn with_retry(mut self, strategy: RetryStrategy) -> Self {
199 self.retry = strategy;
200 self
201 }
202
203 #[inline]
204 pub fn identity(&self) -> IdentityId {
205 self.identity
206 }
207
208 #[inline]
209 pub fn token(&self) -> Option<&str> {
210 self.token.as_deref()
211 }
212
213 #[inline]
214 pub fn is_authenticated(&self) -> bool {
215 self.authenticated
216 }
217
218 #[instrument(name = "session::query", level = "debug", skip(self, params), fields(rql = %rql))]
219 pub fn query(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
220 self.engine.query_as(self.identity, rql, params.into())
221 }
222
223 #[instrument(name = "session::command", level = "debug", skip(self, params), fields(rql = %rql))]
224 pub fn command(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
225 let params = params.into();
226 self.retry
227 .execute(self.engine.rng(), rql, || self.engine.command_as(self.identity, rql, params.clone()))
228 }
229
230 #[instrument(name = "session::admin", level = "debug", skip(self, params), fields(rql = %rql))]
231 pub fn admin(&self, rql: &str, params: impl Into<Params>) -> ExecutionResult {
232 let params = params.into();
233 self.retry.execute(self.engine.rng(), rql, || self.engine.admin_as(self.identity, rql, params.clone()))
234 }
235}
236
#[cfg(test)]
mod retry_tests {
    use std::{cell::Cell, time::Duration};

    use reifydb_core::{execution::ExecutionResult, metric::ExecutionMetrics};
    use reifydb_runtime::context::rng::Rng;
    use reifydb_type::{
        error::{Diagnostic, Error},
        fragment::Fragment,
    };

    use super::{Backoff, RetryStrategy, compute_backoff, exponential_cap};

    /// Result with no frames and no error.
    fn success() -> ExecutionResult {
        ExecutionResult {
            frames: vec![],
            error: None,
            metrics: ExecutionMetrics::default(),
        }
    }

    /// Result carrying an error whose diagnostic code is `code`.
    fn failure(code: &str) -> ExecutionResult {
        let diagnostic = Diagnostic {
            code: code.to_string(),
            rql: None,
            message: format!("{} test", code),
            column: None,
            fragment: Fragment::None,
            label: None,
            help: None,
            notes: vec![],
            cause: None,
            operator_chain: None,
        };
        ExecutionResult {
            frames: vec![],
            error: Some(Error(Box::new(diagnostic))),
            metrics: ExecutionMetrics::default(),
        }
    }

    /// Strategy that never sleeps, so the retry tests run instantly.
    fn no_sleep_strategy(max_attempts: u32) -> RetryStrategy {
        RetryStrategy {
            max_attempts,
            backoff: Backoff::None,
        }
    }

    /// Jittered backoff built from millisecond values.
    fn jitter_ms(base_ms: u64, max_ms: u64) -> Backoff {
        Backoff::ExponentialJitter {
            base: Duration::from_millis(base_ms),
            max: Duration::from_millis(max_ms),
        }
    }

    #[test]
    fn success_first_try_runs_closure_once() {
        let strategy = no_sleep_strategy(5);
        let rng = Rng::default();
        let invocations = Cell::new(0u32);

        let outcome = strategy.execute(&rng, "", || {
            invocations.set(invocations.get() + 1);
            success()
        });

        assert!(outcome.is_ok());
        assert_eq!(invocations.get(), 1);
    }

    #[test]
    fn non_conflict_error_is_not_retried() {
        let strategy = no_sleep_strategy(5);
        let rng = Rng::default();
        let invocations = Cell::new(0u32);

        let outcome = strategy.execute(&rng, "", || {
            invocations.set(invocations.get() + 1);
            failure("TXN_002")
        });

        assert!(outcome.is_err());
        assert_eq!(invocations.get(), 1);
    }

    #[test]
    fn conflict_retries_then_succeeds() {
        let strategy = no_sleep_strategy(5);
        let rng = Rng::default();
        let invocations = Cell::new(0u32);

        // The first two calls conflict; the third one succeeds.
        let outcome = strategy.execute(&rng, "", || {
            let previous_calls = invocations.get();
            invocations.set(previous_calls + 1);
            if previous_calls >= 2 {
                success()
            } else {
                failure("TXN_001")
            }
        });

        assert!(outcome.is_ok());
        assert_eq!(invocations.get(), 3);
    }

    #[test]
    fn conflict_exhausts_attempts_returns_last_error() {
        let strategy = no_sleep_strategy(4);
        let rng = Rng::default();
        let invocations = Cell::new(0u32);

        let outcome = strategy.execute(&rng, "", || {
            invocations.set(invocations.get() + 1);
            failure("TXN_001")
        });

        assert!(outcome.is_err());
        assert_eq!(outcome.error.as_ref().unwrap().code, "TXN_001");
        assert_eq!(invocations.get(), 4);
    }

    #[test]
    fn jittered_backoff_stays_within_cap() {
        let base = Duration::from_millis(10);
        let max = Duration::from_millis(100);
        let backoff = Backoff::ExponentialJitter {
            base,
            max,
        };
        let rng = Rng::default();

        for attempt in 0..8 {
            let cap = exponential_cap(base, max, attempt);
            for _ in 0..50 {
                let d = compute_backoff(&backoff, attempt, &rng);
                assert!(d <= cap, "attempt {}: {:?} exceeds cap {:?}", attempt, d, cap);
            }
        }
    }

    #[test]
    fn seeded_rng_produces_deterministic_jitter() {
        let backoff = jitter_ms(5, 200);
        let delays_for = |seed: u64| -> Vec<Duration> {
            let rng = Rng::seeded(seed);
            let mut delays = Vec::new();
            for attempt in 0..8 {
                delays.push(compute_backoff(&backoff, attempt, &rng));
            }
            delays
        };

        assert_eq!(delays_for(42), delays_for(42));
        assert_ne!(delays_for(42), delays_for(43));
    }

    #[test]
    fn seeded_rng_produces_exact_pinned_jitter_values() {
        let backoff = jitter_ms(5, 200);
        let nanos_for = |seed: u64| -> Vec<u64> {
            let rng = Rng::seeded(seed);
            (0..8).map(|attempt| compute_backoff(&backoff, attempt, &rng).as_nanos() as u64).collect()
        };

        let expected_42: Vec<u64> = vec![
            3_848_394,
            113_809,
            2_934_288,
            23_292_485,
            77_680_508,
            31_066_617,
            36_519_179,
            190_866_841,
        ];
        let expected_43: Vec<u64> = vec![
            3_974_671, 4_842_103, 12_057_439, 29_830_325, 72_334_216, 22_229_100, 36_417_439, 81_417_246,
        ];

        // Two passes: a freshly seeded generator must reproduce the pinned
        // sequence every time it is created.
        for _ in 0..2 {
            assert_eq!(nanos_for(42), expected_42);
            assert_eq!(nanos_for(43), expected_43);
        }
    }

    #[test]
    fn exponential_cap_saturates_at_max() {
        let base = Duration::from_millis(5);
        let max = Duration::from_millis(200);

        let cases = [
            (0, Duration::from_millis(5)),
            (1, Duration::from_millis(10)),
            (5, Duration::from_millis(160)),
            (6, max),
            (100, max),
        ];
        for (attempt, expected) in cases {
            assert_eq!(exponential_cap(base, max, attempt), expected);
        }
    }

    #[test]
    fn default_uses_jittered_backoff() {
        let strategy = RetryStrategy::default();
        assert_eq!(strategy.max_attempts, 10);

        if let Backoff::ExponentialJitter {
            base,
            max,
        } = strategy.backoff
        {
            assert_eq!(base, Duration::from_millis(5));
            assert_eq!(max, Duration::from_millis(200));
        } else {
            panic!("expected ExponentialJitter default");
        }
    }
}