use std::collections::HashMap;
use std::time::Duration;
use futures_core::Stream;
use seedlink_rs_protocol::SequenceNumber;
use tracing::{debug, info, warn};
use crate::SeedLinkClient;
use crate::error::{ClientError, Result};
use crate::state::{ClientConfig, OwnedFrame, StationKey};
/// Retry/backoff policy applied when a reconnect is needed.
///
/// Before each attempt the client sleeps for the current backoff. After a
/// failed attempt the backoff is multiplied by `multiplier` and capped at
/// `max_backoff`.
#[derive(Clone, Debug)]
pub struct ReconnectConfig {
    /// Delay slept before the first reconnect attempt of a cycle.
    pub initial_backoff: Duration,
    /// Upper bound on the delay between attempts.
    pub max_backoff: Duration,
    /// Growth factor applied to the delay after each failed attempt.
    pub multiplier: f64,
    /// Maximum attempts per reconnect cycle; `0` means retry without limit.
    pub max_attempts: u32,
}

impl Default for ReconnectConfig {
    /// Starts at one second, doubles each time up to one minute, and never
    /// gives up (`max_attempts == 0`).
    fn default() -> Self {
        const ONE_SECOND: Duration = Duration::from_millis(1_000);
        const ONE_MINUTE: Duration = Duration::from_secs(60);
        Self {
            max_attempts: 0,
            multiplier: 2.0,
            max_backoff: ONE_MINUTE,
            initial_backoff: ONE_SECOND,
        }
    }
}
/// A single subscription command, recorded in call order so that the same
/// session can be rebuilt on a fresh connection after a reconnect.
#[derive(Clone, Debug)]
enum SubscriptionStep {
    /// A `station()` call; becomes the "current station" for later steps.
    Station { station: String, network: String },
    /// A `select()` call with its selector pattern.
    Select { pattern: String },
    /// A `data()` call. On replay it resumes from the tracked sequence of the
    /// current station, if one is known.
    Data,
    /// A `data_from()` call with the caller-supplied starting sequence.
    DataFrom(SequenceNumber),
    /// A `time_window()` call; `end` is `None` when no end time was given.
    TimeWindow { start: String, end: Option<String> },
}
/// A [`SeedLinkClient`] wrapper that reconnects when the stream ends. After
/// reconnecting it replays the recorded subscription commands and skips
/// frames it has already delivered.
pub struct ReconnectingClient {
    /// Server address used for the initial connection and every reconnect.
    addr: String,
    /// Client settings reused for every (re)connection.
    config: ClientConfig,
    /// Backoff and retry policy for reconnects.
    reconnect: ReconnectConfig,
    /// Subscription commands in the order they were issued, for replay.
    subscriptions: Vec<SubscriptionStep>,
    /// The live connection; `None` while reconnecting or after giving up.
    client: Option<SeedLinkClient>,
    /// Per-station sequence numbers synced from the live client. Used to
    /// resume after a reconnect and to skip replayed frames.
    sequences: HashMap<StationKey, SequenceNumber>,
}
impl ReconnectingClient {
pub async fn connect(addr: &str) -> Result<Self> {
Self::connect_with_config(addr, ClientConfig::default(), ReconnectConfig::default()).await
}
pub async fn connect_with_config(
addr: &str,
config: ClientConfig,
reconnect: ReconnectConfig,
) -> Result<Self> {
let client = SeedLinkClient::connect_with_config(addr, config.clone()).await?;
Ok(Self {
addr: addr.to_owned(),
config,
reconnect,
subscriptions: Vec::new(),
client: Some(client),
sequences: HashMap::new(),
})
}
pub async fn station(&mut self, station: &str, network: &str) -> Result<()> {
self.subscriptions.push(SubscriptionStep::Station {
station: station.to_owned(),
network: network.to_owned(),
});
self.client_mut()?.station(station, network).await
}
pub async fn select(&mut self, pattern: &str) -> Result<()> {
self.subscriptions.push(SubscriptionStep::Select {
pattern: pattern.to_owned(),
});
self.client_mut()?.select(pattern).await
}
pub async fn data(&mut self) -> Result<()> {
self.subscriptions.push(SubscriptionStep::Data);
self.client_mut()?.data().await
}
pub async fn data_from(&mut self, sequence: SequenceNumber) -> Result<()> {
self.subscriptions
.push(SubscriptionStep::DataFrom(sequence));
self.client_mut()?.data_from(sequence).await
}
pub async fn time_window(&mut self, start: &str, end: Option<&str>) -> Result<()> {
self.subscriptions.push(SubscriptionStep::TimeWindow {
start: start.to_owned(),
end: end.map(|s| s.to_owned()),
});
self.client_mut()?.time_window(start, end).await
}
pub async fn end_stream(&mut self) -> Result<()> {
self.client_mut()?.end_stream().await
}
pub async fn next_frame(&mut self) -> Result<Option<OwnedFrame>> {
loop {
let result = match self.client.as_mut() {
Some(client) => client.next_frame().await,
None => return Err(ClientError::Disconnected),
};
match result {
Ok(Some(frame)) => {
if let Some(key) = frame.station_key()
&& let Some(&tracked) = self.sequences.get(&key)
&& frame.sequence() <= tracked
{
debug!(
seq = %frame.sequence(),
tracked = %tracked,
station = ?key,
"skipping duplicate frame"
);
continue;
}
self.sync_sequences();
return Ok(Some(frame));
}
Ok(None) => {
debug!("stream ended, attempting reconnect");
match self.attempt_reconnect().await {
Ok(()) => {
continue;
}
Err(ClientError::ReconnectFailed { attempts }) => {
warn!(attempts, "reconnect failed, giving up");
return Err(ClientError::ReconnectFailed { attempts });
}
Err(e) => return Err(e),
}
}
Err(e) => return Err(e),
}
}
}
pub fn into_stream(self) -> impl Stream<Item = Result<OwnedFrame>> {
async_stream::try_stream! {
let mut this = self;
loop {
match this.next_frame().await {
Ok(Some(frame)) => yield frame,
Ok(None) => break,
Err(ClientError::ReconnectFailed { .. }) => break,
Err(e) => Err(e)?,
}
}
}
}
pub fn last_sequence(&self, network: &str, station: &str) -> Option<SequenceNumber> {
let key = StationKey {
network: network.to_owned(),
station: station.to_owned(),
};
self.sequences.get(&key).copied()
}
pub fn sequences(&self) -> &HashMap<StationKey, SequenceNumber> {
&self.sequences
}
fn client_mut(&mut self) -> Result<&mut SeedLinkClient> {
self.client.as_mut().ok_or(ClientError::Disconnected)
}
fn sync_sequences(&mut self) {
if let Some(client) = &self.client {
for (key, seq) in client.sequences() {
self.sequences.insert(key.clone(), *seq);
}
}
}
async fn attempt_reconnect(&mut self) -> Result<()> {
self.client = None;
let mut backoff = self.reconnect.initial_backoff;
let max_attempts = self.reconnect.max_attempts;
for attempt in 1.. {
if max_attempts > 0 && attempt > max_attempts {
return Err(ClientError::ReconnectFailed {
attempts: max_attempts,
});
}
info!(attempt, backoff_ms = backoff.as_millis(), "reconnecting");
tokio::time::sleep(backoff).await;
match SeedLinkClient::connect_with_config(&self.addr, self.config.clone()).await {
Ok(mut new_client) => {
if let Err(e) = self.replay_subscriptions(&mut new_client).await {
warn!(attempt, error = %e, "replay failed, retrying");
backoff = self.next_backoff(backoff);
continue;
}
if let Err(e) = new_client.end_stream().await {
warn!(attempt, error = %e, "end_stream failed, retrying");
backoff = self.next_backoff(backoff);
continue;
}
info!(attempt, "reconnected successfully");
self.client = Some(new_client);
return Ok(());
}
Err(e) => {
warn!(attempt, error = %e, "reconnect attempt failed");
backoff = self.next_backoff(backoff);
}
}
}
unreachable!()
}
fn next_backoff(&self, current: Duration) -> Duration {
let next = current.mul_f64(self.reconnect.multiplier);
next.min(self.reconnect.max_backoff)
}
async fn replay_subscriptions(&self, client: &mut SeedLinkClient) -> Result<()> {
let mut current_station: Option<StationKey> = None;
for step in &self.subscriptions {
match step {
SubscriptionStep::Station { station, network } => {
client.station(station, network).await?;
current_station = Some(StationKey {
network: network.clone(),
station: station.clone(),
});
}
SubscriptionStep::Select { pattern } => {
client.select(pattern).await?;
}
SubscriptionStep::Data => {
if let Some(ref key) = current_station {
if let Some(seq) = self.sequences.get(key) {
debug!(%seq, station = %key.station, network = %key.network, "resuming from sequence");
client.data_from(*seq).await?;
} else {
client.data().await?;
}
} else {
client.data().await?;
}
}
SubscriptionStep::DataFrom(seq) => {
if let Some(ref key) = current_station
&& let Some(tracked) = self.sequences.get(key)
&& *tracked > *seq
{
client.data_from(*tracked).await?;
continue;
}
client.data_from(*seq).await?;
}
SubscriptionStep::TimeWindow { start, end } => {
client.time_window(start, end.as_deref()).await?;
}
}
}
Ok(())
}
}
/// Field-by-field duplicate of a [`ClientConfig`].
///
/// All fields are `Copy`, so this is a plain copy. Destructuring is
/// exhaustive, so adding a field to `ClientConfig` forces this impl to be
/// revisited.
impl Clone for ClientConfig {
    fn clone(&self) -> Self {
        let &Self {
            connect_timeout,
            read_timeout,
            prefer_v4,
        } = self;
        Self {
            connect_timeout,
            read_timeout,
            prefer_v4,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::{MockConfig, MockServer};
    use seedlink_rs_protocol::frame::v3;

    /// Copies `code` into `field`, truncated to the field width, and fills
    /// the rest of the field with ASCII spaces.
    fn write_padded(field: &mut [u8], code: &str) {
        field.fill(b' ');
        for (slot, &byte) in field.iter_mut().zip(code.as_bytes()) {
            *slot = byte;
        }
    }

    /// Encodes a v3 frame with sequence `seq`. The payload carries a
    /// space-padded station code at bytes 8..13 and network code at 18..20.
    fn make_v3_frame(seq: u64, station: &str, network: &str) -> Vec<u8> {
        let mut payload = [0u8; v3::PAYLOAD_LEN];
        write_padded(&mut payload[8..13], station);
        write_padded(&mut payload[18..20], network);
        v3::write(SequenceNumber::new(seq), &payload).unwrap()
    }

    /// Reconnect policy with a 10 ms initial backoff so tests stay fast.
    fn fast_reconnect(max_backoff_ms: u64, max_attempts: u32) -> ReconnectConfig {
        ReconnectConfig {
            initial_backoff: Duration::from_millis(10),
            max_backoff: Duration::from_millis(max_backoff_ms),
            max_attempts,
            ..Default::default()
        }
    }

    /// Connects a client restricted to protocol v3 (`prefer_v4: false`).
    async fn connect_v3(addr: &str, reconnect: ReconnectConfig) -> ReconnectingClient {
        let client_config = ClientConfig {
            prefer_v4: false,
            ..Default::default()
        };
        ReconnectingClient::connect_with_config(addr, client_config, reconnect)
            .await
            .unwrap()
    }

    /// Issues `station` + `data` for each entry in order, then `end_stream`.
    async fn subscribe(client: &mut ReconnectingClient, stations: &[(&str, &str)]) {
        for &(station, network) in stations {
            client.station(station, network).await.unwrap();
            client.data().await.unwrap();
        }
        client.end_stream().await.unwrap();
    }

    #[tokio::test]
    async fn reconnect_on_disconnect() {
        let server = MockServer::start(MockConfig {
            close_after_stream: true,
            max_connections: 2,
            connection_frames: Some(vec![
                vec![make_v3_frame(1, "ANMO", "IU")],
                vec![make_v3_frame(2, "ANMO", "IU")],
            ]),
            ..MockConfig::v3_default(vec![])
        })
        .await;
        let addr = server.addr().to_string();
        let mut client = connect_v3(&addr, fast_reconnect(50, 3)).await;
        subscribe(&mut client, &[("ANMO", "IU")]).await;

        for expected in [1_u64, 2] {
            let frame = client.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.sequence(), SequenceNumber::new(expected));
        }
    }

    #[tokio::test]
    async fn reconnect_max_attempts() {
        let server = MockServer::start(MockConfig {
            close_after_stream: true,
            max_connections: 1,
            ..MockConfig::v3_default(vec![make_v3_frame(1, "ANMO", "IU")])
        })
        .await;
        let addr = server.addr().to_string();
        let mut client = connect_v3(&addr, fast_reconnect(20, 2)).await;
        subscribe(&mut client, &[("ANMO", "IU")]).await;

        let frame = client.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.sequence(), SequenceNumber::new(1));
        let err = client.next_frame().await.unwrap_err();
        assert!(matches!(err, ClientError::ReconnectFailed { attempts: 2 }));
    }

