mod context_impl;
mod device_registry;
mod lid_pn;
mod sender_keys;
mod sessions;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;

use anyhow::{Result, anyhow};
use async_lock::{Mutex, RwLock};
use futures::FutureExt;
use log::{debug, error, info, trace, warn};
use rand::{Rng, RngExt};
use thiserror::Error;

use wacore::appstate::patch_decode::WAPatchName;
use wacore::client::context::GroupInfo;
use wacore::runtime::{Runtime, timeout as rt_timeout};
use wacore::types::message::StanzaKey;
use wacore::xml::DisplayableNode;
use wacore_binary::builder::NodeBuilder;
use wacore_binary::jid::{Jid, JidExt};
use wacore_binary::node::{Attrs, Node, NodeValue};
use waproto::whatsapp as wa;

use crate::appstate_sync::AppStateProcessor;
use crate::cache::Cache;
use crate::cache_config::CacheConfig;
use crate::cache_store::TypedCache;
use crate::handlers::chatstate::ChatStateEvent;
use crate::handshake;
use crate::jid_utils::server_jid;
use crate::lid_pn_cache::LidPnCache;
use crate::pair;
use crate::socket::{NoiseSocket, SocketError, error::EncryptSendError};
use crate::store::{commands::DeviceCommand, persistence_manager::PersistenceManager};
use crate::sync_task::MajorSyncTask;
use crate::types::enc_handler::EncHandler;
use crate::types::events::{ConnectFailureReason, Event};
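/// Declarative matcher for incoming stanzas, used to register one-shot node
/// waiters on the [`Client`]. A filter matches a [`Node`] whose tag equals
/// `tag` and which carries every listed attribute with the given value.
///
/// A minimal sketch of building a filter (`jid` is assumed to be a valid
/// [`Jid`]):
///
/// ```ignore
/// // Matches <receipt from="..."/> stanzas coming from one specific peer.
/// let filter = NodeFilter::tag("receipt").from_jid(&jid);
/// ```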
#[derive(Debug, Clone)]
pub struct NodeFilter {
tag: String,
attrs: Vec<(String, String)>,
}
impl NodeFilter {
pub fn tag(tag: impl Into<String>) -> Self {
Self {
tag: tag.into(),
attrs: Vec::new(),
}
}
pub fn attr(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.attrs.push((key.into(), value.into()));
self
}
pub fn from_jid(self, jid: &Jid) -> Self {
self.attr("from", jid.to_string())
}
fn matches(&self, node: &Node) -> bool {
node.tag == self.tag
&& self
.attrs
.iter()
.all(|(k, v)| node.attrs.get(k.as_str()).is_some_and(|attr| *attr == *v))
}
}
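/// A registered one-shot waiter: the first incoming node matching `filter`
/// is delivered over `tx`.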
struct NodeWaiter {
filter: NodeFilter,
tx: futures::channel::oneshot::Sender<Arc<Node>>,
}
type ChatStateHandler = Arc<dyn Fn(ChatStateEvent) + Send + Sync>;
const APP_STATE_RETRY_MAX_ATTEMPTS: u32 = 6;
const TRANSPORT_CONNECT_TIMEOUT: Duration = Duration::from_secs(20);
#[cfg(feature = "debug-diagnostics")]
#[derive(Debug, Clone)]
pub struct MemoryDiagnostics {
pub group_cache: u64,
pub device_cache: u64,
pub device_registry_cache: u64,
pub lid_pn_lid_entries: u64,
pub lid_pn_pn_entries: u64,
pub retried_group_messages: u64,
pub recent_messages: u64,
pub message_retry_counts: u64,
pub pdo_pending_requests: u64,
pub session_locks: u64,
pub message_queues: u64,
pub message_enqueue_locks: u64,
pub response_waiters: usize,
pub node_waiters: usize,
pub pending_retries: usize,
pub presence_subscriptions: usize,
pub app_state_key_requests: usize,
pub app_state_syncing: usize,
pub signal_cache_sessions: usize,
pub signal_cache_identities: usize,
pub signal_cache_sender_keys: usize,
pub chatstate_handlers: usize,
pub custom_enc_handlers: usize,
}
#[cfg(feature = "debug-diagnostics")]
impl std::fmt::Display for MemoryDiagnostics {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "=== Memory Diagnostics ===")?;
writeln!(f, "--- Moka caches (TTL-bounded) ---")?;
writeln!(f, " group_cache: {}", self.group_cache)?;
writeln!(f, " device_cache: {}", self.device_cache)?;
writeln!(
f,
" device_registry_cache: {}",
self.device_registry_cache
)?;
writeln!(f, " lid_pn (lid): {}", self.lid_pn_lid_entries)?;
writeln!(f, " lid_pn (pn): {}", self.lid_pn_pn_entries)?;
writeln!(
f,
" retried_group_messages: {}",
self.retried_group_messages
)?;
writeln!(f, " recent_messages: {}", self.recent_messages)?;
writeln!(f, " message_retry_counts: {}", self.message_retry_counts)?;
writeln!(f, " pdo_pending_requests: {}", self.pdo_pending_requests)?;
writeln!(f, "--- Moka caches (capacity-only) ---")?;
writeln!(f, " session_locks: {}", self.session_locks)?;
writeln!(f, " message_queues: {}", self.message_queues)?;
writeln!(
f,
" message_enqueue_locks: {}",
self.message_enqueue_locks
)?;
writeln!(f, "--- Unbounded collections ---")?;
writeln!(f, " response_waiters: {}", self.response_waiters)?;
writeln!(f, " node_waiters: {}", self.node_waiters)?;
writeln!(f, " pending_retries: {}", self.pending_retries)?;
writeln!(
f,
" presence_subscriptions: {}",
self.presence_subscriptions
)?;
writeln!(
f,
" app_state_key_requests: {}",
self.app_state_key_requests
)?;
writeln!(f, " app_state_syncing: {}", self.app_state_syncing)?;
writeln!(
f,
" signal_sessions: {}",
self.signal_cache_sessions
)?;
writeln!(
f,
" signal_identities: {}",
self.signal_cache_identities
)?;
writeln!(
f,
" signal_sender_keys: {}",
self.signal_cache_sender_keys
)?;
writeln!(f, "--- Misc ---")?;
writeln!(f, " chatstate_handlers: {}", self.chatstate_handlers)?;
writeln!(f, " custom_enc_handlers: {}", self.custom_enc_handlers)?;
Ok(())
}
}
#[derive(Debug, Error)]
pub enum ClientError {
#[error("client is not connected")]
NotConnected,
#[error("socket error: {0}")]
Socket(#[from] SocketError),
#[error("encrypt/send error: {0}")]
EncryptSend(#[from] EncryptSendError),
#[error("client is already connected")]
AlreadyConnected,
#[error("client is not logged in")]
NotLoggedIn,
}
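/// Progress counters for the server's offline message replay: armed when an
/// `<ib><offline_preview count=.../>` stanza arrives and drained as
/// `offline`-tagged stanzas are processed (see [`Client::process_node`]).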
#[derive(Debug)]
pub(crate) struct OfflineSyncMetrics {
pub active: AtomicBool,
pub total_messages: AtomicUsize,
pub processed_messages: AtomicUsize,
pub start_time: std::sync::Mutex<Option<wacore::time::Instant>>,
}
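/// The stateful connection manager: owns the transport, Noise socket,
/// caches, sync state, and the event bus, and drives login, stanza
/// processing, and reconnection. Construct it with [`Client::new`], drive it
/// with [`Client::run`], and drain the returned [`MajorSyncTask`] receiver
/// via [`Client::process_sync_task`].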
pub struct Client {
pub(crate) runtime: Arc<dyn Runtime>,
pub(crate) core: wacore::client::CoreClient,
pub(crate) persistence_manager: Arc<PersistenceManager>,
pub(crate) media_conn: Arc<RwLock<Option<crate::mediaconn::MediaConn>>>,
pub(crate) is_logged_in: Arc<AtomicBool>,
pub(crate) is_connecting: Arc<AtomicBool>,
pub(crate) is_running: Arc<AtomicBool>,
is_connected: Arc<AtomicBool>,
pub(crate) shutdown_notifier: Arc<event_listener::Event>,
pub(crate) last_data_received_ms: Arc<AtomicU64>,
pub(crate) last_data_sent_ms: Arc<AtomicU64>,
pub(crate) transport: Arc<Mutex<Option<Arc<dyn crate::transport::Transport>>>>,
pub(crate) transport_events:
Arc<Mutex<Option<async_channel::Receiver<crate::transport::TransportEvent>>>>,
pub(crate) transport_factory: Arc<dyn crate::transport::TransportFactory>,
pub(crate) noise_socket: Arc<Mutex<Option<Arc<NoiseSocket>>>>,
pub(crate) response_waiters:
Arc<Mutex<HashMap<String, futures::channel::oneshot::Sender<wacore_binary::Node>>>>,
node_waiters: std::sync::Mutex<Vec<NodeWaiter>>,
node_waiter_count: AtomicUsize,
pub(crate) unique_id: String,
pub(crate) id_counter: Arc<AtomicU64>,
pub(crate) unified_session: crate::unified_session::UnifiedSessionManager,
pub(crate) signal_cache: Arc<crate::store::signal_cache::SignalStoreCache>,
pub(crate) message_processing_semaphore: std::sync::Mutex<Arc<async_lock::Semaphore>>,
pub(crate) message_semaphore_generation: Arc<AtomicU64>,
pub(crate) session_locks: Cache<String, Arc<async_lock::Mutex<()>>>,
pub(crate) message_queues: Cache<String, async_channel::Sender<Arc<Node>>>,
pub(crate) lid_pn_cache: Arc<LidPnCache>,
pub(crate) message_enqueue_locks: Cache<String, Arc<async_lock::Mutex<()>>>,
pub group_cache: async_lock::Mutex<Option<Arc<TypedCache<Jid, GroupInfo>>>>,
#[allow(clippy::type_complexity)]
pub device_cache: async_lock::Mutex<Option<Arc<TypedCache<Jid, Vec<Jid>>>>>,
pub(crate) retried_group_messages: Cache<String, ()>,
pub(crate) expected_disconnect: Arc<AtomicBool>,
pub(crate) connection_generation: Arc<AtomicU64>,
pub(crate) recent_messages: Cache<StanzaKey, Vec<u8>>,
pub(crate) pending_retries: Arc<async_lock::Mutex<HashSet<String>>>,
pub(crate) message_retry_counts: Cache<String, u8>,
pub enable_auto_reconnect: Arc<AtomicBool>,
pub auto_reconnect_errors: Arc<AtomicU32>,
pub(crate) needs_initial_full_sync: Arc<AtomicBool>,
pub(crate) app_state_processor: async_lock::Mutex<Option<Arc<AppStateProcessor>>>,
pub(crate) app_state_key_requests: Arc<Mutex<HashMap<String, wacore::time::Instant>>>,
pub(crate) app_state_syncing: Arc<Mutex<HashSet<WAPatchName>>>,
pub(crate) initial_keys_synced_notifier: Arc<event_listener::Event>,
pub(crate) initial_app_state_keys_received: Arc<AtomicBool>,
pub(crate) server_has_prekeys: Arc<AtomicBool>,
pub(crate) prekey_upload_lock: Arc<async_lock::Mutex<()>>,
pub(crate) offline_sync_notifier: Arc<event_listener::Event>,
pub(crate) offline_sync_completed: Arc<AtomicBool>,
pub(crate) history_sync_tasks_in_flight: Arc<AtomicUsize>,
pub(crate) history_sync_idle_notifier: Arc<event_listener::Event>,
pub(crate) presence_subscriptions: Arc<async_lock::Mutex<HashSet<Jid>>>,
pub(crate) offline_sync_metrics: Arc<OfflineSyncMetrics>,
pub(crate) socket_ready_notifier: Arc<event_listener::Event>,
pub(crate) is_ready: Arc<AtomicBool>,
pub(crate) connected_notifier: Arc<event_listener::Event>,
pub(crate) major_sync_task_sender: async_channel::Sender<MajorSyncTask>,
pub(crate) pairing_cancellation_tx: Arc<Mutex<Option<async_channel::Sender<()>>>>,
pub(crate) pair_code_state: Arc<Mutex<wacore::pair_code::PairCodeState>>,
pub custom_enc_handlers: Arc<async_lock::RwLock<HashMap<String, Arc<dyn EncHandler>>>>,
pub(crate) chatstate_handlers: Arc<RwLock<Vec<ChatStateHandler>>>,
pub(crate) pdo_pending_requests: Cache<String, crate::pdo::PendingPdoRequest>,
pub(crate) device_registry_cache: TypedCache<String, wacore::store::traits::DeviceListRecord>,
pub(crate) stanza_router: crate::handlers::router::StanzaRouter,
pub(crate) synchronous_ack: bool,
pub http_client: Arc<dyn crate::http::HttpClient>,
pub(crate) override_version: Option<(u32, u32, u32)>,
pub(crate) skip_history_sync: AtomicBool,
pub(crate) cache_config: CacheConfig,
}
impl Client {
pub(crate) fn swap_message_semaphore(&self, permits: usize) {
let mut guard = match self.message_processing_semaphore.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
*guard = Arc::new(async_lock::Semaphore::new(permits));
self.message_semaphore_generation
.fetch_add(1, Ordering::SeqCst);
}
fn should_downgrade_sync_error(&self, err: &anyhow::Error) -> bool {
if self.is_shutting_down() {
return true;
}
matches!(
err.downcast_ref::<crate::request::IqError>(),
Some(
crate::request::IqError::NotConnected
| crate::request::IqError::InternalChannelClosed
)
)
}
fn log_sync_error(&self, context: &str, err: &anyhow::Error) {
if self.should_downgrade_sync_error(err) {
debug!("Skipping {context} during shutdown: {err}");
} else {
warn!("Failed {context}: {err}");
}
}
fn is_fully_ready(&self) -> bool {
self.is_connected() && self.is_logged_in() && self.is_ready.load(Ordering::Relaxed)
}
fn dispatch_connected(&self) {
self.is_ready.store(true, Ordering::Relaxed);
self.core
.event_bus
.dispatch(&Event::Connected(crate::types::events::Connected));
self.connected_notifier.notify(usize::MAX);
}
pub fn set_skip_history_sync(&self, enabled: bool) {
self.skip_history_sync.store(enabled, Ordering::Relaxed);
}
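/// Executes one [`MajorSyncTask`] (history sync or app state sync) received
/// from the channel returned by [`Client::new`]. Callers are expected to
/// drain that receiver themselves; a minimal sketch, assuming `client` and
/// `rx` came from [`Client::new`]:
///
/// ```ignore
/// while let Ok(task) = rx.recv().await {
///     client.process_sync_task(task).await;
/// }
/// ```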
pub async fn process_sync_task(self: &Arc<Self>, task: crate::sync_task::MajorSyncTask) {
match task {
crate::sync_task::MajorSyncTask::HistorySync {
message_id,
notification,
} => {
self.process_history_sync_task(message_id, *notification)
.await;
self.finish_history_sync_task();
}
crate::sync_task::MajorSyncTask::AppStateSync { name, full_sync } => {
if let Err(e) = self.process_app_state_sync_task(name, full_sync).await {
log::warn!("App state sync task for {name:?} failed: {e}");
}
}
}
}
pub fn skip_history_sync_enabled(&self) -> bool {
self.skip_history_sync.load(Ordering::Relaxed)
}
pub(crate) fn is_shutting_down(&self) -> bool {
self.expected_disconnect.load(Ordering::Relaxed) || !self.is_running.load(Ordering::Relaxed)
}
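/// Builds a client with the default [`CacheConfig`]; see
/// [`Client::new_with_cache_config`] for the variant that accepts one.
/// Returns the client together with the receiver side of the major-sync
/// task channel, which the embedder must drain via
/// [`Client::process_sync_task`].
///
/// Hypothetical wiring sketch (the `make_*` helpers stand in for whatever
/// runtime, store, transport, and HTTP implementations the embedder
/// supplies; they are not part of this crate):
///
/// ```ignore
/// let (client, sync_rx) = Client::new(
///     make_runtime(),
///     make_persistence_manager().await,
///     make_transport_factory(),
///     make_http_client(),
///     None, // no version override
/// )
/// .await;
/// ```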
pub async fn new(
runtime: Arc<dyn Runtime>,
persistence_manager: Arc<PersistenceManager>,
transport_factory: Arc<dyn crate::transport::TransportFactory>,
http_client: Arc<dyn crate::http::HttpClient>,
override_version: Option<(u32, u32, u32)>,
) -> (Arc<Self>, async_channel::Receiver<MajorSyncTask>) {
Self::new_with_cache_config(
runtime,
persistence_manager,
transport_factory,
http_client,
override_version,
CacheConfig::default(),
)
.await
}
pub async fn new_with_cache_config(
runtime: Arc<dyn Runtime>,
persistence_manager: Arc<PersistenceManager>,
transport_factory: Arc<dyn crate::transport::TransportFactory>,
http_client: Arc<dyn crate::http::HttpClient>,
override_version: Option<(u32, u32, u32)>,
cache_config: CacheConfig,
) -> (Arc<Self>, async_channel::Receiver<MajorSyncTask>) {
let mut unique_id_bytes = [0u8; 2];
rand::make_rng::<rand::rngs::StdRng>().fill_bytes(&mut unique_id_bytes);
let device_snapshot = persistence_manager.get_device_snapshot().await;
let core = wacore::client::CoreClient::new(device_snapshot.core.clone());
let (tx, rx) = async_channel::bounded(32);
let this = Self {
runtime: runtime.clone(),
core,
persistence_manager: persistence_manager.clone(),
media_conn: Arc::new(RwLock::new(None)),
is_logged_in: Arc::new(AtomicBool::new(false)),
is_connecting: Arc::new(AtomicBool::new(false)),
is_running: Arc::new(AtomicBool::new(false)),
is_connected: Arc::new(AtomicBool::new(false)),
shutdown_notifier: Arc::new(event_listener::Event::new()),
last_data_received_ms: Arc::new(AtomicU64::new(0)),
last_data_sent_ms: Arc::new(AtomicU64::new(0)),
transport: Arc::new(Mutex::new(None)),
transport_events: Arc::new(Mutex::new(None)),
transport_factory,
noise_socket: Arc::new(Mutex::new(None)),
response_waiters: Arc::new(Mutex::new(HashMap::new())),
node_waiters: std::sync::Mutex::new(Vec::new()),
node_waiter_count: AtomicUsize::new(0),
unique_id: format!("{}.{}", unique_id_bytes[0], unique_id_bytes[1]),
id_counter: Arc::new(AtomicU64::new(0)),
unified_session: crate::unified_session::UnifiedSessionManager::new(),
signal_cache: Arc::new(crate::store::signal_cache::SignalStoreCache::new()),
message_processing_semaphore: std::sync::Mutex::new(Arc::new(
async_lock::Semaphore::new(1),
)),
message_semaphore_generation: Arc::new(AtomicU64::new(0)),
session_locks: Cache::builder()
.max_capacity(cache_config.session_locks_capacity.max(1))
.build(),
message_queues: Cache::builder()
.max_capacity(cache_config.message_queues_capacity.max(1))
.build(),
lid_pn_cache: Arc::new(LidPnCache::with_config(
&cache_config.lid_pn_cache,
cache_config.cache_stores.lid_pn_cache.clone(),
)),
message_enqueue_locks: Cache::builder()
.max_capacity(cache_config.message_enqueue_locks_capacity.max(1))
.build(),
group_cache: async_lock::Mutex::new(None),
device_cache: async_lock::Mutex::new(None),
retried_group_messages: cache_config.retried_group_messages.build_with_ttl(),
expected_disconnect: Arc::new(AtomicBool::new(false)),
connection_generation: Arc::new(AtomicU64::new(0)),
recent_messages: cache_config.recent_messages.build_with_ttl(),
pending_retries: Arc::new(async_lock::Mutex::new(HashSet::new())),
message_retry_counts: cache_config.message_retry_counts.build_with_ttl(),
offline_sync_metrics: Arc::new(OfflineSyncMetrics {
active: AtomicBool::new(false),
total_messages: AtomicUsize::new(0),
processed_messages: AtomicUsize::new(0),
start_time: std::sync::Mutex::new(None),
}),
enable_auto_reconnect: Arc::new(AtomicBool::new(true)),
auto_reconnect_errors: Arc::new(AtomicU32::new(0)),
needs_initial_full_sync: Arc::new(AtomicBool::new(false)),
app_state_processor: async_lock::Mutex::new(None),
app_state_key_requests: Arc::new(Mutex::new(HashMap::new())),
app_state_syncing: Arc::new(Mutex::new(HashSet::new())),
initial_keys_synced_notifier: Arc::new(event_listener::Event::new()),
initial_app_state_keys_received: Arc::new(AtomicBool::new(false)),
server_has_prekeys: Arc::new(AtomicBool::new(true)),
prekey_upload_lock: Arc::new(async_lock::Mutex::new(())),
offline_sync_notifier: Arc::new(event_listener::Event::new()),
offline_sync_completed: Arc::new(AtomicBool::new(false)),
history_sync_tasks_in_flight: Arc::new(AtomicUsize::new(0)),
history_sync_idle_notifier: Arc::new(event_listener::Event::new()),
presence_subscriptions: Arc::new(async_lock::Mutex::new(HashSet::new())),
socket_ready_notifier: Arc::new(event_listener::Event::new()),
is_ready: Arc::new(AtomicBool::new(false)),
connected_notifier: Arc::new(event_listener::Event::new()),
major_sync_task_sender: tx,
pairing_cancellation_tx: Arc::new(Mutex::new(None)),
pair_code_state: Arc::new(Mutex::new(wacore::pair_code::PairCodeState::default())),
custom_enc_handlers: Arc::new(async_lock::RwLock::new(HashMap::new())),
chatstate_handlers: Arc::new(RwLock::new(Vec::new())),
pdo_pending_requests: cache_config.pdo_pending_requests.build_with_ttl(),
device_registry_cache: cache_config.device_registry_cache.build_typed_ttl(
cache_config.cache_stores.device_registry_cache.clone(),
"device_registry",
),
stanza_router: Self::create_stanza_router(),
synchronous_ack: false,
http_client,
override_version,
skip_history_sync: AtomicBool::new(false),
cache_config,
};
let arc = Arc::new(this);
let warm_up_arc = arc.clone();
arc.runtime
.spawn(Box::pin(async move {
if let Err(e) = warm_up_arc.warm_up_lid_pn_cache().await {
warn!("Failed to warm up LID-PN cache: {e}");
}
}))
.detach();
let cleanup_arc = arc.clone();
arc.runtime
.spawn(Box::pin(async move {
cleanup_arc.device_registry_cleanup_loop().await;
}))
.detach();
(arc, rx)
}
pub(crate) async fn get_group_cache(&self) -> Arc<TypedCache<Jid, GroupInfo>> {
let mut guard = self.group_cache.lock().await;
if let Some(cache) = guard.as_ref() {
return cache.clone();
}
debug!("Initializing Group Cache for the first time.");
let cache = Arc::new(
self.cache_config
.group_cache
.build_typed_ttl(self.cache_config.cache_stores.group_cache.clone(), "group"),
);
*guard = Some(cache.clone());
cache
}
pub(crate) async fn get_device_cache(&self) -> Arc<TypedCache<Jid, Vec<Jid>>> {
let mut guard = self.device_cache.lock().await;
if let Some(cache) = guard.as_ref() {
return cache.clone();
}
debug!("Initializing Device Cache for the first time.");
let cache = Arc::new(self.cache_config.device_cache.build_typed_ttl(
self.cache_config.cache_stores.device_cache.clone(),
"device",
));
*guard = Some(cache.clone());
cache
}
pub(crate) async fn get_app_state_processor(&self) -> Arc<AppStateProcessor> {
let mut guard = self.app_state_processor.lock().await;
if let Some(proc) = guard.as_ref() {
return proc.clone();
}
debug!("Initializing AppStateProcessor for the first time.");
let proc = Arc::new(AppStateProcessor::new(
self.persistence_manager.backend(),
self.runtime.clone(),
));
*guard = Some(proc.clone());
proc
}
fn create_stanza_router() -> crate::handlers::router::StanzaRouter {
use crate::handlers::{
basic::{AckHandler, FailureHandler, StreamErrorHandler, SuccessHandler},
chatstate::ChatstateHandler,
ib::IbHandler,
iq::IqHandler,
message::MessageHandler,
notification::NotificationHandler,
receipt::ReceiptHandler,
router::StanzaRouter,
unimplemented::UnimplementedHandler,
};
let mut router = StanzaRouter::new();
router.register(Arc::new(MessageHandler));
router.register(Arc::new(ReceiptHandler));
router.register(Arc::new(IqHandler));
router.register(Arc::new(SuccessHandler));
router.register(Arc::new(FailureHandler));
router.register(Arc::new(StreamErrorHandler));
router.register(Arc::new(IbHandler));
router.register(Arc::new(NotificationHandler));
router.register(Arc::new(AckHandler));
router.register(Arc::new(ChatstateHandler));
router.register(Arc::new(UnimplementedHandler::for_call()));
router.register(Arc::new(crate::handlers::presence::PresenceHandler));
router
}
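/// Registers a handler on the client's event bus; it is invoked for every
/// dispatched [`Event`].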
pub fn register_handler(&self, handler: Arc<dyn wacore::types::events::EventHandler>) {
self.core.event_bus.add_handler(handler);
}
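/// Registers a callback that is invoked (on a spawned task) for every
/// decoded chat-state stanza. A minimal sketch:
///
/// ```ignore
/// client
///     .register_chatstate_handler(Arc::new(|_event: ChatStateEvent| {
///         // React to typing / recording / paused updates here.
///     }))
///     .await;
/// ```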
pub async fn register_chatstate_handler(
&self,
handler: Arc<dyn Fn(ChatStateEvent) + Send + Sync>,
) {
self.chatstate_handlers.write().await.push(handler);
}
pub(crate) async fn dispatch_chatstate_event(
&self,
stanza: wacore::iq::chatstate::ChatstateStanza,
) {
use wacore::iq::chatstate::{ChatstateSource, ReceivedChatState};
use wacore::types::events::ChatPresenceUpdate;
use wacore::types::message::MessageSource;
use wacore::types::presence::{ChatPresence, ChatPresenceMedia};
let (chat, sender, is_group) = match &stanza.source {
ChatstateSource::User { from } => (from.clone(), from.clone(), false),
ChatstateSource::Group { from, participant } => {
(from.clone(), participant.clone(), true)
}
};
let (state, media) = match stanza.state {
ReceivedChatState::Typing => (ChatPresence::Composing, ChatPresenceMedia::Text),
ReceivedChatState::RecordingAudio => {
(ChatPresence::Composing, ChatPresenceMedia::Audio)
}
ReceivedChatState::Idle => (ChatPresence::Paused, ChatPresenceMedia::Text),
};
self.core
.event_bus
.dispatch(&Event::ChatPresence(ChatPresenceUpdate {
source: MessageSource {
chat,
sender,
is_from_me: false,
is_group,
addressing_mode: None,
sender_alt: None,
recipient_alt: None,
broadcast_list_owner: None,
recipient: None,
},
state,
media,
}));
let event = ChatStateEvent::from_stanza(stanza);
let handlers = self.chatstate_handlers.read().await.clone();
for handler in handlers {
let event_clone = event.clone();
let handler_clone = handler.clone();
self.runtime
.spawn(Box::pin(async move {
(handler_clone)(event_clone);
}))
.detach();
}
}
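/// Drives the connect/read/reconnect loop until [`Client::disconnect`] is
/// called or auto-reconnect is disabled. Unexpected drops back off along a
/// Fibonacci schedule; expected disconnects (e.g. the 515 stream error after
/// pairing) reconnect immediately.
///
/// Typically spawned once after construction (`spawn_on_your_executor` is a
/// stand-in for whatever executor the embedder uses):
///
/// ```ignore
/// let run_client = client.clone();
/// spawn_on_your_executor(async move { run_client.run().await });
/// ```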
pub async fn run(self: &Arc<Self>) {
if self.is_running.swap(true, Ordering::SeqCst) {
warn!("Client `run` method called while already running.");
return;
}
while self.is_running.load(Ordering::Relaxed) {
self.expected_disconnect.store(false, Ordering::Relaxed);
if let Err(connect_err) = self.connect().await {
error!("Failed to connect: {connect_err:#}. Will retry...");
} else {
if self.read_messages_loop().await.is_err() {
warn!(
"Message loop exited with an error. Will attempt to reconnect if enabled."
);
} else if self.expected_disconnect.load(Ordering::Relaxed) {
debug!("Message loop exited gracefully (expected disconnect).");
} else {
info!("Message loop exited gracefully.");
}
self.cleanup_connection_state().await;
}
if !self.enable_auto_reconnect.load(Ordering::Relaxed) {
info!("Auto-reconnect disabled, shutting down.");
self.is_running.store(false, Ordering::Relaxed);
break;
}
if self.expected_disconnect.load(Ordering::Relaxed) {
self.auto_reconnect_errors.store(0, Ordering::Relaxed);
info!("Expected disconnect (e.g., 515), reconnecting immediately...");
continue;
}
let error_count = self.auto_reconnect_errors.fetch_add(1, Ordering::SeqCst);
let delay = fibonacci_backoff(error_count);
info!(
"Will attempt to reconnect in {:?} (attempt {})",
delay,
error_count + 1
);
self.runtime.sleep(delay).await;
}
info!("Client run loop has shut down.");
}
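/// Resolves the client version and opens the transport in parallel (each
/// bounded by [`TRANSPORT_CONNECT_TIMEOUT`]), performs the Noise handshake,
/// and starts the keepalive loop. Fails with
/// [`ClientError::AlreadyConnected`] if a connection attempt is already in
/// flight or a connection is established.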
pub async fn connect(self: &Arc<Self>) -> Result<(), anyhow::Error> {
if self.is_connecting.swap(true, Ordering::SeqCst) {
return Err(ClientError::AlreadyConnected.into());
}
let _guard = scopeguard::guard((), |_| {
self.is_connecting.store(false, Ordering::Relaxed);
});
if self.is_connected() {
return Err(ClientError::AlreadyConnected.into());
}
self.is_logged_in.store(false, Ordering::Relaxed);
self.is_ready.store(false, Ordering::Relaxed);
self.is_connected.store(false, Ordering::Relaxed);
self.offline_sync_completed.store(false, Ordering::Relaxed);
self.server_has_prekeys.store(true, Ordering::Relaxed);
let version_future = rt_timeout(
&*self.runtime,
TRANSPORT_CONNECT_TIMEOUT,
crate::version::resolve_and_update_version(
&self.persistence_manager,
&self.http_client,
self.override_version,
),
);
let transport_future = rt_timeout(
&*self.runtime,
TRANSPORT_CONNECT_TIMEOUT,
self.transport_factory.create_transport(),
);
debug!("Connecting WebSocket and fetching latest client version in parallel...");
let (version_result, transport_result) = futures::join!(version_future, transport_future);
version_result
.map_err(|_| anyhow!("Version fetch timed out after {TRANSPORT_CONNECT_TIMEOUT:?}"))?
.map_err(|e| anyhow!("Failed to resolve app version: {}", e))?;
let (transport, mut transport_events) = transport_result.map_err(|_| {
anyhow!("Transport connect timed out after {TRANSPORT_CONNECT_TIMEOUT:?}")
})??;
debug!("Version fetch and transport connection established.");
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let noise_socket = handshake::do_handshake(
self.runtime.clone(),
&device_snapshot,
transport.clone(),
&mut transport_events,
)
.await?;
*self.transport.lock().await = Some(transport);
*self.transport_events.lock().await = Some(transport_events);
*self.noise_socket.lock().await = Some(noise_socket);
self.is_connected.store(true, Ordering::Release);
self.socket_ready_notifier.notify(usize::MAX);
let client_clone = self.clone();
self.runtime
.spawn(Box::pin(async move { client_clone.keepalive_loop().await }))
.detach();
Ok(())
}
pub async fn disconnect(self: &Arc<Self>) {
info!("Disconnecting client intentionally.");
self.expected_disconnect.store(true, Ordering::Relaxed);
self.is_running.store(false, Ordering::Relaxed);
self.shutdown_notifier.notify(usize::MAX);
if let Err(e) = self.persistence_manager.flush().await {
log::error!("Failed to flush device state during disconnect: {e}");
}
if let Some(transport) = self.transport.lock().await.as_ref() {
transport.disconnect().await;
}
self.cleanup_connection_state().await;
}
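/// Stored into `auto_reconnect_errors` by [`Client::reconnect`], presumably
/// so the next retry starts a few steps up the Fibonacci backoff curve
/// instead of retrying instantly.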
pub const RECONNECT_BACKOFF_STEP: u32 = 4;
pub async fn reconnect(self: &Arc<Self>) {
info!("Reconnecting: dropping transport for auto-reconnect.");
self.auto_reconnect_errors
.store(Self::RECONNECT_BACKOFF_STEP, Ordering::Relaxed);
if let Some(transport) = self.transport.lock().await.as_ref() {
transport.disconnect().await;
}
}
pub async fn reconnect_immediately(self: &Arc<Self>) {
info!("Reconnecting immediately (expected disconnect).");
self.expected_disconnect.store(true, Ordering::Relaxed);
if let Some(transport) = self.transport.lock().await.as_ref() {
transport.disconnect().await;
}
}
async fn cleanup_connection_state(&self) {
self.is_logged_in.store(false, Ordering::Relaxed);
self.is_ready.store(false, Ordering::Relaxed);
self.shutdown_notifier.notify(usize::MAX);
*self.transport.lock().await = None;
*self.transport_events.lock().await = None;
*self.noise_socket.lock().await = None;
self.is_connected.store(false, Ordering::Release);
self.retried_group_messages.invalidate_all();
self.signal_cache.clear().await;
self.swap_message_semaphore(1);
self.last_data_received_ms.store(0, Ordering::Relaxed);
self.last_data_sent_ms.store(0, Ordering::Relaxed);
self.offline_sync_completed.store(false, Ordering::Relaxed);
self.offline_sync_metrics
.active
.store(false, Ordering::Release);
self.offline_sync_metrics
.total_messages
.store(0, Ordering::Release);
self.offline_sync_metrics
.processed_messages
.store(0, Ordering::Release);
match self.offline_sync_metrics.start_time.lock() {
Ok(mut guard) => *guard = None,
Err(poison) => *poison.into_inner() = None,
}
self.server_has_prekeys.store(true, Ordering::Relaxed);
self.history_sync_tasks_in_flight
.store(0, Ordering::Relaxed);
self.history_sync_idle_notifier.notify(usize::MAX);
let mut waiters_map = self.response_waiters.lock().await;
let waiter_count = waiters_map.len();
*waiters_map = HashMap::new();
drop(waiters_map);
if waiter_count > 0 {
debug!(
"Dropping {} orphaned IQ response waiter(s) on disconnect",
waiter_count
);
}
*self.app_state_key_requests.lock().await = HashMap::new();
*self.app_state_syncing.lock().await = HashSet::new();
*self.media_conn.write().await = None;
if let Some(proc) = self.app_state_processor.lock().await.as_ref() {
proc.clear_key_cache().await;
}
}
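/// Snapshots entry counts across every cache and waiter collection, for
/// leak hunting in long-running processes (only available with the
/// `debug-diagnostics` feature). The [`Display`] impl pretty-prints the
/// report:
///
/// ```ignore
/// let diag = client.memory_diagnostics().await;
/// log::info!("{diag}");
/// ```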
#[cfg(feature = "debug-diagnostics")]
pub async fn memory_diagnostics(&self) -> MemoryDiagnostics {
let (sig_sessions, sig_identities, sig_sender_keys) =
self.signal_cache.entry_counts().await;
let (lid_lid, lid_pn) = self.lid_pn_cache.entry_counts();
MemoryDiagnostics {
group_cache: self
.group_cache
.lock()
.await
.as_ref()
.map_or(0, |c| c.entry_count()),
device_cache: self
.device_cache
.lock()
.await
.as_ref()
.map_or(0, |c| c.entry_count()),
device_registry_cache: self.device_registry_cache.entry_count(),
lid_pn_lid_entries: lid_lid,
lid_pn_pn_entries: lid_pn,
retried_group_messages: self.retried_group_messages.entry_count(),
recent_messages: self.recent_messages.entry_count(),
message_retry_counts: self.message_retry_counts.entry_count(),
pdo_pending_requests: self.pdo_pending_requests.entry_count(),
session_locks: self.session_locks.entry_count(),
message_queues: self.message_queues.entry_count(),
message_enqueue_locks: self.message_enqueue_locks.entry_count(),
response_waiters: self.response_waiters.lock().await.len(),
node_waiters: self.node_waiter_count.load(Ordering::Relaxed),
pending_retries: self.pending_retries.lock().await.len(),
presence_subscriptions: self.presence_subscriptions.lock().await.len(),
app_state_key_requests: self.app_state_key_requests.lock().await.len(),
app_state_syncing: self.app_state_syncing.lock().await.len(),
signal_cache_sessions: sig_sessions,
signal_cache_identities: sig_identities,
signal_cache_sender_keys: sig_sender_keys,
chatstate_handlers: self.chatstate_handlers.read().await.len(),
custom_enc_handlers: self.custom_enc_handlers.read().await.len(),
}
}
pub(crate) async fn flush_signal_cache(&self) -> Result<(), anyhow::Error> {
let device = self.persistence_manager.get_device_arc().await;
let device_guard = device.read().await;
self.signal_cache
.flush(&*device_guard.backend)
.await
.map_err(|e| anyhow::anyhow!("Failed to flush signal cache: {e}"))
}
async fn read_messages_loop(self: &Arc<Self>) -> Result<(), anyhow::Error> {
debug!("Starting message processing loop...");
let mut rx_guard = self.transport_events.lock().await;
let transport_events = rx_guard
.take()
.ok_or_else(|| anyhow::anyhow!("Cannot start message loop: not connected"))?;
drop(rx_guard);
let mut frame_decoder = wacore::framing::FrameDecoder::new();
loop {
futures::select_biased! {
_ = self.shutdown_notifier.listen().fuse() => {
debug!("Shutdown signaled in message loop. Exiting message loop.");
return Ok(());
},
event_result = transport_events.recv().fuse() => {
match event_result {
Ok(crate::transport::TransportEvent::DataReceived(data)) => {
self.last_data_received_ms.store(
wacore::time::now_millis() as u64,
Ordering::Relaxed,
);
frame_decoder.feed(&data);
let mut frames_in_batch: u32 = 0;
while let Some(encrypted_frame) = frame_decoder.decode_frame() {
if let Some(node) = self.decrypt_frame(&encrypted_frame).await {
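// Auth results, stream errors, messages, and <ib> are handled inline on
// the read loop (preserving their relative order); everything else is
// processed concurrently on a spawned task.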
let process_inline = matches!(
node.tag.as_ref(),
"success" | "failure" | "stream:error" | "message" | "ib"
);
if process_inline {
self.process_decrypted_node(node).await;
} else {
let client = self.clone();
self.runtime.spawn(Box::pin(async move {
client.process_decrypted_node(node).await;
})).detach();
}
}
if self.expected_disconnect.load(Ordering::Relaxed) {
debug!("Expected disconnect signaled during frame processing. Exiting message loop.");
return Ok(());
}
frames_in_batch += 1;
if frames_in_batch.is_multiple_of(self.runtime.yield_frequency())
&& let Some(yield_fut) = self.runtime.yield_now()
{
yield_fut.await;
}
}
},
Ok(crate::transport::TransportEvent::Disconnected) | Err(_) => {
self.cleanup_connection_state().await;
if !self.expected_disconnect.load(Ordering::Relaxed) {
self.core.event_bus.dispatch(&Event::Disconnected(crate::types::events::Disconnected));
debug!("Transport disconnected unexpectedly.");
return Err(anyhow::anyhow!("Transport disconnected unexpectedly"));
} else {
debug!("Transport disconnected as expected.");
return Ok(());
}
}
Ok(crate::transport::TransportEvent::Connected) => {
debug!("Transport connected event received");
}
}
}
}
}
}
pub(crate) async fn decrypt_frame(
self: &Arc<Self>,
encrypted_frame: &bytes::Bytes,
) -> Option<wacore_binary::node::Node> {
let noise_socket_arc = { self.noise_socket.lock().await.clone() };
let noise_socket = match noise_socket_arc {
Some(s) => s,
None => {
log::error!("Cannot process frame: not connected (no noise socket)");
return None;
}
};
let decrypted_payload = match noise_socket.decrypt_frame(encrypted_frame) {
Ok(p) => p,
Err(e) => {
log::error!("Failed to decrypt frame: {e}");
return None;
}
};
let unpacked_data_cow = match wacore_binary::util::unpack(&decrypted_payload) {
Ok(data) => data,
Err(e) => {
log::warn!(target: "Client/Recv", "Failed to decompress frame: {e}");
return None;
}
};
match wacore_binary::marshal::unmarshal_ref(unpacked_data_cow.as_ref()) {
Ok(node_ref) => Some(node_ref.to_owned()),
Err(e) => {
log::warn!(target: "Client/Recv", "Failed to unmarshal node: {e}");
None
}
}
}
pub(crate) async fn process_decrypted_node(self: &Arc<Self>, node: wacore_binary::node::Node) {
let node_arc = Arc::new(node);
self.process_node(node_arc).await;
}
pub(crate) async fn process_node(self: &Arc<Self>, node: Arc<Node>) {
if node.tag.as_ref() == "ib" {
if let Some(preview) = node.get_optional_child("offline_preview") {
let count: usize = preview
.attrs
.get("count")
.and_then(|v| v.as_str().parse().ok())
.unwrap_or(0);
if count == 0 {
self.offline_sync_metrics
.active
.store(false, Ordering::Release);
debug!(target: "Client/OfflineSync", "Sync COMPLETED: 0 items.");
} else {
self.offline_sync_metrics
.total_messages
.store(count, Ordering::Release);
self.offline_sync_metrics
.processed_messages
.store(0, Ordering::Release);
self.offline_sync_metrics
.active
.store(true, Ordering::Release);
match self.offline_sync_metrics.start_time.lock() {
Ok(mut guard) => *guard = Some(wacore::time::Instant::now()),
Err(poison) => *poison.into_inner() = Some(wacore::time::Instant::now()),
}
debug!(target: "Client/OfflineSync", "Sync STARTED: Expecting {} items.", count);
}
} else if self.offline_sync_metrics.active.load(Ordering::Acquire)
&& node.get_optional_child("offline").is_some()
{
let processed = self
.offline_sync_metrics
.processed_messages
.load(Ordering::Acquire);
let elapsed = match self.offline_sync_metrics.start_time.lock() {
Ok(guard) => guard.map(|t| t.elapsed()).unwrap_or_default(),
Err(poison) => poison.into_inner().map(|t| t.elapsed()).unwrap_or_default(),
};
debug!(target: "Client/OfflineSync", "Sync COMPLETED: End marker received. Processed {} items in {:.2?}.", processed, elapsed);
self.offline_sync_metrics
.active
.store(false, Ordering::Release);
}
}
if self.offline_sync_metrics.active.load(Ordering::Acquire) {
if node.attrs.contains_key("offline") {
let processed = self
.offline_sync_metrics
.processed_messages
.fetch_add(1, Ordering::Release)
+ 1;
let total = self
.offline_sync_metrics
.total_messages
.load(Ordering::Acquire);
if processed.is_multiple_of(50) || processed == total {
trace!(target: "Client/OfflineSync", "Sync Progress: {}/{}", processed, total);
}
if processed >= total {
let elapsed = match self.offline_sync_metrics.start_time.lock() {
Ok(guard) => guard.map(|t| t.elapsed()).unwrap_or_default(),
Err(poison) => poison.into_inner().map(|t| t.elapsed()).unwrap_or_default(),
};
debug!(target: "Client/OfflineSync", "Sync COMPLETED: Processed {} items in {:.2?}.", processed, elapsed);
self.offline_sync_metrics
.active
.store(false, Ordering::Release);
}
}
}
if node.tag.as_ref() == "iq"
&& let Some(sync_node) = node.get_optional_child("sync")
&& let Some(collection_node) = sync_node.get_optional_child("collection")
{
let name = collection_node.attrs().optional_string("name");
let name = name.as_deref().unwrap_or("<unknown>");
debug!(target: "Client/Recv", "Received app state sync response for '{name}' (hiding content).");
} else {
debug!(target: "Client/Recv","{}", DisplayableNode(&node));
}
let mut cancelled = false;
if node.tag.as_ref() == "xmlstreamend" {
if self.expected_disconnect.load(Ordering::Relaxed) {
debug!("Received <xmlstreamend/>, expected disconnect.");
} else {
warn!("Received <xmlstreamend/>, treating as disconnect.");
}
self.shutdown_notifier.notify(usize::MAX);
return;
}
if self.node_waiter_count.load(Ordering::Relaxed) > 0 {
self.resolve_node_waiters(&node);
}
if node.tag.as_ref() == "iq"
&& let Some(id) = node.attrs.get("id").map(|v| v.as_str())
{
let has_waiter = self.response_waiters.lock().await.contains_key(id.as_ref());
if has_waiter && self.handle_iq_response(Arc::clone(&node)).await {
return;
}
}
if !self
.stanza_router
.dispatch(self.clone(), Arc::clone(&node), &mut cancelled)
.await
{
warn!(
"Received unknown top-level node: {}",
DisplayableNode(&node)
);
}
if self.should_ack(&node) && !cancelled {
self.maybe_deferred_ack(node).await;
}
}
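/// Receipt-bearing stanza types for which the server expects an `<ack/>`,
/// provided they carry both an `id` and a `from` attribute.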
fn should_ack(&self, node: &Node) -> bool {
matches!(
node.tag.as_ref(),
"message" | "receipt" | "notification" | "call"
) && node.attrs.contains_key("id")
&& node.attrs.contains_key("from")
}
async fn maybe_deferred_ack(self: &Arc<Self>, node: Arc<Node>) {
if self.synchronous_ack {
if let Err(e) = self.send_ack_for(&node).await {
warn!("Failed to send ack: {e:?}");
}
} else {
let this = self.clone();
self.runtime
.spawn(Box::pin(async move {
if let Err(e) = this.send_ack_for(&node).await {
warn!("Failed to send ack: {e:?}");
}
}))
.detach();
}
}
async fn send_ack_for(&self, node: &Node) -> Result<(), ClientError> {
if self.expected_disconnect.load(Ordering::Relaxed) {
return Ok(());
}
if !self.is_connected() {
return Err(ClientError::NotConnected);
}
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let ack = match build_ack_node(node, device_snapshot.pn.as_ref()) {
Some(ack) => ack,
None => return Ok(()),
};
self.send_node(ack).await
}
pub(crate) async fn handle_unimplemented(&self, tag: &str) {
warn!("TODO: Implement handler for <{tag}>");
}
pub async fn set_passive(&self, passive: bool) -> Result<(), crate::request::IqError> {
use wacore::iq::passive::PassiveModeSpec;
self.execute(PassiveModeSpec::new(passive)).await
}
pub async fn clean_dirty_bits(
&self,
type_: &str,
timestamp: Option<&str>,
) -> Result<(), crate::request::IqError> {
use wacore::iq::dirty::CleanDirtyBitsSpec;
let spec = CleanDirtyBitsSpec::single(type_, timestamp)?;
self.execute(spec).await
}
pub async fn fetch_props(&self) -> Result<(), crate::request::IqError> {
use wacore::iq::props::PropsSpec;
use wacore::store::commands::DeviceCommand;
let stored_hash = self
.persistence_manager
.get_device_snapshot()
.await
.props_hash
.clone();
let spec = match &stored_hash {
Some(hash) => {
debug!("Fetching props with hash for delta update...");
PropsSpec::with_hash(hash)
}
None => {
debug!("Fetching props (full, no stored hash)...");
PropsSpec::new()
}
};
let response = self.execute(spec).await?;
if response.delta_update {
debug!(
"Props delta update received ({} changed props)",
response.props.len()
);
} else {
debug!(
"Props full update received ({} props, hash={:?})",
response.props.len(),
response.hash
);
}
if let Some(new_hash) = response.hash {
self.persistence_manager
.process_command(DeviceCommand::SetPropsHash(Some(new_hash)))
.await;
}
Ok(())
}
pub async fn fetch_privacy_settings(
&self,
) -> Result<wacore::iq::privacy::PrivacySettingsResponse, crate::request::IqError> {
use wacore::iq::privacy::PrivacySettingsSpec;
debug!("Fetching privacy settings...");
self.execute(PrivacySettingsSpec::new()).await
}
pub async fn set_privacy_setting(
&self,
category: &str,
value: &str,
) -> Result<(), crate::request::IqError> {
use wacore::iq::privacy::SetPrivacySettingSpec;
self.execute(SetPrivacySettingSpec::new(category, value))
.await
}
pub async fn set_default_disappearing_mode(
&self,
duration: u32,
) -> Result<(), crate::request::IqError> {
use wacore::iq::privacy::SetDefaultDisappearingModeSpec;
self.execute(SetDefaultDisappearingModeSpec::new(duration))
.await
}
pub async fn get_business_profile(
&self,
jid: &wacore_binary::jid::Jid,
) -> Result<Option<wacore::iq::business::BusinessProfile>, crate::request::IqError> {
use wacore::iq::business::BusinessProfileSpec;
self.execute(BusinessProfileSpec::new(jid)).await
}
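/// Declines an incoming call by sending a `<call><reject/></call>` stanza
/// back to the caller. Sketch (`call_id` and `call_from` come from the
/// incoming call event):
///
/// ```ignore
/// client.reject_call(&call_id, &call_from).await?;
/// ```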
pub async fn reject_call(
&self,
call_id: &str,
call_from: &wacore_binary::jid::Jid,
) -> Result<(), anyhow::Error> {
anyhow::ensure!(!call_id.is_empty(), "call_id cannot be empty");
let id = self.generate_request_id();
let stanza = wacore_binary::builder::NodeBuilder::new("call")
.attr("to", call_from.clone())
.attr("id", id)
.children([wacore_binary::builder::NodeBuilder::new("reject")
.attr("call-id", call_id)
.attr("call-creator", call_from.clone())
.attr("count", "0")
.build()])
.build();
self.send_node(stanza).await?;
Ok(())
}
pub async fn send_digest_key_bundle(&self) -> Result<(), crate::request::IqError> {
use wacore::iq::prekeys::DigestKeyBundleSpec;
debug!("Sending digest key bundle...");
self.execute(DigestKeyBundleSpec::new()).await.map(|_| ())
}
pub(crate) async fn handle_success(self: &Arc<Self>, node: &wacore_binary::node::Node) {
if self.expected_disconnect.load(Ordering::Relaxed) {
debug!("Ignoring <success> stanza: expected disconnect pending");
return;
}
if self.is_logged_in.swap(true, Ordering::SeqCst) {
debug!("Ignoring duplicate <success> stanza (already logged in)");
return;
}
let current_generation = self.connection_generation.fetch_add(1, Ordering::SeqCst) + 1;
info!(
"Successfully authenticated with WhatsApp servers! (gen={})",
current_generation
);
self.auto_reconnect_errors.store(0, Ordering::Relaxed);
self.update_server_time_offset(node);
if let Some(lid_value) = node.attrs.get("lid") {
if let Some(lid) = lid_value.to_jid() {
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
if device_snapshot.lid.as_ref() != Some(&lid) {
debug!("Updating LID from server to '{lid}'");
self.persistence_manager
.process_command(DeviceCommand::SetLid(Some(lid)))
.await;
}
} else {
warn!("Failed to parse LID from success stanza: {lid_value}");
}
} else {
warn!("LID not found in <success> stanza. Group messaging may fail.");
}
let client_clone = self.clone();
let task_generation = current_generation;
self.runtime.spawn(Box::pin(async move {
macro_rules! check_generation {
() => {
if client_clone.connection_generation.load(Ordering::SeqCst) != task_generation
{
debug!("Post-login task cancelled: connection generation changed");
return;
}
};
}
debug!(
"Starting post-login initialization sequence (gen={})...",
task_generation
);
let device_snapshot = client_clone.persistence_manager.get_device_snapshot().await;
let needs_pushname_from_sync = device_snapshot.push_name.is_empty();
if needs_pushname_from_sync {
debug!("Push name is empty - will be set from app state sync (setting_pushName)");
}
if !client_clone.is_connected() {
debug!(
"Skipping post-login init: connection closed (likely pairing phase reconnect)"
);
return;
}
check_generation!();
client_clone.send_unified_session().await;
check_generation!();
if let Err(e) = client_clone
.establish_primary_phone_session_immediate()
.await
{
warn!(target: "Client/PDO", "Failed to establish session with primary phone on login: {:?}", e);
}
check_generation!();
if let Err(e) = client_clone.upload_pre_keys(false).await {
warn!("Failed to upload pre-keys during startup: {e:?}");
}
check_generation!();
if let Err(e) = client_clone.set_passive(false).await {
warn!("Failed to send post-connect active IQ: {e:?}");
}
client_clone.wait_for_offline_delivery_end().await;
check_generation!();
if !client_clone.is_connected() {
debug!("Skipping presence: connection closed");
return;
}
let bg_client = client_clone.clone();
let bg_generation = task_generation;
client_clone.runtime.spawn(Box::pin(async move {
if bg_client.connection_generation.load(Ordering::SeqCst) != bg_generation {
debug!("Skipping background init queries: connection generation changed");
return;
}
if !bg_client.is_connected() {
debug!("Skipping background init queries: connection closed");
return;
}
debug!(
"Sending background initialization queries (Props, Blocklist, Privacy, Digest)..."
);
let props_fut = bg_client.fetch_props();
let binding = bg_client.blocking();
let blocklist_fut = binding.get_blocklist();
let privacy_fut = bg_client.fetch_privacy_settings();
let digest_fut = bg_client.validate_digest_key();
let (r_props, r_block, r_priv, r_digest) =
futures::join!(props_fut, blocklist_fut, privacy_fut, digest_fut);
if let Err(e) = r_props {
warn!("Background init: Failed to fetch props: {e:?}");
}
if let Err(e) = r_block {
warn!("Background init: Failed to fetch blocklist: {e:?}");
}
if let Err(e) = r_priv {
warn!("Background init: Failed to fetch privacy settings: {e:?}");
}
if let Err(e) = r_digest {
warn!("Background init: Failed to validate digest key: {e:?}");
}
if let Err(e) = bg_client.tc_token().prune_expired().await {
warn!("Background init: Failed to prune expired tc_tokens: {e:?}");
}
})).detach();
check_generation!();
let flag_set = client_clone.needs_initial_full_sync.load(Ordering::Relaxed);
let needs_initial_sync = flag_set || needs_pushname_from_sync;
if needs_initial_sync {
debug!(
target: "Client/AppState",
"Starting Initial App State Sync (flag_set={flag_set}, needs_pushname={needs_pushname_from_sync})"
);
if !client_clone
.initial_app_state_keys_received
.load(Ordering::Relaxed)
{
debug!(
target: "Client/AppState",
"Waiting up to 5s for app state keys..."
);
let _ = rt_timeout(
&*client_clone.runtime,
Duration::from_secs(5),
client_clone.initial_keys_synced_notifier.listen(),
)
.await;
check_generation!();
}
const CRITICAL_SYNC_TIMEOUT_SECS: u64 = 180;
let timeout_client = client_clone.clone();
let timeout_generation = task_generation;
let timeout_rt = client_clone.runtime.clone();
let critical_sync_timeout_handle = timeout_rt.spawn(Box::pin(async move {
timeout_client.runtime.sleep(Duration::from_secs(CRITICAL_SYNC_TIMEOUT_SECS)).await;
if timeout_client.connection_generation.load(Ordering::SeqCst)
!= timeout_generation
{
return;
}
let push_name = timeout_client.get_push_name().await;
if push_name.is_empty() {
warn!(
target: "Client/AppState",
"Critical app state sync timed out after {CRITICAL_SYNC_TIMEOUT_SECS}s \
(push_name not synced). Reconnecting to retry."
);
timeout_client.reconnect_immediately().await;
} else {
debug!(
target: "Client/AppState",
"Critical sync timeout fired but push_name was already synced"
);
}
}));
check_generation!();
match client_clone
.sync_collections_batched(vec![
WAPatchName::CriticalBlock,
WAPatchName::CriticalUnblockLow,
])
.await
{
Ok(()) => {
critical_sync_timeout_handle.abort();
check_generation!();
client_clone
.resubscribe_presence_subscriptions(task_generation)
.await;
check_generation!();
client_clone.dispatch_connected();
}
Err(e) => {
client_clone.log_sync_error("critical app state sync", &e);
return;
}
}
let sync_client = client_clone.clone();
let sync_generation = task_generation;
client_clone.runtime.spawn(Box::pin(async move {
if sync_client.connection_generation.load(Ordering::SeqCst) != sync_generation {
debug!("App state sync cancelled: connection generation changed");
return;
}
if let Err(e) = sync_client
.sync_collections_batched(vec![
WAPatchName::RegularLow,
WAPatchName::RegularHigh,
WAPatchName::Regular,
])
.await
{
sync_client.log_sync_error("non-critical app state sync", &e);
}
sync_client
.needs_initial_full_sync
.store(false, Ordering::Relaxed);
debug!(target: "Client/AppState", "Initial App State Sync Completed.");
})).detach();
} else {
let device_snapshot = client_clone.persistence_manager.get_device_snapshot().await;
if !device_snapshot.push_name.is_empty() {
if let Err(e) = client_clone.presence().set_available().await {
warn!("Failed to send initial presence: {e:?}");
} else {
debug!("Initial presence sent successfully.");
}
}
client_clone
.resubscribe_presence_subscriptions(task_generation)
.await;
check_generation!();
client_clone.dispatch_connected();
}
})).detach();
}
pub(crate) async fn handle_ack_response(&self, node: Node) -> bool {
let id_opt = node.attrs.get("id").map(|v| v.as_str().into_owned());
if let Some(id) = id_opt
&& let Some(waiter) = self.response_waiters.lock().await.remove(&id)
{
if waiter.send(node).is_err() {
warn!(target: "Client/Ack", "Failed to send ACK response to waiter for ID {id}. Receiver was likely dropped.");
}
return true;
}
false
}
#[allow(dead_code)]
pub(crate) async fn fetch_app_state_with_retry(&self, name: WAPatchName) -> anyhow::Result<()> {
{
let mut syncing = self.app_state_syncing.lock().await;
if !syncing.insert(name) {
debug!(target: "Client/AppState", "Skipping sync for {:?}: already in flight", name);
return Ok(());
}
}
let result = self.fetch_app_state_with_retry_inner(name).await;
self.app_state_syncing.lock().await.remove(&name);
result
}
#[allow(dead_code)]
async fn fetch_app_state_with_retry_inner(&self, name: WAPatchName) -> anyhow::Result<()> {
let mut attempt = 0u32;
loop {
attempt += 1;
let res = self.process_app_state_sync_task(name, false).await;
match res {
Ok(()) => return Ok(()),
Err(e) => {
if e.downcast_ref::<crate::appstate_sync::AppStateSyncError>()
.is_some_and(|ase| {
matches!(ase, crate::appstate_sync::AppStateSyncError::KeyNotFound(_))
})
&& attempt == 1
{
if !self.initial_app_state_keys_received.load(Ordering::Relaxed) {
debug!(target: "Client/AppState", "App state key missing for {:?}; waiting up to 10s for key share then retrying", name);
if rt_timeout(
&*self.runtime,
Duration::from_secs(10),
self.initial_keys_synced_notifier.listen(),
)
.await
.is_err()
{
warn!(target: "Client/AppState", "Timeout waiting for key share for {:?}; retrying anyway", name);
}
}
continue;
}
let is_db_locked = e.downcast_ref::<wacore::store::error::StoreError>()
.is_some_and(|se| matches!(se, wacore::store::error::StoreError::Database(msg) if msg.contains("locked") || msg.contains("busy")))
|| e.downcast_ref::<crate::appstate_sync::AppStateSyncError>()
.is_some_and(|ase| matches!(ase, crate::appstate_sync::AppStateSyncError::Store(wacore::store::error::StoreError::Database(msg)) if msg.contains("locked") || msg.contains("busy")));
if is_db_locked && attempt < APP_STATE_RETRY_MAX_ATTEMPTS {
let backoff = Duration::from_millis(200 * attempt as u64 + 150);
warn!(target: "Client/AppState", "Attempt {} for {:?} failed due to locked DB; backing off {:?} and retrying", attempt, name, backoff);
self.runtime.sleep(backoff).await;
continue;
}
return Err(e);
}
}
}
}
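/// Syncs several app state collections in one `w:sync:app:state` IQ,
/// skipping any collection that already has a sync in flight and always
/// releasing the in-flight markers afterwards.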
pub(crate) async fn sync_collections_batched(
&self,
collections: Vec<WAPatchName>,
) -> anyhow::Result<()> {
if collections.is_empty() {
return Ok(());
}
let pending = {
let mut syncing = self.app_state_syncing.lock().await;
let mut filtered = Vec::with_capacity(collections.len());
for name in collections {
if syncing.insert(name) {
filtered.push(name);
} else {
debug!(target: "Client/AppState", "Skipping {:?} in batch: already in flight", name);
}
}
filtered
};
if pending.is_empty() {
return Ok(());
}
let all_collections: Vec<WAPatchName> = pending.clone();
let result = self.sync_collections_batched_inner(pending).await;
{
let mut syncing = self.app_state_syncing.lock().await;
for name in &all_collections {
syncing.remove(name);
}
}
result
}
async fn sync_collections_batched_inner(
&self,
mut pending: Vec<WAPatchName>,
) -> anyhow::Result<()> {
use wacore::appstate::patch_decode::CollectionSyncError;
const MAX_ITERATIONS: usize = 5;
let mut iteration = 0;
while !pending.is_empty() && iteration < MAX_ITERATIONS {
iteration += 1;
debug!(
target: "Client/AppState",
"Batched sync iteration {}/{}: {:?}",
iteration, MAX_ITERATIONS, pending
);
let backend = self.persistence_manager.backend();
let mut collection_nodes = Vec::with_capacity(pending.len());
let mut was_snapshot = std::collections::HashSet::new();
for &name in &pending {
let state = backend.get_version(name.as_str()).await?;
let want_snapshot = state.version == 0;
if want_snapshot {
was_snapshot.insert(name);
}
let mut builder = NodeBuilder::new("collection")
.attr("name", name.as_str())
.attr(
"return_snapshot",
if want_snapshot { "true" } else { "false" },
);
if !want_snapshot {
builder = builder.attr("version", state.version.to_string());
}
collection_nodes.push(builder.build());
}
let sync_node = NodeBuilder::new("sync").children(collection_nodes).build();
let iq = crate::request::InfoQuery {
namespace: "w:sync:app:state",
query_type: crate::request::InfoQueryType::Set,
to: server_jid().clone(),
target: None,
id: None,
content: Some(wacore_binary::node::NodeContent::Nodes(vec![sync_node])),
timeout: Some(Duration::from_secs(30)),
};
let resp = self.send_iq(iq).await?;
let mut pre_downloaded: std::collections::HashMap<String, Vec<u8>> =
std::collections::HashMap::new();
if let Ok(patch_lists) = wacore::appstate::patch_decode::parse_patch_lists(&resp) {
for pl in &patch_lists {
if let Some(ext) = &pl.snapshot_ref
&& let Some(path) = &ext.direct_path
{
match self.download(ext).await {
Ok(bytes) => {
pre_downloaded.insert(path.clone(), bytes);
}
Err(e) => {
warn!(
"Failed to download external snapshot for {:?}: {e}",
pl.name
);
}
}
}
for patch in &pl.patches {
if let Some(ext) = &patch.external_mutations
&& let Some(path) = &ext.direct_path
{
match self.download(ext).await {
Ok(bytes) => {
pre_downloaded.insert(path.clone(), bytes);
}
Err(e) => {
let v =
patch.version.as_ref().and_then(|v| v.version).unwrap_or(0);
warn!(
"Failed to download external mutations for patch v{}: {e}",
v
);
}
}
}
}
}
}
let download = |ext: &wa::ExternalBlobReference| -> anyhow::Result<Vec<u8>> {
if let Some(path) = &ext.direct_path {
if let Some(bytes) = pre_downloaded.get(path) {
Ok(bytes.clone())
} else {
Err(anyhow::anyhow!(
"external blob not pre-downloaded: {}",
path
))
}
} else {
Err(anyhow::anyhow!("external blob has no directPath"))
}
};
let proc = self.get_app_state_processor().await;
let results = proc.decode_multi_patch_list(&resp, &download, true).await?;
let mut needs_refetch = Vec::new();
for (mutations, new_state, list) in results {
let name = list.name;
if let Some(ref err) = list.error {
match err {
CollectionSyncError::Conflict { has_more } => {
if *has_more {
warn!(target: "Client/AppState", "Collection {:?} conflict (has_more=true), will refetch", name);
needs_refetch.push(name);
} else {
debug!(target: "Client/AppState", "Collection {:?} conflict (has_more=false), treating as success (no pending mutations)", name);
}
continue;
}
CollectionSyncError::Fatal { code, text } => {
warn!(target: "Client/AppState", "Collection {:?} fatal error {}: {}", name, code, text);
continue;
}
CollectionSyncError::Retry { code, text } => {
warn!(target: "Client/AppState", "Collection {:?} retryable error {}: {}, will refetch", name, code, text);
needs_refetch.push(name);
continue;
}
}
}
let missing = match proc.get_missing_key_ids(&list).await {
Ok(v) => v,
Err(e) => {
warn!("Failed to get missing key IDs for {:?}: {}", name, e);
Vec::new()
}
};
if !missing.is_empty() {
let mut to_request: Vec<Vec<u8>> = Vec::with_capacity(missing.len());
let mut guard = self.app_state_key_requests.lock().await;
let now = wacore::time::Instant::now();
for key_id in missing {
let hex_id = hex::encode(&key_id);
let should = guard
.get(&hex_id)
.map(|t| t.elapsed() > std::time::Duration::from_secs(24 * 3600))
.unwrap_or(true);
if should {
guard.insert(hex_id, now);
to_request.push(key_id);
}
}
guard.retain(|_, t| t.elapsed() < std::time::Duration::from_secs(24 * 3600));
drop(guard);
if !to_request.is_empty() {
self.request_app_state_keys(&to_request).await;
}
}
let full_sync = was_snapshot.contains(&name);
for m in mutations {
self.dispatch_app_state_mutation(&m, full_sync).await;
}
backend
.set_version(name.as_str(), new_state.clone())
.await?;
if list.has_more_patches {
needs_refetch.push(name);
}
debug!(
target: "Client/AppState",
"Batched sync: {:?} done (version={}, has_more={})",
name, new_state.version, list.has_more_patches
);
}
pending = needs_refetch;
}
if !pending.is_empty() {
warn!(
target: "Client/AppState",
"Batched sync: max iterations ({}) reached for {:?}",
MAX_ITERATIONS, pending
);
}
Ok(())
}
pub(crate) async fn process_app_state_sync_task(
&self,
name: WAPatchName,
full_sync: bool,
) -> anyhow::Result<()> {
if self.is_shutting_down() {
debug!(target: "Client/AppState", "Skipping app state sync task {:?}: client is shutting down", name);
return Ok(());
}
let backend = self.persistence_manager.backend();
let mut full_sync = full_sync;
let mut state = backend.get_version(name.as_str()).await?;
if state.version == 0 {
full_sync = true;
}
let mut has_more = true;
let mut want_snapshot = full_sync;
const MAX_PAGINATION_ITERATIONS: u32 = 500;
let mut iteration = 0u32;
while has_more {
if self.is_shutting_down() {
debug!(target: "Client/AppState", "Stopping app state sync task {:?}: shutdown detected", name);
break;
}
iteration += 1;
if iteration > MAX_PAGINATION_ITERATIONS {
warn!(target: "Client/AppState", "App state sync for {:?} exceeded {} iterations, aborting", name, MAX_PAGINATION_ITERATIONS);
break;
}
debug!(target: "Client/AppState", "Fetching app state patch batch: name={:?} want_snapshot={want_snapshot} version={} full_sync={} has_more_previous={}", name, state.version, full_sync, has_more);
let mut collection_builder = NodeBuilder::new("collection")
.attr("name", name.as_str())
.attr(
"return_snapshot",
if want_snapshot { "true" } else { "false" },
);
if !want_snapshot {
collection_builder = collection_builder.attr("version", state.version.to_string());
}
let sync_node = NodeBuilder::new("sync")
.children([collection_builder.build()])
.build();
let iq = crate::request::InfoQuery {
namespace: "w:sync:app:state",
query_type: crate::request::InfoQueryType::Set,
to: server_jid().clone(),
target: None,
id: None,
content: Some(wacore_binary::node::NodeContent::Nodes(vec![sync_node])),
timeout: None,
};
let resp = self.send_iq(iq).await?;
if self.is_shutting_down() {
debug!(target: "Client/AppState", "Discarding app state sync response for {:?}: shutdown detected", name);
break;
}
debug!(target: "Client/AppState", "Received IQ response for {:?}; decoding patches", name);
let decode_start = wacore::time::Instant::now();
let mut pre_downloaded: std::collections::HashMap<String, Vec<u8>> =
std::collections::HashMap::new();
if let Ok(pl) = wacore::appstate::patch_decode::parse_patch_list(&resp) {
debug!(target: "Client/AppState", "Parsed patch list for {:?}: has_snapshot_ref={} has_more_patches={} patches_count={}",
name, pl.snapshot_ref.is_some(), pl.has_more_patches, pl.patches.len());
if let Some(ext) = &pl.snapshot_ref
&& let Some(path) = &ext.direct_path
{
match self.download(ext).await {
Ok(bytes) => {
debug!(target: "Client/AppState", "Downloaded external snapshot ({} bytes)", bytes.len());
pre_downloaded.insert(path.clone(), bytes);
}
Err(e) => {
warn!("Failed to download external snapshot: {e}");
}
}
}
for patch in &pl.patches {
if let Some(ext) = &patch.external_mutations
&& let Some(path) = &ext.direct_path
{
let patch_version =
patch.version.as_ref().and_then(|v| v.version).unwrap_or(0);
match self.download(ext).await {
Ok(bytes) => {
debug!(target: "Client/AppState", "Downloaded external mutations for patch v{} ({} bytes)", patch_version, bytes.len());
pre_downloaded.insert(path.clone(), bytes);
}
Err(e) => {
warn!(
"Failed to download external mutations for patch v{}: {e}",
patch_version
);
}
}
}
}
}
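// Serve blobs from the pre-downloaded map so the decode step below stays
// free of network I/O; anything not pre-fetched is reported as an error.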
let download = |ext: &wa::ExternalBlobReference| -> anyhow::Result<Vec<u8>> {
if let Some(path) = &ext.direct_path {
if let Some(bytes) = pre_downloaded.get(path) {
Ok(bytes.clone())
} else {
Err(anyhow::anyhow!(
"external blob not pre-downloaded: {}",
path
))
}
} else {
Err(anyhow::anyhow!("external blob has no directPath"))
}
};
let proc = self.get_app_state_processor().await;
let (mutations, new_state, list) =
proc.decode_patch_list(&resp, &download, true).await?;
let decode_elapsed = decode_start.elapsed();
if decode_elapsed.as_millis() > 500 {
debug!(target: "Client/AppState", "Patch decode for {:?} took {:?}", name, decode_elapsed);
}
let missing = match proc.get_missing_key_ids(&list).await {
Ok(v) => v,
Err(e) => {
warn!("Failed to get missing key IDs for {:?}: {}", name, e);
Vec::new()
}
};
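// Same rate-limited key request as the batched path above: at most one
// request per key per 24 hours.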
if !missing.is_empty() {
let mut to_request: Vec<Vec<u8>> = Vec::with_capacity(missing.len());
let mut guard = self.app_state_key_requests.lock().await;
let now = wacore::time::Instant::now();
for key_id in missing {
let hex_id = hex::encode(&key_id);
let should = guard
.get(&hex_id)
.map(|t| t.elapsed() > std::time::Duration::from_secs(24 * 3600))
.unwrap_or(true);
if should {
guard.insert(hex_id, now);
to_request.push(key_id);
}
}
guard.retain(|_, t| t.elapsed() < std::time::Duration::from_secs(24 * 3600));
drop(guard);
if !to_request.is_empty() {
self.request_app_state_keys(&to_request).await;
}
}
for m in mutations {
debug!(target: "Client/AppState", "Dispatching mutation kind={} index_len={} full_sync={}", m.index.first().map(|s| s.as_str()).unwrap_or(""), m.index.len(), full_sync);
self.dispatch_app_state_mutation(&m, full_sync).await;
}
state = new_state;
has_more = list.has_more_patches;
want_snapshot = false;
debug!(target: "Client/AppState", "After processing batch name={:?} has_more={has_more} new_version={}", name, state.version);
}
backend.set_version(name.as_str(), state.clone()).await?;
debug!(target: "Client/AppState", "Completed and saved app state sync for {:?} (final version={})", name, state.version);
Ok(())
}
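/// Asks our own devices for missing app state sync keys by sending an
/// `AppStateSyncKeyRequest` protocol message addressed to our own JID.
/// No-ops when no key IDs are given or we have no own JID yet; callers are
/// expected to deduplicate key IDs (see the 24-hour rate limiting above).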
async fn request_app_state_keys(&self, raw_key_ids: &[Vec<u8>]) {
if raw_key_ids.is_empty() {
return;
}
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let own_jid = match device_snapshot.pn.clone() {
Some(j) => j,
None => return,
};
let key_ids: Vec<wa::message::AppStateSyncKeyId> = raw_key_ids
.iter()
.map(|k| wa::message::AppStateSyncKeyId {
key_id: Some(k.clone()),
})
.collect();
let msg = wa::Message {
protocol_message: Some(Box::new(wa::message::ProtocolMessage {
r#type: Some(wa::message::protocol_message::Type::AppStateSyncKeyRequest as i32),
app_state_sync_key_request: Some(wa::message::AppStateSyncKeyRequest { key_ids }),
..Default::default()
})),
..Default::default()
};
if let Err(e) = self
.send_message_impl(
own_jid,
&msg,
Some(self.generate_message_id().await),
true,
false,
None,
vec![],
)
.await
{
warn!("Failed to send app state key request: {e}");
}
}
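/// Builds and uploads an app state patch for `collection_name`, then
/// re-syncs the collection so our local version converges with whatever
/// the server accepted.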
pub(crate) async fn send_app_state_patch(
&self,
collection_name: &str,
mutations: Vec<(wa::SyncdMutation, Vec<u8>)>,
) -> Result<()> {
let proc = self.get_app_state_processor().await;
let (patch_bytes, base_version) = proc.build_patch(collection_name, mutations).await?;
let collection_node = NodeBuilder::new("collection")
.attr("name", collection_name)
.attr("version", base_version.to_string())
.attr("return_snapshot", "false")
.children([NodeBuilder::new("patch").bytes(patch_bytes).build()])
.build();
let sync_node = NodeBuilder::new("sync").children([collection_node]).build();
let iq = crate::request::InfoQuery {
namespace: "w:sync:app:state",
query_type: crate::request::InfoQueryType::Set,
to: server_jid().clone(),
target: None,
id: None,
content: Some(wacore_binary::node::NodeContent::Nodes(vec![sync_node])),
timeout: None,
};
self.send_iq(iq).await?;
if let Ok(patch_name) = collection_name.parse::<WAPatchName>()
&& let Err(e) = self.fetch_app_state_with_retry(patch_name).await
{
log::warn!("Failed to re-sync {collection_name} after patch send: {e}");
}
Ok(())
}
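/// Routes a decoded app state mutation to the right handler. Only `Set`
/// operations with a non-empty index are considered. Chat actions are
/// delegated to the chat-actions feature; `setting_pushName` mutations are
/// persisted, dispatched as `SelfPushNameUpdated`, and trigger an initial
/// presence announcement when we learn our push name for the first time.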
async fn dispatch_app_state_mutation(
&self,
m: &crate::appstate_sync::Mutation,
full_sync: bool,
) {
use wacore::types::events::Event;
if m.operation != wa::syncd_mutation::SyncdOperation::Set {
return;
}
if m.index.is_empty() {
return;
}
if crate::features::chat_actions::dispatch_chat_mutation(&self.core.event_bus, m, full_sync)
{
return;
}
if m.index[0] == "setting_pushName"
&& let Some(val) = &m.action_value
&& let Some(act) = &val.push_name_setting
&& let Some(new_name) = &act.name
{
let new_name = new_name.clone();
let bus = self.core.event_bus.clone();
let snapshot = self.persistence_manager.get_device_snapshot().await;
let old = snapshot.push_name.clone();
if old != new_name {
debug!(target: "Client/AppState", "Persisting push name from app state mutation: '{}' (old='{}')", new_name, old);
self.persistence_manager
.process_command(DeviceCommand::SetPushName(new_name.clone()))
.await;
bus.dispatch(&Event::SelfPushNameUpdated(
crate::types::events::SelfPushNameUpdated {
from_server: true,
old_name: old.clone(),
new_name: new_name.clone(),
},
));
if old.is_empty() && !new_name.is_empty() {
debug!(target: "Client/AppState", "Sending presence after receiving initial pushname from app state sync");
if let Err(e) = self.presence().set_available().await {
warn!(target: "Client/AppState", "Failed to send presence after pushname sync: {e:?}");
}
}
} else {
debug!(target: "Client/AppState", "Push name mutation received but name unchanged: '{}'", new_name);
}
}
}
async fn expect_disconnect(&self) {
self.expected_disconnect.store(true, Ordering::Relaxed);
}
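/// Handles a `<stream:error>` from the server. A `<conflict>` child or the
/// codes 516/401 end the session (auto-reconnect disabled, logout/replaced
/// events dispatched); 515 is the expected post-pairing restart; 409 means
/// another session replaced us; 429 extends the reconnect backoff; 503 simply
/// reconnects. Anything else is surfaced as `Event::StreamError`. In every
/// case the shutdown notifier is signalled at the end.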
pub(crate) async fn handle_stream_error(&self, node: &wacore_binary::node::Node) {
self.is_logged_in.store(false, Ordering::Relaxed);
let mut attrs = node.attrs();
let code_cow = attrs.optional_string("code");
let code = code_cow.as_deref().unwrap_or("");
let conflict_type = node
.get_optional_child("conflict")
.map(|n| {
n.attrs()
.optional_string("type")
.as_deref()
.unwrap_or("")
.to_string()
})
.unwrap_or_default();
if !conflict_type.is_empty() {
info!(
"Got stream error indicating client was removed or replaced (conflict={}). Logging out.",
conflict_type
);
self.expect_disconnect().await;
self.enable_auto_reconnect.store(false, Ordering::Relaxed);
let event = if conflict_type == "replaced" {
Event::StreamReplaced(crate::types::events::StreamReplaced)
} else {
Event::LoggedOut(crate::types::events::LoggedOut {
on_connect: false,
reason: ConnectFailureReason::LoggedOut,
})
};
self.core.event_bus.dispatch(&event);
let transport_opt = self.transport.lock().await.clone();
if let Some(transport) = transport_opt {
self.runtime
.spawn(Box::pin(async move {
info!("Disconnecting transport after conflict");
transport.disconnect().await;
}))
.detach();
}
} else {
match code {
"515" => {
info!(
"Got 515 stream error, server is closing stream (expected after pairing). Will auto-reconnect."
);
self.expect_disconnect().await;
let transport_opt = self.transport.lock().await.clone();
if let Some(transport) = transport_opt {
self.runtime
.spawn(Box::pin(async move {
info!("Disconnecting transport after 515");
transport.disconnect().await;
}))
.detach();
}
}
"516" => {
info!("Got 516 stream error (device removed). Logging out.");
self.expect_disconnect().await;
self.enable_auto_reconnect.store(false, Ordering::Relaxed);
self.core.event_bus.dispatch(&Event::LoggedOut(
crate::types::events::LoggedOut {
on_connect: false,
reason: ConnectFailureReason::LoggedOut,
},
));
let transport_opt = self.transport.lock().await.clone();
if let Some(transport) = transport_opt {
self.runtime
.spawn(Box::pin(async move {
info!("Disconnecting transport after 516");
transport.disconnect().await;
}))
.detach();
}
}
"401" => {
info!("Got 401 stream error (unauthorized). Logging out.");
self.expect_disconnect().await;
self.enable_auto_reconnect.store(false, Ordering::Relaxed);
self.core.event_bus.dispatch(&Event::LoggedOut(
crate::types::events::LoggedOut {
on_connect: false,
reason: ConnectFailureReason::LoggedOut,
},
));
let transport_opt = self.transport.lock().await.clone();
if let Some(transport) = transport_opt {
self.runtime
.spawn(Box::pin(async move {
info!("Disconnecting transport after 401");
transport.disconnect().await;
}))
.detach();
}
}
"409" => {
info!("Got 409 stream error (conflict). Another session replaced this one.");
self.expect_disconnect().await;
self.enable_auto_reconnect.store(false, Ordering::Relaxed);
self.core
.event_bus
.dispatch(&Event::StreamReplaced(crate::types::events::StreamReplaced));
let transport_opt = self.transport.lock().await.clone();
if let Some(transport) = transport_opt {
self.runtime
.spawn(Box::pin(async move {
info!("Disconnecting transport after 409");
transport.disconnect().await;
}))
.detach();
}
}
"429" => {
warn!(
"Got 429 stream error (rate limited). Will auto-reconnect with extended backoff."
);
self.auto_reconnect_errors.fetch_add(5, Ordering::Relaxed);
}
"503" => {
info!("Got 503 service unavailable, will auto-reconnect.");
}
_ => {
error!("Unknown stream error: {}", DisplayableNode(node));
self.expect_disconnect().await;
self.core.event_bus.dispatch(&Event::StreamError(
crate::types::events::StreamError {
code: code.to_string(),
raw: Some(node.clone()),
},
));
}
}
}
info!("Notifying shutdown from stream error handler");
self.shutdown_notifier.notify(usize::MAX);
}
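/// Handles a `<failure>` node received during connect. The numeric `reason`
/// attribute decides whether we reconnect, log out, or surface a temporary
/// ban, outdated-client, or generic connect-failure event.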
pub(crate) async fn handle_connect_failure(&self, node: &wacore_binary::node::Node) {
self.expected_disconnect.store(true, Ordering::Relaxed);
self.shutdown_notifier.notify(usize::MAX);
let mut attrs = node.attrs();
let reason_code = attrs.optional_u64("reason").unwrap_or(0) as i32;
let reason = ConnectFailureReason::from(reason_code);
if reason.should_reconnect() {
self.expected_disconnect.store(false, Ordering::Relaxed);
} else {
self.enable_auto_reconnect.store(false, Ordering::Relaxed);
}
if reason.is_logged_out() {
info!("Got {reason:?} connect failure, logging out.");
self.core
.event_bus
.dispatch(&wacore::types::events::Event::LoggedOut(
crate::types::events::LoggedOut {
on_connect: true,
reason,
},
));
} else if let ConnectFailureReason::TempBanned = reason {
let ban_code = attrs.optional_u64("code").unwrap_or(0) as i32;
let expire_secs = attrs.optional_u64("expire").unwrap_or(0);
let expire_duration =
chrono::Duration::try_seconds(expire_secs as i64).unwrap_or_default();
warn!("Temporary ban connect failure: {}", DisplayableNode(node));
self.core.event_bus.dispatch(&Event::TemporaryBan(
crate::types::events::TemporaryBan {
code: crate::types::events::TempBanReason::from(ban_code),
expire: expire_duration,
},
));
} else if let ConnectFailureReason::ClientOutdated = reason {
error!("Client is outdated and was rejected by server.");
self.core
.event_bus
.dispatch(&Event::ClientOutdated(crate::types::events::ClientOutdated));
} else {
warn!("Unknown connect failure: {}", DisplayableNode(node));
self.core.event_bus.dispatch(&Event::ConnectFailure(
crate::types::events::ConnectFailure {
reason,
message: attrs
.optional_string("message")
.as_deref()
.unwrap_or("")
.to_string(),
raw: Some(node.clone()),
},
));
}
}
pub(crate) async fn handle_iq(self: &Arc<Self>, node: &wacore_binary::node::Node) -> bool {
if node.attrs.get("type").is_some_and(|s| s == "get")
&& (node.get_optional_child("ping").is_some()
|| node
.attrs
.get("xmlns")
.is_some_and(|s| s == "urn:xmpp:ping"))
{
info!("Received ping, sending pong.");
let mut parser = node.attrs();
let from_jid = parser.jid("from");
let id = parser.optional_string("id").map(|s| s.to_string());
let pong = build_pong(from_jid.to_string(), id.as_deref());
if let Err(e) = self.send_node(pong).await {
warn!("Failed to send pong: {e:?}");
}
return true;
}
if pair::handle_iq(self, node).await {
return true;
}
false
}
pub fn is_connected(&self) -> bool {
self.is_connected.load(Ordering::Acquire)
}
pub fn is_logged_in(&self) -> bool {
self.is_logged_in.load(Ordering::Relaxed)
}
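/// Registers a one-shot waiter that resolves with the first incoming node
/// matching `filter`. Dropping the returned receiver cancels the waiter;
/// stale entries are swept out in `resolve_node_waiters`.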
pub fn wait_for_node(
&self,
filter: NodeFilter,
) -> futures::channel::oneshot::Receiver<Arc<Node>> {
let (tx, rx) = futures::channel::oneshot::channel();
self.node_waiter_count.fetch_add(1, Ordering::Release);
let mut waiters = self
.node_waiters
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
waiters.push(NodeWaiter { filter, tx });
rx
}
fn resolve_node_waiters(&self, node: &Arc<Node>) {
let mut waiters = self
.node_waiters
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let mut i = 0;
while i < waiters.len() {
if waiters[i].tx.is_canceled() {
waiters.swap_remove(i);
self.node_waiter_count.fetch_sub(1, Ordering::Release);
} else if waiters[i].filter.matches(node) {
let w = waiters.swap_remove(i);
self.node_waiter_count.fetch_sub(1, Ordering::Release);
let _ = w.tx.send(Arc::clone(node));
} else {
i += 1;
}
}
}
pub(crate) fn update_server_time_offset(&self, node: &wacore_binary::node::Node) {
self.unified_session.update_server_time_offset(node);
}
pub(crate) async fn send_unified_session(&self) {
if !self.is_connected() {
debug!(target: "Client/UnifiedSession", "Skipping: not connected");
return;
}
let Some((node, _sequence)) = self.unified_session.prepare_send().await else {
return;
};
if let Err(e) = self.send_node(node).await {
debug!(target: "Client/UnifiedSession", "Send failed: {e}");
self.unified_session.clear_last_sent().await;
}
}
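/// Waits until the socket layer is connected, or errors after `timeout`.
/// The connected flag is re-checked after registering the listener so a
/// notification racing with registration is never missed.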
pub async fn wait_for_socket(&self, timeout: std::time::Duration) -> Result<(), anyhow::Error> {
if self.is_connected() {
return Ok(());
}
let notified = self.socket_ready_notifier.listen();
if self.is_connected() {
return Ok(());
}
rt_timeout(&*self.runtime, timeout, notified)
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for socket"))
}
pub async fn wait_for_connected(
&self,
timeout: std::time::Duration,
) -> Result<(), anyhow::Error> {
if self.is_fully_ready() {
return Ok(());
}
let notified = self.connected_notifier.listen();
if self.is_fully_ready() {
return Ok(());
}
rt_timeout(&*self.runtime, timeout, notified)
.await
.map_err(|_| anyhow::anyhow!("Timeout waiting for connection"))
}
pub fn persistence_manager(&self) -> Arc<PersistenceManager> {
self.persistence_manager.clone()
}
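/// Sends an edit for a previously sent message. The new content is wrapped
/// in a `ProtocolMessage` of type `MessageEdit` keyed to `original_id`; for
/// group chats the key's participant is our own addressing JID for that
/// group, while direct chats require being logged in. Returns the original
/// message ID on success.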
pub async fn edit_message(
&self,
to: Jid,
original_id: impl Into<String>,
new_content: wa::Message,
) -> Result<String, anyhow::Error> {
let original_id = original_id.into();
let participant = if to.is_group() {
Some(
self.get_own_jid_for_group(&to)
.await?
.to_non_ad()
.to_string(),
)
} else {
if self.get_pn().await.is_none() {
return Err(anyhow::Error::from(ClientError::NotLoggedIn));
}
None
};
let edit_container_message = wa::Message {
edited_message: Some(Box::new(wa::message::FutureProofMessage {
message: Some(Box::new(wa::Message {
protocol_message: Some(Box::new(wa::message::ProtocolMessage {
key: Some(wa::MessageKey {
remote_jid: Some(to.to_string()),
from_me: Some(true),
id: Some(original_id.clone()),
participant,
}),
r#type: Some(wa::message::protocol_message::Type::MessageEdit as i32),
edited_message: Some(Box::new(new_content)),
timestamp_ms: Some(wacore::time::now_millis()),
..Default::default()
})),
..Default::default()
})),
})),
..Default::default()
};
self.send_message_impl(
to,
&edit_container_message,
None,
false,
false,
Some(crate::types::message::EditAttribute::MessageEdit),
vec![],
)
.await?;
Ok(original_id)
}
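/// Marshals `node` to the binary wire format, encrypts it over the current
/// noise socket, and sends it, updating the last-data-sent timestamp.
/// Fails with `NotConnected` when no socket is active.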
pub async fn send_node(&self, node: Node) -> Result<(), ClientError> {
let noise_socket_arc = { self.noise_socket.lock().await.clone() };
let noise_socket = match noise_socket_arc {
Some(socket) => socket,
None => return Err(ClientError::NotConnected),
};
debug!(target: "Client/Send", "{}", DisplayableNode(&node));
let mut plaintext_buf = Vec::with_capacity(1024);
if let Err(e) = wacore_binary::marshal::marshal_to(&node, &mut plaintext_buf) {
error!("Failed to marshal node: {e:?}");
return Err(SocketError::Crypto("Marshal error".to_string()).into());
}
let encrypted_buf = Vec::with_capacity(plaintext_buf.len() + 32);
if let Err(e) = noise_socket
.encrypt_and_send(plaintext_buf, encrypted_buf)
.await
{
return Err(e.into());
}
self.last_data_sent_ms
.store(wacore::time::now_millis() as u64, Ordering::Relaxed);
Ok(())
}
pub(crate) async fn update_push_name_and_notify(self: &Arc<Self>, new_name: String) {
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let old_name = device_snapshot.push_name.clone();
if old_name == new_name {
return;
}
log::debug!("Updating push name from '{}' -> '{}'", old_name, new_name);
self.persistence_manager
.process_command(DeviceCommand::SetPushName(new_name.clone()))
.await;
self.core.event_bus.dispatch(&Event::SelfPushNameUpdated(
crate::types::events::SelfPushNameUpdated {
from_server: true,
old_name,
new_name: new_name.clone(),
},
));
let client_clone = self.clone();
self.runtime
.spawn(Box::pin(async move {
if let Err(e) = client_clone.presence().set_available().await {
log::warn!("Failed to send presence after push name update: {:?}", e);
} else {
log::debug!("Sent presence after push name update.");
}
}))
.detach();
}
pub async fn get_push_name(&self) -> String {
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
device_snapshot.push_name.clone()
}
pub async fn get_pn(&self) -> Option<Jid> {
let snapshot = self.persistence_manager.get_device_snapshot().await;
snapshot.pn.clone()
}
pub async fn get_lid(&self) -> Option<Jid> {
let snapshot = self.persistence_manager.get_device_snapshot().await;
snapshot.lid.clone()
}
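/// Resolves which of our own JIDs to use inside `group_jid`: the group's
/// addressing mode decides between our LID (falling back to PN if we have
/// none) and our PN. Defaults to PN addressing when the group info query
/// fails.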
pub(crate) async fn get_own_jid_for_group(
&self,
group_jid: &Jid,
) -> Result<Jid, anyhow::Error> {
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let own_pn = device_snapshot
.pn
.clone()
.ok_or_else(|| anyhow::Error::from(ClientError::NotLoggedIn))?;
let addressing_mode = self
.groups()
.query_info(group_jid)
.await
.map(|info| info.addressing_mode)
.unwrap_or(crate::types::message::AddressingMode::Pn);
Ok(match addressing_mode {
crate::types::message::AddressingMode::Lid => {
device_snapshot.lid.clone().unwrap_or(own_pn)
}
crate::types::message::AddressingMode::Pn => own_pn,
})
}
pub(crate) async fn make_stanza_key(&self, chat: Jid, id: String) -> StanzaKey {
let chat = self.resolve_encryption_jid(&chat).await;
StanzaKey { chat, id }
}
pub(crate) async fn send_protocol_receipt(
&self,
id: String,
receipt_type: crate::types::presence::ReceiptType,
) {
if id.is_empty() {
return;
}
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
if let Some(own_jid) = &device_snapshot.pn {
let type_str = match receipt_type {
crate::types::presence::ReceiptType::HistorySync => "hist_sync",
crate::types::presence::ReceiptType::Read => "read",
crate::types::presence::ReceiptType::ReadSelf => "read-self",
crate::types::presence::ReceiptType::Delivered => "delivery",
crate::types::presence::ReceiptType::Played => "played",
crate::types::presence::ReceiptType::PlayedSelf => "played-self",
crate::types::presence::ReceiptType::Inactive => "inactive",
crate::types::presence::ReceiptType::PeerMsg => "peer_msg",
crate::types::presence::ReceiptType::Sender => "sender",
crate::types::presence::ReceiptType::ServerError => "server-error",
crate::types::presence::ReceiptType::Retry => "retry",
crate::types::presence::ReceiptType::EncRekeyRetry => "enc_rekey_retry",
crate::types::presence::ReceiptType::Other(ref s) => s.as_str(),
};
let node = NodeBuilder::new("receipt")
.attrs([
("id", id),
("type", type_str.to_string()),
("to", own_jid.to_non_ad().to_string()),
])
.build();
if let Err(e) = self.send_node(node).await {
warn!(
"Failed to send protocol receipt of type {:?} for message ID {}: {:?}",
receipt_type, id, e
);
}
}
}
}
fn build_pong(to: String, id: Option<&str>) -> wacore_binary::node::Node {
let mut builder = NodeBuilder::new("iq").attr("to", to).attr("type", "result");
if let Some(id) = id {
builder = builder.attr("id", id);
}
builder.build()
}
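/// Builds the `<ack/>` for an incoming stanza, or `None` when `id`/`from`
/// are missing. The ack echoes `participant` and, except for `<message>`
/// stanzas and encrypt+identity notifications, the `type` attribute;
/// message acks additionally carry our own device PN as `from`.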
fn build_ack_node(node: &Node, own_device_pn: Option<&Jid>) -> Option<Node> {
let id = node.attrs.get("id")?.clone();
let from = node.attrs.get("from")?.clone();
let participant = node.attrs.get("participant").cloned();
let typ = if node.tag != "message" && !is_encrypt_identity_notification(node) {
node.attrs.get("type").cloned()
} else {
None
};
let mut attrs = Attrs::new();
attrs.insert("class", NodeValue::String(node.tag.to_string()));
attrs.insert("id", id);
attrs.insert("to", from);
if node.tag == "message"
&& let Some(own_device_pn) = own_device_pn
{
attrs.insert("from", NodeValue::Jid(own_device_pn.clone()));
}
if let Some(p) = participant {
attrs.insert("participant", p);
}
if let Some(t) = typ {
attrs.insert("type", t);
}
Some(Node {
tag: Cow::Borrowed("ack"),
attrs,
content: None,
})
}
fn is_encrypt_identity_notification(node: &Node) -> bool {
node.tag == "notification"
&& node.attrs.get("type").is_some_and(|v| v == "encrypt")
&& node.get_optional_child("identity").is_some()
}
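/// Fibonacci-style reconnect backoff starting at 1s and saturating at 15
/// minutes (`MAX_MS`), with roughly ±10% jitter applied to the base delay.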
fn fibonacci_backoff(attempt: u32) -> Duration {
const MAX_MS: u64 = 900_000;
let mut a: u64 = 1000;
let mut b: u64 = 1000;
for _ in 0..attempt {
let next = a.saturating_add(b).min(MAX_MS);
a = b;
b = next;
}
let base = a.min(MAX_MS);
let jitter_range = base / 10;
let jitter = if jitter_range > 0 {
rand::make_rng::<rand::rngs::StdRng>().random_range(0..=(jitter_range * 2)) as i64
- jitter_range as i64
} else {
0
};
let ms = (base as i64 + jitter).max(0) as u64;
Duration::from_millis(ms)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::lid_pn_cache::LearningSource;
use crate::test_utils::MockHttpClient;
use futures::channel::oneshot;
use wacore_binary::jid::SERVER_JID;
#[tokio::test]
async fn test_ack_behavior_for_incoming_stanzas() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
use wacore_binary::node::{Attrs, Node, NodeContent};
let mut receipt_attrs = Attrs::new();
receipt_attrs.insert("from".to_string(), "@s.whatsapp.net".to_string());
receipt_attrs.insert("id".to_string(), "RCPT-1".to_string());
let receipt_node = Node::new(
"receipt",
receipt_attrs,
Some(NodeContent::String("test".to_string())),
);
let mut notification_attrs = Attrs::new();
notification_attrs.insert("from".to_string(), "@s.whatsapp.net".to_string());
notification_attrs.insert("id".to_string(), "NOTIF-1".to_string());
let notification_node = Node::new(
"notification",
notification_attrs,
Some(NodeContent::String("test".to_string())),
);
assert!(
client.should_ack(&receipt_node),
"should_ack must still return TRUE for <receipt> stanzas."
);
assert!(
client.should_ack(&notification_node),
"should_ack must still return TRUE for <notification> stanzas."
);
info!(
"✅ test_ack_behavior_for_incoming_stanzas passed: Client correctly differentiates which stanzas to acknowledge."
);
}
#[tokio::test]
async fn test_ack_waiter_resolves() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let test_id = "ack-test-123".to_string();
let (tx, rx) = oneshot::channel();
client
.response_waiters
.lock()
.await
.insert(test_id.clone(), tx);
assert!(
client.response_waiters.lock().await.contains_key(&test_id),
"Waiter should be inserted before handling ack"
);
let ack_node = NodeBuilder::new("ack")
.attr("id", test_id.clone())
.attr("from", SERVER_JID)
.build();
let handled = client.handle_ack_response(ack_node).await;
assert!(
handled,
"handle_ack_response should return true when waiter exists"
);
match tokio::time::timeout(Duration::from_secs(1), rx).await {
Ok(Ok(response_node)) => {
assert!(
response_node
.attrs
.get("id")
.is_some_and(|v| v == test_id.as_str()),
"Response node should have correct ID"
);
}
Ok(Err(_)) => panic!("Receiver was dropped without being sent a value"),
Err(_) => panic!("Test timed out waiting for ack response"),
}
assert!(
!client.response_waiters.lock().await.contains_key(&test_id),
"Waiter should be removed after handling"
);
info!(
"✅ test_ack_waiter_resolves passed: ACK response correctly resolves pending waiters"
);
}
#[tokio::test]
async fn test_ack_without_matching_waiter() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let ack_node = NodeBuilder::new("ack")
.attr("id", "non-existent-id")
.attr("from", SERVER_JID)
.build();
let handled = client.handle_ack_response(ack_node).await;
assert!(
!handled,
"handle_ack_response should return false when no waiter exists"
);
info!(
"✅ test_ack_without_matching_waiter passed: ACK without matching waiter handled gracefully"
);
}
#[tokio::test]
async fn test_lid_pn_cache_basic_operations() {
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_lid_cache_basic?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let phone = "559980000001";
let lid = "100000012345678";
assert!(
client.lid_pn_cache.get_current_lid(phone).await.is_none(),
"Cache should be empty initially"
);
client
.add_lid_pn_mapping(lid, phone, LearningSource::Usync)
.await
.expect("Failed to persist LID-PN mapping in tests");
let cached_lid = client.lid_pn_cache.get_current_lid(phone).await;
assert!(cached_lid.is_some(), "Cache should contain the mapping");
assert_eq!(
cached_lid.expect("cache should have LID"),
lid,
"Cached LID should match what we inserted"
);
let cached_phone = client.lid_pn_cache.get_phone_number(lid).await;
assert!(cached_phone.is_some(), "Reverse lookup should work");
assert_eq!(
cached_phone.expect("reverse lookup should return phone"),
phone,
"Cached phone should match what we inserted"
);
assert!(
client
.lid_pn_cache
.get_current_lid("559980000002")
.await
.is_none(),
"Different phone number should not have a mapping"
);
info!("✅ test_lid_pn_cache_basic_operations passed: LID-PN cache works correctly");
}
#[tokio::test]
async fn test_lid_pn_cache_timestamp_resolution() {
let backend = Arc::new(
crate::store::SqliteStore::new(
"file:memdb_lid_cache_timestamp?mode=memory&cache=shared",
)
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let phone = "559980000001";
let lid_old = "100000012345678";
let lid_new = "100000087654321";
client
.add_lid_pn_mapping(lid_old, phone, LearningSource::Usync)
.await
.expect("Failed to persist LID-PN mapping in tests");
assert_eq!(
client
.lid_pn_cache
.get_current_lid(phone)
.await
.expect("cache should have LID"),
lid_old,
"Initial LID should be stored"
);
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
client
.add_lid_pn_mapping(lid_new, phone, LearningSource::PeerPnMessage)
.await
.expect("Failed to persist LID-PN mapping in tests");
assert_eq!(
client
.lid_pn_cache
.get_current_lid(phone)
.await
.expect("cache should have newer LID"),
lid_new,
"Newer LID should be returned for phone lookup"
);
assert_eq!(
client
.lid_pn_cache
.get_phone_number(lid_old)
.await
.expect("reverse lookup should return phone"),
phone,
"Old LID should still map to phone"
);
assert_eq!(
client
.lid_pn_cache
.get_phone_number(lid_new)
.await
.expect("reverse lookup should return phone"),
phone,
"New LID should also map to phone"
);
info!(
"✅ test_lid_pn_cache_timestamp_resolution passed: Timestamp-based resolution works correctly"
);
}
#[tokio::test]
async fn test_get_lid_for_phone_via_send_context_resolver() {
use wacore::client::context::SendContextResolver;
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_get_lid_for_phone?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let phone = "559980000001";
let lid = "100000012345678";
assert!(
client.get_lid_for_phone(phone).await.is_none(),
"get_lid_for_phone should return None before caching"
);
client
.add_lid_pn_mapping(lid, phone, LearningSource::Usync)
.await
.expect("Failed to persist LID-PN mapping in tests");
let result = client.get_lid_for_phone(phone).await;
assert!(
result.is_some(),
"get_lid_for_phone should return Some after caching"
);
assert_eq!(
result.expect("get_lid_for_phone should return Some"),
lid,
"get_lid_for_phone should return the cached LID"
);
info!(
"✅ test_get_lid_for_phone_via_send_context_resolver passed: SendContextResolver correctly returns cached LID"
);
}
#[tokio::test]
async fn test_wait_for_offline_delivery_end_returns_immediately_when_flag_set() {
let backend = Arc::new(
crate::store::SqliteStore::new(
"file:memdb_offline_sync_flag_set?mode=memory&cache=shared",
)
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
client
.offline_sync_completed
.store(true, std::sync::atomic::Ordering::Relaxed);
let start = std::time::Instant::now();
client.wait_for_offline_delivery_end().await;
let elapsed = start.elapsed();
assert!(
elapsed.as_millis() < 100,
"wait_for_offline_delivery_end should return immediately when flag is set, took {:?}",
elapsed
);
info!("✅ test_wait_for_offline_delivery_end_returns_immediately_when_flag_set passed");
}
#[tokio::test]
async fn test_wait_for_offline_delivery_end_times_out_when_flag_not_set() {
let backend = Arc::new(
crate::store::SqliteStore::new(
"file:memdb_offline_sync_timeout?mode=memory&cache=shared",
)
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let start = std::time::Instant::now();
client
.wait_for_offline_delivery_end_with_timeout(std::time::Duration::from_millis(50))
.await;
let elapsed = start.elapsed();
let semaphore = match client.message_processing_semaphore.lock() {
Ok(guard) => guard.clone(),
Err(poisoned) => poisoned.into_inner().clone(),
};
let mut guards = Vec::new();
while let Some(guard) = semaphore.try_acquire() {
guards.push(guard);
}
let permits = guards.len();
drop(guards);
assert!(
elapsed.as_millis() >= 45,
"Should have waited for the configured timeout duration, took {:?}",
elapsed
);
assert!(
client
.offline_sync_completed
.load(std::sync::atomic::Ordering::Relaxed),
"wait_for_offline_delivery_end should mark offline sync complete on timeout"
);
assert_eq!(
permits, 64,
"timeout completion should restore parallel permits"
);
info!("✅ test_wait_for_offline_delivery_end_times_out_when_flag_not_set passed");
}
#[tokio::test]
async fn test_wait_for_offline_delivery_end_returns_on_notify() {
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_offline_notify?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let client_clone = client.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
client_clone.offline_sync_notifier.notify(usize::MAX);
});
let start = std::time::Instant::now();
client.wait_for_offline_delivery_end().await;
let elapsed = start.elapsed();
assert!(
elapsed.as_millis() < 200,
"wait_for_offline_delivery_end should return when notified, took {:?}",
elapsed
);
assert!(
elapsed.as_millis() >= 45,
"Should have waited for the notify, only took {:?}",
elapsed
);
info!("✅ test_wait_for_offline_delivery_end_returns_on_notify passed");
}
#[tokio::test]
async fn test_offline_sync_flag_initially_false() {
let backend = Arc::new(
crate::store::SqliteStore::new(
"file:memdb_offline_flag_initial?mode=memory&cache=shared",
)
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert!(
!client
.offline_sync_completed
.load(std::sync::atomic::Ordering::Relaxed),
"offline_sync_completed should be false when Client is first created"
);
info!("✅ test_offline_sync_flag_initially_false passed");
}
#[tokio::test]
async fn test_offline_sync_lifecycle() {
use std::sync::atomic::Ordering;
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_offline_lifecycle?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert!(!client.offline_sync_completed.load(Ordering::Relaxed));
let client_waiter = client.clone();
let waiter_handle = tokio::spawn(async move {
client_waiter.wait_for_offline_delivery_end().await;
true
});
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
assert!(
!waiter_handle.is_finished(),
"Waiter should still be waiting"
);
client.offline_sync_completed.store(true, Ordering::Relaxed);
client.offline_sync_notifier.notify(usize::MAX);
let result = tokio::time::timeout(std::time::Duration::from_millis(100), waiter_handle)
.await
.expect("Waiter should complete after notify")
.expect("Waiter task should not panic");
assert!(result, "Waiter should have completed successfully");
assert!(client.offline_sync_completed.load(Ordering::Relaxed));
info!("✅ test_offline_sync_lifecycle passed");
}
#[tokio::test]
async fn test_establish_primary_phone_session_fails_without_pn() {
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_no_pn?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let result = client.establish_primary_phone_session_immediate().await;
assert!(
result.is_err(),
"establish_primary_phone_session_immediate should fail when no PN is set"
);
let err = result.unwrap_err();
assert!(
err.downcast_ref::<ClientError>()
.is_some_and(|e| matches!(e, ClientError::NotLoggedIn)),
"Error should be ClientError::NotLoggedIn, got: {}",
err
);
info!("✅ test_establish_primary_phone_session_fails_without_pn passed");
}
#[tokio::test]
async fn test_ensure_e2e_sessions_waits_for_offline_sync() {
use std::sync::atomic::Ordering;
use wacore_binary::jid::Jid;
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_ensure_e2e_waits?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert!(!client.offline_sync_completed.load(Ordering::Relaxed));
let client_clone = client.clone();
let ensure_handle = tokio::spawn(async move {
client_clone.ensure_e2e_sessions(vec![]).await
});
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
assert!(
ensure_handle.is_finished(),
"ensure_e2e_sessions should return immediately for empty JID list"
);
let client_clone = client.clone();
let test_jid = Jid::pn("559999999999");
let ensure_handle = tokio::spawn(async move {
let start = std::time::Instant::now();
let _ = client_clone.ensure_e2e_sessions(vec![test_jid]).await;
start.elapsed()
});
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
assert!(
!ensure_handle.is_finished(),
"ensure_e2e_sessions should be waiting for offline sync"
);
client.offline_sync_completed.store(true, Ordering::Relaxed);
client.offline_sync_notifier.notify(usize::MAX);
let result = tokio::time::timeout(std::time::Duration::from_secs(2), ensure_handle).await;
assert!(
result.is_ok(),
"ensure_e2e_sessions should complete after offline sync"
);
info!("✅ test_ensure_e2e_sessions_waits_for_offline_sync passed");
}
#[tokio::test]
async fn test_immediate_session_does_not_wait_for_offline_sync() {
use std::sync::atomic::Ordering;
use wacore_binary::jid::Jid;
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_immediate_no_wait?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend.clone())
.await
.expect("persistence manager should initialize"),
);
pm.modify_device(|device| {
device.pn = Some(Jid::pn("559999999999"));
})
.await;
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert!(!client.offline_sync_completed.load(Ordering::Relaxed));
let start = std::time::Instant::now();
let result = tokio::time::timeout(
std::time::Duration::from_millis(500),
client.establish_primary_phone_session_immediate(),
)
.await;
let elapsed = start.elapsed();
assert!(
result.is_ok(),
"establish_primary_phone_session_immediate should not wait for offline sync, timed out"
);
assert!(
elapsed.as_millis() < 500,
"establish_primary_phone_session_immediate should not wait, took {:?}",
elapsed
);
info!(
"establish_primary_phone_session_immediate completed in {:?} (result: {:?})",
elapsed,
result.unwrap().is_ok()
);
info!("✅ test_immediate_session_does_not_wait_for_offline_sync passed");
}
#[tokio::test]
async fn test_establish_session_skips_when_exists() {
use wacore::libsignal::protocol::SessionRecord;
use wacore::libsignal::store::SessionStore;
use wacore::types::jid::JidExt;
use wacore_binary::jid::Jid;
let backend = Arc::new(
crate::store::SqliteStore::new("file:memdb_skip_existing?mode=memory&cache=shared")
.await
.expect("Failed to create in-memory backend for test"),
);
let pm = Arc::new(
PersistenceManager::new(backend.clone())
.await
.expect("persistence manager should initialize"),
);
let own_pn = Jid::pn("559999999999");
pm.modify_device(|device| {
device.pn = Some(own_pn.clone());
})
.await;
let primary_phone_jid = own_pn.with_device(0);
let signal_addr = primary_phone_jid.to_protocol_address();
let dummy_session = SessionRecord::new_fresh();
{
let device_arc = pm.get_device_arc().await;
let device = device_arc.read().await;
device
.store_session(&signal_addr, &dummy_session)
.await
.expect("Failed to store test session");
let exists = device
.contains_session(&signal_addr)
.await
.expect("Failed to check session");
assert!(exists, "Session should exist after store");
}
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm.clone(),
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let result = client.establish_primary_phone_session_immediate().await;
assert!(
result.is_ok(),
"establish_primary_phone_session_immediate should succeed when session exists"
);
{
let device_arc = pm.get_device_arc().await;
let device = device_arc.read().await;
let exists = device
.contains_session(&signal_addr)
.await
.expect("Failed to check session");
assert!(exists, "Session should still exist after the call");
}
info!("✅ test_establish_session_skips_when_exists passed");
}
#[test]
fn test_mac_failure_prevention_flow_documentation() {
fn should_establish_session(
check_result: Result<bool, &'static str>,
) -> Result<bool, String> {
match check_result {
Ok(true) => Ok(false),
Ok(false) => Ok(true),
Err(e) => Err(format!("Cannot verify session: {}", e)),
}
}
let result = should_establish_session(Ok(true));
assert_eq!(result, Ok(false), "Should skip when session exists");
let result = should_establish_session(Ok(false));
assert_eq!(result, Ok(true), "Should establish when no session");
let result = should_establish_session(Err("database error"));
assert!(result.is_err(), "Should fail when check fails");
info!("✅ test_mac_failure_prevention_flow_documentation passed");
}
#[test]
fn test_unified_session_id_calculation() {
const DAY_MS: i64 = 24 * 60 * 60 * 1000;
const WEEK_MS: i64 = 7 * DAY_MS;
const OFFSET_MS: i64 = 3 * DAY_MS;
fn calculate_session_id(now_ms: i64, server_offset_ms: i64) -> i64 {
let adjusted_now = now_ms + server_offset_ms;
(adjusted_now + OFFSET_MS) % WEEK_MS
}
let now_ms = 1706000000000_i64;
let id = calculate_session_id(now_ms, 0);
assert!(
(0..WEEK_MS).contains(&id),
"Session ID should be in [0, WEEK_MS)"
);
let id_with_positive_offset = calculate_session_id(now_ms, 5000);
assert!(
(0..WEEK_MS).contains(&id_with_positive_offset),
"Session ID should be in [0, WEEK_MS)"
);
let id_with_negative_offset = calculate_session_id(now_ms, -5000);
assert!(
(0..WEEK_MS).contains(&id_with_negative_offset),
"Session ID should be in [0, WEEK_MS)"
);
let wrap_test_now = WEEK_MS - OFFSET_MS + 1000;
let wrapped_id = calculate_session_id(wrap_test_now, 0);
assert_eq!(wrapped_id, 1000, "Should wrap around correctly");
let boundary_now = WEEK_MS - OFFSET_MS;
let boundary_id = calculate_session_id(boundary_now, 0);
assert_eq!(boundary_id, 0, "At exact boundary should be 0");
}
#[tokio::test]
async fn test_server_time_offset_extraction() {
use wacore_binary::builder::NodeBuilder;
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert_eq!(
client.unified_session.server_time_offset_ms(),
0,
"Initial offset should be 0"
);
let server_time = wacore::time::now_secs() + 10;
let node = NodeBuilder::new("success")
.attr("t", server_time.to_string())
.build();
client.update_server_time_offset(&node);
let offset = client.unified_session.server_time_offset_ms();
assert!(
(offset - 10000).abs() < 1000,
"Offset should be approximately 10000ms, got {}",
offset
);
let node_no_t = NodeBuilder::new("success").build();
client.update_server_time_offset(&node_no_t);
let offset_after = client.unified_session.server_time_offset_ms();
assert!(
(offset_after - offset).abs() < 100,
"Offset should not change when 't' is missing"
);
let node_invalid = NodeBuilder::new("success")
.attr("t", "not_a_number")
.build();
client.update_server_time_offset(&node_invalid);
let offset_after_invalid = client.unified_session.server_time_offset_ms();
assert!(
(offset_after_invalid - offset).abs() < 100,
"Offset should not change when 't' is invalid"
);
let node_zero = NodeBuilder::new("success").attr("t", "0").build();
client.update_server_time_offset(&node_zero);
let offset_after_zero = client.unified_session.server_time_offset_ms();
assert!(
(offset_after_zero - offset).abs() < 100,
"Offset should not change when 't' is 0"
);
info!("✅ test_server_time_offset_extraction passed");
}
#[tokio::test]
async fn test_unified_session_manager_integration() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert_eq!(
client.unified_session.sequence(),
0,
"Initial sequence should be 0"
);
loop {
client.unified_session.reset().await;
let result = client.unified_session.prepare_send().await;
assert!(result.is_some(), "First send should succeed");
let (node, seq) = result.unwrap();
assert_eq!(node.tag, "ib", "Should be an IB stanza");
assert_eq!(seq, 1, "First sequence should be 1 (pre-increment)");
assert_eq!(client.unified_session.sequence(), 1);
let result2 = client.unified_session.prepare_send().await;
if result2.is_none() {
assert_eq!(client.unified_session.sequence(), 1);
break;
}
tokio::task::yield_now().await;
}
client.unified_session.clear_last_sent().await;
let result3 = client.unified_session.prepare_send().await;
assert!(result3.is_some(), "Should succeed after clearing");
let (_, seq3) = result3.unwrap();
assert_eq!(seq3, 1, "Sequence resets when session ID changes");
assert_eq!(client.unified_session.sequence(), 1);
info!("✅ test_unified_session_manager_integration passed");
}
#[test]
fn test_unified_session_protocol_node() {
use wacore::ib::{IbStanza, UnifiedSession};
use wacore::protocol::ProtocolNode;
let session = UnifiedSession::new("123456789");
assert_eq!(session.id, "123456789");
assert_eq!(session.tag(), "unified_session");
let node = session.into_node();
assert_eq!(node.tag, "unified_session");
assert!(node.attrs.get("id").is_some_and(|v| v == "123456789"));
let stanza = IbStanza::unified_session(UnifiedSession::new("987654321"));
assert_eq!(stanza.tag(), "ib");
let ib_node = stanza.into_node();
assert_eq!(ib_node.tag, "ib");
let children = ib_node.children().expect("IB stanza should have children");
assert_eq!(children.len(), 1);
assert_eq!(children[0].tag, "unified_session");
assert!(
children[0]
.attrs
.get("id")
.is_some_and(|v| v == "987654321")
);
info!("✅ test_unified_session_protocol_node passed");
}
async fn create_offline_sync_test_client() -> Arc<Client> {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
client
}
#[tokio::test]
async fn test_ib_thread_metadata_does_not_end_sync() {
let client = create_offline_sync_test_client().await;
client
.offline_sync_metrics
.active
.store(true, Ordering::Release);
let node = NodeBuilder::new("ib")
.children([NodeBuilder::new("thread_metadata")
.children([NodeBuilder::new("item").build()])
.build()])
.build();
client.process_node(Arc::new(node)).await;
assert!(
client.offline_sync_metrics.active.load(Ordering::Acquire),
"<ib><thread_metadata> should NOT end offline sync"
);
}
#[tokio::test]
async fn test_ib_edge_routing_does_not_end_sync() {
let client = create_offline_sync_test_client().await;
client
.offline_sync_metrics
.active
.store(true, Ordering::Release);
let node = NodeBuilder::new("ib")
.children([NodeBuilder::new("edge_routing")
.children([NodeBuilder::new("routing_info")
.bytes(vec![1, 2, 3])
.build()])
.build()])
.build();
client.process_node(Arc::new(node)).await;
assert!(
client.offline_sync_metrics.active.load(Ordering::Acquire),
"<ib><edge_routing> should NOT end offline sync"
);
}
#[tokio::test]
async fn test_ib_dirty_does_not_end_sync() {
let client = create_offline_sync_test_client().await;
client
.offline_sync_metrics
.active
.store(true, Ordering::Release);
let node = NodeBuilder::new("ib")
.children([NodeBuilder::new("dirty")
.attr("type", "groups")
.attr("timestamp", "1234")
.build()])
.build();
client.process_node(Arc::new(node)).await;
assert!(
client.offline_sync_metrics.active.load(Ordering::Acquire),
"<ib><dirty> should NOT end offline sync"
);
}
#[tokio::test]
async fn test_ib_offline_child_ends_sync() {
let client = create_offline_sync_test_client().await;
client
.offline_sync_metrics
.active
.store(true, Ordering::Release);
client
.offline_sync_metrics
.total_messages
.store(301, Ordering::Release);
let node = NodeBuilder::new("ib")
.children([NodeBuilder::new("offline").attr("count", "301").build()])
.build();
client.process_node(Arc::new(node)).await;
assert!(
!client.offline_sync_metrics.active.load(Ordering::Acquire),
"<ib><offline count='301'/> should end offline sync"
);
}
#[tokio::test]
async fn test_ib_offline_preview_starts_sync() {
let client = create_offline_sync_test_client().await;
let node = NodeBuilder::new("ib")
.children([NodeBuilder::new("offline_preview")
.attr("count", "301")
.attr("message", "168")
.attr("notification", "62")
.attr("receipt", "68")
.attr("appdata", "0")
.build()])
.build();
client.process_node(Arc::new(node)).await;
assert!(
client.offline_sync_metrics.active.load(Ordering::Acquire),
"offline_preview with count>0 should activate sync"
);
assert_eq!(
client
.offline_sync_metrics
.total_messages
.load(Ordering::Acquire),
301
);
}
#[tokio::test]
async fn test_offline_message_increments_processed() {
let client = create_offline_sync_test_client().await;
client
.offline_sync_metrics
.active
.store(true, Ordering::Release);
client
.offline_sync_metrics
.total_messages
.store(100, Ordering::Release);
let node = NodeBuilder::new("message")
.attr("offline", "1")
.attr("from", "5551234567@s.whatsapp.net")
.attr("id", "TEST123")
.attr("t", "1772884671")
.attr("type", "text")
.build();
client.process_node(Arc::new(node)).await;
assert_eq!(
client
.offline_sync_metrics
.processed_messages
.load(Ordering::Acquire),
1,
"offline message should increment processed count"
);
}
#[tokio::test]
async fn test_handle_iq_ping_with_child_element() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let ping_node = NodeBuilder::new("iq")
.attr("type", "get")
.attr("from", SERVER_JID)
.attr("id", "ping-child-1")
.children([NodeBuilder::new("ping").build()])
.build();
let handled = client.handle_iq(&ping_node).await;
assert!(
handled,
"handle_iq must recognize ping with <ping> child element"
);
}
#[tokio::test]
async fn test_handle_iq_ping_with_xmlns_attribute() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let ping_node = NodeBuilder::new("iq")
.attr("type", "get")
.attr("from", SERVER_JID)
.attr("id", "ping-xmlns-1")
.attr("xmlns", "urn:xmpp:ping")
.build();
let handled = client.handle_iq(&ping_node).await;
assert!(
handled,
"handle_iq must recognize ping with xmlns=\"urn:xmpp:ping\" attribute (no children)"
);
}
#[tokio::test]
async fn test_handle_iq_ping_with_both_child_and_xmlns() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let ping_node = NodeBuilder::new("iq")
.attr("type", "get")
.attr("from", SERVER_JID)
.attr("id", "ping-both-1")
.attr("xmlns", "urn:xmpp:ping")
.children([NodeBuilder::new("ping").build()])
.build();
let handled = client.handle_iq(&ping_node).await;
assert!(
handled,
"handle_iq must handle ping with both child and xmlns"
);
}
#[tokio::test]
async fn test_handle_iq_non_ping_returns_false() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let non_ping_node = NodeBuilder::new("iq")
.attr("type", "get")
.attr("from", SERVER_JID)
.attr("id", "not-a-ping")
.attr("xmlns", "some:other:namespace")
.build();
let handled = client.handle_iq(&non_ping_node).await;
assert!(
!handled,
"handle_iq must NOT treat non-ping xmlns as a ping"
);
}
#[tokio::test]
async fn test_handle_iq_ping_wrong_type_returns_false() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let result_node = NodeBuilder::new("iq")
.attr("type", "result")
.attr("from", SERVER_JID)
.attr("id", "ping-result-1")
.attr("xmlns", "urn:xmpp:ping")
.build();
let handled = client.handle_iq(&result_node).await;
assert!(
!handled,
"handle_iq must NOT respond to type=\"result\" even with ping xmlns"
);
}
#[test]
fn test_build_pong_with_id() {
let pong = build_pong("s.whatsapp.net".to_string(), Some("ping-123"));
assert!(
pong.attrs.get("id").is_some_and(|v| v == "ping-123"),
"pong should include id when server ping has one"
);
assert!(pong.attrs.get("type").is_some_and(|v| v == "result"));
assert!(pong.attrs.get("to").is_some_and(|v| v == "s.whatsapp.net"));
}
#[test]
fn test_build_pong_without_id() {
let pong = build_pong("s.whatsapp.net".to_string(), None);
assert!(
!pong.attrs.contains_key("id"),
"pong should NOT include id when server ping has none"
);
assert!(pong.attrs.get("type").is_some_and(|v| v == "result"));
}
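// Classifier for identity-change notifications: type="encrypt" with an
// <identity> child qualifies; other notification types must not.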
#[test]
fn test_encrypt_identity_notification_omits_type() {
let node = NodeBuilder::new("notification")
.attr("from", "186303081611421@lid")
.attr("id", "4128735301")
.attr("type", "encrypt")
.children([NodeBuilder::new("identity").build()])
.build();
assert!(
is_encrypt_identity_notification(&node),
"identity-change notification ACK must omit type to match WA Web"
);
}
#[test]
fn test_device_notification_is_not_encrypt_identity() {
let node = NodeBuilder::new("notification")
.attr("from", "186303081611421@lid")
.attr("id", "269488578")
.attr("type", "devices")
.children([NodeBuilder::new("remove").build()])
.build();
assert!(
!is_encrypt_identity_notification(&node),
"device notification is not an encrypt+identity notification"
);
}
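// build_ack_node for messages: echo to/participant, include our own device
// PN as "from", and never echo the incoming type (matching whatsmeow).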
#[test]
fn test_build_ack_node_for_message_omits_type_includes_from() {
let incoming = NodeBuilder::new("message")
.attr("from", "120363161500776365@g.us")
.attr("id", "A5791A5392EF60E3FB0670098DE010D4")
.attr("type", "text")
.attr("participant", "181531758878822@lid")
.build();
let own_device_pn: Jid = "155500012345:48@s.whatsapp.net"
.parse()
.expect("own device PN JID should parse");
let ack = build_ack_node(&incoming, Some(&own_device_pn))
.expect("message ack should be buildable");
assert_eq!(ack.tag, "ack");
assert!(ack.attrs.get("class").is_some_and(|v| v == "message"));
assert!(
ack.attrs
.get("to")
.is_some_and(|v| v == "120363161500776365@g.us")
);
assert!(
ack.attrs
.get("from")
.is_some_and(|v| v == "155500012345:48@s.whatsapp.net")
);
assert!(
ack.attrs
.get("participant")
.is_some_and(|v| v == "181531758878822@lid")
);
assert!(
!ack.attrs.contains_key("type"),
"message ACK must NOT echo type (matches whatsmeow behavior)"
);
}
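// Notification ACKs differ from message ACKs: both the type attribute and
// our device PN "from" are omitted.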
#[test]
fn test_build_ack_node_for_identity_change_omits_type_and_from() {
let incoming = NodeBuilder::new("notification")
.attr("from", "186303081611421@lid")
.attr("id", "4128735301")
.attr("type", "encrypt")
.children([NodeBuilder::new("identity").build()])
.build();
let own_device_pn: Jid = "155500012345:48@s.whatsapp.net"
.parse()
.expect("own device PN JID should parse");
let ack = build_ack_node(&incoming, Some(&own_device_pn))
.expect("notification ack should be buildable");
assert!(ack.attrs.get("class").is_some_and(|v| v == "notification"));
assert!(
!ack.attrs.contains_key("type"),
"identity-change notification ACK must omit type"
);
assert!(
!ack.attrs.contains_key("from"),
"notification ACKs should not include our device PN"
);
}
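// Receipt ACKs echo the incoming type verbatim when one is present...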
#[test]
fn test_build_ack_node_for_receipt_with_type_echoes_type() {
let incoming = NodeBuilder::new("receipt")
.attr("from", "156535032389744@lid")
.attr("id", "RCPT-WITH-TYPE")
.attr("type", "read")
.build();
let own_device_pn: Jid = "155500012345:48@s.whatsapp.net"
.parse()
.expect("own device PN JID should parse");
let ack = build_ack_node(&incoming, Some(&own_device_pn))
.expect("receipt ack should be buildable");
assert!(ack.attrs.get("class").is_some_and(|v| v == "receipt"));
assert!(
ack.attrs.get("type").is_some_and(|v| v == "read"),
"receipt ACK must echo the type attribute when present"
);
assert!(
!ack.attrs.contains_key("from"),
"receipt ACKs should not include our device PN"
);
}
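// ...and omit the type entirely when the incoming receipt has none.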
#[test]
fn test_build_ack_node_for_receipt_without_type_omits_type() {
let incoming = NodeBuilder::new("receipt")
.attr("from", "156535032389744@lid")
.attr("id", "RCPT-NO-TYPE")
.build();
let own_device_pn: Jid = "155500012345:48@s.whatsapp.net"
.parse()
.expect("own device PN JID should parse");
let ack = build_ack_node(&incoming, Some(&own_device_pn))
.expect("receipt ack should be buildable");
assert!(ack.attrs.get("class").is_some_and(|v| v == "receipt"));
assert!(
!ack.attrs.contains_key("type"),
"receipt ACK must NOT contain type when the incoming receipt has no type attribute"
);
assert!(
!ack.attrs.contains_key("from"),
"receipt ACKs should not include our device PN"
);
}
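// A server ping may arrive without an id; it must still be recognized and
// answered (the pong simply omits the id, per the build_pong tests above).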
#[tokio::test]
async fn test_handle_iq_ping_without_id() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let ping_node = NodeBuilder::new("iq")
.attr("type", "get")
.attr("from", SERVER_JID)
.attr("xmlns", "urn:xmpp:ping")
.build();
let handled = client.handle_iq(&ping_node).await;
assert!(
handled,
"handle_iq must recognize ping without id attribute"
);
}
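// Reconnect backoff: fibonacci_backoff(n) follows the Fibonacci sequence in
// seconds (1, 1, 2, 3, 5, 8, ...), capped at 900s, with ±10% jitter applied;
// the assertions below allow for that jitter window.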
#[test]
fn test_fibonacci_backoff_sequence() {
let expected_base_ms = [1000, 1000, 2000, 3000, 5000, 8000, 13000, 21000];
for (attempt, &base) in expected_base_ms.iter().enumerate() {
let delay = fibonacci_backoff(attempt as u32);
let ms = delay.as_millis() as u64;
let low = base - base / 10;
let high = base + base / 10;
assert!(
(low..=high).contains(&ms),
"attempt {attempt}: expected {low}..={high}ms, got {ms}ms"
);
}
}
#[test]
fn test_fibonacci_backoff_max_900s() {
let delay = fibonacci_backoff(100);
let ms = delay.as_millis() as u64;
assert!(
ms <= 990_000,
"should never exceed 900s + 10% jitter, got {ms}ms"
);
assert!(
ms >= 810_000,
"should be at least 900s - 10% jitter, got {ms}ms"
);
}
#[test]
fn test_fibonacci_backoff_first_attempt_is_1s() {
let delay = fibonacci_backoff(0);
let ms = delay.as_millis() as u64;
assert!(
(900..=1100).contains(&ms),
"first attempt should be ~1s (±10%), got {ms}ms"
);
}
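// Stream error handling: 401 (unauthorized) and 409 (conflict) are fatal and
// disable auto-reconnect; 429 (rate limited) and 503 (unavailable) keep
// auto-reconnect enabled, with 429 additionally inflating the backoff counter.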
#[tokio::test]
async fn test_stream_error_401_disables_reconnect() {
let client = create_offline_sync_test_client().await;
let node = NodeBuilder::new("stream:error").attr("code", "401").build();
client.handle_stream_error(&node).await;
assert!(
!client.enable_auto_reconnect.load(Ordering::Relaxed),
"401 should disable auto-reconnect"
);
}
#[tokio::test]
async fn test_stream_error_409_disables_reconnect() {
let client = create_offline_sync_test_client().await;
let node = NodeBuilder::new("stream:error").attr("code", "409").build();
client.handle_stream_error(&node).await;
assert!(
!client.enable_auto_reconnect.load(Ordering::Relaxed),
"409 should disable auto-reconnect"
);
}
#[tokio::test]
async fn test_stream_error_429_keeps_reconnect_with_backoff() {
let client = create_offline_sync_test_client().await;
let before = client.auto_reconnect_errors.load(Ordering::Relaxed);
let node = NodeBuilder::new("stream:error").attr("code", "429").build();
client.handle_stream_error(&node).await;
assert!(
client.enable_auto_reconnect.load(Ordering::Relaxed),
"429 should keep auto-reconnect enabled"
);
let after = client.auto_reconnect_errors.load(Ordering::Relaxed);
assert_eq!(
after,
before + 5,
"429 should increase backoff by exactly 5: before={before}, after={after}"
);
}
#[tokio::test]
async fn test_stream_error_503_keeps_reconnect() {
let client = create_offline_sync_test_client().await;
let node = NodeBuilder::new("stream:error").attr("code", "503").build();
client.handle_stream_error(&node).await;
assert!(
client.enable_auto_reconnect.load(Ordering::Relaxed),
"503 should keep auto-reconnect enabled"
);
}
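// Smoke test that Client::new_with_cache_config accepts a caller-supplied
// CacheConfig; the client should construct normally and start logged out.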
#[tokio::test]
async fn test_custom_cache_config_is_respected() {
use crate::cache_config::{CacheConfig, CacheEntryConfig};
use std::time::Duration;
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let custom_config = CacheConfig {
group_cache: CacheEntryConfig::new(Some(Duration::from_secs(60)), 10),
device_cache: CacheEntryConfig::new(Some(Duration::from_secs(60)), 10),
..CacheConfig::default()
};
let (client, _rx) = Client::new_with_cache_config(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
custom_config,
)
.await;
assert!(!client.is_logged_in());
}
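// is_connected() is backed by an atomic flag, so it must stay accurate (and
// non-blocking) even while another task holds the noise_socket mutex.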
#[tokio::test]
async fn test_is_connected_not_affected_by_mutex_contention() {
use crate::socket::NoiseSocket;
use wacore::handshake::NoiseCipher;
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
assert!(!client.is_connected(), "should start disconnected");
let transport: Arc<dyn crate::transport::Transport> =
Arc::new(crate::transport::mock::MockTransport);
let key = [0u8; 32];
let write_key = NoiseCipher::new(&key).expect("valid key");
let read_key = NoiseCipher::new(&key).expect("valid key");
let noise_socket = NoiseSocket::new(
Arc::new(crate::runtime_impl::TokioRuntime),
transport,
write_key,
read_key,
);
*client.noise_socket.lock().await = Some(Arc::new(noise_socket));
client.is_connected.store(true, Ordering::Release);
assert!(client.is_connected(), "should report connected");
let _guard = client.noise_socket.lock().await;
assert!(
client.is_connected(),
"is_connected() must return true even while noise_socket mutex is held"
);
}
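// With no socket attached, send_ack_for must surface NotConnected instead of
// silently dropping the ACK.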
#[tokio::test]
async fn test_send_ack_for_returns_error_when_disconnected() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
let receipt = NodeBuilder::new("receipt")
.attr("from", "120363040237990503@g.us")
.attr("id", "TEST-RECEIPT-ID")
.attr("participant", "236395184570386@lid")
.build();
let result = client.send_ack_for(&receipt).await;
assert!(
matches!(result, Err(ClientError::NotConnected)),
"send_ack_for must return Err(NotConnected) when disconnected, got: {result:?}"
);
}
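// During an expected (intentional) disconnect, an unsendable ACK is treated
// as Ok rather than an error.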
#[tokio::test]
async fn test_send_ack_for_returns_ok_on_expected_disconnect() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend)
.await
.expect("persistence manager should initialize"),
);
let (client, _rx) = Client::new(
Arc::new(crate::runtime_impl::TokioRuntime),
pm,
Arc::new(crate::transport::mock::MockTransportFactory::new()),
Arc::new(MockHttpClient),
None,
)
.await;
client.expected_disconnect.store(true, Ordering::Relaxed);
let receipt = NodeBuilder::new("receipt")
.attr("from", "120363040237990503@g.us")
.attr("id", "TEST-RECEIPT-ID")
.build();
let result = client.send_ack_for(&receipt).await;
assert!(
result.is_ok(),
"send_ack_for should return Ok during expected disconnect"
);
}
}