// layer_client/retry.rs
//! Retry policies for handling `FLOOD_WAIT`, transient I/O errors, and DC-migration redirects.

use std::num::NonZeroU32;
use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;

use tokio::time::sleep;

use crate::errors::InvocationError;
11
12/// Extension methods on [`crate::errors::RpcError`] for routing decisions.
13impl crate::errors::RpcError {
14    /// If this is a DC-migration redirect (code 303), returns the target DC id.
15    ///
16    /// Telegram sends these for:
17    /// - `PHONE_MIGRATE_X`  : user's home DC during auth
18    /// - `NETWORK_MIGRATE_X`: general redirect
19    /// - `FILE_MIGRATE_X`   : file download/upload DC
20    /// - `USER_MIGRATE_X`   : account migration
21    ///
22    /// All have `code == 303` and a numeric suffix that is the DC id.
23    pub fn migrate_dc_id(&self) -> Option<i32> {
24        if self.code != 303 {
25            return None;
26        }
27        //  pattern: any *_MIGRATE_* name with a numeric value
28        let is_migrate = self.name == "PHONE_MIGRATE"
29            || self.name == "NETWORK_MIGRATE"
30            || self.name == "FILE_MIGRATE"
31            || self.name == "USER_MIGRATE"
32            || self.name.ends_with("_MIGRATE");
33        if is_migrate {
34            // value is the DC id; fall back to DC 2 (Amsterdam) if missing
35            Some(self.value.unwrap_or(2) as i32)
36        } else {
37            None
38        }
39    }
40}
41
42/// Extension on [`InvocationError`] for migrate detection.
43impl InvocationError {
44    /// If this error is a DC-migration redirect, returns the target DC id.
45    pub fn migrate_dc_id(&self) -> Option<i32> {
46        match self {
47            Self::Rpc(r) => r.migrate_dc_id(),
48            _ => None,
49        }
50    }
51}
52
53// RetryPolicy trait
54
/// Controls how the client reacts when an RPC call fails.
///
/// Implement this trait to provide custom flood-wait handling, circuit
/// breakers, or exponential back-off.
///
/// Policies are driven by [`RetryLoop::advance`], which calls
/// [`RetryPolicy::should_retry`] once per failure with cumulative state
/// (failure count, total slept time, latest error) in the context.
pub trait RetryPolicy: Send + Sync + 'static {
    /// Decide whether to retry the failed request.
    ///
    /// Return `ControlFlow::Continue(delay)` to sleep `delay` and retry.
    /// Return `ControlFlow::Break(())` to propagate `ctx.error` to the caller.
    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration>;
}
66
/// Context passed to [`RetryPolicy::should_retry`] on each failure.
///
/// [`RetryLoop`] owns one of these per RPC call and updates it after every
/// retry, so policies can base decisions on cumulative state rather than
/// only the latest error.
pub struct RetryContext {
    /// Number of times this request has failed (starts at 1).
    pub fail_count: NonZeroU32,
    /// Total time already slept for this request across all prior retries.
    pub slept_so_far: Duration,
    /// The most recent error.
    pub error: InvocationError,
}
76
77// Built-in policies
78
79/// Never retry: propagate every error immediately.
80pub struct NoRetries;
81
82impl RetryPolicy for NoRetries {
83    fn should_retry(&self, _: &RetryContext) -> ControlFlow<(), Duration> {
84        ControlFlow::Break(())
85    }
86}
87
/// Automatically sleep on `FLOOD_WAIT` and retry once on transient I/O errors.
///
/// Default retry policy. Sleeps on `FLOOD_WAIT` (and `SLOWMODE_WAIT`, which
/// has the same code-420 semantics), backs off on I/O errors.
///
/// ```rust
/// # use layer_client::retry::AutoSleep;
/// let policy = AutoSleep {
/// threshold: std::time::Duration::from_secs(60),
/// io_errors_as_flood_of: Some(std::time::Duration::from_secs(1)),
/// };
/// ```
pub struct AutoSleep {
    /// Maximum flood-wait the library will automatically sleep through.
    ///
    /// If Telegram asks us to wait longer than this, the error is propagated.
    /// Compared against the server-requested whole seconds; defaults to 60 s.
    pub threshold: Duration,

