use crate::client::Client;
use crate::request::IqError;
use log::{debug, warn};
use rand::Rng;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use wacore_ng::iq::spec::IqSpec;
/// Milliseconds elapsed between `timestamp_ms` and the current wall-clock time.
///
/// `timestamp_ms` is compared against `chrono::Utc::now().timestamp_millis()`.
/// A value of `0` is treated as "never set" and yields `None`. A timestamp that
/// lies ahead of the current time saturates to `Some(0)` instead of underflowing.
fn ms_since(timestamp_ms: u64) -> Option<u64> {
    (timestamp_ms != 0).then(|| {
        let current_ms = chrono::Utc::now().timestamp_millis() as u64;
        current_ms.saturating_sub(timestamp_ms)
    })
}
/// Decides whether the socket should be considered dead.
///
/// The socket is reported dead only when all of the following hold:
/// * something has been sent (`last_sent_ms != 0`),
/// * nothing has been received at or after that send
///   (`last_received_ms < last_sent_ms`), and
/// * strictly more than `DEAD_SOCKET_TIME` has passed since the send.
fn is_dead_socket(last_sent_ms: u64, last_received_ms: u64) -> bool {
    let never_sent = last_sent_ms == 0;
    let answered_since_send = last_received_ms >= last_sent_ms;
    if never_sent || answered_since_send {
        return false;
    }

    let threshold_ms = DEAD_SOCKET_TIME.as_millis() as u64;
    ms_since(last_sent_ms).is_some_and(|waited_ms| waited_ms > threshold_ms)
}
/// Lower bound of the randomized sleep between keepalive checks. Also used as
/// the "recent activity" window: if data was received more recently than this,
/// the keepalive loop skips the ping for that round.
// NOTE(review): the unit tests pin MIN/MAX under a "matches_wa_web" name —
// presumably the WhatsApp Web client's timings; confirm before changing.
const KEEP_ALIVE_INTERVAL_MIN: Duration = Duration::from_secs(15);
/// Upper bound (inclusive) of the randomized sleep between keepalive checks.
const KEEP_ALIVE_INTERVAL_MAX: Duration = Duration::from_secs(30);
/// Timeout handed to `KeepaliveSpec::with_timeout` when a ping IQ is built.
const KEEP_ALIVE_RESPONSE_DEADLINE: Duration = Duration::from_secs(20);
/// How long after a send the socket may go without any newer received data
/// before `is_dead_socket` reports it as dead.
const DEAD_SOCKET_TIME: Duration = Duration::from_secs(20);
/// Outcome of a single keepalive attempt, as consumed by `Client::keepalive_loop`.
#[derive(Debug, PartialEq)]
enum KeepaliveResult {
/// A pong was received, or the ping was skipped because other IQ responses
/// were still pending. The loop resets its error counter.
Ok,
/// The ping failed with an error classified as transient (timeout, server
/// error, parse error). The loop bumps its error counter and keeps running.
TransientFailure,
/// The client is not connected, or the ping failed with an error classified
/// as fatal (socket error, disconnect, not connected, internal channel
/// closed). The loop exits.
FatalFailure,
}
fn classify_keepalive_error(e: &IqError) -> KeepaliveResult {
match e {
IqError::Socket(_)
| IqError::Disconnected(_)
| IqError::NotConnected
| IqError::InternalChannelClosed => KeepaliveResult::FatalFailure,
IqError::Timeout | IqError::ServerError { .. } | IqError::ParseError(_) => {
KeepaliveResult::TransientFailure
}
}
}
impl Client {
    /// Sends one keepalive ping and reports how it went.
    ///
    /// # Returns
    /// * `KeepaliveResult::FatalFailure` if the client is not connected, or if
    ///   `send_iq` fails with an error that `classify_keepalive_error` deems fatal.
    /// * `KeepaliveResult::Ok` if the ping was skipped because other IQ responses
    ///   are still being awaited, or if a pong arrived. On a pong the measured
    ///   round-trip time is passed to
    ///   `unified_session.update_server_time_offset_with_rtt`.
    /// * `KeepaliveResult::TransientFailure` for every other `send_iq` error.
    async fn send_keepalive(&self) -> KeepaliveResult {
        if !self.is_connected() {
            return KeepaliveResult::FatalFailure;
        }

        // The lock guard is a temporary of this statement, so it is released
        // before any ping is sent.
        let has_pending = !self.response_waiters.lock().await.is_empty();
        if has_pending {
            debug!(target: "Client/Keepalive", "Skipping ping: IQ responses pending");
            return KeepaliveResult::Ok;
        }

        debug!(target: "Client/Keepalive", "Sending keepalive ping");
        // Wall-clock send time, forwarded to the server-time-offset update below.
        let start_ms = chrono::Utc::now().timestamp_millis();
        // The round trip itself is timed on the monotonic clock: subtracting two
        // wall-clock readings can go negative (or balloon) if the system clock is
        // stepped while the ping is in flight.
        let sent_at = std::time::Instant::now();
        let iq = wacore_ng::iq::keepalive::KeepaliveSpec::with_timeout(KEEP_ALIVE_RESPONSE_DEADLINE)
            .build_iq();

        match self.send_iq(iq).await {
            Ok(response_node) => {
                let rtt_ms = i64::try_from(sent_at.elapsed().as_millis()).unwrap_or(i64::MAX);
                debug!(target: "Client/Keepalive", "Received keepalive pong (RTT: {rtt_ms}ms)");
                self.unified_session.update_server_time_offset_with_rtt(
                    &response_node,
                    start_ms,
                    rtt_ms,
                );
                KeepaliveResult::Ok
            }
            Err(e) => {
                let result = classify_keepalive_error(&e);
                warn!(target: "Client/Keepalive", "Keepalive ping failed: {e:?}");
                result
            }
        }
    }

