use std::num::NonZeroU32;
use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use crate::errors::InvocationError;
/// A strategy deciding whether a failed invocation is retried, and after what delay.
///
/// The `Send + Sync + 'static` bounds allow a single policy to be shared across tasks,
/// e.g. stored as an `Arc<dyn RetryPolicy>`.
pub trait RetryPolicy: Send + Sync + 'static {
    /// Examines the failure described by `context`.
    ///
    /// Return `ControlFlow::Continue(delay)` to wait `delay` and try again, or
    /// `ControlFlow::Break(())` to stop retrying.
    fn should_retry(&self, context: &RetryContext) -> ControlFlow<(), Duration>;
}
/// State of an in-progress retry sequence, passed to `RetryPolicy::should_retry`.
pub struct RetryContext {
    /// Number of failures so far, including the one being evaluated (`RetryLoop` starts it at 1).
    pub fail_count: NonZeroU32,
    /// Total time already spent sleeping between attempts.
    pub slept_so_far: Duration,
    /// The error produced by the most recent attempt.
    pub error: InvocationError,
}
/// A policy that never retries: every failure is reported to the caller immediately.
pub struct NoRetries;

impl RetryPolicy for NoRetries {
    /// Always gives up, regardless of the error or the number of attempts.
    fn should_retry(&self, _context: &RetryContext) -> ControlFlow<(), Duration> {
        ControlFlow::Break(())
    }
}
/// Retry policy that sleeps through server-imposed waits and a first transient I/O failure.
pub struct AutoSleep {
    /// Longest server-requested wait (FLOOD_WAIT / SLOWMODE_WAIT) that will be slept through;
    /// longer waits make the policy give up.
    pub threshold: Duration,
    /// Delay before retrying an I/O error, or `None` to never retry I/O errors.
    pub io_errors_as_flood_of: Option<Duration>,
}

impl Default for AutoSleep {
    /// Sleeps through waits of up to one minute and retries an I/O error after one second.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        let one_second = Duration::from_secs(1);
        AutoSleep {
            threshold: one_minute,
            io_errors_as_flood_of: Some(one_second),
        }
    }
}
/// Returns `base` shifted by a pseudo-random offset in `[-max_jitter_secs, +max_jitter_secs]`.
///
/// The offset is derived deterministically from `seed` with a SplitMix64-style bit mixer.
/// It has millisecond granularity, and the sub-millisecond part of `base` is dropped.
/// A negative offset larger than `base` clamps the result to zero.
///
/// All arithmetic saturates, so very large `base` or `max_jitter_secs` values cannot overflow.
fn jitter_duration(base: Duration, seed: u32, max_jitter_secs: u64) -> Duration {
    // SplitMix64-style finalizer: spreads the small seed across all 64 bits.
    let mut h = u64::from(seed) ^ 0x9e37_79b9_7f4a_7c15;
    h ^= h >> 30;
    h = h.wrapping_mul(0xbf58_476d_1ce4_e5b9);
    h ^= h >> 27;
    h = h.wrapping_mul(0x94d0_49bb_1331_11eb);
    h ^= h >> 31;

    let max_jitter_ms = max_jitter_secs.saturating_mul(1000);
    // `offset_ms` lies in [0, 2 * max_jitter_ms]; subtracting `max_jitter_ms` centres it on 0.
    let span_ms = max_jitter_ms.saturating_mul(2).saturating_add(1);
    let offset_ms = h % span_ms;
    let base_ms = u64::try_from(base.as_millis()).unwrap_or(u64::MAX);
    let final_ms = if offset_ms >= max_jitter_ms {
        base_ms.saturating_add(offset_ms - max_jitter_ms)
    } else {
        base_ms.saturating_sub(max_jitter_ms - offset_ms)
    };
    Duration::from_millis(final_ms)
}
impl AutoSleep {
    /// Handles a server-mandated wait of `secs` seconds reported as `{label}_{secs}`.
    ///
    /// Gives up when the wait exceeds `self.threshold`. Otherwise it sleeps for the full
    /// mandated wait plus a seeded jitter of 0-2s.
    fn server_wait(&self, label: &str, secs: u64, seed: u32) -> ControlFlow<(), Duration> {
        let wait = Duration::from_secs(secs);
        if wait > self.threshold {
            return ControlFlow::Break(());
        }
        // Jitter goes strictly *after* the mandated wait: retrying even slightly early just
        // earns another wait error. `jitter_duration(1s, _, 1)` yields an offset in [0s, 2s].
        let delay = wait + jitter_duration(Duration::from_secs(1), seed, 1);
        tracing::info!("{label}_{secs}: sleeping {delay:?} before retry");
        ControlFlow::Continue(delay)
    }
}

