//! Automatic reconnection with exponential backoff (`moq_native/reconnect.rs`).

1use std::task::{Poll, ready};
2use std::time::Duration;
3
4use moq_net::kio;
5use url::Url;
6
7use crate::Client;
8
/// Exponential backoff configuration for reconnection attempts.
// Usable three ways: as CLI flags / environment variables (`clap::Args`), as a
// (de)serializable config section (`serde`), and programmatically via `Default`.
//
// NOTE: clap derives `--help` text from the `///` comments on the fields below, so
// rewording them changes user-visible CLI output.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
// `default`: fields missing from the serialized form fall back to `Backoff::default()`.
// `deny_unknown_fields`: unrecognized keys are a deserialization error instead of being ignored.
#[serde(default, deny_unknown_fields)]
pub struct Backoff {
	// Durations are parsed from human-readable strings (e.g. "1s") by `humantime` on the
	// command line and by `humantime_serde` when (de)serializing.
	// The CLI default here duplicates the value in the `Default` impl; keep the two in sync.
	/// Initial delay before first reconnect attempt.
	#[arg(
		id = "backoff-initial",
		long,
		default_value = "1s",
		env = "MOQ_BACKOFF_INITIAL",
		value_parser = humantime::parse_duration,
	)]
	#[serde(with = "humantime_serde")]
	pub initial: Duration,

	// The reconnect loop multiplies the current delay by this after every failed attempt.
	/// Multiplier applied to delay after each failure.
	#[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
	pub multiplier: u32,

	// Upper bound applied to the delay after each multiplication.
	/// Maximum delay between reconnect attempts.
	#[arg(
		id = "backoff-max",
		long,
		default_value = "30s",
		env = "MOQ_BACKOFF_MAX",
		value_parser = humantime::parse_duration,
	)]
	#[serde(with = "humantime_serde")]
	pub max: Duration,

	// A zero duration disables the give-up check entirely (the loop tests `is_zero()`).
	/// Maximum time to spend retrying before giving up.
	/// Resets after each successful connection. Set to 0 for unlimited retries.
	#[arg(
		id = "backoff-timeout",
		long,
		default_value = "5m",
		env = "MOQ_BACKOFF_TIMEOUT",
		value_parser = humantime::parse_duration,
	)]
	#[serde(with = "humantime_serde")]
	pub timeout: Duration,
}
51
52impl Default for Backoff {
53	fn default() -> Self {
54		Self {
55			initial: Duration::from_secs(1),
56			multiplier: 2,
57			max: Duration::from_secs(30),
58			timeout: Duration::from_secs(300),
59		}
60	}
61}
62
/// A connection lifecycle transition reported by [`Reconnect::status`].
///
/// The reconnect loop publishes `Connected` right after a successful connect and
/// `Disconnected` once that session's `closed()` future resolves. Failed connection
/// attempts do not publish a status.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
	/// A session connected (the first connect, or a reconnect after a drop).
	Connected,
	/// An established session dropped; a reconnect attempt follows.
	Disconnected,
}
71
/// Shared reconnect state, observed by consumers through a [`kio`] channel.
///
/// The channel closing (all producers dropped) is the terminal signal; `error`
/// distinguishes a permanent give-up from a graceful close.
///
/// The derived `Default` (`status: None`, `error: None`) is the state before the first
/// connection attempt has succeeded.
#[derive(Default)]
struct State {
	/// Current connection status, or `None` before the first connect.
	status: Option<Status>,
	/// Set when the reconnect loop permanently gives up (reconnect timeout exceeded).
	///
	/// Written once by the background task, just before its producer is dropped.
	error: Option<anyhow::Error>,
}
83
/// Handle to a background reconnect loop.
///
/// Spawns a tokio task that connects, waits for session close, then reconnects with exponential
/// backoff. [`status`](Self::status) reports connection changes; [`closed`](Self::closed) waits for
/// the loop to stop. Dropping the handle aborts the background task.
pub struct Reconnect {
	/// Abort handle for the spawned reconnect task; triggered from `Drop`.
	abort: tokio::task::AbortHandle,
	/// Consumer side of the channel the background task publishes [`State`] on.
	state: kio::Consumer<State>,
	/// The last status returned by [`status`](Self::status), for change detection.
	last_reported: Option<Status>,
}
95
96impl Reconnect {
97	pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
98		let producer = kio::Producer::<State>::default();
99		let state = producer.consume();
100		let task = tokio::spawn(async move {
101			if let Err(err) = Self::run(&producer, client, url, backoff).await {
102				tracing::error!(err = %format!("{err:#}"), "reconnect loop exited");
103				if let Ok(mut state) = producer.write() {
104					state.error = Some(err);
105				}
106			}
107			// Dropping the producer here closes the channel, signaling consumers.
108		});
109		Self {
110			abort: task.abort_handle(),
111			state,
112			last_reported: None,
113		}
114	}
115
116	async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> anyhow::Result<()> {
117		let mut delay = backoff.initial;
118		let mut retry_start = tokio::time::Instant::now();
119		let mut last_error: Option<anyhow::Error> = None;
120
121		loop {
122			if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
123				let timeout = backoff.timeout;
124				return Err(last_error
125					.map(|e| e.context(format!("reconnect timed out after {timeout:?}")))
126					.unwrap_or_else(|| anyhow::anyhow!("reconnect timed out after {timeout:?}")));
127			}
128
129			tracing::info!(%url, "connecting");
130
131			match client.connect(url.clone()).await {
132				Ok(session) => {
133					tracing::info!(%url, "connected");
134					delay = backoff.initial;
135					last_error = None;
136					if let Ok(mut state) = state.write() {
137						state.status = Some(Status::Connected);
138					}
139					let _ = session.closed().await;
140					tracing::warn!(%url, "session closed, reconnecting");
141					if let Ok(mut state) = state.write() {
142						state.status = Some(Status::Disconnected);
143					}
144					retry_start = tokio::time::Instant::now();
145				}
146				Err(err) => {
147					tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
148					last_error = Some(err);
149					tokio::time::sleep(delay).await;
150					delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
151				}
152			}
153		}
154	}
155
156	/// Poll for the next connection status change since this handle last reported one.
157	///
158	/// `Ready(Ok(status))` on a change, `Ready(Err)` once the loop has stopped (the give-up error,
159	/// or a generic one when the handle is dropped), `Pending` otherwise.
160	pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<anyhow::Result<Status>> {
161		let last = self.last_reported;
162		let status = match ready!(self.state.poll(waiter, |state| match state.status {
163			Some(status) if Some(status) != last => Poll::Ready(status),
164			_ => Poll::Pending,
165		})) {
166			Ok(status) => status,
167			Err(state) => return Poll::Ready(Err(terminal(&state))),
168		};
169
170		self.last_reported = Some(status);
171		Poll::Ready(Ok(status))
172	}
173
174	/// Wait until the connection status changes from what this handle last reported.
175	///
176	/// Returns the current [`Status`]. The loop alternates `Connected`/`Disconnected`, so successive
177	/// calls alternate too; but a status that flips and flips back before the caller polls is
178	/// reported once. This tracks the *current* state, not every edge.
179	pub async fn status(&mut self) -> anyhow::Result<Status> {
180		kio::wait(|waiter| self.poll_status(waiter)).await
181	}
182
183	/// Poll whether the reconnect loop has stopped.
184	///
185	/// `Ready(Err)` if it permanently gave up (reconnect timeout exceeded), `Ready(Ok(()))` if
186	/// stopped by dropping the handle, `Pending` while it's still running.
187	pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<anyhow::Result<()>> {
188		ready!(self.state.poll_closed(waiter));
189		Poll::Ready(match &self.state.read().error {
190			Some(err) => Err(anyhow::anyhow!("{err:#}")),
191			None => Ok(()),
192		})
193	}
194
195	/// Wait until the reconnect loop stops.
196	pub async fn closed(&self) -> anyhow::Result<()> {
197		kio::wait(|waiter| self.poll_closed(waiter)).await
198	}
199}
200
201impl Drop for Reconnect {
202	fn drop(&mut self) {
203		self.abort.abort();
204	}
205}
206
207/// The terminal error read from a closed channel's final state.
208fn terminal(state: &State) -> anyhow::Error {
209	match &state.error {
210		Some(err) => anyhow::anyhow!("{err:#}"),
211		None => anyhow::anyhow!("reconnect stopped"),
212	}
213}
214
#[cfg(test)]
mod tests {
	use super::*;

	/// `Backoff::default()` must yield 1s initial, x2 multiplier, 30s cap, 300s timeout.
	#[test]
	fn test_backoff_default() {
		let Backoff {
			initial,
			multiplier,
			max,
			timeout,
		} = Backoff::default();

		assert_eq!(initial, Duration::from_secs(1));
		assert_eq!(multiplier, 2);
		assert_eq!(max, Duration::from_secs(30));
		assert_eq!(timeout, Duration::from_secs(300));
	}
}