1use std::task::{Poll, ready};
2use std::time::Duration;
3
4use moq_net::kio;
5use url::Url;
6
7use crate::Client;
8
/// Exponential backoff settings used by the reconnect loop.
///
/// Each field can be set from a CLI flag or environment variable (via `clap`)
/// or from a serde-deserialized config. With serde, missing fields fall back
/// to [`Default`] and unknown fields are rejected.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Backoff {
    /// Delay before the first retry after a failed connection attempt.
    // NOTE: keep `id` ahead of the bare `long` in these attributes.
    // NOTE(review): clap derive appears to take the flag name from the id only when it comes first — confirm.
    #[arg(
        id = "backoff-initial",
        long,
        default_value = "1s",
        env = "MOQ_BACKOFF_INITIAL",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub initial: Duration,

    /// Factor the retry delay is multiplied by after each failed attempt.
    #[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
    pub multiplier: u32,

    /// Upper bound on the delay between retries.
    #[arg(
        id = "backoff-max",
        long,
        default_value = "30s",
        env = "MOQ_BACKOFF_MAX",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub max: Duration,

    /// Maximum time to keep retrying before giving up (0 retries forever).
    #[arg(
        id = "backoff-timeout",
        long,
        default_value = "5m",
        env = "MOQ_BACKOFF_TIMEOUT",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
}
51
52impl Default for Backoff {
53 fn default() -> Self {
54 Self {
55 initial: Duration::from_secs(1),
56 multiplier: 2,
57 max: Duration::from_secs(30),
58 timeout: Duration::from_secs(300),
59 }
60 }
61}
62
/// Connection state reported by the reconnect loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// A session has been established.
    Connected,
    /// A previously established session has closed.
    Disconnected,
}
71
/// State shared between the background reconnect task (producer) and the
/// [`Reconnect`] handle (consumer).
#[derive(Default)]
struct State {
    /// Latest connection status; `None` until the first successful connect.
    status: Option<Status>,
    /// Error that made the reconnect loop exit, if it exited with one.
    error: Option<anyhow::Error>,
}
83
/// Handle to a background task that keeps a connection alive by reconnecting.
///
/// Dropping the handle aborts the background task.
pub struct Reconnect {
    /// Abort handle for the spawned reconnect task; triggered on drop.
    abort: tokio::task::AbortHandle,
    /// Consumer side of the state published by the reconnect task.
    state: kio::Consumer<State>,
    /// Status most recently returned by `poll_status`, used to report only changes.
    last_reported: Option<Status>,
}
95
96impl Reconnect {
97 pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
98 let producer = kio::Producer::<State>::default();
99 let state = producer.consume();
100 let task = tokio::spawn(async move {
101 if let Err(err) = Self::run(&producer, client, url, backoff).await {
102 tracing::error!(err = %format!("{err:#}"), "reconnect loop exited");
103 if let Ok(mut state) = producer.write() {
104 state.error = Some(err);
105 }
106 }
107 });
109 Self {
110 abort: task.abort_handle(),
111 state,
112 last_reported: None,
113 }
114 }
115
116 async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> anyhow::Result<()> {
117 let mut delay = backoff.initial;
118 let mut retry_start = tokio::time::Instant::now();
119 let mut last_error: Option<anyhow::Error> = None;
120
121 loop {
122 if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
123 let timeout = backoff.timeout;
124 return Err(last_error
125 .map(|e| e.context(format!("reconnect timed out after {timeout:?}")))
126 .unwrap_or_else(|| anyhow::anyhow!("reconnect timed out after {timeout:?}")));
127 }
128
129 tracing::info!(%url, "connecting");
130
131 match client.connect(url.clone()).await {
132 Ok(session) => {
133 tracing::info!(%url, "connected");
134 delay = backoff.initial;
135 last_error = None;
136 if let Ok(mut state) = state.write() {
137 state.status = Some(Status::Connected);
138 }
139 let _ = session.closed().await;
140 tracing::warn!(%url, "session closed, reconnecting");
141 if let Ok(mut state) = state.write() {
142 state.status = Some(Status::Disconnected);
143 }
144 retry_start = tokio::time::Instant::now();
145 }
146 Err(err) => {
147 tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
148 last_error = Some(err);
149 tokio::time::sleep(delay).await;
150 delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
151 }
152 }
153 }
154 }
155
156 pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<anyhow::Result<Status>> {
161 let last = self.last_reported;
162 let status = match ready!(self.state.poll(waiter, |state| match state.status {
163 Some(status) if Some(status) != last => Poll::Ready(status),
164 _ => Poll::Pending,
165 })) {
166 Ok(status) => status,
167 Err(state) => return Poll::Ready(Err(terminal(&state))),
168 };
169
170 self.last_reported = Some(status);
171 Poll::Ready(Ok(status))
172 }
173
174 pub async fn status(&mut self) -> anyhow::Result<Status> {
180 kio::wait(|waiter| self.poll_status(waiter)).await
181 }
182
183 pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<anyhow::Result<()>> {
188 ready!(self.state.poll_closed(waiter));
189 Poll::Ready(match &self.state.read().error {
190 Some(err) => Err(anyhow::anyhow!("{err:#}")),
191 None => Ok(()),
192 })
193 }
194
195 pub async fn closed(&self) -> anyhow::Result<()> {
197 kio::wait(|waiter| self.poll_closed(waiter)).await
198 }
199}
200
impl Drop for Reconnect {
    /// Aborts the background reconnect task so it does not outlive this handle.
    fn drop(&mut self) {
        self.abort.abort();
    }
}
206
207fn terminal(state: &State) -> anyhow::Error {
209 match &state.error {
210 Some(err) => anyhow::anyhow!("{err:#}"),
211 None => anyhow::anyhow!("reconnect stopped"),
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// `Backoff::default()` must yield 1s / x2 / 30s / 300s.
    #[test]
    fn test_backoff_default() {
        let Backoff {
            initial,
            multiplier,
            max,
            timeout,
        } = Backoff::default();

        assert_eq!(initial, Duration::from_secs(1));
        assert_eq!(multiplier, 2);
        assert_eq!(max, Duration::from_secs(30));
        assert_eq!(timeout, Duration::from_secs(300));
    }
}