impl RetryPolicy for AutoSleep {
    /// Retries the following errors and gives up on everything else:
    ///
    /// * 420 FLOOD_WAIT / SLOWMODE_WAIT: retried when the wait is within `threshold`.
    /// * I/O errors: retried once (on the first failure), when `io_errors_as_flood_of` is set.
    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration> {
        match &ctx.error {
            InvocationError::Rpc(rpc) if rpc.code == 420 && rpc.name == "FLOOD_WAIT" => self
                .server_wait(
                    "FLOOD_WAIT",
                    u64::from(rpc.value.unwrap_or(0)),
                    ctx.fail_count.get(),
                ),
            InvocationError::Rpc(rpc) if rpc.code == 420 && rpc.name == "SLOWMODE_WAIT" => self
                .server_wait(
                    "SLOWMODE_WAIT",
                    u64::from(rpc.value.unwrap_or(0)),
                    ctx.fail_count.get(),
                ),
            InvocationError::Io(_) if ctx.fail_count.get() <= 1 => {
                if let Some(d) = self.io_errors_as_flood_of {
                    tracing::info!(
                        "I/O error (attempt {}): sleeping {d:?} before retry",
                        ctx.fail_count.get()
                    );
                    ControlFlow::Continue(d)
                } else {
                    ControlFlow::Break(())
                }
            }
            _ => ControlFlow::Break(()),
        }
    }
}
/// Drives the retries of a single invocation according to a [`RetryPolicy`].
///
/// Feed each failure to [`RetryLoop::advance`]. It either sleeps and returns `Ok(())`, meaning
/// try again, or hands the error back once the policy gives up.
pub struct RetryLoop {
    policy: Arc<dyn RetryPolicy>,
    ctx: RetryContext,
}

impl RetryLoop {
    /// Creates a loop positioned at the first failure (`fail_count == 1`, nothing slept yet).
    pub fn new(policy: Arc<dyn RetryPolicy>) -> Self {
        Self {
            policy,
            ctx: RetryContext {
                fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
                slept_so_far: Duration::default(),
                // Placeholder: `advance` overwrites it before the policy ever sees the context.
                error: InvocationError::Dropped,
            },
        }
    }

    /// Records `err`, asks the policy what to do, and sleeps if it says to retry.
    ///
    /// # Errors
    /// Returns `err` back to the caller when the policy decides not to retry.
    pub async fn advance(&mut self, err: InvocationError) -> Result<(), InvocationError> {
        self.ctx.error = err;
        match self.policy.should_retry(&self.ctx) {
            ControlFlow::Continue(delay) => {
                sleep(delay).await;
                // Saturate like `fail_count` below: a pathological policy delay must not
                // panic with "overflow when adding durations".
                self.ctx.slept_so_far = self.ctx.slept_so_far.saturating_add(delay);
                self.ctx.fail_count = self.ctx.fail_count.saturating_add(1);
                Ok(())
            }
            ControlFlow::Break(()) => Err(std::mem::replace(
                &mut self.ctx.error,
                InvocationError::Dropped,
            )),
        }
    }
}
/// Internal state of a [`CircuitBreaker`].
#[derive(Debug)]
enum CbState {
    /// Retries are allowed (with backoff) until the failure count reaches the threshold.
    Closed { consecutive_failures: u32 },
    /// Tripped: retries are refused until `cooldown` has elapsed since `tripped_at`.
    Open { tripped_at: std::time::Instant },
}

/// Retry policy shared across calls that refuses all retries once `threshold` failures have
/// accumulated, then stays tripped for `cooldown`.
pub struct CircuitBreaker {
    threshold: u32,
    cooldown: Duration,
    state: std::sync::Mutex<CbState>,
}

