Skip to main content

moq_native/
reconnect.rs

1use std::task::{Poll, ready};
2use std::time::Duration;
3
4use moq_net::kio;
5use url::Url;
6
7use crate::{Client, Error};
8
/// Exponential backoff configuration for reconnection attempts.
// NOTE: the `///` comments on the fields below are also picked up by clap as `--help` text,
// so maintainer-only notes in this struct are written as plain `//` comments.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Backoff {
	/// Initial delay before first reconnect attempt.
	// The reconnect loop starts from this delay and returns to it after every successful
	// connect. It is used as-is: `max` is only applied once the delay has been scaled.
	#[arg(
		id = "backoff-initial",
		long,
		default_value = "1s",
		env = "MOQ_BACKOFF_INITIAL",
		value_parser = humantime::parse_duration,
	)]
	#[serde(with = "humantime_serde")]
	pub initial: Duration,

	/// Multiplier applied to delay after each failure.
	// After each failed attempt the loop scales the current delay by this factor and caps
	// the result at `max`.
	#[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
	pub multiplier: u32,

	/// Maximum delay between reconnect attempts.
	// Upper bound for the scaled delay (see `multiplier`).
	#[arg(
		id = "backoff-max",
		long,
		default_value = "30s",
		env = "MOQ_BACKOFF_MAX",
		value_parser = humantime::parse_duration,
	)]
	#[serde(with = "humantime_serde")]
	pub max: Duration,

	/// Maximum time to spend retrying before giving up.
	/// Resets after each successful connection. Set to 0 for unlimited retries.
	// A zero duration skips the give-up check. The retry clock starts when the loop starts
	// and restarts each time an established session closes.
	#[arg(
		id = "backoff-timeout",
		long,
		default_value = "5m",
		env = "MOQ_BACKOFF_TIMEOUT",
		value_parser = humantime::parse_duration,
	)]
	#[serde(with = "humantime_serde")]
	pub timeout: Duration,
}
51
52impl Default for Backoff {
53	fn default() -> Self {
54		Self {
55			initial: Duration::from_secs(1),
56			multiplier: 2,
57			max: Duration::from_secs(30),
58			timeout: Duration::from_secs(300),
59		}
60	}
61}
62
/// A connection lifecycle transition reported by [`Reconnect::status`].
///
/// Failed connection attempts are not reported: the reconnect loop only publishes a status
/// when a connect succeeds or when an established session closes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
	/// A session connected (the first connect, or a reconnect after a drop).
	Connected,
	/// An established session dropped; a reconnect attempt follows.
	Disconnected,
}
71
/// Shared reconnect state, observed by consumers through a [`kio`] channel.
///
/// The channel closing (all producers dropped) is the terminal signal; `error`
/// distinguishes a permanent give-up from a graceful close.
#[derive(Default)]
struct State {
	/// Current connection status, or `None` before the first connect.
	status: Option<Status>,
	/// Set when the reconnect loop exits with an error: either the reconnect timeout was
	/// exceeded, or a connect attempt failed with an error for which `is_auth()` is true
	/// (those are returned immediately instead of being retried).
	error: Option<Error>,
}
83
/// Handle to a background reconnect loop.
///
/// Spawns a tokio task that connects, waits for session close, then reconnects with exponential
/// backoff. [`status`](Self::status) reports connection changes; [`closed`](Self::closed) waits for
/// the loop to stop. Dropping the handle aborts the background task.
pub struct Reconnect {
	/// Abort handle for the spawned reconnect task; triggered when this handle is dropped.
	abort: tokio::task::AbortHandle,
	/// Consumer side of the shared [`State`] channel that the background task writes to.
	state: kio::Consumer<State>,
	/// The last status returned by [`status`](Self::status), for change detection.
	last_reported: Option<Status>,
}
95
96impl Reconnect {
97	pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
98		let producer = kio::Producer::<State>::default();
99		let state = producer.consume();
100		let task = tokio::spawn(async move {
101			if let Err(err) = Self::run(&producer, client, url, backoff).await {
102				tracing::error!(%err, "reconnect loop exited");
103				if let Ok(mut state) = producer.write() {
104					state.error = Some(err);
105				}
106			}
107			// Dropping the producer here closes the channel, signaling consumers.
108		});
109		Self {
110			abort: task.abort_handle(),
111			state,
112			last_reported: None,
113		}
114	}
115
116	async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> crate::Result<()> {
117		let mut delay = backoff.initial;
118		let mut retry_start = tokio::time::Instant::now();
119		let mut last_error: Option<Error> = None;
120
121		loop {
122			if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
123				let timeout = backoff.timeout;
124				let msg = match last_error {
125					Some(err) => format!("reconnect timed out after {timeout:?}: {err}"),
126					None => format!("reconnect timed out after {timeout:?}"),
127				};
128				return Err(Error::Reconnect(msg));
129			}
130
131			tracing::info!(%url, "connecting");
132
133			match client.connect(url.clone()).await {
134				Ok(session) => {
135					tracing::info!(%url, "connected");
136					delay = backoff.initial;
137					last_error = None;
138					if let Ok(mut state) = state.write() {
139						state.status = Some(Status::Connected);
140					}
141					let _ = session.closed().await;
142					tracing::warn!(%url, "session closed, reconnecting");
143					if let Ok(mut state) = state.write() {
144						state.status = Some(Status::Disconnected);
145					}
146					retry_start = tokio::time::Instant::now();
147				}
148				Err(err) => {
149					if err.is_auth() {
150						return Err(err);
151					}
152					tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
153					last_error = Some(err);
154					tokio::time::sleep(delay).await;
155					delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
156				}
157			}
158		}
159	}
160
161	/// Poll for the next connection status change since this handle last reported one.
162	///
163	/// `Ready(Ok(status))` on a change, `Ready(Err)` once the loop has stopped (the give-up error,
164	/// or a generic one when the handle is dropped), `Pending` otherwise.
165	pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<crate::Result<Status>> {
166		let last = self.last_reported;
167		let status = match ready!(self.state.poll(waiter, |state| match state.status {
168			Some(status) if Some(status) != last => Poll::Ready(status),
169			_ => Poll::Pending,
170		})) {
171			Ok(status) => status,
172			Err(state) => return Poll::Ready(Err(terminal(&state))),
173		};
174
175		self.last_reported = Some(status);
176		Poll::Ready(Ok(status))
177	}
178
179	/// Wait until the connection status changes from what this handle last reported.
180	///
181	/// Returns the current [`Status`]. The loop alternates `Connected`/`Disconnected`, so successive
182	/// calls alternate too; but a status that flips and flips back before the caller polls is
183	/// reported once. This tracks the *current* state, not every edge.
184	pub async fn status(&mut self) -> crate::Result<Status> {
185		kio::wait(|waiter| self.poll_status(waiter)).await
186	}
187
188	/// Poll whether the reconnect loop has stopped.
189	///
190	/// `Ready(Err)` if it permanently gave up (reconnect timeout exceeded), `Ready(Ok(()))` if
191	/// stopped by dropping the handle, `Pending` while it's still running.
192	pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<crate::Result<()>> {
193		ready!(self.state.poll_closed(waiter));
194		Poll::Ready(match &self.state.read().error {
195			Some(err) => Err(err.clone()),
196			None => Ok(()),
197		})
198	}
199
200	/// Wait until the reconnect loop stops.
201	pub async fn closed(&self) -> crate::Result<()> {
202		kio::wait(|waiter| self.poll_closed(waiter)).await
203	}
204}
205
206impl Drop for Reconnect {
207	fn drop(&mut self) {
208		self.abort.abort();
209	}
210}
211
212/// The terminal error read from a closed channel's final state.
213fn terminal(state: &State) -> Error {
214	match &state.error {
215		Some(err) => err.clone(),
216		None => Error::Reconnect("reconnect stopped".to_string()),
217	}
218}
219
#[cfg(test)]
mod tests {
	use super::*;

	/// `Backoff::default()` must yield 1s initial, x2 multiplier, 30s max, 300s timeout.
	#[test]
	fn test_backoff_default() {
		let Backoff {
			initial,
			multiplier,
			max,
			timeout,
		} = Backoff::default();

		assert_eq!(initial, Duration::from_secs(1));
		assert_eq!(multiplier, 2);
		assert_eq!(max, Duration::from_secs(30));
		assert_eq!(timeout, Duration::from_secs(300));
	}
}