1use std::task::{Poll, ready};
2use std::time::Duration;
3
4use moq_net::kio;
5use url::Url;
6
7use crate::{Client, Error};
8
/// Exponential backoff settings for the reconnect loop.
///
/// Can be supplied as command-line/environment options (via `clap::Args`) or
/// as a serde config section. Durations are written as human-readable strings
/// such as `1s` or `5m` in both forms.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
// `default`: fields missing from the input fall back to `Backoff::default()`.
// `deny_unknown_fields`: unrecognized keys are a deserialization error.
#[serde(default, deny_unknown_fields)]
pub struct Backoff {
    /// Delay before the first retry; the delay resets to this value after a successful connection.
    #[arg(
        // NOTE(review): the explicit `backoff-` prefixed id presumably keeps this
        // argument from clashing with other argument groups — confirm with callers.
        id = "backoff-initial",
        long,
        default_value = "1s",
        env = "MOQ_BACKOFF_INITIAL",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub initial: Duration,

    /// Factor the delay is multiplied by after each failed connection attempt.
    #[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
    pub multiplier: u32,

    /// Cap applied to the delay each time it is multiplied.
    #[arg(
        id = "backoff-max",
        long,
        default_value = "30s",
        env = "MOQ_BACKOFF_MAX",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub max: Duration,

    /// Give up once attempts have kept failing for longer than this; zero disables the limit.
    #[arg(
        id = "backoff-timeout",
        long,
        default_value = "5m",
        env = "MOQ_BACKOFF_TIMEOUT",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
}
51
52impl Default for Backoff {
53 fn default() -> Self {
54 Self {
55 initial: Duration::from_secs(1),
56 multiplier: 2,
57 max: Duration::from_secs(30),
58 timeout: Duration::from_secs(300),
59 }
60 }
61}
62
/// Connection state reported by [`Reconnect`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
    /// A connection attempt succeeded; published right after `connect` returns a session.
    Connected,
    /// A previously established session has closed and the loop is reconnecting.
    ///
    /// Only published after a session ends; attempts that fail before any
    /// session was established do not publish a status.
    Disconnected,
}
71
/// State written by the background reconnect task and read through the
/// [`Reconnect`] handle.
#[derive(Default)]
struct State {
    /// Most recent connection status; `None` until the first connection succeeds.
    status: Option<Status>,
    /// Error the reconnect loop exited with, if it has given up.
    error: Option<Error>,
}
83
/// Handle to a background task that keeps a connection alive, reconnecting
/// with backoff whenever it fails or closes.
///
/// Dropping the handle aborts the background task.
pub struct Reconnect {
    /// Abort handle of the spawned reconnect task; used by `Drop`.
    abort: tokio::task::AbortHandle,
    /// Consumer side of the state published by the reconnect task.
    state: kio::Consumer<State>,
    /// Last status handed out by `poll_status`, so the same status is not reported twice in a row.
    last_reported: Option<Status>,
}
95
96impl Reconnect {
97 pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
98 let producer = kio::Producer::<State>::default();
99 let state = producer.consume();
100 let task = tokio::spawn(async move {
101 if let Err(err) = Self::run(&producer, client, url, backoff).await {
102 tracing::error!(%err, "reconnect loop exited");
103 if let Ok(mut state) = producer.write() {
104 state.error = Some(err);
105 }
106 }
107 });
109 Self {
110 abort: task.abort_handle(),
111 state,
112 last_reported: None,
113 }
114 }
115
116 async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> crate::Result<()> {
117 let mut delay = backoff.initial;
118 let mut retry_start = tokio::time::Instant::now();
119 let mut last_error: Option<Error> = None;
120
121 loop {
122 if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
123 let timeout = backoff.timeout;
124 let msg = match last_error {
125 Some(err) => format!("reconnect timed out after {timeout:?}: {err}"),
126 None => format!("reconnect timed out after {timeout:?}"),
127 };
128 return Err(Error::Reconnect(msg));
129 }
130
131 tracing::info!(%url, "connecting");
132
133 match client.connect(url.clone()).await {
134 Ok(session) => {
135 tracing::info!(%url, "connected");
136 delay = backoff.initial;
137 last_error = None;
138 if let Ok(mut state) = state.write() {
139 state.status = Some(Status::Connected);
140 }
141 let _ = session.closed().await;
142 tracing::warn!(%url, "session closed, reconnecting");
143 if let Ok(mut state) = state.write() {
144 state.status = Some(Status::Disconnected);
145 }
146 retry_start = tokio::time::Instant::now();
147 }
148 Err(err) => {
149 tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
150 last_error = Some(err);
151 tokio::time::sleep(delay).await;
152 delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
153 }
154 }
155 }
156 }
157
158 pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<crate::Result<Status>> {
163 let last = self.last_reported;
164 let status = match ready!(self.state.poll(waiter, |state| match state.status {
165 Some(status) if Some(status) != last => Poll::Ready(status),
166 _ => Poll::Pending,
167 })) {
168 Ok(status) => status,
169 Err(state) => return Poll::Ready(Err(terminal(&state))),
170 };
171
172 self.last_reported = Some(status);
173 Poll::Ready(Ok(status))
174 }
175
176 pub async fn status(&mut self) -> crate::Result<Status> {
182 kio::wait(|waiter| self.poll_status(waiter)).await
183 }
184
185 pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<crate::Result<()>> {
190 ready!(self.state.poll_closed(waiter));
191 Poll::Ready(match &self.state.read().error {
192 Some(err) => Err(err.clone()),
193 None => Ok(()),
194 })
195 }
196
197 pub async fn closed(&self) -> crate::Result<()> {
199 kio::wait(|waiter| self.poll_closed(waiter)).await
200 }
201}
202
203impl Drop for Reconnect {
204 fn drop(&mut self) {
205 self.abort.abort();
206 }
207}
208
209fn terminal(state: &State) -> Error {
211 match &state.error {
212 Some(err) => err.clone(),
213 None => Error::Reconnect("reconnect stopped".to_string()),
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the values produced by `Backoff::default()`.
    #[test]
    fn test_backoff_default() {
        let Backoff {
            initial,
            multiplier,
            max,
            timeout,
        } = Backoff::default();

        assert_eq!(initial, Duration::from_secs(1));
        assert_eq!(multiplier, 2);
        assert_eq!(max, Duration::from_secs(30));
        assert_eq!(timeout, Duration::from_secs(300));
    }
}