impl CircuitBreaker {
    /// Builds a closed breaker that trips once its failure count reaches `threshold`.
    ///
    /// # Panics
    /// Panics if `threshold` is zero.
    pub fn new(threshold: u32, cooldown: Duration) -> Self {
        assert!(threshold > 0, "CircuitBreaker threshold must be at least 1");
        let initial = CbState::Closed {
            consecutive_failures: 0,
        };
        CircuitBreaker {
            threshold,
            cooldown,
            state: std::sync::Mutex::new(initial),
        }
    }
}
impl RetryPolicy for CircuitBreaker {
    /// Counts the failure and decides whether a retry is allowed.
    ///
    /// * Open, cooldown not elapsed: give up.
    /// * Open, cooldown elapsed: close again (counting this failure as the first) and retry
    ///   after 200ms.
    /// * Closed: increment the count. Trip and give up once it reaches `threshold`; otherwise
    ///   retry after 200ms * 2^(count-1), with the exponent capped at 4 (at most 3.2s).
    fn should_retry(&self, _ctx: &RetryContext) -> ControlFlow<(), Duration> {
        // Every write below replaces the whole `CbState`, so a mutex poisoned by a panic
        // elsewhere still holds a valid state. Recover it rather than panicking on every
        // subsequent retry decision.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        match &*state {
            CbState::Open { tripped_at } => {
                if tripped_at.elapsed() >= self.cooldown {
                    *state = CbState::Closed {
                        consecutive_failures: 1,
                    };
                    ControlFlow::Continue(Duration::from_millis(200))
                } else {
                    ControlFlow::Break(())
                }
            }
            CbState::Closed {
                consecutive_failures,
            } => {
                let new_count = consecutive_failures + 1;
                if new_count >= self.threshold {
                    tracing::warn!(
                        "[ferogram] CircuitBreaker tripped after {new_count} consecutive failures"
                    );
                    *state = CbState::Open {
                        tripped_at: std::time::Instant::now(),
                    };
                    ControlFlow::Break(())
                } else {
                    let backoff_ms = 200u64 * (1u64 << new_count.saturating_sub(1).min(4));
                    *state = CbState::Closed {
                        consecutive_failures: new_count,
                    };
                    ControlFlow::Continue(Duration::from_millis(backoff_ms))
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::RpcError;
    use std::io;

    /// Builds a `RpcError` from its three fields.
    fn rpc_error(code: i32, name: &str, value: Option<u32>) -> RpcError {
        RpcError {
            code,
            name: name.into(),
            value,
        }
    }

    /// Wraps an RPC error as an `InvocationError`.
    fn rpc(code: i32, name: &str, value: Option<u32>) -> InvocationError {
        InvocationError::Rpc(rpc_error(code, name, value))
    }

    /// A `FLOOD_WAIT` of `secs` seconds.
    fn flood(secs: u32) -> InvocationError {
        rpc(420, "FLOOD_WAIT", Some(secs))
    }

    /// A transient connection-reset I/O error.
    fn io_err() -> InvocationError {
        InvocationError::Io(io::Error::new(io::ErrorKind::ConnectionReset, "reset"))
    }

    /// Assembles a `RetryContext` from its parts.
    fn ctx_with(
        fail_count: NonZeroU32,
        slept_so_far: Duration,
        error: InvocationError,
    ) -> RetryContext {
        RetryContext {
            fail_count,
            slept_so_far,
            error,
        }
    }

    /// Context used by the circuit-breaker tests: attempt `n` failing with a 500.
    fn cb_ctx(n: u32) -> RetryContext {
        ctx_with(
            NonZeroU32::new(n).unwrap(),
            Duration::default(),
            rpc(500, "INTERNAL", None),
        )
    }

    #[test]
    fn no_retries_always_breaks() {
        let ctx = ctx_with(
            NonZeroU32::new(1).expect("1 is nonzero"),
            Duration::default(),
            flood(10),
        );
        assert!(matches!(NoRetries.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_retries_flood_under_threshold() {
        let policy = AutoSleep::default();
        let ctx = ctx_with(
            NonZeroU32::new(1).expect("1 is nonzero"),
            Duration::default(),
            flood(30),
        );
        let secs = match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => d.as_secs_f64(),
            other => panic!("expected Continue, got {other:?}"),
        };
        assert!(
            (28.0..=32.0).contains(&secs),
            "expected 28-32s delay (jitter), got {secs:.3}s"
        );
    }

    #[test]
    fn autosleep_breaks_flood_over_threshold() {
        let policy = AutoSleep::default();
        let ctx = ctx_with(
            NonZeroU32::new(1).expect("1 is nonzero"),
            Duration::default(),
            flood(120),
        );
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_second_flood_retry_is_honoured() {
        let policy = AutoSleep::default();
        let ctx = ctx_with(
            NonZeroU32::new(2).expect("2 is nonzero"),
            Duration::from_secs(30),
            flood(30),
        );
        let secs = match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => d.as_secs_f64(),
            other => panic!("expected Continue on second FLOOD_WAIT, got {other:?}"),
        };
        assert!(
            (28.0..=32.0).contains(&secs),
            "expected 28-32s on second FLOOD_WAIT, got {secs:.3}s"
        );
    }

    #[test]
    fn autosleep_retries_io_once() {
        let policy = AutoSleep::default();
        let ctx = ctx_with(
            NonZeroU32::new(1).expect("1 is nonzero"),
            Duration::default(),
            io_err(),
        );
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(1)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_no_io_retry_after_first() {
        let policy = AutoSleep::default();
        let ctx = ctx_with(
            NonZeroU32::new(4).expect("4 is nonzero"),
            Duration::from_secs(3),
            io_err(),
        );
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_breaks_other_rpc() {
        let policy = AutoSleep::default();
        let ctx = ctx_with(
            NonZeroU32::new(1).expect("1 is nonzero"),
            Duration::default(),
            rpc(400, "BAD_REQUEST", None),
        );
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn migrate_dc_id_detected() {
        let e = rpc_error(303, "PHONE_MIGRATE", Some(5));
        assert_eq!(e.migrate_dc_id(), Some(5));
    }

    #[test]
    fn network_migrate_detected() {
        let e = rpc_error(303, "NETWORK_MIGRATE", Some(3));
        assert_eq!(e.migrate_dc_id(), Some(3));
    }

    #[test]
    fn file_migrate_detected() {
        let e = rpc_error(303, "FILE_MIGRATE", Some(4));
        assert_eq!(e.migrate_dc_id(), Some(4));
    }

    #[test]
    fn non_migrate_is_none() {
        let e = rpc_error(420, "FLOOD_WAIT", Some(30));
        assert_eq!(e.migrate_dc_id(), None);
    }

    #[test]
    fn migrate_falls_back_to_dc2_when_no_value() {
        let e = rpc_error(303, "PHONE_MIGRATE", None);
        assert_eq!(e.migrate_dc_id(), Some(2));
    }

    #[tokio::test]
    async fn retry_loop_gives_up_on_no_retries() {
        let mut rl = RetryLoop::new(Arc::new(NoRetries));
        let result = rl.advance(rpc(400, "SOMETHING_WRONG", None)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn retry_loop_increments_fail_count() {
        let policy = AutoSleep {
            threshold: Duration::from_secs(60),
            io_errors_as_flood_of: Some(Duration::from_millis(1)),
        };
        let mut rl = RetryLoop::new(Arc::new(policy));
        // First I/O error is retried; the second arrives with fail_count == 2 and is not.
        assert!(rl.advance(io_err()).await.is_ok());
        assert!(rl.advance(io_err()).await.is_err());
    }

    #[test]
    fn circuit_breaker_trips_after_threshold() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(60));
        assert!(matches!(cb.should_retry(&cb_ctx(1)), ControlFlow::Continue(_)));
        assert!(matches!(cb.should_retry(&cb_ctx(2)), ControlFlow::Continue(_)));
        assert!(matches!(cb.should_retry(&cb_ctx(3)), ControlFlow::Break(())));
        assert!(matches!(cb.should_retry(&cb_ctx(4)), ControlFlow::Break(())));
    }

    #[test]
    fn circuit_breaker_resets_after_cooldown() {
        let cb = CircuitBreaker::new(2, Duration::from_millis(10));
        assert!(matches!(cb.should_retry(&cb_ctx(1)), ControlFlow::Continue(_)));
        assert!(matches!(cb.should_retry(&cb_ctx(2)), ControlFlow::Break(())));
        std::thread::sleep(Duration::from_millis(20));
        assert!(matches!(cb.should_retry(&cb_ctx(1)), ControlFlow::Continue(_)));
    }
}