use std::task::{Poll, ready};
use std::time::Duration;
use moq_net::kio;
use url::Url;
use crate::Client;
/// Exponential backoff settings used when retrying a failed connection.
///
/// Values can come from CLI arguments or environment variables
/// (e.g. `MOQ_BACKOFF_INITIAL`), or be deserialized with `serde`, where
/// durations use the `humantime` text form such as `"30s"`. Fields missing
/// from serialized input fall back to [`Default`]; unknown fields are rejected.
// NOTE: keep `id` ahead of `long` inside each `#[arg(...)]`; clap derive
// attributes are order-sensitive, so do not reshuffle them casually.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Backoff {
/// Delay before retrying after the first failed attempt; the delay resets
/// to this value after every successful connection.
#[arg(
id = "backoff-initial",
long,
default_value = "1s",
env = "MOQ_BACKOFF_INITIAL",
value_parser = humantime::parse_duration,
)]
#[serde(with = "humantime_serde")]
pub initial: Duration,
/// Factor the delay is multiplied by after each failed attempt.
#[arg(id = "backoff-multiplier", long, default_value_t = 2, env = "MOQ_BACKOFF_MULTIPLIER")]
pub multiplier: u32,
/// Upper bound applied to the delay as it grows.
#[arg(
id = "backoff-max",
long,
default_value = "30s",
env = "MOQ_BACKOFF_MAX",
value_parser = humantime::parse_duration,
)]
#[serde(with = "humantime_serde")]
pub max: Duration,
/// How long to keep retrying without a successful connection before giving
/// up; a zero duration disables the limit.
#[arg(
id = "backoff-timeout",
long,
default_value = "5m",
env = "MOQ_BACKOFF_TIMEOUT",
value_parser = humantime::parse_duration,
)]
#[serde(with = "humantime_serde")]
pub timeout: Duration,
}
impl Default for Backoff {
    /// Returns the stock settings: 1s initial delay, doubling per failure,
    /// capped at 30s, giving up after 5 minutes.
    ///
    /// These mirror the `default_value`s declared on the CLI arguments.
    fn default() -> Self {
        let initial = Duration::from_secs(1);
        let max = Duration::from_secs(30);
        let timeout = Duration::from_secs(5 * 60);

        Self {
            initial,
            multiplier: 2,
            max,
            timeout,
        }
    }
}
/// Connection state published by the reconnect loop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
/// A connection attempt succeeded and a session is established.
Connected,
/// The established session has closed; the loop is reconnecting.
Disconnected,
}
/// State shared between the background reconnect task (writer) and the
/// [`Reconnect`] handle (reader).
#[derive(Default)]
struct State {
/// Most recent connection status; `None` until the first connection succeeds.
status: Option<Status>,
/// Error that made the reconnect loop give up, if it did.
error: Option<anyhow::Error>,
}
/// Handle to a background task that keeps reconnecting to a URL.
///
/// Dropping the handle aborts the task.
pub struct Reconnect {
/// Abort handle for the spawned reconnect task; triggered on drop.
abort: tokio::task::AbortHandle,
/// Read side of the state the background task writes to.
state: kio::Consumer<State>,
/// Status most recently returned by `poll_status`, so only changes are reported.
last_reported: Option<Status>,
}
impl Reconnect {
pub(crate) fn new(client: Client, url: Url, backoff: Backoff) -> Self {
let producer = kio::Producer::<State>::default();
let state = producer.consume();
let task = tokio::spawn(async move {
if let Err(err) = Self::run(&producer, client, url, backoff).await {
tracing::error!(err = %format!("{err:#}"), "reconnect loop exited");
if let Ok(mut state) = producer.write() {
state.error = Some(err);
}
}
});
Self {
abort: task.abort_handle(),
state,
last_reported: None,
}
}
async fn run(state: &kio::Producer<State>, client: Client, url: Url, backoff: Backoff) -> anyhow::Result<()> {
let mut delay = backoff.initial;
let mut retry_start = tokio::time::Instant::now();
let mut last_error: Option<anyhow::Error> = None;
loop {
if !backoff.timeout.is_zero() && retry_start.elapsed() > backoff.timeout {
let timeout = backoff.timeout;
return Err(last_error
.map(|e| e.context(format!("reconnect timed out after {timeout:?}")))
.unwrap_or_else(|| anyhow::anyhow!("reconnect timed out after {timeout:?}")));
}
tracing::info!(%url, "connecting");
match client.connect(url.clone()).await {
Ok(session) => {
tracing::info!(%url, "connected");
delay = backoff.initial;
last_error = None;
if let Ok(mut state) = state.write() {
state.status = Some(Status::Connected);
}
let _ = session.closed().await;
tracing::warn!(%url, "session closed, reconnecting");
if let Ok(mut state) = state.write() {
state.status = Some(Status::Disconnected);
}
retry_start = tokio::time::Instant::now();
}
Err(err) => {
tracing::warn!(%url, %err, ?delay, "connection failed, retrying");
last_error = Some(err);
tokio::time::sleep(delay).await;
delay = std::cmp::min(delay * backoff.multiplier, backoff.max);
}
}
}
}
pub fn poll_status(&mut self, waiter: &kio::Waiter) -> Poll<anyhow::Result<Status>> {
let last = self.last_reported;
let status = match ready!(self.state.poll(waiter, |state| match state.status {
Some(status) if Some(status) != last => Poll::Ready(status),
_ => Poll::Pending,
})) {
Ok(status) => status,
Err(state) => return Poll::Ready(Err(terminal(&state))),
};
self.last_reported = Some(status);
Poll::Ready(Ok(status))
}
pub async fn status(&mut self) -> anyhow::Result<Status> {
kio::wait(|waiter| self.poll_status(waiter)).await
}
pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<anyhow::Result<()>> {
ready!(self.state.poll_closed(waiter));
Poll::Ready(match &self.state.read().error {
Some(err) => Err(anyhow::anyhow!("{err:#}")),
None => Ok(()),
})
}
pub async fn closed(&self) -> anyhow::Result<()> {
kio::wait(|waiter| self.poll_closed(waiter)).await
}
}
impl Drop for Reconnect {
/// Aborts the background reconnect task when the handle is dropped.
fn drop(&mut self) {
self.abort.abort();
}
}
/// Builds the error reported once the reconnect loop is no longer running.
///
/// Uses the recorded error's full cause chain as the message when one exists,
/// and a generic "reconnect stopped" message otherwise.
fn terminal(state: &State) -> anyhow::Error {
    state.error.as_ref().map_or_else(
        || anyhow::anyhow!("reconnect stopped"),
        |err| anyhow::anyhow!("{err:#}"),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the values produced by `Backoff::default()`.
    #[test]
    fn test_backoff_default() {
        let Backoff {
            initial,
            multiplier,
            max,
            timeout,
        } = Backoff::default();

        assert_eq!(initial, Duration::from_secs(1));
        assert_eq!(multiplier, 2);
        assert_eq!(max, Duration::from_secs(30));
        assert_eq!(timeout, Duration::from_secs(300));
    }
}