use crate::error::{IgtlError, Result};
use crate::protocol::header::Header;
use crate::protocol::message::{IgtlMessage, Message};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
use tracing::{debug, info, warn};
/// Exponential-backoff policy used when re-establishing a dropped connection.
///
/// The delay for reconnection attempt `n` (0-based) is
/// `initial_delay * backoff_multiplier^n`, capped at `max_delay`. When
/// `use_jitter` is set, up to 24% extra pseudo-random delay is added. The
/// result is still never larger than `max_delay`.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of reconnection attempts; `None` retries forever.
    pub max_attempts: Option<usize>,
    /// Base delay used for attempt 0.
    pub initial_delay: Duration,
    /// Upper bound on any single delay returned by `delay_for_attempt`.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each subsequent attempt.
    pub backoff_multiplier: f64,
    /// Whether to add pseudo-random jitter to each delay.
    pub use_jitter: bool,
}

impl Default for ReconnectConfig {
    /// 10 attempts, 100 ms initial delay, 30 s cap, doubling, jitter enabled.
    fn default() -> Self {
        Self {
            max_attempts: Some(10),
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}

impl ReconnectConfig {
    /// Default policy, but retries without an attempt limit.
    pub fn infinite() -> Self {
        Self {
            max_attempts: None,
            ..Default::default()
        }
    }

    /// Default policy with the given maximum number of attempts.
    pub fn with_max_attempts(attempts: usize) -> Self {
        Self {
            max_attempts: Some(attempts),
            ..Default::default()
        }
    }

    /// Default policy with custom initial and maximum delays.
    pub fn with_delays(initial: Duration, max: Duration) -> Self {
        Self {
            initial_delay: initial,
            max_delay: max,
            ..Default::default()
        }
    }

    /// Computes how long to wait before reconnection attempt `attempt` (0-based).
    ///
    /// The result is always `<= self.max_delay`, including when jitter is enabled.
    /// Sub-millisecond parts of `initial_delay` are truncated.
    pub(crate) fn delay_for_attempt(&self, attempt: usize) -> Duration {
        // Clamp the exponent: a plain `attempt as i32` wraps to a negative value
        // for huge attempt counts (reachable with `max_attempts: None`), which
        // would shrink the delay instead of saturating it at `max_delay`.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let max_ms = self.max_delay.as_millis() as f64;
        let base_ms = (self.initial_delay.as_millis() as f64
            * self.backoff_multiplier.powi(exponent))
        .min(max_ms);
        let mut delay = Duration::from_millis(base_ms as u64);
        if self.use_jitter {
            use std::collections::hash_map::RandomState;
            use std::hash::{BuildHasher, Hash, Hasher};
            // Each `RandomState::new()` uses fresh keys, so hashing the attempt
            // number yields a cheap pseudo-random value without a rand dependency.
            let mut hasher = RandomState::new().build_hasher();
            attempt.hash(&mut hasher);
            let jitter = (hasher.finish() % 25) as f64 / 100.0;
            let jitter_ms = (delay.as_millis() as f64 * jitter) as u64;
            delay = Duration::from_millis((delay.as_millis() as u64).saturating_add(jitter_ms));
        }
        // Jitter is applied after the cap; clamp again so `max_delay` stays a true ceiling.
        delay.min(self.max_delay)
    }
}
#[deprecated(
since = "0.2.0",
note = "Use ClientBuilder instead: ClientBuilder::new().tcp(addr).async_mode().with_reconnect(config).build().await"
)]
pub struct ReconnectClient {
addr: String,
stream: Option<TcpStream>,
config: ReconnectConfig,
verify_crc: bool,
reconnect_count: usize,
}
impl ReconnectClient {
pub async fn connect(addr: &str, config: ReconnectConfig) -> Result<Self> {
info!(addr = addr, "Creating reconnecting client");
let stream = Self::try_connect(addr).await?;
Ok(ReconnectClient {
addr: addr.to_string(),
stream: Some(stream),
config,
verify_crc: true,
reconnect_count: 0,
})
}
async fn try_connect(addr: &str) -> Result<TcpStream> {
debug!(addr = addr, "Attempting connection");
let stream = TcpStream::connect(addr).await?;
info!(addr = addr, "Connected successfully");
Ok(stream)
}
async fn ensure_connected(&mut self) -> Result<()> {
if self.stream.is_some() {
return Ok(());
}
let mut attempt = 0;
loop {
if let Some(max) = self.config.max_attempts {
if attempt >= max {
warn!(
attempts = attempt,
max_attempts = max,
"Max reconnection attempts reached"
);
return Err(IgtlError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Max reconnection attempts exceeded",
)));
}
}
let delay = self.config.delay_for_attempt(attempt);
info!(
attempt = attempt + 1,
delay_ms = delay.as_millis(),
"Reconnecting..."
);
sleep(delay).await;
match Self::try_connect(&self.addr).await {
Ok(stream) => {
self.stream = Some(stream);
self.reconnect_count += 1;
info!(
reconnect_count = self.reconnect_count,
"Reconnection successful"
);
return Ok(());
}
Err(e) => {
warn!(
attempt = attempt + 1,
error = %e,
"Reconnection attempt failed"
);
attempt += 1;
}
}
}
}
pub fn reconnect_count(&self) -> usize {
self.reconnect_count
}
pub fn is_connected(&self) -> bool {
self.stream.is_some()
}
pub fn set_verify_crc(&mut self, verify: bool) {
self.verify_crc = verify;
}
pub fn verify_crc(&self) -> bool {
self.verify_crc
}
pub async fn send<T: Message>(&mut self, msg: &IgtlMessage<T>) -> Result<()> {
let data = msg.encode()?;
let msg_type = msg.header.type_name.as_str().unwrap_or("UNKNOWN");
debug!(
msg_type = msg_type,
size = data.len(),
"Sending message (with auto-reconnect)"
);
loop {
self.ensure_connected().await?;
if let Some(stream) = &mut self.stream {
match stream.write_all(&data).await {
Ok(_) => {
stream.flush().await?;
debug!(msg_type = msg_type, "Message sent successfully");
return Ok(());
}
Err(e) => {
warn!(error = %e, "Send failed, will reconnect");
self.stream = None;
}
}
}
}
}
pub async fn receive<T: Message>(&mut self) -> Result<IgtlMessage<T>> {
loop {
self.ensure_connected().await?;
if let Some(stream) = &mut self.stream {
let mut header_buf = vec![0u8; Header::SIZE];
match stream.read_exact(&mut header_buf).await {
Ok(_) => {}
Err(e) => {
warn!(error = %e, "Header read failed, will reconnect");
self.stream = None;
continue;
}
}
let header = match Header::decode(&header_buf) {
Ok(h) => h,
Err(e) => {
warn!(error = %e, "Header decode failed");
return Err(e);
}
};
let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
debug!(
msg_type = msg_type,
body_size = header.body_size,
"Received message header"
);
let mut body_buf = vec![0u8; header.body_size as usize];
match stream.read_exact(&mut body_buf).await {
Ok(_) => {}
Err(e) => {
warn!(error = %e, "Body read failed, will reconnect");
self.stream = None;
continue;
}
}
let mut full_msg = header_buf;
full_msg.extend_from_slice(&body_buf);
return IgtlMessage::decode_with_options(&full_msg, self.verify_crc);
}
}
}
pub async fn reconnect(&mut self) -> Result<()> {
info!("Manual reconnection triggered");
self.stream = None;
self.ensure_connected().await
}
}
#[cfg(test)]
#[allow(deprecated)] // These tests exercise the deprecated ReconnectClient directly.
mod tests {
    use super::*;

    #[test]
    fn test_reconnect_config_defaults() {
        let config = ReconnectConfig::default();
        assert_eq!(config.max_attempts, Some(10));
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.use_jitter);
    }

    #[test]
    fn test_reconnect_config_infinite() {
        let config = ReconnectConfig::infinite();
        assert_eq!(config.max_attempts, None);
    }

    #[test]
    fn test_reconnect_config_delay_calculation() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
        assert!(config.delay_for_attempt(20) <= config.max_delay);
    }

    #[tokio::test]
    async fn test_reconnect_client_creation() {
        // Reserve an ephemeral port and release it, so nothing should be listening
        // there. A hard-coded port may be in use on the test machine.
        let addr = {
            let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            listener.local_addr().unwrap()
        };
        let config = ReconnectConfig::with_max_attempts(1);
        let result = ReconnectClient::connect(&addr.to_string(), config).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_reconnect_count() {
        let config = ReconnectConfig::default();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let _ = listener.accept().await;
            }
        });
        let client = ReconnectClient::connect(&addr.to_string(), config)
            .await
            .unwrap();
        assert_eq!(client.reconnect_count(), 0);
    }
}