    /// Background task that keeps the connection alive until it is told to stop.
    ///
    /// Each round sleeps for a random duration in
    /// `KEEP_ALIVE_INTERVAL_MIN..=KEEP_ALIVE_INTERVAL_MAX`, then:
    /// 1. exits if the client is no longer connected;
    /// 2. calls `reconnect_immediately` and exits if `is_dead_socket` reports that
    ///    a send has gone unanswered for too long;
    /// 3. skips the ping if data was received within `KEEP_ALIVE_INTERVAL_MIN`;
    /// 4. otherwise sends a ping via `send_keepalive` and exits on a fatal result.
    ///
    /// The loop also exits as soon as `shutdown_notifier` fires during the sleep.
    pub(crate) async fn keepalive_loop(self: Arc<Self>) {
        // Consecutive transient failures; only used for logging.
        let mut error_count = 0u32;
        loop {
            let interval_ms = rand::rng().random_range(
                KEEP_ALIVE_INTERVAL_MIN.as_millis()..=KEEP_ALIVE_INTERVAL_MAX.as_millis(),
            );
            let interval = Duration::from_millis(interval_ms as u64);

            // NOTE(review): a fresh `notified()` future is created on every
            // iteration and is not polled while the sleep branch body runs. If
            // shutdown is signalled with `notify_waiters`, a signal raised during
            // the ping below would presumably only be noticed through the
            // `is_connected()` check of the next round — confirm that is intended.
            tokio::select! {
                _ = tokio::time::sleep(interval) => {
                    if !self.is_connected() {
                        debug!(target: "Client/Keepalive", "Not connected, exiting keepalive loop.");
                        return;
                    }

                    let last_sent = self.last_data_sent_ms.load(Ordering::Relaxed);
                    let last_recv = self.last_data_received_ms.load(Ordering::Relaxed);

                    if is_dead_socket(last_sent, last_recv) {
                        let elapsed = ms_since(last_sent).unwrap_or(0);
                        warn!(
                            target: "Client/Keepalive",
                            "No data received for {:.1}s after send (dead socket), forcing reconnect.",
                            elapsed as f64 / 1000.0
                        );
                        self.reconnect_immediately().await;
                        return;
                    }

                    // Recent inbound traffic already proves the link is alive.
                    if let Some(since_recv) = ms_since(last_recv)
                        && since_recv < KEEP_ALIVE_INTERVAL_MIN.as_millis() as u64
                    {
                        if error_count > 0 {
                            debug!(target: "Client/Keepalive", "Keepalive restored (recent activity).");
                            error_count = 0;
                        }
                        continue;
                    }

                    match self.send_keepalive().await {
                        KeepaliveResult::Ok => {
                            if error_count > 0 {
                                debug!(target: "Client/Keepalive", "Keepalive restored after {error_count} failure(s).");
                            }
                            error_count = 0;
                        }
                        KeepaliveResult::FatalFailure => {
                            debug!(target: "Client/Keepalive", "Fatal keepalive failure, exiting loop.");
                            return;
                        }
                        KeepaliveResult::TransientFailure => {
                            error_count += 1;
                            // Transient covers timeouts, server errors and parse
                            // errors alike, so the message must not say "timeout".
                            warn!(target: "Client/Keepalive", "Transient keepalive failure, error count: {error_count}");
                        }
                    }
                },
                _ = self.shutdown_notifier.notified() => {
                    debug!(target: "Client/Keepalive", "Shutdown signaled, exiting keepalive loop.");
                    return;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::socket::error::SocketError;
    use wacore_binary_ng::builder::NodeBuilder;

    /// Current wall-clock time in Unix milliseconds, read the same way as `ms_since`.
    fn now_ms() -> u64 {
        chrono::Utc::now().timestamp_millis() as u64
    }

    /// A Unix-millisecond timestamp lying `delta_ms` before the current time.
    fn ms_ago(delta_ms: u64) -> u64 {
        now_ms().saturating_sub(delta_ms)
    }

    // ---- classify_keepalive_error ----

    #[test]
    fn test_classify_timeout_is_transient() {
        let verdict = classify_keepalive_error(&IqError::Timeout);
        assert_eq!(
            verdict,
            KeepaliveResult::TransientFailure,
            "Timeout should be transient — connection may recover"
        );
    }

    #[test]
    fn test_classify_not_connected_is_fatal() {
        let verdict = classify_keepalive_error(&IqError::NotConnected);
        assert_eq!(verdict, KeepaliveResult::FatalFailure);
    }

    #[test]
    fn test_classify_internal_channel_closed_is_fatal() {
        let verdict = classify_keepalive_error(&IqError::InternalChannelClosed);
        assert_eq!(verdict, KeepaliveResult::FatalFailure);
    }

    #[test]
    fn test_classify_socket_error_is_fatal() {
        let err = IqError::Socket(SocketError::Crypto("test".to_string()));
        assert_eq!(classify_keepalive_error(&err), KeepaliveResult::FatalFailure);
    }

    #[test]
    fn test_classify_disconnected_is_fatal() {
        let node = NodeBuilder::new("disconnect").build();
        let err = IqError::Disconnected(node);
        assert_eq!(classify_keepalive_error(&err), KeepaliveResult::FatalFailure);
    }

    #[test]
    fn test_classify_server_error_is_transient() {
        let err = IqError::ServerError {
            code: 500,
            text: "internal".to_string(),
        };
        assert_eq!(
            classify_keepalive_error(&err),
            KeepaliveResult::TransientFailure,
            "ServerError should be transient — server may recover"
        );
    }

    #[test]
    fn test_classify_parse_error_is_transient() {
        let err = IqError::ParseError(anyhow::anyhow!("bad response"));
        assert_eq!(
            classify_keepalive_error(&err),
            KeepaliveResult::TransientFailure,
            "ParseError should be transient — bad response, not a dead connection"
        );
    }

    // ---- ms_since ----

    #[test]
    fn test_ms_since_never_set() {
        assert_eq!(ms_since(0), None, "should return None when timestamp is 0");
    }

    #[test]
    fn test_ms_since_recent() {
        let elapsed = ms_since(now_ms()).unwrap();
        assert!(elapsed < 100, "should be near-zero, got {elapsed}ms");
    }

    #[test]
    fn test_ms_since_stale() {
        let elapsed = ms_since(ms_ago(30_000)).unwrap();
        assert!(
            (29_000..=31_000).contains(&elapsed),
            "should be ~30s, got {elapsed}ms"
        );
    }

    // ---- is_dead_socket ----

    #[test]
    fn test_dead_socket_never_sent() {
        assert!(!is_dead_socket(0, 0));
    }

    #[test]
    fn test_dead_socket_received_after_send() {
        let sent = now_ms();
        assert!(!is_dead_socket(sent, sent + 1));
    }

    #[test]
    fn test_dead_socket_sent_recently() {
        assert!(!is_dead_socket(now_ms(), 0));
    }

    #[test]
    fn test_dead_socket_sent_long_ago_no_reply() {
        assert!(is_dead_socket(ms_ago(30_000), 0));
    }

    #[test]
    fn test_dead_socket_sent_long_ago_old_reply() {
        let sent = ms_ago(30_000);
        // The last reply predates the send by one second.
        let received = sent.saturating_sub(1_000);
        assert!(is_dead_socket(sent, received));
    }

    #[test]
    fn test_dead_socket_sent_long_ago_recent_reply() {
        let sent = ms_ago(30_000);
        let received = ms_ago(1_000);
        assert!(!is_dead_socket(sent, received));
    }

    // ---- timing constants ----

    #[test]
    fn test_keepalive_interval_matches_wa_web() {
        assert_eq!(KEEP_ALIVE_INTERVAL_MIN, Duration::from_secs(15));
        assert_eq!(KEEP_ALIVE_INTERVAL_MAX, Duration::from_secs(30));
    }

    #[test]
    fn test_dead_socket_time_matches_wa_web() {
        assert_eq!(DEAD_SOCKET_TIME, Duration::from_secs(20));
    }
}