use crate::error::{IgtlError, Result};
use crate::io::reconnect::ReconnectConfig;
use crate::protocol::header::Header;
use crate::protocol::message::{IgtlMessage, Message};
use rustls::pki_types::ServerName;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::sleep;
use tokio_rustls::client::TlsStream;
use tokio_rustls::{rustls, TlsConnector};
use tracing::{debug, info, trace, warn};
/// Async TLS client for `IgtlMessage`s that re-establishes its connection on demand.
///
/// When an I/O failure drops the stream, the next `send`/`receive` reconnects
/// using the retry policy held in `reconnect_config`.
pub struct TcpAsyncTlsReconnectClient {
    // ---- runtime state ----
    /// Live TLS stream; `None` while disconnected (e.g. after an I/O failure).
    stream: Option<TlsStream<TcpStream>>,
    /// Counter reported by `reconnect_count()`.
    reconnect_count: usize,
    /// Forwarded to the message decoder to enable/disable CRC verification.
    verify_crc: bool,

    // ---- connection target & policy ----
    /// Host used for the TCP connection and as the TLS server name.
    hostname: String,
    /// TCP port of the server.
    port: u16,
    /// Shared rustls client configuration used for every handshake.
    tls_config: Arc<rustls::ClientConfig>,
    /// Retry/backoff policy applied while (re)connecting.
    reconnect_config: ReconnectConfig,
}
impl TcpAsyncTlsReconnectClient {
    /// Creates a client and establishes the initial TLS connection to `hostname:port`.
    ///
    /// The initial connection is retried with the same policy as later
    /// reconnections (`reconnect_config`). Attempts made here are not counted by
    /// [`reconnect_count`](Self::reconnect_count).
    ///
    /// CRC verification of received messages starts enabled; see
    /// [`set_verify_crc`](Self::set_verify_crc).
    ///
    /// # Errors
    ///
    /// Returns an error if no connection succeeds before
    /// `reconnect_config.max_attempts` (when set) is exhausted.
    pub async fn connect(
        hostname: &str,
        port: u16,
        tls_config: Arc<rustls::ClientConfig>,
        reconnect_config: ReconnectConfig,
    ) -> Result<Self> {
        info!(
            hostname = hostname,
            port = port,
            "Creating TLS reconnecting client"
        );
        let mut client = Self {
            hostname: hostname.to_string(),
            port,
            tls_config,
            reconnect_config,
            stream: None,
            reconnect_count: 0,
            verify_crc: true,
        };
        // The first connection is not a *re*connection, so bypass the counter
        // maintained by `ensure_connected`.
        client.connect_with_retry().await?;
        Ok(client)
    }

