use std::num::NonZeroU32;
use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use crate::errors::InvocationError;
impl crate::errors::RpcError {
    /// Returns the data-center id this error asks the client to migrate to, if any.
    ///
    /// A migration error is one with code `303` whose name ends in `_MIGRATE`
    /// (e.g. `PHONE_MIGRATE`, `NETWORK_MIGRATE`, `FILE_MIGRATE`, `USER_MIGRATE`).
    /// The target DC is read from `value`; when `value` is absent, DC `2` is
    /// returned as a fallback.
    ///
    /// Returns `None` for every other error.
    pub fn migrate_dc_id(&self) -> Option<i32> {
        // The suffix check already covers every named `*_MIGRATE` variant, so there
        // is no need to enumerate them individually.
        if self.code == 303 && self.name.ends_with("_MIGRATE") {
            // NOTE(review): `as i32` wraps values above `i32::MAX`; DC ids are
            // expected to be small, so this is not guarded against here.
            Some(self.value.unwrap_or(2) as i32)
        } else {
            None
        }
    }
}
impl InvocationError {
    /// Returns the data-center id to migrate to when this is an RPC migration error.
    ///
    /// Only `Rpc` errors can request a migration; every other variant yields `None`.
    pub fn migrate_dc_id(&self) -> Option<i32> {
        if let Self::Rpc(rpc_error) = self {
            rpc_error.migrate_dc_id()
        } else {
            None
        }
    }
}
/// Decides whether a failed request should be retried.
///
/// Implementations inspect the [`RetryContext`] describing the latest failure and
/// return either `ControlFlow::Continue(delay)` (sleep for `delay`, then try again)
/// or `ControlFlow::Break(())` (give up and hand the error back to the caller).
///
/// The `Send + Sync + 'static` bounds allow a policy to be stored as
/// `Arc<dyn RetryPolicy>`, which is how [`RetryLoop::new`] accepts it.
pub trait RetryPolicy: Send + Sync + 'static {
    /// Returns the delay before the next attempt, or `Break(())` to stop retrying.
    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration>;
}
/// Retry state of a single request, passed to [`RetryPolicy::should_retry`].
pub struct RetryContext {
    /// Number of failures so far, including the one described by `error`
    /// (`RetryLoop` starts this at 1).
    pub fail_count: NonZeroU32,
    /// Total time already spent sleeping between attempts.
    pub slept_so_far: Duration,
    /// The error produced by the most recent attempt.
    pub error: InvocationError,
}
/// A retry policy that never retries: the first error is handed straight back.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoRetries;
impl RetryPolicy for NoRetries {
    /// Always stops: whatever the error, it is final.
    fn should_retry(&self, _ctx: &RetryContext) -> ControlFlow<(), Duration> {
        // No delay is ever requested, so the caller gets the error back immediately.
        ControlFlow::Break(())
    }
}
/// A retry policy that sleeps through short server-imposed waits and, optionally,
/// pauses once before retrying after an I/O error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutoSleep {
    /// Longest `FLOOD_WAIT` / `SLOWMODE_WAIT` that will be slept through. It is
    /// compared in whole seconds (sub-second precision is truncated); longer waits
    /// are returned to the caller as errors.
    pub threshold: Duration,
    /// Delay before retrying a request whose first attempt failed with an I/O
    /// error, or `None` to never retry I/O errors.
    pub io_errors_as_flood_of: Option<Duration>,
}
impl Default for AutoSleep {
fn default() -> Self {
Self {
threshold: Duration::from_secs(60),
io_errors_as_flood_of: Some(Duration::from_secs(1)),
}
}
}
impl RetryPolicy for AutoSleep {
    /// Decides whether to retry after the latest failure, and after how long.
    ///
    /// Only the *first* failure of a request is ever retried:
    ///
    /// * `420 FLOOD_WAIT` / `420 SLOWMODE_WAIT`: sleep for the number of seconds in
    ///   the error's `value` (a missing value counts as `0`), provided it does not
    ///   exceed `threshold` (compared in whole seconds).
    /// * I/O errors: sleep for `io_errors_as_flood_of`, if it is set.
    ///
    /// Every other error, every wait above the threshold, and every repeated
    /// failure yields `Break(())`.
    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration> {
        // `RetryLoop` keeps asking until the policy breaks. Without this guard, a
        // server that keeps answering FLOOD_WAIT would be retried forever.
        let first_failure = ctx.fail_count.get() == 1;

        match &ctx.error {
            InvocationError::Rpc(rpc)
                if first_failure
                    && rpc.code == 420
                    && (rpc.name == "FLOOD_WAIT" || rpc.name == "SLOWMODE_WAIT") =>
            {
                let secs = u64::from(rpc.value.unwrap_or(0));
                if secs <= self.threshold.as_secs() {
                    tracing::info!("{}_{secs}: sleeping before retry", rpc.name);
                    ControlFlow::Continue(Duration::from_secs(secs))
                } else {
                    ControlFlow::Break(())
                }
            }
            InvocationError::Io(_) if first_failure => match self.io_errors_as_flood_of {
                Some(d) => {
                    tracing::info!("I/O error: sleeping {d:?} before retry");
                    ControlFlow::Continue(d)
                }
                None => ControlFlow::Break(()),
            },
            _ => ControlFlow::Break(()),
        }
    }
}
/// Drives repeated attempts of a request, consulting a [`RetryPolicy`] after each
/// failure to decide whether (and after how long) to try again.
pub struct RetryLoop {
    /// Policy consulted on every failure.
    policy: Arc<dyn RetryPolicy>,
    /// State handed to `policy`; updated by `advance` after every sleep.
    ctx: RetryContext,
}
impl RetryLoop {
    /// Creates a loop driven by `policy`, positioned before the first failure.
    pub fn new(policy: Arc<dyn RetryPolicy>) -> Self {
        // `Dropped` is only a placeholder: `advance` overwrites it before the
        // policy ever inspects the context.
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: InvocationError::Dropped,
        };
        Self { policy, ctx }
    }

    /// Records `err` as the latest failure and asks the policy what to do next.
    ///
    /// When the policy wants another attempt, this sleeps for the requested delay,
    /// then adds it to `slept_so_far`, bumps `fail_count`, and returns `Ok(())` so
    /// the caller can try again.
    ///
    /// # Errors
    ///
    /// Returns `err` back when the policy decides to stop retrying.
    pub async fn advance(&mut self, err: InvocationError) -> Result<(), InvocationError> {
        self.ctx.error = err;

        let delay = match self.policy.should_retry(&self.ctx) {
            ControlFlow::Continue(delay) => delay,
            ControlFlow::Break(()) => {
                // Move the error out, leaving the placeholder in its slot.
                return Err(std::mem::replace(
                    &mut self.ctx.error,
                    InvocationError::Dropped,
                ));
            }
        };

        sleep(delay).await;
        self.ctx.slept_so_far += delay;
        self.ctx.fail_count = self.ctx.fail_count.saturating_add(1);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Builds a bare RPC error value.
    fn rpc_error(code: i32, name: &str, value: Option<u32>) -> crate::errors::RpcError {
        crate::errors::RpcError {
            code,
            name: name.into(),
            value,
        }
    }

    /// Wraps an RPC error as an `InvocationError`.
    fn rpc(code: i32, name: &str, value: Option<u32>) -> InvocationError {
        InvocationError::Rpc(rpc_error(code, name, value))
    }

    /// A `420 FLOOD_WAIT` asking the client to wait `secs` seconds.
    fn flood(secs: u32) -> InvocationError {
        rpc(420, "FLOOD_WAIT", Some(secs))
    }

    /// A connection-reset I/O error.
    fn io_err() -> InvocationError {
        InvocationError::Io(io::Error::new(io::ErrorKind::ConnectionReset, "reset"))
    }

    /// Builds the context a policy sees on the `fail_count`-th failure.
    fn context(fail_count: u32, slept_so_far: Duration, error: InvocationError) -> RetryContext {
        RetryContext {
            fail_count: NonZeroU32::new(fail_count).expect("fail_count must be non-zero"),
            slept_so_far,
            error,
        }
    }

    #[test]
    fn no_retries_always_breaks() {
        let ctx = context(1, Duration::default(), flood(10));
        assert!(matches!(NoRetries.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_retries_flood_under_threshold() {
        let policy = AutoSleep::default();
        let ctx = context(1, Duration::default(), flood(30));
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(30)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_retries_flood_at_threshold() {
        let policy = AutoSleep::default();
        let ctx = context(1, Duration::default(), flood(60));
        assert_eq!(
            policy.should_retry(&ctx),
            ControlFlow::Continue(Duration::from_secs(60))
        );
    }

    #[test]
    fn autosleep_breaks_flood_over_threshold() {
        let policy = AutoSleep::default();
        let ctx = context(1, Duration::default(), flood(120));
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_no_second_flood_retry() {
        let policy = AutoSleep::default();
        let ctx = context(2, Duration::from_secs(30), flood(30));
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_retries_slowmode_under_threshold() {
        let policy = AutoSleep::default();
        let ctx = context(1, Duration::default(), rpc(420, "SLOWMODE_WAIT", Some(10)));
        assert_eq!(
            policy.should_retry(&ctx),
            ControlFlow::Continue(Duration::from_secs(10))
        );
    }

    #[test]
    fn autosleep_retries_io_once() {
        let policy = AutoSleep::default();
        let ctx = context(1, Duration::default(), io_err());
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(1)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_no_second_io_retry() {
        let policy = AutoSleep::default();
        let ctx = context(2, Duration::from_secs(1), io_err());
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_io_breaks_when_disabled() {
        let policy = AutoSleep {
            io_errors_as_flood_of: None,
            ..AutoSleep::default()
        };
        let ctx = context(1, Duration::default(), io_err());
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_breaks_other_rpc() {
        let policy = AutoSleep::default();
        let ctx = context(1, Duration::default(), rpc(400, "BAD_REQUEST", None));
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn migrate_dc_id_detected() {
        assert_eq!(rpc_error(303, "PHONE_MIGRATE", Some(5)).migrate_dc_id(), Some(5));
    }

    #[test]
    fn network_migrate_detected() {
        assert_eq!(rpc_error(303, "NETWORK_MIGRATE", Some(3)).migrate_dc_id(), Some(3));
    }

    #[test]
    fn file_migrate_detected() {
        assert_eq!(rpc_error(303, "FILE_MIGRATE", Some(4)).migrate_dc_id(), Some(4));
    }

    #[test]
    fn user_migrate_detected() {
        assert_eq!(rpc_error(303, "USER_MIGRATE", Some(1)).migrate_dc_id(), Some(1));
    }

    #[test]
    fn non_migrate_is_none() {
        assert_eq!(rpc_error(420, "FLOOD_WAIT", Some(30)).migrate_dc_id(), None);
    }

    #[test]
    fn migrate_requires_code_303() {
        assert_eq!(rpc_error(400, "PHONE_MIGRATE", Some(5)).migrate_dc_id(), None);
    }

    #[test]
    fn migrate_falls_back_to_dc2_when_no_value() {
        assert_eq!(rpc_error(303, "PHONE_MIGRATE", None).migrate_dc_id(), Some(2));
    }

    #[test]
    fn invocation_error_migrate_dc_id() {
        assert_eq!(rpc(303, "FILE_MIGRATE", Some(4)).migrate_dc_id(), Some(4));
        assert_eq!(io_err().migrate_dc_id(), None);
    }

    #[tokio::test]
    async fn retry_loop_gives_up_on_no_retries() {
        let mut rl = RetryLoop::new(Arc::new(NoRetries));
        match rl.advance(rpc(400, "SOMETHING_WRONG", None)).await {
            Err(InvocationError::Rpc(e)) => assert_eq!(e.name, "SOMETHING_WRONG"),
            Err(_) => panic!("expected the original RPC error back"),
            Ok(()) => panic!("NoRetries must not retry"),
        }
    }

    #[tokio::test]
    async fn retry_loop_increments_fail_count() {
        let mut rl = RetryLoop::new(Arc::new(AutoSleep {
            threshold: Duration::from_secs(60),
            io_errors_as_flood_of: Some(Duration::from_millis(1)),
        }));
        assert!(rl.advance(io_err()).await.is_ok());
        assert_eq!(rl.ctx.fail_count.get(), 2);
        assert_eq!(rl.ctx.slept_so_far, Duration::from_millis(1));
        assert!(rl.advance(io_err()).await.is_err());
    }
}