use crate::error::{MqttError, Result};
use crate::time::{Duration, Instant};
use crate::Transport;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio::time::interval;
use tracing::{debug, instrument, warn};
/// Counters and timestamps describing the connection history of a
/// `TransportManager`.
#[derive(Clone, Debug, Default)]
pub struct ConnectionStats {
    /// Number of `connect()` calls whose transport connect succeeded.
    pub connections_established: u64,
    /// Number of `connect()` calls whose transport connect returned an error.
    pub connection_failures: u64,
    /// Total bytes passed to successful `write()` calls (saturating).
    pub bytes_sent: u64,
    /// Total bytes reported by successful `read()` calls (saturating).
    pub bytes_received: u64,
    /// When the most recent successful connect completed.
    pub last_connected: Option<Instant>,
    /// When the most recent `disconnect()` completed.
    pub last_disconnected: Option<Instant>,
    /// Time elapsed since `last_connected`; only populated in the snapshot
    /// returned by `TransportManager::stats()` while connected.
    pub uptime: Option<Duration>,
}
/// Lifecycle state of a `TransportManager` connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No transport connection is open.
    Disconnected,
    /// A `connect()` call is in flight.
    Connecting,
    /// The transport is connected; `read()` and `write()` are permitted.
    Connected,
    /// A `disconnect()` call is closing the transport.
    Closing,
}
/// Tuning knobs for a `TransportManager` and its background tasks.
#[derive(Clone, Debug)]
pub struct ManagerConfig {
    /// Whether `start_background_tasks` spawns the auto-reconnect loop.
    pub auto_reconnect: bool,
    /// Initial reconnect backoff; the backoff is reset to this value after a
    /// successful connect.
    pub reconnect_delay: Duration,
    /// Upper bound for the reconnect backoff.
    pub max_reconnect_delay: Duration,
    /// Factor applied to the backoff after each failed reconnect attempt.
    pub reconnect_multiplier: f64,
    /// Period of the keep-alive task; `None` disables it.
    pub keep_alive_interval: Option<Duration>,
    /// Inactivity period after which the connection is closed; `None`
    /// disables the idle check.
    pub idle_timeout: Option<Duration>,
}