    /// Opens a TCP connection to `hostname:port` and performs the TLS handshake.
    ///
    /// `hostname` doubles as the TLS server name, so it must be accepted by
    /// `ServerName::try_from`.
    ///
    /// # Errors
    ///
    /// - TCP connect or `local_addr` failure (I/O error propagated as-is);
    /// - `hostname` rejected as a server name (`InvalidInput`);
    /// - TLS handshake failure (`ConnectionRefused`).
    async fn try_connect(
        hostname: &str,
        port: u16,
        tls_config: Arc<rustls::ClientConfig>,
    ) -> Result<TlsStream<TcpStream>> {
        let addr = format!("{}:{}", hostname, port);
        debug!(addr = %addr, "Attempting TLS connection");
        // Pass host and port separately instead of re-parsing `addr`: the
        // "host:port" text form is ambiguous for IPv6 literals such as "::1".
        let tcp_stream = TcpStream::connect((hostname, port)).await?;
        let local_addr = tcp_stream.local_addr()?;
        let server_name = ServerName::try_from(hostname.to_string()).map_err(|e| {
            IgtlError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Invalid hostname: {}", e),
            ))
        })?;
        let connector = TlsConnector::from(tls_config);
        let stream = connector.connect(server_name, tcp_stream).await.map_err(|e| {
            warn!(error = %e, "TLS handshake failed");
            IgtlError::Io(std::io::Error::new(
                std::io::ErrorKind::ConnectionRefused,
                format!("TLS handshake failed: {}", e),
            ))
        })?;
        info!(
            local_addr = %local_addr,
            remote_addr = %addr,
            "TLS connection established"
        );
        Ok(stream)
    }

    /// Ensures a live stream exists, reconnecting if the previous one was dropped.
    ///
    /// Every successful re-establishment increments `reconnect_count`, however
    /// many attempts it took.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`connect_with_retry`](Self::connect_with_retry).
    async fn ensure_connected(&mut self) -> Result<()> {
        if self.stream.is_some() {
            return Ok(());
        }
        self.connect_with_retry().await?;
        self.reconnect_count += 1;
        info!(
            reconnect_count = self.reconnect_count,
            "TLS reconnection successful"
        );
        Ok(())
    }

    /// Calls `try_connect` repeatedly, following `reconnect_config`, until a stream
    /// is stored in `self.stream` or the attempt limit is reached.
    ///
    /// The first attempt is made immediately; later attempts sleep for
    /// `reconnect_config.delay_for_attempt(attempt)` first.
    ///
    /// # Errors
    ///
    /// Returns a `TimedOut` I/O error once `max_attempts` (when set) attempts
    /// have failed.
    async fn connect_with_retry(&mut self) -> Result<()> {
        let mut attempt = 0;
        loop {
            if let Some(max) = self.reconnect_config.max_attempts {
                if attempt >= max {
                    warn!(
                        attempts = attempt,
                        max_attempts = max,
                        "Max reconnection attempts reached"
                    );
                    return Err(IgtlError::Io(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "Max reconnection attempts exceeded",
                    )));
                }
            }
            if attempt > 0 {
                let delay = self.reconnect_config.delay_for_attempt(attempt);
                info!(
                    attempt = attempt + 1,
                    delay_ms = delay.as_millis(),
                    "Reconnecting with TLS..."
                );
                sleep(delay).await;
            }
            match Self::try_connect(&self.hostname, self.port, Arc::clone(&self.tls_config)).await
            {
                Ok(stream) => {
                    self.stream = Some(stream);
                    return Ok(());
                }
                Err(e) => {
                    warn!(
                        attempt = attempt + 1,
                        error = %e,
                        "TLS reconnection attempt failed"
                    );
                    attempt += 1;
                }
            }
        }
    }

    /// Number of times the connection was re-established after being dropped by a
    /// failed `send`/`receive`. The initial connection made by `connect` is not
    /// counted.
    pub fn reconnect_count(&self) -> usize {
        self.reconnect_count
    }

    /// Returns `true` if a stream is currently held.
    ///
    /// This does not probe the socket. A dead peer is only noticed by the next
    /// `send`/`receive` that fails.
    pub fn is_connected(&self) -> bool {
        self.stream.is_some()
    }

    /// Enables or disables CRC verification when decoding received messages.
    pub fn set_verify_crc(&mut self, verify: bool) {
        self.verify_crc = verify;
    }

    /// Whether received messages are decoded with CRC verification enabled.
    pub fn verify_crc(&self) -> bool {
        self.verify_crc
    }

    /// Encodes `msg` and writes it to the server, reconnecting on I/O failure.
    ///
    /// If writing or flushing fails, the stream is dropped. A new connection is
    /// then established and the whole encoded message is sent again. This repeats
    /// until a send succeeds or reconnection gives up.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails or reconnection exhausts
    /// `reconnect_config.max_attempts`.
    pub async fn send<T: Message>(&mut self, msg: &IgtlMessage<T>) -> Result<()> {
        let data = msg.encode()?;
        let msg_type = msg.header.type_name.as_str().unwrap_or("UNKNOWN");
        let device_name = msg.header.device_name.as_str().unwrap_or("UNKNOWN");
        debug!(
            msg_type = msg_type,
            device_name = device_name,
            size = data.len(),
            "Sending message over TLS (with auto-reconnect)"
        );
        loop {
            self.ensure_connected().await?;
            if let Some(stream) = self.stream.as_mut() {
                // A flush failure is handled exactly like a write failure. The
                // data may not have reached the peer and the stream is no longer
                // usable, so it must not be kept around.
                let outcome = match stream.write_all(&data).await {
                    Ok(()) => stream.flush().await,
                    Err(e) => Err(e),
                };
                match outcome {
                    Ok(()) => {
                        trace!(
                            msg_type = msg_type,
                            bytes_sent = data.len(),
                            "Message sent successfully over TLS"
                        );
                        return Ok(());
                    }
                    Err(e) => {
                        warn!(error = %e, "TLS send failed, will reconnect");
                        self.stream = None;
                    }
                }
            }
        }
    }

    /// Reads one complete message (header + body) and decodes it as `T`.
    ///
    /// If an I/O error occurs while reading the header or body, the stream is
    /// dropped and the client reconnects. Reading then restarts from a fresh
    /// header on the new connection, and the partially read message is discarded.
    ///
    /// # Errors
    ///
    /// - Reconnection exhausts `reconnect_config.max_attempts`.
    /// - The header cannot be decoded, or its body size cannot be addressed in
    ///   memory. The stream is dropped in both cases because the message boundary
    ///   is lost, so the next call starts on a fresh connection.
    /// - The complete frame fails to decode (the `verify_crc` setting is forwarded
    ///   to the decoder). The stream is kept, since the whole frame was consumed.
    pub async fn receive<T: Message>(&mut self) -> Result<IgtlMessage<T>> {
        loop {
            self.ensure_connected().await?;
            let stream = match self.stream.as_mut() {
                Some(stream) => stream,
                None => continue,
            };

            let mut frame = vec![0u8; Header::SIZE];
            if let Err(e) = stream.read_exact(&mut frame).await {
                warn!(error = %e, "TLS header read failed, will reconnect");
                self.stream = None;
                continue;
            }

            let header = match Header::decode(&frame) {
                Ok(h) => h,
                Err(e) => {
                    warn!(error = %e, "Header decode failed");
                    // The body length is unknown, so this stream can no longer be
                    // realigned to the next message boundary.
                    self.stream = None;
                    return Err(e);
                }
            };
            let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
            let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
            debug!(
                msg_type = msg_type,
                device_name = device_name,
                body_size = header.body_size,
                version = header.version,
                "Received message header over TLS"
            );

            // `body_size` comes from the peer. Reject sizes that cannot be addressed
            // instead of letting an `as usize` cast silently truncate them and
            // desynchronise the stream.
            // NOTE(review): there is no upper bound here, so a corrupt or hostile
            // header can still request a very large allocation.
            let frame_len = match usize::try_from(header.body_size)
                .ok()
                .and_then(|body_len| body_len.checked_add(Header::SIZE))
            {
                Some(len) => len,
                None => {
                    warn!(
                        body_size = header.body_size,
                        "Message body size not addressable, dropping connection"
                    );
                    self.stream = None;
                    return Err(IgtlError::Io(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Message body size exceeds addressable memory",
                    )));
                }
            };

            // Grow the header buffer in place so the body is read directly after
            // it. This avoids a separate body buffer and a copy.
            frame.resize(frame_len, 0);
            if let Err(e) = stream.read_exact(&mut frame[Header::SIZE..]).await {
                warn!(error = %e, "TLS body read failed, will reconnect");
                self.stream = None;
                continue;
            }
            trace!(
                msg_type = msg_type,
                bytes_read = frame.len() - Header::SIZE,
                "Message body received over TLS"
            );

            let result = IgtlMessage::decode_with_options(&frame, self.verify_crc);
            match &result {
                Ok(_) => {
                    debug!(
                        msg_type = msg_type,
                        device_name = device_name,
                        "Message decoded successfully"
                    );
                }
                Err(e) => {
                    warn!(
                        msg_type = msg_type,
                        error = %e,
                        "Failed to decode message"
                    );
                }
            }
            return result;
        }
    }
}