ferogram_mtsender/retry.rs

// Copyright (c) Ankit Chaubey <ankitchaubey.dev@gmail.com>
//
// ferogram: async Telegram MTProto client in Rust
// https://github.com/ankit-chaubey/ferogram
//
// Licensed under either the MIT License or the Apache License 2.0.
// See the LICENSE-MIT or LICENSE-APACHE file in this repository:
// https://github.com/ankit-chaubey/ferogram
//
// Feel free to use, modify, and share this code.
// Please keep this notice when redistributing.

use std::num::NonZeroU32;
use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;

use tokio::time::sleep;

use crate::errors::InvocationError;

// RetryPolicy trait

/// Controls how the client reacts when an RPC call fails.
///
/// Implement this trait to provide custom flood-wait handling, circuit
/// breakers, or exponential back-off.
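///
/// A minimal custom-policy sketch (illustrative only; `MaxAttempts` is a
/// hypothetical type, not part of this crate): give up after a fixed number
/// of attempts, with a capped exponential back-off in between.
///
/// ```rust,ignore
/// use std::ops::ControlFlow;
/// use std::time::Duration;
///
/// struct MaxAttempts(u32);
///
/// impl RetryPolicy for MaxAttempts {
///     fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration> {
///         if ctx.fail_count.get() >= self.0 {
///             return ControlFlow::Break(());
///         }
///         // 100 ms × 2^(n-1), capped at 3.2 s.
///         let shift = (ctx.fail_count.get() - 1).min(5);
///         ControlFlow::Continue(Duration::from_millis(100u64 << shift))
///     }
/// }
/// ```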
pub trait RetryPolicy: Send + Sync + 'static {
    /// Decide whether to retry the failed request.
    ///
    /// Return `ControlFlow::Continue(delay)` to sleep `delay` and retry.
    /// Return `ControlFlow::Break(())` to propagate `ctx.error` to the caller.
    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration>;
}

/// Context passed to [`RetryPolicy::should_retry`] on each failure.
pub struct RetryContext {
    /// Number of times this request has failed (starts at 1).
    pub fail_count: NonZeroU32,
    /// Total time already slept for this request across all prior retries.
    pub slept_so_far: Duration,
    /// The most recent error.
    pub error: InvocationError,
}

// Built-in policies

/// Never retry: propagate every error immediately.
pub struct NoRetries;

impl RetryPolicy for NoRetries {
    fn should_retry(&self, _: &RetryContext) -> ControlFlow<(), Duration> {
        ControlFlow::Break(())
    }
}

/// Automatically sleep on `FLOOD_WAIT` and `SLOWMODE_WAIT`, and retry once on
/// transient I/O errors.
///
/// This is the default retry policy.
///
/// ```rust
/// # use ferogram_mtsender::AutoSleep;
/// let policy = AutoSleep {
///     threshold: std::time::Duration::from_secs(60),
///     io_errors_as_flood_of: Some(std::time::Duration::from_secs(1)),
/// };
/// ```
pub struct AutoSleep {
    /// Maximum flood-wait the library will automatically sleep through.
    ///
    /// If Telegram asks us to wait longer than this, the error is propagated.
    pub threshold: Duration,

    /// If `Some(d)`, treat the first I/O error as a flood wait of duration `d`
    /// and retry once. `None` propagates I/O errors immediately.
    pub io_errors_as_flood_of: Option<Duration>,
}

impl Default for AutoSleep {
    fn default() -> Self {
        Self {
            threshold: Duration::from_secs(60),
            io_errors_as_flood_of: Some(Duration::from_secs(1)),
        }
    }
}

/// Add deterministic ±`max_jitter_secs` jitter to `base`.
///
/// Uses a fast integer hash of `seed` (the fail count), so no `rand` crate is
/// needed. Clients are usually at different fail counts at any given moment,
/// so in practice the spread helps avoid a thundering herd on simultaneous
/// FLOOD_WAITs; callers that share the same fail count do receive the same
/// offset, since the hash is deterministic.
fn jitter_duration(base: Duration, seed: u32, max_jitter_secs: u64) -> Duration {
    // Murmur3-inspired finalizer.
    let h = {
        let mut v = seed as u64 ^ 0x9e37_79b9_7f4a_7c15;
        v ^= v >> 30;
        v = v.wrapping_mul(0xbf58_476d_1ce4_e5b9);
        v ^= v >> 27;
        v = v.wrapping_mul(0x94d0_49bb_1331_11eb);
        v ^= v >> 31;
        v
    };
    // Map into [-max_jitter_secs, +max_jitter_secs] in milliseconds.
    let range_ms = max_jitter_secs * 1000 * 2 + 1;
    let jitter_ms = (h % range_ms) as i64 - (max_jitter_secs * 1000) as i64;
    let base_ms = base.as_millis() as i64;
    let final_ms = (base_ms + jitter_ms).max(0) as u64;
    Duration::from_millis(final_ms)
}
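
// Worked example of the mapping above (illustrative): with
// `max_jitter_secs = 2`, `range_ms` is 4001, so `h % range_ms` falls in
// 0..=4000 and `jitter_ms` in -2000..=+2000. A 30 s FLOOD_WAIT therefore
// retries after 28-32 s, and the `.max(0)` clamp keeps tiny bases from
// going negative.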

impl RetryPolicy for AutoSleep {
    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration> {
        match &ctx.error {
            // FLOOD_WAIT / SLOWMODE_WAIT: sleep as long as Telegram asks, plus
            // ±2 s jitter. Jitter spreads retries across clients that all hit
            // the same limit simultaneously (e.g. after a server-side
            // rate-limit window resets). SLOWMODE_WAIT has the same semantics
            // and is very common in group bots that send messages faster than
            // the channel's slowmode allows.
            InvocationError::Rpc(rpc)
                if rpc.code == 420
                    && (rpc.name == "FLOOD_WAIT" || rpc.name == "SLOWMODE_WAIT") =>
            {
                let secs = rpc.value.unwrap_or(0) as u64;
                if secs <= self.threshold.as_secs() {
                    let delay =
                        jitter_duration(Duration::from_secs(secs), ctx.fail_count.get(), 2);
                    tracing::info!("{}_{secs}: sleeping {delay:?} before retry", rpc.name);
                    ControlFlow::Continue(delay)
                } else {
                    ControlFlow::Break(())
                }
            }

            // Transient I/O errors: back off briefly and retry once.
            InvocationError::Io(_) if ctx.fail_count.get() <= 1 => {
                if let Some(d) = self.io_errors_as_flood_of {
                    tracing::info!(
                        "I/O error (attempt {}): sleeping {d:?} before retry",
                        ctx.fail_count.get()
                    );
                    ControlFlow::Continue(d)
                } else {
                    ControlFlow::Break(())
                }
            }

            _ => ControlFlow::Break(()),
        }
    }
}
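
// Example decision trace with the defaults above (illustrative):
// `FLOOD_WAIT_30` is under the 60 s threshold, so the client sleeps ~30 s
// (±2 s jitter) and retries; `FLOOD_WAIT_120` exceeds the threshold and is
// propagated to the caller unchanged.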

// RetryLoop

/// Drives the retry loop for a single RPC call.
///
/// Create one per call, then call `advance` after every failure.
///
/// ```rust,ignore
/// let mut rl = RetryLoop::new(Arc::clone(&self.inner.retry_policy));
/// loop {
///     match self.do_rpc_call(req).await {
///         Ok(body) => return Ok(body),
///         Err(e) => rl.advance(e).await?,
///     }
/// }
/// ```
///
/// `advance` either:
/// - sleeps the required duration and returns `Ok(())` → caller should retry, or
/// - returns `Err(e)` → caller should propagate.
///
/// This is the single source of truth; previously the same loop was
/// copy-pasted into `rpc_call_raw`, `rpc_write`, and the reconnect path.
pub struct RetryLoop {
    policy: Arc<dyn RetryPolicy>,
    ctx: RetryContext,
}

impl RetryLoop {
    pub fn new(policy: Arc<dyn RetryPolicy>) -> Self {
        Self {
            policy,
            ctx: RetryContext {
                fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
                slept_so_far: Duration::default(),
                error: InvocationError::Dropped,
            },
        }
    }

    /// Record a failure and either sleep+return-Ok (retry) or return-Err (give up).
    ///
    /// Mutates `self` to track cumulative state across retries.
    pub async fn advance(&mut self, err: InvocationError) -> Result<(), InvocationError> {
        self.ctx.error = err;
        match self.policy.should_retry(&self.ctx) {
            ControlFlow::Continue(delay) => {
                sleep(delay).await;
                self.ctx.slept_so_far += delay;
                // saturating_add: if somehow we overflow NonZeroU32, clamp at MAX
                self.ctx.fail_count = self.ctx.fail_count.saturating_add(1);
                Ok(())
            }
            ControlFlow::Break(()) => {
                // Move the error out so the caller doesn't have to clone it
                Err(std::mem::replace(
                    &mut self.ctx.error,
                    InvocationError::Dropped,
                ))
            }
        }
    }
}

// CircuitBreaker

/// Internal state of a [`CircuitBreaker`].
#[derive(Debug)]
enum CbState {
    /// Normal operation: counting consecutive failures.
    Closed { consecutive_failures: u32 },
    /// Breaker tripped: all calls rejected until cooldown expires.
    Open { tripped_at: std::time::Instant },
}
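
// Transition summary for `should_retry` below (n = consecutive failures,
// t = threshold):
//
//   Closed{n}, n+1 <  t  --error-->  Closed{n+1}  (Continue, back-off)
//   Closed{n}, n+1 >= t  --error-->  Open         (Break, breaker trips)
//   Open, cooldown not elapsed  -->  Open         (Break)
//   Open, cooldown elapsed      -->  Closed{1}    (Continue, 200 ms)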

/// A [`RetryPolicy`] that stops retrying after `threshold` consecutive
/// failures and then rejects all calls for a `cooldown` window before
/// resetting.
///
/// # States
/// - **Closed** (normal): allows retries, increments a failure counter on
///   each error, and applies an exponential back-off for up to `threshold − 1`
///   attempts. On the `threshold`-th consecutive failure the breaker trips.
/// - **Open** (tripped): rejects every call immediately (`Break`) for the
///   duration of `cooldown`.
/// - **Reset**: once `cooldown` has elapsed the breaker closes again and the
///   failure counter restarts (the call that closes the breaker counts as
///   the first failure).
///
/// Because [`RetryPolicy`] has no success callback, the breaker cannot
/// distinguish a successful probe from a clean run; the counter simply
/// restarts when the cooldown expires. For a full half-open probe you can
/// wrap `CircuitBreaker` in a custom `RetryPolicy`.
///
/// # Example
/// ```rust
/// use ferogram_mtsender::CircuitBreaker;
/// use std::time::Duration;
///
/// // Trip after 5 consecutive errors; stay open for 30 s.
/// let policy = CircuitBreaker::new(5, Duration::from_secs(30));
/// ```
pub struct CircuitBreaker {
    /// Number of consecutive failures before the breaker trips.
    threshold: u32,
    /// How long the breaker stays open before resetting.
    cooldown: Duration,
    state: std::sync::Mutex<CbState>,
}

impl CircuitBreaker {
    /// Create a new `CircuitBreaker`.
    ///
    /// - `threshold`: failures before the breaker trips (minimum 1).
    /// - `cooldown`: how long the breaker stays open.
    pub fn new(threshold: u32, cooldown: Duration) -> Self {
        assert!(
            threshold >= 1,
            "CircuitBreaker threshold must be at least 1"
        );
        Self {
            threshold,
            cooldown,
            state: std::sync::Mutex::new(CbState::Closed {
                consecutive_failures: 0,
            }),
        }
    }
}

impl RetryPolicy for CircuitBreaker {
    fn should_retry(&self, _ctx: &RetryContext) -> ControlFlow<(), Duration> {
        let mut state = self.state.lock().expect("lock poisoned");
        match &*state {
            CbState::Open { tripped_at } => {
                if tripped_at.elapsed() >= self.cooldown {
                    // Cooldown expired: reset to Closed, allow retry with a small delay.
                    *state = CbState::Closed {
                        consecutive_failures: 1,
                    };
                    ControlFlow::Continue(Duration::from_millis(200))
                } else {
                    // Still open: reject immediately.
                    ControlFlow::Break(())
                }
            }
            CbState::Closed {
                consecutive_failures,
            } => {
                let new_count = consecutive_failures + 1;
                if new_count >= self.threshold {
                    tracing::warn!(
                        "[ferogram] CircuitBreaker tripped after {new_count} consecutive failures"
                    );
                    *state = CbState::Open {
                        tripped_at: std::time::Instant::now(),
                    };
                    ControlFlow::Break(())
                } else {
                    // Exponential back-off: 200 ms × 2^(n−1), capped at 3.2 s.
                    let backoff_ms = 200u64 * (1u64 << new_count.saturating_sub(1).min(4));
                    *state = CbState::Closed {
                        consecutive_failures: new_count,
                    };
                    ControlFlow::Continue(Duration::from_millis(backoff_ms))
                }
            }
        }
    }
}
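
// A breaker plugs into the same machinery as any other policy. A sketch,
// reusing the loop from the `RetryLoop` docs above:
//
//     let policy = Arc::new(CircuitBreaker::new(5, Duration::from_secs(30)));
//     let mut rl = RetryLoop::new(policy);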

// Tests

#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::RpcError;
    use std::io;

    fn flood(secs: u32) -> InvocationError {
        InvocationError::Rpc(RpcError {
            code: 420,
            name: "FLOOD_WAIT".into(),
            value: Some(secs),
        })
    }

    fn io_err() -> InvocationError {
        InvocationError::Io(io::Error::new(io::ErrorKind::ConnectionReset, "reset"))
    }

    fn rpc(code: i32, name: &str, value: Option<u32>) -> InvocationError {
        InvocationError::Rpc(RpcError {
            code,
            name: name.into(),
            value,
        })
    }

    // NoRetries

    #[test]
    fn no_retries_always_breaks() {
        let policy = NoRetries;
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
            slept_so_far: Duration::default(),
            error: flood(10),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    // AutoSleep

    #[test]
    fn autosleep_retries_flood_under_threshold() {
        let policy = AutoSleep::default(); // threshold = 60s
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
            slept_so_far: Duration::default(),
            error: flood(30),
        };
        match policy.should_retry(&ctx) {
            // Jitter of ±2s is applied; accept 28..=32 s.
            ControlFlow::Continue(d) => {
                let secs = d.as_secs_f64();
                assert!(
                    secs >= 28.0 && secs <= 32.0,
                    "expected 28-32s delay (jitter), got {secs:.3}s"
                );
            }
            other => panic!("expected Continue, got {other:?}"),
        }
    }
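
    // Illustrative addition: `jitter_duration` is a pure function, so equal
    // inputs must give equal outputs, and the result stays within ±2 s.
    #[test]
    fn jitter_is_deterministic_and_bounded() {
        let a = jitter_duration(Duration::from_secs(30), 1, 2);
        let b = jitter_duration(Duration::from_secs(30), 1, 2);
        assert_eq!(a, b);
        let secs = a.as_secs_f64();
        assert!(
            secs >= 28.0 && secs <= 32.0,
            "jitter out of band: {secs:.3}s"
        );
    }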

    #[test]
    fn autosleep_breaks_flood_over_threshold() {
        let policy = AutoSleep::default(); // threshold = 60s
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
            slept_so_far: Duration::default(),
            error: flood(120),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_second_flood_retry_is_honoured() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(2).expect("2 is nonzero"),
            slept_so_far: Duration::from_secs(30),
            error: flood(30),
        };
        match policy.should_retry(&ctx) {
            // Jitter of ±2s; accept 28..=32 s.
            ControlFlow::Continue(d) => {
                let secs = d.as_secs_f64();
                assert!(
                    secs >= 28.0 && secs <= 32.0,
                    "expected 28-32s on second FLOOD_WAIT, got {secs:.3}s"
                );
            }
            other => panic!("expected Continue on second FLOOD_WAIT, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_retries_io_once() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
            slept_so_far: Duration::default(),
            error: io_err(),
        };
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(1)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_no_io_retry_after_first() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(4).expect("4 is nonzero"),
            slept_so_far: Duration::from_secs(3),
            error: io_err(),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_breaks_other_rpc() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).expect("1 is nonzero"),
            slept_so_far: Duration::default(),
            error: rpc(400, "BAD_REQUEST", None),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    // RpcError::migrate_dc_id

    #[test]
    fn migrate_dc_id_detected() {
        let e = RpcError {
            code: 303,
            name: "PHONE_MIGRATE".into(),
            value: Some(5),
        };
        assert_eq!(e.migrate_dc_id(), Some(5));
    }

    #[test]
    fn network_migrate_detected() {
        let e = RpcError {
            code: 303,
            name: "NETWORK_MIGRATE".into(),
            value: Some(3),
        };
        assert_eq!(e.migrate_dc_id(), Some(3));
    }

    #[test]
    fn file_migrate_detected() {
        let e = RpcError {
            code: 303,
            name: "FILE_MIGRATE".into(),
            value: Some(4),
        };
        assert_eq!(e.migrate_dc_id(), Some(4));
    }

    #[test]
    fn non_migrate_is_none() {
        let e = RpcError {
            code: 420,
            name: "FLOOD_WAIT".into(),
            value: Some(30),
        };
        assert_eq!(e.migrate_dc_id(), None);
    }

    #[test]
    fn migrate_falls_back_to_dc2_when_no_value() {
        let e = RpcError {
            code: 303,
            name: "PHONE_MIGRATE".into(),
            value: None,
        };
        assert_eq!(e.migrate_dc_id(), Some(2));
    }

    // RetryLoop

    #[tokio::test]
    async fn retry_loop_gives_up_on_no_retries() {
        let mut rl = RetryLoop::new(Arc::new(NoRetries));
        let err = rpc(400, "SOMETHING_WRONG", None);
        let result = rl.advance(err).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn retry_loop_increments_fail_count() {
        let mut rl = RetryLoop::new(Arc::new(AutoSleep {
            threshold: Duration::from_secs(60),
            io_errors_as_flood_of: Some(Duration::from_millis(1)),
        }));
        assert!(rl.advance(io_err()).await.is_ok());
        assert!(rl.advance(io_err()).await.is_err());
    }
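
    // Illustrative addition: on `Break` the loop hands back the exact error
    // it was given, moved out via `std::mem::replace` rather than cloned.
    #[tokio::test]
    async fn retry_loop_returns_original_error() {
        let mut rl = RetryLoop::new(Arc::new(NoRetries));
        match rl.advance(rpc(400, "BAD_REQUEST", None)).await {
            Err(InvocationError::Rpc(e)) => assert_eq!(e.name, "BAD_REQUEST"),
            _ => panic!("expected the original RpcError back"),
        }
    }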

    // CircuitBreaker

    #[test]
    fn circuit_breaker_trips_after_threshold() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(60));
        let ctx = |n: u32| RetryContext {
            fail_count: NonZeroU32::new(n).unwrap(),
            slept_so_far: Duration::default(),
            error: rpc(500, "INTERNAL", None),
        };
        // First two failures: Continue (backoff)
        assert!(matches!(cb.should_retry(&ctx(1)), ControlFlow::Continue(_)));
        assert!(matches!(cb.should_retry(&ctx(2)), ControlFlow::Continue(_)));
        // Third: trips the breaker → Break
        assert!(matches!(cb.should_retry(&ctx(3)), ControlFlow::Break(())));
        // Subsequent calls while open → Break immediately
        assert!(matches!(cb.should_retry(&ctx(4)), ControlFlow::Break(())));
    }

    #[test]
    fn circuit_breaker_resets_after_cooldown() {
        let cb = CircuitBreaker::new(2, Duration::from_millis(10));
        let ctx = |n: u32| RetryContext {
            fail_count: NonZeroU32::new(n).unwrap(),
            slept_so_far: Duration::default(),
            error: rpc(500, "INTERNAL", None),
        };
        // Trip the breaker
        assert!(matches!(cb.should_retry(&ctx(1)), ControlFlow::Continue(_)));
        assert!(matches!(cb.should_retry(&ctx(2)), ControlFlow::Break(())));
        // Wait for cooldown
        std::thread::sleep(Duration::from_millis(20));
        // After cooldown: breaker resets → Continue again
        assert!(matches!(cb.should_retry(&ctx(1)), ControlFlow::Continue(_)));
    }
}
567}