use crate::error::Result;
use crate::packet::Packet;
use crate::transport::PacketWriter;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::time::Duration;
use super::unified::UnifiedWriter;
#[cfg(feature = "transport-quic")]
use crate::session::SessionState;
#[derive(Debug, Default)]
pub(crate) struct KeepaliveState {
last_ping_sent: Option<tokio::time::Instant>,
last_pong_received: Option<tokio::time::Instant>,
}
impl KeepaliveState {
pub(crate) fn reset(&mut self) {
self.last_ping_sent = None;
self.last_pong_received = None;
}
pub(crate) fn has_outstanding_ping(&self) -> bool {
match (self.last_ping_sent, self.last_pong_received) {
(Some(sent_at), Some(received_at)) => sent_at > received_at,
(Some(_), None) => true,
_ => false,
}
}
pub(crate) fn record_ping_sent(&mut self) {
self.last_ping_sent = Some(tokio::time::Instant::now());
}
pub(crate) fn record_pong_received(&mut self) {
self.last_pong_received = Some(tokio::time::Instant::now());
}
pub(crate) fn is_timeout(&self, timeout_duration: Duration) -> bool {
match (self.last_ping_sent, self.last_pong_received) {
(Some(sent_at), Some(received_at)) => {
sent_at > received_at && sent_at.elapsed() > timeout_duration
}
(Some(sent_at), None) => sent_at.elapsed() > timeout_duration,
_ => false,
}
}
}
pub(crate) fn owns_current_connection(
connection_epoch: u64,
current_connection_epoch: &AtomicU64,
) -> bool {
current_connection_epoch.load(Ordering::SeqCst) == connection_epoch
}
pub(crate) fn mark_disconnected_if_current(
connected: &AtomicBool,
connection_epoch: u64,
current_connection_epoch: &AtomicU64,
) {
if owns_current_connection(connection_epoch, current_connection_epoch) {
connected.store(false, Ordering::SeqCst);
}
}
pub(super) const PINGREQ_LOG_INTERVAL: u32 = 20;
pub(super) async fn send_pingreq_with_priority(
writer: &Arc<tokio::sync::Mutex<UnifiedWriter>>,
config: &mqtt5_protocol::KeepaliveConfig,
) -> Result<()> {
let max_attempts = config.lock_retry_attempts;
let retry_delay = Duration::from_millis(u64::from(config.lock_retry_delay_ms));
for attempt in 0..max_attempts {
if let Ok(mut guard) = writer.try_lock() {
return guard.write_packet(Packet::PingReq).await;
}
if attempt > 0 && attempt % PINGREQ_LOG_INTERVAL == 0 {
tracing::warn!(
attempt,
max_attempts,
"PINGREQ waiting for writer lock - possible contention"
);
}
tokio::time::sleep(retry_delay).await;
}
tracing::error!(
max_attempts,
"Failed to acquire writer lock for PINGREQ, falling back to blocking"
);
writer.lock().await.write_packet(Packet::PingReq).await
}
pub(super) async fn keepalive_task_with_writer(
writer: Arc<tokio::sync::Mutex<UnifiedWriter>>,
keepalive_interval: Duration,
keepalive_state: Arc<Mutex<KeepaliveState>>,
connected: Arc<AtomicBool>,
connection_epoch: u64,
current_connection_epoch: Arc<AtomicU64>,
keepalive_config: Option<mqtt5_protocol::KeepaliveConfig>,
) {
let config = keepalive_config.unwrap_or_default();
let ping_interval = config.ping_interval(keepalive_interval);
let timeout_duration = config.timeout_duration(keepalive_interval);
let mut interval = tokio::time::interval(ping_interval);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval.tick().await;
loop {
interval.tick().await;
{
let state = keepalive_state.lock();
if state.is_timeout(timeout_duration) {
tracing::error!("Keepalive timeout - no PINGRESP received");
mark_disconnected_if_current(
&connected,
connection_epoch,
¤t_connection_epoch,
);
break;
}
}
let should_send_ping = {
let mut state = keepalive_state.lock();
let should_send_ping = !state.has_outstanding_ping();
if should_send_ping {
state.record_ping_sent();
}
should_send_ping
};
if !should_send_ping {
continue;
}
match tokio::time::timeout(
timeout_duration,
send_pingreq_with_priority(&writer, &config),
)
.await
{
Ok(Ok(())) => {}
Ok(Err(e)) => {
tracing::error!("Error sending PINGREQ: {e}");
mark_disconnected_if_current(
&connected,
connection_epoch,
¤t_connection_epoch,
);
break;
}
Err(_) => {
tracing::error!("PINGREQ send timed out");
mark_disconnected_if_current(
&connected,
connection_epoch,
¤t_connection_epoch,
);
break;
}
}
}
}
#[cfg(feature = "transport-quic")]
pub(super) async fn flow_expiration_task(session: Arc<tokio::sync::RwLock<SessionState>>) {
let check_interval = Duration::from_secs(60);
let mut interval = tokio::time::interval(check_interval);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval.tick().await;
loop {
interval.tick().await;
let expired = session.read().await.expire_flows().await;
if !expired.is_empty() {
tracing::debug!(count = expired.len(), "Expired {} flows", expired.len());
}
}
}
#[cfg(test)]
mod tests {
use super::{mark_disconnected_if_current, owns_current_connection, KeepaliveState};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tokio::time::{Duration, Instant};
#[test]
fn stale_epoch_is_not_current() {
assert!(!owns_current_connection(1, &AtomicU64::new(2)));
}
#[test]
fn current_epoch_matches() {
assert!(owns_current_connection(2, &AtomicU64::new(2)));
}
#[test]
fn stale_keepalive_epoch_does_not_disconnect_current_connection() {
let connected = AtomicBool::new(true);
let current_epoch = AtomicU64::new(2);
mark_disconnected_if_current(&connected, 1, ¤t_epoch);
assert!(connected.load(Ordering::SeqCst));
}
#[test]
fn current_keepalive_epoch_disconnects_connection() {
let connected = AtomicBool::new(true);
let current_epoch = AtomicU64::new(2);
mark_disconnected_if_current(&connected, 2, ¤t_epoch);
assert!(!connected.load(Ordering::SeqCst));
}
#[test]
fn reset_clears_keepalive_tracking() {
let mut state = KeepaliveState {
last_ping_sent: Some(Instant::now() - Duration::from_secs(2)),
last_pong_received: Some(Instant::now() - Duration::from_secs(1)),
};
state.reset();
assert_eq!(state.last_ping_sent, None);
assert_eq!(state.last_pong_received, None);
assert!(!state.has_outstanding_ping());
}
}