1use std::task::{Poll, ready};
2use std::time::Duration;
3
4use moq_net::kio;
5use url::Url;
6
7use crate::{Client, Error};
8
9#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
11#[serde(default, deny_unknown_fields)]
12pub struct Backoff {
13 #[arg(
15 id = "backoff-initial",
16 long,
17 default_value = "1s",
18 env = "MOQ_BACKOFF_INITIAL",
19 value_parser = humantime::parse_duration,
20 )]
21 #[serde(with = "humantime_serde")]
22 pub initial: Duration,
23
24 #[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
26 pub multiplier: u32,
27
28 #[arg(
30 id = "backoff-max",
31 long,
32 default_value = "30s",
33 env = "MOQ_BACKOFF_MAX",
34 value_parser = humantime::parse_duration,
35 )]
36 #[serde(with = "humantime_serde")]
37 pub max: Duration,
38
39 #[arg(
42 id = "backoff-timeout",
43 long,
44 default_value = "5m",
45 env = "MOQ_BACKOFF_TIMEOUT",
46 value_parser = humantime::parse_duration,
47 )]
48 #[serde(with = "humantime_serde")]
49 pub timeout: Duration,
50}
51
52impl Default for Backoff {
53 fn default() -> Self {
54 Self {
55 initial: Duration::from_secs(1),
56 multiplier: 2,
57 max: Duration::from_secs(30),
58 timeout: Duration::from_secs(300),
59 }
60 }
61}
62
/// Connection status reported by the reconnect loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// A connection attempt succeeded and the session is currently open.
    Connected,

    /// A previously established session has closed.
    Disconnected,
}
71
72#[derive(Default)]
77struct State {
78 status: Option<Status>,
80 error: Option<Error>,
82}
83
84pub struct Reconnect {
90 abort: tokio::task::AbortHandle,
91 state: kio::Consumer<State>,
92 last_reported: Option<Status>,
94}
95
96impl Reconnect {
97 pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
98 let producer = kio::Producer::<State>::default();
99 let state = producer.consume();
100 let task = tokio::spawn(async move {
101 if let Err(err) = Self::run(&producer, client, url, backoff).await {
102 tracing::error!(%err, "reconnect loop exited");
103 if let Ok(mut state) = producer.write() {
104 state.error = Some(err);
105 }
106 }
107 });
109 Self {
110 abort: task.abort_handle(),
111 state,
112 last_reported: None,
113 }
114 }
115
116 async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> crate::Result<()> {
117 let mut delay = backoff.initial;
118 let mut retry_start = tokio::time::Instant::now();
119 let mut last_error: Option<Error> = None;
120
121 loop {
122 if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
123 let timeout = backoff.timeout;
124 let msg = match last_error {
125 Some(err) => format!("reconnect timed out after {timeout:?}: {err}"),
126 None => format!("reconnect timed out after {timeout:?}"),
127 };
128 return Err(Error::Reconnect(msg));
129 }
130
131 tracing::info!(%url, "connecting");
132
133 match client.connect(url.clone()).await {
134 Ok(session) => {
135 tracing::info!(%url, "connected");
136 delay = backoff.initial;
137 last_error = None;
138 if let Ok(mut state) = state.write() {
139 state.status = Some(Status::Connected);
140 }
141 let _ = session.closed().await;
142 tracing::warn!(%url, "session closed, reconnecting");
143 if let Ok(mut state) = state.write() {
144 state.status = Some(Status::Disconnected);
145 }
146 retry_start = tokio::time::Instant::now();
147 }
148 Err(err) => {
149 if err.is_auth() {
150 return Err(err);
151 }
152 tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
153 last_error = Some(err);
154 tokio::time::sleep(delay).await;
155 delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
156 }
157 }
158 }
159 }
160
161 pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<crate::Result<Status>> {
166 let last = self.last_reported;
167 let status = match ready!(self.state.poll(waiter, |state| match state.status {
168 Some(status) if Some(status) != last => Poll::Ready(status),
169 _ => Poll::Pending,
170 })) {
171 Ok(status) => status,
172 Err(state) => return Poll::Ready(Err(terminal(&state))),
173 };
174
175 self.last_reported = Some(status);
176 Poll::Ready(Ok(status))
177 }
178
179 pub async fn status(&mut self) -> crate::Result<Status> {
185 kio::wait(|waiter| self.poll_status(waiter)).await
186 }
187
188 pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<crate::Result<()>> {
193 ready!(self.state.poll_closed(waiter));
194 Poll::Ready(match &self.state.read().error {
195 Some(err) => Err(err.clone()),
196 None => Ok(()),
197 })
198 }
199
200 pub async fn closed(&self) -> crate::Result<()> {
202 kio::wait(|waiter| self.poll_closed(waiter)).await
203 }
204}
205
206impl Drop for Reconnect {
207 fn drop(&mut self) {
208 self.abort.abort();
209 }
210}
211
212fn terminal(state: &State) -> Error {
214 match &state.error {
215 Some(err) => err.clone(),
216 None => Error::Reconnect("reconnect stopped".to_string()),
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Default` impl must keep producing the documented values.
    #[test]
    fn test_backoff_default() {
        let Backoff {
            initial,
            multiplier,
            max,
            timeout,
        } = Backoff::default();

        assert_eq!(initial, Duration::from_secs(1));
        assert_eq!(multiplier, 2);
        assert_eq!(max, Duration::from_secs(30));
        assert_eq!(timeout, Duration::from_secs(300));
    }
}