1use std::time::Duration;
2
3use url::Url;
4
5use crate::Client;
6
7#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
9#[serde(default, deny_unknown_fields)]
10pub struct Backoff {
11 #[arg(
13 id = "backoff-initial",
14 long,
15 default_value = "1s",
16 env = "MOQ_BACKOFF_INITIAL",
17 value_parser = humantime::parse_duration,
18 )]
19 #[serde(with = "humantime_serde")]
20 pub initial: Duration,
21
22 #[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
24 pub multiplier: u32,
25
26 #[arg(
28 id = "backoff-max",
29 long,
30 default_value = "30s",
31 env = "MOQ_BACKOFF_MAX",
32 value_parser = humantime::parse_duration,
33 )]
34 #[serde(with = "humantime_serde")]
35 pub max: Duration,
36
37 #[arg(
40 id = "backoff-timeout",
41 long,
42 env = "MOQ_BACKOFF_TIMEOUT",
43 value_parser = humantime::parse_duration,
44 )]
45 #[serde(default, with = "humantime_serde", skip_serializing_if = "Option::is_none")]
46 pub timeout: Option<Duration>,
47}
48
49impl Default for Backoff {
50 fn default() -> Self {
51 Self {
52 initial: Duration::from_secs(1),
53 multiplier: 2,
54 max: Duration::from_secs(30),
55 timeout: None,
56 }
57 }
58}
59
60pub struct Reconnect {
65 abort: tokio::task::AbortHandle,
66 closed_rx: tokio::sync::watch::Receiver<bool>,
67}
68
69impl Reconnect {
70 pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
71 let (closed_tx, closed_rx) = tokio::sync::watch::channel(false);
72 let task = tokio::spawn(async move {
73 if let Err(err) = Self::run(client, url, backoff).await {
74 tracing::error!(%err, "reconnect loop exited");
75 let _ = closed_tx.send(true);
76 }
77 });
78 Self {
79 abort: task.abort_handle(),
80 closed_rx,
81 }
82 }
83
84 async fn run(client: Client, url: Url, backoff: Backoff) -> anyhow::Result<()> {
85 let mut delay = backoff.initial;
86 let mut retry_start = tokio::time::Instant::now();
87
88 loop {
89 if let Some(timeout) = backoff.timeout {
90 if retry_start.elapsed() > timeout {
91 anyhow::bail!("reconnect timed out after {timeout:?}");
92 }
93 }
94
95 tracing::info!(%url, "connecting");
96
97 match client.connect(url.clone()).await {
98 Ok(session) => {
99 tracing::info!(%url, "connected");
100 delay = backoff.initial;
101 let _ = session.closed().await;
102 tracing::warn!(%url, "session closed, reconnecting");
103 retry_start = tokio::time::Instant::now();
104 }
105 Err(err) => {
106 tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
107 tokio::time::sleep(delay).await;
108 delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
109 }
110 }
111 }
112 }
113
114 pub async fn closed(&self) -> anyhow::Result<()> {
119 let mut rx = self.closed_rx.clone();
120 match rx.wait_for(|&v| v).await {
121 Ok(_) => anyhow::bail!("reconnect timed out"),
122 Err(_) => Ok(()),
123 }
124 }
125
126 pub fn close(self) {
128 self.abort.abort();
129 }
130}
131
132impl Drop for Reconnect {
133 fn drop(&mut self) {
134 self.abort.abort();
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    /// Minimal command wrapper so the `Backoff` `Args` fragment can be parsed on its own.
    #[derive(clap::Parser)]
    struct Cli {
        #[command(flatten)]
        backoff: Backoff,
    }

    #[test]
    fn test_backoff_default() {
        let backoff = Backoff::default();
        assert_eq!(backoff.initial, Duration::from_secs(1));
        assert_eq!(backoff.multiplier, 2);
        assert_eq!(backoff.max, Duration::from_secs(30));
        assert_eq!(backoff.timeout, None);
    }

    /// `Backoff::default()` and the `#[arg(default_value...)]` attributes are declared
    /// separately; this keeps them from drifting apart.
    // NOTE: assumes no MOQ_BACKOFF_* variables are set in the test environment.
    #[test]
    fn test_backoff_cli_defaults_match_default() {
        let cli = Cli::try_parse_from(["test"]).expect("parsing with no arguments should succeed");
        let expected = Backoff::default();
        assert_eq!(cli.backoff.initial, expected.initial);
        assert_eq!(cli.backoff.multiplier, expected.multiplier);
        assert_eq!(cli.backoff.max, expected.max);
        assert_eq!(cli.backoff.timeout, expected.timeout);
    }

    #[test]
    fn test_backoff_cli_parses_durations() {
        let cli = Cli::try_parse_from([
            "test",
            "--backoff-initial",
            "250ms",
            "--backoff-timeout",
            "1m",
        ])
        .expect("valid humantime durations should parse");
        assert_eq!(cli.backoff.initial, Duration::from_millis(250));
        assert_eq!(cli.backoff.timeout, Some(Duration::from_secs(60)));
    }
}