    #[tokio::test]
    async fn reconnect_resumes_sequence_verified_on_wire() {
        let server = MockServer::start(MockConfig {
            close_after_stream: true,
            max_connections: 2,
            connection_frames: Some(vec![
                vec![
                    make_v3_frame(10, "ANMO", "IU"),
                    make_v3_frame(11, "ANMO", "IU"),
                ],
                vec![
                    make_v3_frame(10, "ANMO", "IU"),
                    make_v3_frame(11, "ANMO", "IU"),
                    make_v3_frame(12, "ANMO", "IU"),
                ],
            ]),
            ..MockConfig::v3_default(vec![])
        })
        .await;
        let addr = server.addr().to_string();
        let mut client = connect_v3(&addr, fast_reconnect(50, 3)).await;
        subscribe(&mut client, &[("ANMO", "IU")]).await;

        for expected in [10_u64, 11] {
            let frame = client.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.sequence(), SequenceNumber::new(expected));
        }
        assert_eq!(
            client.last_sequence("IU", "ANMO"),
            Some(SequenceNumber::new(11))
        );

        // The second connection resends 10 and 11; both are skipped.
        let frame = client.next_frame().await.unwrap().unwrap();
        assert_eq!(frame.sequence(), SequenceNumber::new(12));

        let first = server.captured().connection(0);
        for (i, expected) in ["HELLO", "STATION ANMO IU", "DATA", "END"]
            .into_iter()
            .enumerate()
        {
            assert_eq!(first[i], expected);
        }
        let second = server.captured().connection(1);
        for (i, expected) in ["HELLO", "STATION ANMO IU", "DATA 00000B", "END"]
            .into_iter()
            .enumerate()
        {
            assert_eq!(second[i], expected);
        }
    }

    #[tokio::test]
    async fn reconnect_multi_station_resumes_each_sequence() {
        let server = MockServer::start(MockConfig {
            close_after_stream: true,
            max_connections: 2,
            connection_frames: Some(vec![
                vec![
                    make_v3_frame(10, "ANMO", "IU"),
                    make_v3_frame(11, "ANMO", "IU"),
                    make_v3_frame(4, "WLF", "GE"),
                    make_v3_frame(5, "WLF", "GE"),
                ],
                vec![
                    make_v3_frame(11, "ANMO", "IU"),
                    make_v3_frame(12, "ANMO", "IU"),
                    make_v3_frame(5, "WLF", "GE"),
                    make_v3_frame(6, "WLF", "GE"),
                ],
            ]),
            ..MockConfig::v3_default(vec![])
        })
        .await;
        let addr = server.addr().to_string();
        let mut client = connect_v3(&addr, fast_reconnect(50, 3)).await;
        subscribe(&mut client, &[("ANMO", "IU"), ("WLF", "GE")]).await;

        for _ in 0..4 {
            client.next_frame().await.unwrap().unwrap();
        }
        assert_eq!(
            client.last_sequence("IU", "ANMO"),
            Some(SequenceNumber::new(11))
        );
        assert_eq!(
            client.last_sequence("GE", "WLF"),
            Some(SequenceNumber::new(5))
        );

        // Replayed 11 (ANMO) and 5 (WLF) are skipped; only new frames surface.
        for expected in [12_u64, 6] {
            let frame = client.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.sequence(), SequenceNumber::new(expected));
        }

        let second = server.captured().connection(1);
        let expected_commands = [
            "HELLO",
            "STATION ANMO IU",
            "DATA 00000B",
            "STATION WLF GE",
            "DATA 000005",
            "END",
        ];
        for (i, expected) in expected_commands.into_iter().enumerate() {
            assert_eq!(second[i], expected);
        }
    }

    #[tokio::test]
    async fn reconnect_into_stream() {
        use std::pin::pin;
        use tokio_stream::StreamExt;

        let server = MockServer::start(MockConfig {
            close_after_stream: true,
            max_connections: 2,
            connection_frames: Some(vec![
                vec![make_v3_frame(1, "ANMO", "IU")],
                vec![make_v3_frame(2, "ANMO", "IU")],
            ]),
            ..MockConfig::v3_default(vec![])
        })
        .await;
        let addr = server.addr().to_string();
        let mut client = connect_v3(&addr, fast_reconnect(50, 1)).await;
        subscribe(&mut client, &[("ANMO", "IU")]).await;

        let mut stream = pin!(client.into_stream());
        for expected in [1_u64, 2] {
            let frame = stream.next().await.unwrap().unwrap();
            assert_eq!(frame.sequence(), SequenceNumber::new(expected));
        }
        // Once reconnection gives up, the stream ends without an error item.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn reconnect_dedup_skips_all_duplicates() {
        let server = MockServer::start(MockConfig {
            close_after_stream: true,
            max_connections: 2,
            connection_frames: Some(vec![
                vec![
                    make_v3_frame(10, "ANMO", "IU"),
                    make_v3_frame(11, "ANMO", "IU"),
                ],
                vec![
                    make_v3_frame(10, "ANMO", "IU"),
                    make_v3_frame(11, "ANMO", "IU"),
                ],
            ]),
            ..MockConfig::v3_default(vec![])
        })
        .await;
        let addr = server.addr().to_string();
        let mut client = connect_v3(&addr, fast_reconnect(20, 1)).await;
        subscribe(&mut client, &[("ANMO", "IU")]).await;

        for expected in [10_u64, 11] {
            let frame = client.next_frame().await.unwrap().unwrap();
            assert_eq!(frame.sequence(), SequenceNumber::new(expected));
        }
        // Every frame on the second connection is a duplicate, so the client
        // reconnects again and exhausts its single allowed attempt.
        let err = client.next_frame().await.unwrap_err();
        assert!(matches!(err, ClientError::ReconnectFailed { attempts: 1 }));
    }
}