    /// If `Some(d)`, treat the first I/O error as a `d`-second flood wait
    /// and retry once.  `None` propagates I/O errors immediately.
    /// Defaults to `Some(1 s)`.
    pub io_errors_as_flood_of: Option<Duration>,
}
109
110impl Default for AutoSleep {
111    fn default() -> Self {
112        Self {
113            threshold: Duration::from_secs(60),
114            io_errors_as_flood_of: Some(Duration::from_secs(1)),
115        }
116    }
117}
118
119impl RetryPolicy for AutoSleep {
120    fn should_retry(&self, ctx: &RetryContext) -> ControlFlow<(), Duration> {
121        match &ctx.error {
122            // FLOOD_WAIT: sleep exactly as long as Telegram asks, for every
123            // occurrence up to threshold. Removing the fail_count==1 guard
124            // means a second consecutive FLOOD_WAIT is also honoured rather
125            // than propagated immediately.
126            InvocationError::Rpc(rpc) if rpc.code == 420 && rpc.name == "FLOOD_WAIT" => {
127                let secs = rpc.value.unwrap_or(0) as u64;
128                if secs <= self.threshold.as_secs() {
129                    tracing::info!("FLOOD_WAIT_{secs}: sleeping before retry");
130                    ControlFlow::Continue(Duration::from_secs(secs))
131                } else {
132                    ControlFlow::Break(())
133                }
134            }
135
136            // SLOWMODE_WAIT: same semantics as FLOOD_WAIT; very common in
137            // group bots that send messages faster than the channel's slowmode.
138            InvocationError::Rpc(rpc) if rpc.code == 420 && rpc.name == "SLOWMODE_WAIT" => {
139                let secs = rpc.value.unwrap_or(0) as u64;
140                if secs <= self.threshold.as_secs() {
141                    tracing::info!("SLOWMODE_WAIT_{secs}: sleeping before retry");
142                    ControlFlow::Continue(Duration::from_secs(secs))
143                } else {
144                    ControlFlow::Break(())
145                }
146            }
147
148            // Transient I/O errors: back off briefly and retry once.
149            InvocationError::Io(_) if ctx.fail_count.get() == 1 => {
150                if let Some(d) = self.io_errors_as_flood_of {
151                    tracing::info!("I/O error: sleeping {d:?} before retry");
152                    ControlFlow::Continue(d)
153                } else {
154                    ControlFlow::Break(())
155                }
156            }
157
158            _ => ControlFlow::Break(()),
159        }
160    }
161}
162
163// RetryLoop
164
/// Drives the retry loop for a single RPC call.
///
/// Create one per call, then call `advance` after every failure.
///
/// ```rust,ignore
/// let mut rl = RetryLoop::new(Arc::clone(&self.inner.retry_policy));
/// loop {
/// match self.do_rpc_call(req).await {
///     Ok(body) => return Ok(body),
///     Err(e)   => rl.advance(e).await?,
/// }
/// }
/// ```
///
/// `advance` either:
/// - sleeps the required duration and returns `Ok(())` → caller should retry, or
/// - returns `Err(e)` → caller should propagate.
///
/// This is the single source of truth; previously the same loop was
/// copy-pasted into `rpc_call_raw`, `rpc_write`, and the reconnect path.
pub struct RetryLoop {
    // Policy consulted after every failure.
    policy: Arc<dyn RetryPolicy>,
    // Cumulative per-call state handed to the policy on each failure.
    ctx: RetryContext,
}
189
190impl RetryLoop {
191    pub fn new(policy: Arc<dyn RetryPolicy>) -> Self {
192        Self {
193            policy,
194            ctx: RetryContext {
195                fail_count: NonZeroU32::new(1).unwrap(),
196                slept_so_far: Duration::default(),
197                error: InvocationError::Dropped,
198            },
199        }
200    }
201
202    /// Record a failure and either sleep+return-Ok (retry) or return-Err (give up).
203    ///
204    /// Mutates `self` to track cumulative state across retries.
205    pub async fn advance(&mut self, err: InvocationError) -> Result<(), InvocationError> {
206        self.ctx.error = err;
207        match self.policy.should_retry(&self.ctx) {
208            ControlFlow::Continue(delay) => {
209                sleep(delay).await;
210                self.ctx.slept_so_far += delay;
211                // saturating_add: if somehow we overflow NonZeroU32, clamp at MAX
212                self.ctx.fail_count = self.ctx.fail_count.saturating_add(1);
213                Ok(())
214            }
215            ControlFlow::Break(()) => {
216                // Move the error out so the caller doesn't have to clone it
217                Err(std::mem::replace(
218                    &mut self.ctx.error,
219                    InvocationError::Dropped,
220                ))
221            }
222        }
223    }
224}
225
226// Tests
227
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    /// Builds a `FLOOD_WAIT` (code 420) error asking for `secs` seconds.
    fn flood(secs: u32) -> InvocationError {
        InvocationError::Rpc(crate::errors::RpcError {
            code: 420,
            name: "FLOOD_WAIT".into(),
            value: Some(secs),
        })
    }

    /// Builds a transient I/O (connection-reset) error.
    fn io_err() -> InvocationError {
        InvocationError::Io(io::Error::new(io::ErrorKind::ConnectionReset, "reset"))
    }

    /// Builds an arbitrary RPC error.
    fn rpc(code: i32, name: &str, value: Option<u32>) -> InvocationError {
        InvocationError::Rpc(crate::errors::RpcError {
            code,
            name: name.into(),
            value,
        })
    }

    // NoRetries

    #[test]
    fn no_retries_always_breaks() {
        let policy = NoRetries;
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: flood(10),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    // AutoSleep

    #[test]
    fn autosleep_retries_flood_under_threshold() {
        let policy = AutoSleep::default(); // threshold = 60s
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: flood(30),
        };
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(30)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    // NEW: the SLOWMODE_WAIT arm was previously untested.
    #[test]
    fn autosleep_retries_slowmode_under_threshold() {
        let policy = AutoSleep::default(); // threshold = 60s
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: rpc(420, "SLOWMODE_WAIT", Some(10)),
        };
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(10)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_breaks_flood_over_threshold() {
        let policy = AutoSleep::default(); // threshold = 60s
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: flood(120),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    // FIXED: this test (formerly `autosleep_no_second_flood_retry`) asserted
    // Break for a second FLOOD_WAIT, contradicting the documented behaviour:
    // the `fail_count == 1` guard on the FLOOD_WAIT arm was deliberately
    // removed (see the comment in `should_retry`), so a second consecutive
    // FLOOD_WAIT under the threshold is honoured, not propagated.
    #[test]
    fn autosleep_honours_second_flood() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(2).unwrap(),
            slept_so_far: Duration::from_secs(30),
            error: flood(30),
        };
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(30)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_retries_io_once() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: io_err(),
        };
        match policy.should_retry(&ctx) {
            ControlFlow::Continue(d) => assert_eq!(d, Duration::from_secs(1)),
            other => panic!("expected Continue, got {other:?}"),
        }
    }

    #[test]
    fn autosleep_no_second_io_retry() {
        let policy = AutoSleep::default();
        // The I/O arm keeps its fail_count == 1 guard: only one retry.
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(2).unwrap(),
            slept_so_far: Duration::from_secs(1),
            error: io_err(),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    #[test]
    fn autosleep_breaks_other_rpc() {
        let policy = AutoSleep::default();
        let ctx = RetryContext {
            fail_count: NonZeroU32::new(1).unwrap(),
            slept_so_far: Duration::default(),
            error: rpc(400, "BAD_REQUEST", None),
        };
        assert!(matches!(policy.should_retry(&ctx), ControlFlow::Break(())));
    }

    // RpcError::migrate_dc_id

    #[test]
    fn migrate_dc_id_detected() {
        let e = crate::errors::RpcError {
            code: 303,
            name: "PHONE_MIGRATE".into(),
            value: Some(5),
        };
        assert_eq!(e.migrate_dc_id(), Some(5));
    }

    #[test]
    fn network_migrate_detected() {
        let e = crate::errors::RpcError {
            code: 303,
            name: "NETWORK_MIGRATE".into(),
            value: Some(3),
        };
        assert_eq!(e.migrate_dc_id(), Some(3));
    }

    #[test]
    fn file_migrate_detected() {
        let e = crate::errors::RpcError {
            code: 303,
            name: "FILE_MIGRATE".into(),
            value: Some(4),
        };
        assert_eq!(e.migrate_dc_id(), Some(4));
    }

    #[test]
    fn non_migrate_is_none() {
        let e = crate::errors::RpcError {
            code: 420,
            name: "FLOOD_WAIT".into(),
            value: Some(30),
        };
        assert_eq!(e.migrate_dc_id(), None);
    }

    #[test]
    fn migrate_falls_back_to_dc2_when_no_value() {
        let e = crate::errors::RpcError {
            code: 303,
            name: "PHONE_MIGRATE".into(),
            value: None,
        };
        assert_eq!(e.migrate_dc_id(), Some(2));
    }

    // RetryLoop

    #[tokio::test]
    async fn retry_loop_gives_up_on_no_retries() {
        let mut rl = RetryLoop::new(Arc::new(NoRetries));
        let err = rpc(400, "SOMETHING_WRONG", None);
        let result = rl.advance(err).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn retry_loop_increments_fail_count() {
        let mut rl = RetryLoop::new(Arc::new(AutoSleep {
            threshold: Duration::from_secs(60),
            io_errors_as_flood_of: Some(Duration::from_millis(1)),
        }));
        // First I/O error retries (fail_count 1 → 2); second one gives up.
        assert!(rl.advance(io_err()).await.is_ok());
        assert!(rl.advance(io_err()).await.is_err());
    }
}
412}