use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use once_cell::sync::Lazy;
use rand::{thread_rng, Rng};
use crate::errors::{Error, IdleSessionTimeoutError};
/// Fixed retry rule used for idle-session timeouts: two attempts with a zero
/// delay, regardless of any user-configured [`RetryOptions`].
static IDLE_TIMEOUT_RULE: Lazy<RetryRule> = Lazy::new(|| {
    let no_delay = |_: u32| Duration::ZERO;
    RetryRule {
        attempts: 2,
        backoff: Arc::new(no_delay),
    }
});
/// Class of error for which a dedicated retry rule can be registered with
/// `RetryOptions::with_rule`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryCondition {
    /// A transaction conflict (`TransactionConflictError`).
    TransactionConflict,
    /// A network error. This crate's rule lookup classifies any `ClientError`
    /// as this condition.
    NetworkError,
}
/// Per-transaction settings. Both flags default to `false`.
#[derive(Clone, Debug, Default)]
pub struct TransactionOptions {
    /// Whether the transaction is read-only; set with `TransactionOptions::read_only`.
    read_only: bool,
    /// Whether the transaction is deferrable; set with `TransactionOptions::deferrable`.
    deferrable: bool,
}
/// Retry policy: a default rule plus optional per-[`RetryCondition`] overrides.
///
/// The configuration sits behind an [`Arc`], so cloning only bumps a
/// reference count.
#[derive(Clone, Debug)]
pub struct RetryOptions(Arc<RetryOptionsInner>);
/// Shared state behind [`RetryOptions`].
#[derive(Clone, Debug)]
struct RetryOptionsInner {
    /// Rule used when no override is registered for the error's condition.
    default: RetryRule,
    /// Per-condition rules; an entry here takes precedence over `default`.
    overrides: HashMap<RetryCondition, RetryRule>,
}
/// One retry rule: an attempt budget and a backoff function.
///
/// `Debug` is implemented by hand (closures are not `Debug`).
#[derive(Clone)]
pub(crate) struct RetryRule {
    /// Attempt budget. NOTE(review): whether this counts the initial try
    /// depends on the retry loop, which lives outside this module — confirm.
    pub(crate) attempts: u32,
    /// Maps a retry iteration number to a delay (presumably the sleep before
    /// the next attempt — confirm against the retry loop).
    pub(crate) backoff: Arc<dyn Fn(u32) -> Duration + Send + Sync>,
}
impl TransactionOptions {
    /// Sets whether the transaction is read-only.
    ///
    /// Takes and returns `self` by value so calls can be chained. The result
    /// must be used: calling this and dropping the return value has no effect.
    #[must_use]
    pub fn read_only(mut self, read_only: bool) -> Self {
        self.read_only = read_only;
        self
    }

    /// Sets whether the transaction is deferrable.
    ///
    /// Takes and returns `self` by value so calls can be chained. The result
    /// must be used: calling this and dropping the return value has no effect.
    #[must_use]
    pub fn deferrable(mut self, deferrable: bool) -> Self {
        self.deferrable = deferrable;
        self
    }
}
impl Default for RetryRule {
    /// Three attempts with exponential backoff: for retry iteration `n` the
    /// delay is `2^n * 100` ms plus 0–99 ms of random jitter.
    fn default() -> RetryRule {
        RetryRule {
            attempts: 3,
            backoff: Arc::new(|n| {
                // Saturate rather than overflow. `2u64.pow(n)` panics in debug
                // builds (and wraps in release) once `n >= 64`, and the `* 100`
                // overflows even earlier. A huge iteration number now yields
                // the maximum delay instead of a panic or a nonsense value.
                let base_ms = 2u64.saturating_pow(n).saturating_mul(100);
                let jitter_ms: u64 = thread_rng().gen_range(0..100);
                Duration::from_millis(base_ms.saturating_add(jitter_ms))
            }),
        }
    }
}
impl Default for RetryOptions {
    /// Options using `RetryRule::default()` with no per-condition overrides.
    fn default() -> Self {
        let inner = RetryOptionsInner {
            default: RetryRule::default(),
            overrides: HashMap::new(),
        };
        RetryOptions(Arc::new(inner))
    }
}
impl RetryOptions {
pub fn new(
self,
attempts: u32,
backoff: impl Fn(u32) -> Duration + Send + Sync + 'static,
) -> Self {
RetryOptions(Arc::new(RetryOptionsInner {
default: RetryRule {
attempts,
backoff: Arc::new(backoff),
},
overrides: HashMap::new(),
}))
}
pub fn with_rule<F>(
mut self,
condition: RetryCondition,
attempts: u32,
backoff: impl Fn(u32) -> Duration + Send + Sync + 'static,
) -> Self {
let inner = Arc::make_mut(&mut self.0);
inner.overrides.insert(
condition,
RetryRule {
attempts,
backoff: Arc::new(backoff),
},
);
self
}
pub(crate) fn get_rule(&self, err: &Error) -> &RetryRule {
use edgedb_errors::{ClientError, TransactionConflictError};
use RetryCondition::*;
if err.is::<IdleSessionTimeoutError>() {
&IDLE_TIMEOUT_RULE
} else if err.is::<TransactionConflictError>() {
self.0
.overrides
.get(&TransactionConflict)
.unwrap_or(&self.0.default)
} else if err.is::<ClientError>() {
self.0
.overrides
.get(&NetworkError)
.unwrap_or(&self.0.default)
} else {
&self.0.default
}
}
}
/// Debug helper that renders the first few delays produced by a backoff
/// function, e.g. `10s, 20s, 30s, ...`.
///
/// Field `.0` is the backoff function, called with iteration numbers starting
/// at 0. Field `.1` is the number of delays to describe. Nothing is printed
/// when it is 0.
struct DebugBackoff<F>(F, u32);

impl<F> fmt::Debug for DebugBackoff<F>
where
    F: Fn(u32) -> Duration,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Number of delays printed before the remainder is elided as "...".
        const SHOWN: u32 = 3;
        let DebugBackoff(backoff, attempts) = self;
        for i in 0..(*attempts).min(SHOWN) {
            if i > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{:?}", backoff(i))?;
        }
        if *attempts > SHOWN {
            f.write_str(", ...")?;
        }
        Ok(())
    }
}
#[test]
fn debug_backoff() {
    // Linear backoff: 10s, 20s, 30s, ...
    let linear = |i: u32| Duration::from_secs(10 + u64::from(i) * 10);
    let cases = [
        (3, "10s, 20s, 30s"),
        (10, "10s, 20s, 30s, ..."),
        (2, "10s, 20s"),
    ];
    for (attempts, expected) in cases {
        assert_eq!(
            format!("{:?}", DebugBackoff(linear, attempts)),
            expected,
            "attempts = {}",
            attempts
        );
    }
}
impl fmt::Debug for RetryRule {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("RetryRule")
.field("attempts", &self.attempts)
.field("backoff", &DebugBackoff(&*self.backoff, self.attempts))
.finish()
}
}