1use std::sync::Arc;
2use std::time::Duration;
3
4use url::Url;
5
6use crate::Client;
7
/// Exponential backoff settings used by [`Reconnect`].
///
/// Each field can be set with a `--backoff-*` flag or a `MOQ_BACKOFF_*` environment variable.
/// With serde, durations use humantime strings (e.g. `"1s"`, `"5m"`), missing fields fall back
/// to [`Default`], and unknown fields are rejected.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Backoff {
    /// Delay before the first retry after a failed connection attempt.
    ///
    /// The delay is reset to this value whenever a connection succeeds.
    #[arg(
        id = "backoff-initial",
        long,
        default_value = "1s",
        env = "MOQ_BACKOFF_INITIAL",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub initial: Duration,

    /// Factor the retry delay is multiplied by after each failed attempt.
    #[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
    pub multiplier: u32,

    /// Upper bound on the delay between retries.
    #[arg(
        id = "backoff-max",
        long,
        default_value = "30s",
        env = "MOQ_BACKOFF_MAX",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub max: Duration,

    /// Give up after this long without a successful connection; a zero duration retries forever.
    ///
    /// The clock restarts each time an established session closes.
    #[arg(
        id = "backoff-timeout",
        long,
        default_value = "5m",
        env = "MOQ_BACKOFF_TIMEOUT",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
}
50
51impl Default for Backoff {
52 fn default() -> Self {
53 Self {
54 initial: Duration::from_secs(1),
55 multiplier: 2,
56 max: Duration::from_secs(30),
57 timeout: Duration::from_secs(300),
58 }
59 }
60}
61
62pub struct Reconnect {
67 abort: tokio::task::AbortHandle,
68 closed_rx: tokio::sync::watch::Receiver<Option<Arc<anyhow::Error>>>,
69}
70
71impl Reconnect {
72 pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
73 let (closed_tx, closed_rx) = tokio::sync::watch::channel(None::<Arc<anyhow::Error>>);
74 let task = tokio::spawn(async move {
75 if let Err(err) = Self::run(client, url, backoff).await {
76 tracing::error!(err = %format!("{err:#}"), "reconnect loop exited");
77 let _ = closed_tx.send(Some(Arc::new(err)));
78 }
79 });
80 Self {
81 abort: task.abort_handle(),
82 closed_rx,
83 }
84 }
85
86 async fn run(client: Client, url: Url, backoff: Backoff) -> anyhow::Result<()> {
87 let mut delay = backoff.initial;
88 let mut retry_start = tokio::time::Instant::now();
89 let mut last_error: Option<anyhow::Error> = None;
90
91 loop {
92 if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
93 let timeout = backoff.timeout;
94 return Err(last_error
95 .map(|e| e.context(format!("reconnect timed out after {timeout:?}")))
96 .unwrap_or_else(|| anyhow::anyhow!("reconnect timed out after {timeout:?}")));
97 }
98
99 tracing::info!(%url, "connecting");
100
101 match client.connect(url.clone()).await {
102 Ok(session) => {
103 tracing::info!(%url, "connected");
104 delay = backoff.initial;
105 last_error = None;
106 let _ = session.closed().await;
107 tracing::warn!(%url, "session closed, reconnecting");
108 retry_start = tokio::time::Instant::now();
109 }
110 Err(err) => {
111 tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
112 last_error = Some(err);
113 tokio::time::sleep(delay).await;
114 delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
115 }
116 }
117 }
118 }
119
120 pub async fn closed(&self) -> anyhow::Result<()> {
126 let mut rx = self.closed_rx.clone();
127 match rx.wait_for(|v| v.is_some()).await {
128 Ok(v) => {
129 let err = Arc::clone(v.as_ref().expect("predicate matched Some"));
130 Err(anyhow::anyhow!("{err:#}"))
131 }
132 Err(_) => Ok(()),
133 }
134 }
135
136 pub fn close(self) {
138 self.abort.abort();
139 }
140}
141
142impl Drop for Reconnect {
143 fn drop(&mut self) {
144 self.abort.abort();
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backoff_default() {
        // Destructure without `..` so a newly added field forces this test to be revisited.
        let Backoff {
            initial,
            multiplier,
            max,
            timeout,
        } = Backoff::default();

        assert_eq!(initial, Duration::from_secs(1));
        assert_eq!(multiplier, 2);
        assert_eq!(max, Duration::from_secs(30));
        assert_eq!(timeout, Duration::from_secs(300));
    }
}