impl Default for ManagerConfig {
    /// Auto-reconnect enabled with a 1 s backoff doubling up to 60 s, a 30 s
    /// keep-alive interval, and a 5-minute idle timeout.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            auto_reconnect: true,
            reconnect_delay: secs(1),
            max_reconnect_delay: secs(60),
            reconnect_multiplier: 2.0,
            keep_alive_interval: Some(secs(30)),
            idle_timeout: Some(secs(5 * 60)),
        }
    }
}
/// Wraps a [`Transport`] with connection-state tracking, statistics, and
/// optional keep-alive, idle-timeout and auto-reconnect background tasks.
pub struct TransportManager<T: Transport> {
    /// The wrapped transport. `connect`, `disconnect`, `read` and `write`
    /// each hold this lock while they drive the transport.
    transport: Arc<Mutex<T>>,
    /// Configuration supplied at construction time.
    config: ManagerConfig,
    /// Current lifecycle state.
    state: Arc<RwLock<ConnectionState>>,
    /// Accumulated counters and timestamps.
    stats: Arc<RwLock<ConnectionStats>>,
    /// Time of the last recorded activity (connect, successful I/O, or a
    /// keep-alive tick); read by the idle-timeout task.
    last_activity: Arc<RwLock<Instant>>,
    /// Current reconnect backoff: seeded from the config, reset on a
    /// successful connect, and grown after failed reconnect attempts.
    reconnect_delay: Arc<RwLock<Duration>>,
}
impl<T: Transport + 'static> TransportManager<T> {
pub fn new(transport: T, config: ManagerConfig) -> Self {
let initial_delay = config.reconnect_delay;
Self {
transport: Arc::new(Mutex::new(transport)),
config,
state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
stats: Arc::new(RwLock::new(ConnectionStats::default())),
last_activity: Arc::new(RwLock::new(Instant::now())),
reconnect_delay: Arc::new(RwLock::new(initial_delay)),
}
}
pub async fn state(&self) -> ConnectionState {
*self.state.read().await
}
pub async fn stats(&self) -> ConnectionStats {
let mut stats = self.stats.read().await.clone();
if *self.state.read().await == ConnectionState::Connected {
if let Some(connected_at) = stats.last_connected {
stats.uptime = Some(connected_at.elapsed());
}
}
stats
}
#[instrument(skip(self))]
pub async fn connect(&self) -> Result<()> {
let current_state = *self.state.read().await;
match current_state {
ConnectionState::Connected => return Ok(()),
ConnectionState::Connecting => {
return Err(MqttError::ConnectionError(
"Connection already in progress".to_string(),
))
}
_ => {}
}
debug!(state = ?ConnectionState::Connecting, "Transport state changed");
*self.state.write().await = ConnectionState::Connecting;
let mut transport = self.transport.lock().await;
match transport.connect().await {
Ok(()) => {
debug!(state = ?ConnectionState::Connected, "Transport state changed");
*self.state.write().await = ConnectionState::Connected;
let mut stats = self.stats.write().await;
stats.connections_established += 1;
stats.last_connected = Some(Instant::now());
*self.reconnect_delay.write().await = self.config.reconnect_delay;
*self.last_activity.write().await = Instant::now();
Ok(())
}
Err(e) => {
debug!(state = ?ConnectionState::Disconnected, "Transport state changed");
*self.state.write().await = ConnectionState::Disconnected;
let mut stats = self.stats.write().await;
stats.connection_failures += 1;
Err(e)
}
}
}
#[instrument(skip(self))]
pub async fn disconnect(&self) -> Result<()> {
debug!(state = ?ConnectionState::Closing, "Transport state changed");
*self.state.write().await = ConnectionState::Closing;
let mut transport = self.transport.lock().await;
let result = transport.close().await;
debug!(state = ?ConnectionState::Disconnected, "Transport state changed");
*self.state.write().await = ConnectionState::Disconnected;
let mut stats = self.stats.write().await;
stats.last_disconnected = Some(Instant::now());
stats.uptime = None;
result
}
#[instrument(skip(self, buf), fields(buf_len = buf.len()), level = "debug")]
pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
if *self.state.read().await != ConnectionState::Connected {
return Err(MqttError::NotConnected);
}
let mut transport = self.transport.lock().await;
let result = transport.read(buf).await;
if let Ok(n) = result {
*self.last_activity.write().await = Instant::now();
self.stats.write().await.bytes_received = self
.stats
.read()
.await
.bytes_received
.saturating_add(n as u64);
}
result
}
#[instrument(skip(self, buf), fields(buf_len = buf.len()), level = "debug")]
pub async fn write(&self, buf: &[u8]) -> Result<()> {
if *self.state.read().await != ConnectionState::Connected {
return Err(MqttError::NotConnected);
}
let mut transport = self.transport.lock().await;
let result = transport.write(buf).await;
if result.is_ok() {
*self.last_activity.write().await = Instant::now();
self.stats.write().await.bytes_sent = self
.stats
.read()
.await
.bytes_sent
.saturating_add(buf.len() as u64);
}
result
}
pub fn start_background_tasks(self: Arc<Self>) {
if let Some(keep_alive_interval) = self.config.keep_alive_interval {
let manager = Arc::clone(&self);
tokio::spawn(async move {
let mut ticker = interval(keep_alive_interval);
ticker.tick().await;
loop {
ticker.tick().await;
if *manager.state.read().await == ConnectionState::Connected {
*manager.last_activity.write().await = Instant::now();
}
}
});
}
if let Some(idle_timeout) = self.config.idle_timeout {
let manager = Arc::clone(&self);
tokio::spawn(async move {
let mut ticker = interval(Duration::from_secs(10)); ticker.tick().await;
loop {
ticker.tick().await;
if *manager.state.read().await == ConnectionState::Connected {
let last_activity = *manager.last_activity.read().await;
if last_activity.elapsed() > idle_timeout {
warn!(idle_secs = ?last_activity.elapsed().as_secs(), "Idle timeout triggered");
if let Err(e) = manager.disconnect().await {
warn!(error = %e, "Failed to disconnect on idle timeout");
}
}
}
}
});
}
if self.config.auto_reconnect {
let manager = Arc::clone(&self);
tokio::spawn(async move {
let mut ticker = interval(Duration::from_secs(1));
ticker.tick().await;
loop {
ticker.tick().await;
if *manager.state.read().await == ConnectionState::Disconnected {
let delay = *manager.reconnect_delay.read().await;
tokio::time::sleep(delay).await;
if manager.connect().await.is_err() {
let mut new_delay = manager.reconnect_delay.write().await;
*new_delay = Duration::from_secs_f64(
(new_delay.as_secs_f64() * manager.config.reconnect_multiplier)
.min(manager.config.max_reconnect_delay.as_secs_f64()),
);
}
}
}
});
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::tcp::{TcpConfig, TcpTransport};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Builds a manager (default config) around a TCP transport aimed at
    /// localhost:1883; the manager itself is left disconnected.
    fn localhost_manager() -> TransportManager<TcpTransport> {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1883);
        TransportManager::new(TcpTransport::from_addr(addr), ManagerConfig::default())
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = localhost_manager();
        assert_eq!(manager.state().await, ConnectionState::Disconnected);

        let ConnectionStats {
            connections_established,
            connection_failures,
            ..
        } = manager.stats().await;
        assert_eq!(connections_established, 0);
        assert_eq!(connection_failures, 0);
    }

    #[tokio::test]
    async fn test_manager_connect_not_available() {
        // 192.0.2.0/24 is TEST-NET-1 (RFC 5737), so nothing should answer;
        // the short connect timeout presumably bounds how long this takes.
        let unroutable = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1)), 1883);
        let tcp_config = TcpConfig::new(unroutable).with_connect_timeout(Duration::from_millis(100));
        let manager = TransportManager::new(TcpTransport::new(tcp_config), ManagerConfig::default());

        assert!(manager.connect().await.is_err());
        assert_eq!(manager.state().await, ConnectionState::Disconnected);
        assert_eq!(manager.stats().await.connection_failures, 1);
    }

    #[tokio::test]
    async fn test_manager_read_write_not_connected() {
        let manager = localhost_manager();
        let mut scratch = [0u8; 10];

        let read_outcome = manager.read(&mut scratch).await;
        let write_outcome = manager.write(b"test").await;

        assert!(read_outcome.is_err());
        assert!(write_outcome.is_err());
    }

    #[test]
    fn test_manager_config_default() {
        let ManagerConfig {
            auto_reconnect,
            reconnect_delay,
            max_reconnect_delay,
            reconnect_multiplier,
            keep_alive_interval,
            idle_timeout,
        } = ManagerConfig::default();

        assert!(auto_reconnect);
        assert_eq!(reconnect_delay, Duration::from_secs(1));
        assert_eq!(max_reconnect_delay, Duration::from_secs(60));
        assert!((reconnect_multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(keep_alive_interval, Some(Duration::from_secs(30)));
        assert_eq!(idle_timeout, Some(Duration::from_secs(300)));
    }
}