Skip to main content

moq_native/
reconnect.rs

1use std::task::{Poll, ready};
2use std::time::Duration;
3
4use moq_net::kio;
5use url::Url;
6
7use crate::{Client, Error};
8
9/// Exponential backoff configuration for reconnection attempts.
10#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
11#[serde(default, deny_unknown_fields)]
12pub struct Backoff {
13	/// Initial delay before first reconnect attempt.
14	#[arg(
15		id = "backoff-initial",
16		long,
17		default_value = "1s",
18		env = "MOQ_BACKOFF_INITIAL",
19		value_parser = humantime::parse_duration,
20	)]
21	#[serde(with = "humantime_serde")]
22	pub initial: Duration,
23
24	/// Multiplier applied to delay after each failure.
25	#[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
26	pub multiplier: u32,
27
28	/// Maximum delay between reconnect attempts.
29	#[arg(
30		id = "backoff-max",
31		long,
32		default_value = "30s",
33		env = "MOQ_BACKOFF_MAX",
34		value_parser = humantime::parse_duration,
35	)]
36	#[serde(with = "humantime_serde")]
37	pub max: Duration,
38
39	/// Maximum time to spend retrying before giving up.
40	/// Resets after each successful connection. Set to 0 for unlimited retries.
41	#[arg(
42		id = "backoff-timeout",
43		long,
44		default_value = "5m",
45		env = "MOQ_BACKOFF_TIMEOUT",
46		value_parser = humantime::parse_duration,
47	)]
48	#[serde(with = "humantime_serde")]
49	pub timeout: Duration,
50}
51
52impl Default for Backoff {
53	fn default() -> Self {
54		Self {
55			initial: Duration::from_secs(1),
56			multiplier: 2,
57			max: Duration::from_secs(30),
58			timeout: Duration::from_secs(300),
59		}
60	}
61}
62
/// Connection lifecycle event yielded by [`Reconnect::status`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Status {
	/// A session is up: either the very first connection or a successful reconnect after a drop.
	Connected,
	/// A previously established session was lost; the loop is about to try reconnecting.
	Disconnected,
}
71
72/// Shared reconnect state, observed by consumers through a [`kio`] channel.
73///
74/// The channel closing (all producers dropped) is the terminal signal; `error`
75/// distinguishes a permanent give-up from a graceful close.
76#[derive(Default)]
77struct State {
78	/// Current connection status, or `None` before the first connect.
79	status: Option<Status>,
80	/// Set when the reconnect loop permanently gives up (reconnect timeout exceeded).
81	error: Option<Error>,
82}
83
84/// Handle to a background reconnect loop.
85///
86/// Spawns a tokio task that connects, waits for session close, then reconnects with exponential
87/// backoff. [`status`](Self::status) reports connection changes; [`closed`](Self::closed) waits for
88/// the loop to stop. Dropping the handle aborts the background task.
89pub struct Reconnect {
90	abort: tokio::task::AbortHandle,
91	state: kio::Consumer<State>,
92	/// The last status returned by [`status`](Self::status), for change detection.
93	last_reported: Option<Status>,
94}
95
96impl Reconnect {
97	pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
98		let producer = kio::Producer::<State>::default();
99		let state = producer.consume();
100		let task = tokio::spawn(async move {
101			if let Err(err) = Self::run(&producer, client, url, backoff).await {
102				tracing::error!(%err, "reconnect loop exited");
103				if let Ok(mut state) = producer.write() {
104					state.error = Some(err);
105				}
106			}
107			// Dropping the producer here closes the channel, signaling consumers.
108		});
109		Self {
110			abort: task.abort_handle(),
111			state,
112			last_reported: None,
113		}
114	}
115
116	async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> crate::Result<()> {
117		let mut delay = backoff.initial;
118		let mut retry_start = tokio::time::Instant::now();
119		let mut last_error: Option<Error> = None;
120
121		loop {
122			if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
123				let timeout = backoff.timeout;
124				let msg = match last_error {
125					Some(err) => format!("reconnect timed out after {timeout:?}: {err}"),
126					None => format!("reconnect timed out after {timeout:?}"),
127				};
128				return Err(Error::Reconnect(msg));
129			}
130
131			tracing::info!(%url, "connecting");
132
133			match client.connect(url.clone()).await {
134				Ok(session) => {
135					tracing::info!(%url, "connected");
136					delay = backoff.initial;
137					last_error = None;
138					if let Ok(mut state) = state.write() {
139						state.status = Some(Status::Connected);
140					}
141					let _ = session.closed().await;
142					tracing::warn!(%url, "session closed, reconnecting");
143					if let Ok(mut state) = state.write() {
144						state.status = Some(Status::Disconnected);
145					}
146					retry_start = tokio::time::Instant::now();
147				}
148				Err(err) => {
149					tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
150					last_error = Some(err);
151					tokio::time::sleep(delay).await;
152					delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
153				}
154			}
155		}
156	}
157
158	/// Poll for the next connection status change since this handle last reported one.
159	///
160	/// `Ready(Ok(status))` on a change, `Ready(Err)` once the loop has stopped (the give-up error,
161	/// or a generic one when the handle is dropped), `Pending` otherwise.
162	pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<crate::Result<Status>> {
163		let last = self.last_reported;
164		let status = match ready!(self.state.poll(waiter, |state| match state.status {
165			Some(status) if Some(status) != last => Poll::Ready(status),
166			_ => Poll::Pending,
167		})) {
168			Ok(status) => status,
169			Err(state) => return Poll::Ready(Err(terminal(&state))),
170		};
171
172		self.last_reported = Some(status);
173		Poll::Ready(Ok(status))
174	}
175
176	/// Wait until the connection status changes from what this handle last reported.
177	///
178	/// Returns the current [`Status`]. The loop alternates `Connected`/`Disconnected`, so successive
179	/// calls alternate too; but a status that flips and flips back before the caller polls is
180	/// reported once. This tracks the *current* state, not every edge.
181	pub async fn status(&mut self) -> crate::Result<Status> {
182		kio::wait(|waiter| self.poll_status(waiter)).await
183	}
184
185	/// Poll whether the reconnect loop has stopped.
186	///
187	/// `Ready(Err)` if it permanently gave up (reconnect timeout exceeded), `Ready(Ok(()))` if
188	/// stopped by dropping the handle, `Pending` while it's still running.
189	pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<crate::Result<()>> {
190		ready!(self.state.poll_closed(waiter));
191		Poll::Ready(match &self.state.read().error {
192			Some(err) => Err(err.clone()),
193			None => Ok(()),
194		})
195	}
196
197	/// Wait until the reconnect loop stops.
198	pub async fn closed(&self) -> crate::Result<()> {
199		kio::wait(|waiter| self.poll_closed(waiter)).await
200	}
201}
202
203impl Drop for Reconnect {
204	fn drop(&mut self) {
205		self.abort.abort();
206	}
207}
208
209/// The terminal error read from a closed channel's final state.
210fn terminal(state: &State) -> Error {
211	match &state.error {
212		Some(err) => err.clone(),
213		None => Error::Reconnect("reconnect stopped".to_string()),
214	}
215}
216
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_backoff_default() {
		// Destructure exhaustively so that adding a field to `Backoff` forces this test to be revisited.
		let Backoff {
			initial,
			multiplier,
			max,
			timeout,
		} = Backoff::default();

		assert_eq!(initial, Duration::from_secs(1));
		assert_eq!(multiplier, 2);
		assert_eq!(max, Duration::from_secs(30));
		assert_eq!(timeout, Duration::from_secs(5 * 60));
	}
}