use crate::client::Client;
use crate::types::message::EditAttribute;
use anyhow::anyhow;
use log::debug;
use wacore::client::context::SendContextResolver;
use wacore::libsignal::protocol::SignalProtocolError;
use wacore::types::jid::JidExt;
use wacore::types::message::AddressingMode;
#[cfg(test)]
use wacore_binary::DeviceKey;
use wacore_binary::Node;
use wacore_binary::builder::NodeBuilder;
use wacore_binary::{Jid, JidExt as _, Server};
use waproto::whatsapp as wa;
#[derive(Debug, Clone, Default)]
pub struct SendOptions {
pub message_id: Option<String>,
pub extra_stanza_nodes: Vec<Node>,
pub ephemeral_expiration: Option<u32>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SendResult {
pub message_id: String,
pub to: Jid,
}
impl SendResult {
pub fn message_key(&self) -> wa::MessageKey {
wa::MessageKey {
remote_jid: Some(self.to.to_string()),
from_me: Some(true),
id: Some(self.message_id.clone()),
participant: None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum PinDuration {
Hours24,
#[default]
Days7,
Days30,
}
impl PinDuration {
fn as_secs(self) -> u32 {
match self {
Self::Hours24 => 86_400,
Self::Days7 => 604_800,
Self::Days30 => 2_592_000,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum RevokeType {
#[default]
Sender,
Admin { original_sender: Jid },
}
fn infer_stanza_metadata(msg: &wa::Message) -> (Option<EditAttribute>, Option<Node>) {
if msg.pin_in_chat_message.is_some() {
return (Some(EditAttribute::PinInChat), None);
}
if msg.poll_creation_message.is_some()
|| msg.poll_creation_message_v2.is_some()
|| msg.poll_creation_message_v3.is_some()
{
return (None, Some(meta_node("polltype", "creation")));
}
if let Some(ref poll_update) = msg.poll_update_message
&& poll_update.vote.is_some()
{
return (None, Some(meta_node("polltype", "vote")));
}
if msg.event_message.is_some() {
return (None, Some(meta_node("event_type", "creation")));
}
if msg.enc_event_response_message.is_some() {
return (None, Some(meta_node("event_type", "response")));
}
if let Some(ref sec) = msg.secret_encrypted_message
&& sec.secret_enc_type
== Some(wa::message::secret_encrypted_message::SecretEncType::EventEdit as i32)
{
return (None, Some(meta_node("event_type", "edit")));
}
(None, None)
}
fn meta_node(key: &'static str, value: &'static str) -> Node {
NodeBuilder::new("meta").attr(key, value).build()
}
fn infer_biz_node(msg: &wa::Message) -> Option<Node> {
let interactive = extract_interactive_message(msg)?;
let wa::message::interactive_message::InteractiveMessage::NativeFlowMessage(nf) =
interactive.interactive_message.as_ref()?
else {
return None;
};
let first_button_name = nf.buttons.first()?.name.as_deref()?;
let flow_name = button_name_to_flow_name(first_button_name);
Some(
NodeBuilder::new("biz")
.children([NodeBuilder::new("interactive")
.attr("type", "native_flow")
.attr("v", "1")
.children([NodeBuilder::new("native_flow")
.attr("name", flow_name)
.build()])
.build()])
.build(),
)
}
fn extract_interactive_message(msg: &wa::Message) -> Option<&wa::message::InteractiveMessage> {
if let Some(ref doc) = msg.document_with_caption_message
&& let Some(ref inner) = doc.message
&& let Some(ref im) = inner.interactive_message
{
return Some(im);
}
msg.interactive_message.as_deref()
}
fn button_name_to_flow_name(button_name: &str) -> &str {
match button_name {
"review_and_pay" => "order_details",
"payment_info" => "payment_info",
"review_order" | "order_status" => "order_status",
"payment_status" => "payment_status",
"payment_method" => "payment_method",
"payment_reminder" => "payment_reminder",
"open_webview" => "message_with_link",
"message_with_link_status" => "message_with_link_status",
"cta_url" => "cta_url",
"cta_call" => "cta_call",
"cta_copy" => "cta_copy",
"cta_catalog" => "cta_catalog",
"catalog_message" => "catalog_message",
"quick_reply" => "quick_reply",
"galaxy_message" => "galaxy_message",
"booking_confirmation" => "booking_confirmation",
"call_permission_request" => "call_permission_request",
other => other,
}
}
fn build_revoke_message(
remote_jid: &Jid,
from_me: bool,
message_id: String,
participant: Option<String>,
) -> wa::Message {
wa::Message {
protocol_message: Some(Box::new(wa::message::ProtocolMessage {
key: Some(wa::MessageKey {
remote_jid: Some(remote_jid.to_string()),
from_me: Some(from_me),
id: Some(message_id),
participant,
}),
r#type: Some(wa::message::protocol_message::Type::Revoke as i32),
..Default::default()
})),
..Default::default()
}
}
impl Client {
pub async fn send_message(
&self,
to: Jid,
message: wa::Message,
) -> Result<SendResult, anyhow::Error> {
self.send_message_with_options(to, message, SendOptions::default())
.await
}
pub async fn send_message_with_options(
&self,
to: Jid,
mut message: wa::Message,
options: SendOptions,
) -> Result<SendResult, anyhow::Error> {
if let Some(exp) = options.ephemeral_expiration
&& exp > 0
{
use wacore::proto_helpers::MessageExt;
if !message.set_ephemeral_expiration(exp) {
log::warn!("Could not set contextInfo.expiration on this message type");
}
}
let request_id = match options.message_id {
Some(id) => id,
None => self.generate_message_id().await,
};
let result = SendResult {
message_id: request_id.clone(),
to: to.clone(),
};
if to.is_newsletter() {
use prost::Message as _;
let stanza_type = wacore::send::stanza_type_from_message(&message);
let (_, meta_node) = infer_stanza_metadata(&message);
let mut plaintext_builder = NodeBuilder::new("plaintext");
if let Some(mt) = wacore::send::media_type_from_message(&message) {
plaintext_builder = plaintext_builder.attr("mediatype", mt);
}
let mut children = vec![plaintext_builder.bytes(message.encode_to_vec()).build()];
children.extend(meta_node);
children.extend(options.extra_stanza_nodes);
let stanza = NodeBuilder::new("message")
.attr("to", to)
.attr("type", stanza_type)
.attr("id", &request_id)
.children(children)
.build();
self.send_node(stanza).await?;
return Ok(result);
}
let (edit, inferred_meta) = infer_stanza_metadata(&message);
let inferred_biz = infer_biz_node(&message);
let extra_nodes = if inferred_meta.is_none() && inferred_biz.is_none() {
options.extra_stanza_nodes
} else {
let mut nodes = Vec::with_capacity(2 + options.extra_stanza_nodes.len());
nodes.extend(inferred_meta);
nodes.extend(inferred_biz);
nodes.extend(options.extra_stanza_nodes);
nodes
};
self.send_message_impl(
to,
&message,
Some(request_id),
false,
false,
edit,
extra_nodes,
)
.await?;
Ok(result)
}
pub(crate) async fn send_status_message(
&self,
message: wa::Message,
recipients: &[Jid],
options: crate::features::status::StatusSendOptions,
) -> Result<SendResult, anyhow::Error> {
use wacore::client::context::GroupInfo;
use wacore_binary::builder::NodeBuilder;
if recipients.is_empty() {
return Err(anyhow!("Cannot send status with no recipients"));
}
let to = Jid::status_broadcast();
let request_id = self.generate_message_id().await;
let mut device_snapshot = self.persistence_manager.get_device_snapshot().await;
let account_info = device_snapshot.account.take();
let own_jid = device_snapshot
.pn
.take()
.ok_or(crate::client::ClientError::NotLoggedIn)?;
let own_lid = device_snapshot.lid.take().ok_or_else(|| {
anyhow!(
"Cannot send status: device has no LID yet. Finish pairing / LID \
migration before posting status."
)
})?;
for jid in recipients {
if !(jid.is_pn() || jid.is_lid()) {
return Err(anyhow!(
"Invalid status recipient {}: must be a user JID (PN or LID), \
not a group/broadcast/newsletter/hosted/etc.",
jid
));
}
}
use std::collections::HashMap;
let mut resolved: Vec<Option<Jid>> = Vec::with_capacity(recipients.len());
let mut lid_to_pn_map: HashMap<wacore_binary::CompactString, Jid> =
HashMap::with_capacity(recipients.len() + 1);
for jid in recipients {
if let Some(lid_jid) = self.resolve_recipient_to_lid(jid).await {
if jid.is_pn() {
lid_to_pn_map.insert(lid_jid.user.clone(), jid.to_non_ad());
}
resolved.push(Some(lid_jid));
} else {
resolved.push(None);
}
}
lid_to_pn_map.insert(own_lid.user.clone(), own_jid.to_non_ad());
let participants = wacore::send::assemble_status_participants(resolved, &own_lid)?;
let mut group_info =
GroupInfo::with_lid_to_pn_map(participants, AddressingMode::Lid, lid_to_pn_map);
self.add_recent_message(&to, &request_id, &message).await;
let device_store_arc = self.persistence_manager.get_device_arc().await;
let to_str = to.to_string();
let force_skdm = {
use wacore::libsignal::store::sender_key_name::SenderKeyName;
let sender_address = own_lid.to_protocol_address();
let sender_key_name = SenderKeyName::from_parts(&to_str, sender_address.as_str());
let device_guard = device_store_arc.read().await;
let key_exists = self
.signal_cache
.get_sender_key(&sender_key_name, &*device_guard.backend)
.await?
.is_some();
!key_exists
};
let mut store_adapter = self.signal_adapter_from(device_store_arc.clone());
let mut stores = store_adapter.as_signal_stores();
let skdm_target_devices: Option<Vec<Jid>> = if force_skdm {
None
} else {
self.resolve_skdm_targets(&to_str, &group_info, &own_lid)
.await
};
let extra_stanza_nodes = if wacore::send::status_carries_privacy_meta(&message) {
vec![
NodeBuilder::new("meta")
.attr("status_setting", options.privacy.as_str())
.build(),
]
} else {
vec![]
};
let prepared = match wacore::send::prepare_group_stanza(
&*self.runtime,
&mut stores,
self,
&mut group_info,
&own_jid,
&own_lid,
account_info.as_ref(),
to.clone(),
&message,
request_id.clone(),
force_skdm,
skdm_target_devices,
None,
&extra_stanza_nodes,
)
.await
{
Ok(prepared) => prepared,
Err(e) => {
if let Some(SignalProtocolError::NoSenderKeyState(_)) =
e.downcast_ref::<SignalProtocolError>()
{
log::warn!("No sender key for status broadcast, forcing distribution.");
if let Err(e) = self
.persistence_manager
.clear_sender_key_devices(&to_str)
.await
{
log::warn!(
"Failed to clear status SKDM recipients for {}: {:?}",
to_str,
e
);
}
self.sender_key_device_cache.invalidate(&to_str).await;
let mut store_adapter_retry =
self.signal_adapter_from(device_store_arc.clone());
let mut stores_retry = store_adapter_retry.as_signal_stores();
wacore::send::prepare_group_stanza(
&*self.runtime,
&mut stores_retry,
self,
&mut group_info,
&own_jid,
&own_lid,
account_info.as_ref(),
to.clone(),
&message,
request_id.clone(),
true,
None,
None,
&extra_stanza_nodes,
)
.await?
} else {
return Err(e);
}
}
};
let stanza = self
.ensure_status_participants(prepared.node, &group_info)
.await?;
let ack = if let Some(phash) = stanza
.attrs()
.optional_string("phash")
.map(|s| s.into_owned())
{
let rx = self.register_ack_waiter(&request_id).await;
Some((rx, phash))
} else {
None
};
if let Err(e) = self.send_node(stanza).await {
if ack.is_some() {
self.response_waiters.lock().await.remove(&request_id);
}
return Err(e.into());
}
if let Some((rx, phash)) = ack {
self.spawn_phash_validation(rx, phash, to.clone(), true, request_id.clone());
}
self.update_sender_key_devices(&to_str, &prepared.skdm_devices)
.await;
for user in &prepared.stale_device_users {
self.invalidate_device_cache(user).await;
}
self.flush_signal_cache_logged("send_status_message", None)
.await;
Ok(SendResult {
message_id: request_id,
to,
})
}
async fn resolve_skdm_targets(
&self,
group_jid: &str,
group_info: &wacore::client::context::GroupInfo,
own_sending_jid: &Jid,
) -> Option<Vec<Jid>> {
use crate::sender_key_device_cache::SenderKeyDeviceMap;
let pm = self.persistence_manager.clone();
let cached_map = self
.sender_key_device_cache
.get_or_init(group_jid, async {
let db_rows = pm
.get_sender_key_devices(group_jid)
.await
.unwrap_or_else(|e| {
log::warn!(
"Failed to read sender key devices for {}: {:?}",
group_jid,
e
);
vec![]
});
std::sync::Arc::new(SenderKeyDeviceMap::from_db_rows(&db_rows))
})
.await;
let is_lid_mode = group_info.addressing_mode == wacore::types::message::AddressingMode::Lid;
let jids_to_resolve: Vec<Jid> = group_info
.participants
.iter()
.map(|jid| {
if is_lid_mode
&& jid.is_lid()
&& let Some(pn) = group_info.phone_jid_for_lid_user(&jid.user)
{
return pn.to_non_ad();
}
jid.to_non_ad()
})
.collect();
match SendContextResolver::resolve_devices(self, &jids_to_resolve).await {
Ok(all_devices) => {
let all_devices: Vec<Jid> = if is_lid_mode {
all_devices
.into_iter()
.map(|d| group_info.phone_device_jid_to_lid(&d))
.collect()
} else {
all_devices
};
let needs_skdm: Vec<Jid> = all_devices
.into_iter()
.filter(|device| {
if device.is_hosted() {
return false;
}
if device.user == own_sending_jid.user
&& device.device == own_sending_jid.device
{
return false;
}
!cached_map
.device_has_key(&device.user, device.device)
.unwrap_or(false)
|| cached_map.is_user_forgotten(&device.user)
})
.collect();
if needs_skdm.is_empty() {
Some(vec![])
} else {
log::debug!(
"Found {} devices needing SKDM for {}",
needs_skdm.len(),
group_jid
);
Some(needs_skdm)
}
}
Err(e) => {
log::warn!(
"Failed to resolve devices for SKDM check in {}: {:?}",
group_jid,
e
);
None
}
}
}
async fn update_sender_key_devices(&self, group_jid: &str, devices: &[Jid]) {
if devices.is_empty() {
return;
}
if let Err(e) = self
.set_sender_key_status_for_devices(group_jid, devices, true, false)
.await
{
log::warn!(
"Failed to update sender key devices for {}: {:?}",
group_jid,
e
);
}
self.sender_key_device_cache.invalidate(group_jid).await;
}
fn spawn_phash_validation(
&self,
rx: futures::channel::oneshot::Receiver<std::sync::Arc<wacore_binary::OwnedNodeRef>>,
our_phash: String,
jid: Jid,
invalidate_group_cache: bool,
message_id: String,
) {
let Some(client) = self.self_weak.get().and_then(|w| w.upgrade()) else {
return;
};
self.runtime
.spawn(Box::pin(async move {
let ack = match wacore::runtime::timeout(
&*client.runtime,
std::time::Duration::from_secs(10),
rx,
)
.await
{
Ok(Ok(node)) => node,
_ => {
client.response_waiters.lock().await.remove(&message_id);
return;
}
};
if let Some(server) = ack.get().get_attr("phash").map(|v| v.as_str())
&& server != our_phash
{
log::warn!(
"Phash mismatch for {jid}: ours={our_phash}, server={server}. Invalidating caches."
);
if !jid.is_group() && !jid.is_status_broadcast() {
client.invalidate_device_cache(&jid.user).await;
if let Some(own_pn) =
&client.persistence_manager.get_device_snapshot().await.pn
{
client.invalidate_device_cache(&own_pn.user).await;
}
}
let jid_str = jid.to_string();
if jid.is_group() || jid.is_status_broadcast() {
let cleared = client
.persistence_manager
.clear_sender_key_devices(&jid_str)
.await;
if let Err(e) = cleared {
log::warn!(
"phash mismatch: clear_sender_key_devices failed: {e} — \
deleting own sender key as fallback to force redistribution"
);
use wacore::libsignal::store::sender_key_name::SenderKeyName;
use wacore::types::jid::JidExt;
let snapshot =
client.persistence_manager.get_device_snapshot().await;
for own in snapshot.lid.iter().chain(snapshot.pn.iter()) {
let sk =
SenderKeyName::from_parts(&jid_str, own.to_protocol_address().as_str());
client.signal_cache.delete_sender_key(sk.cache_key()).await;
}
let _ = client
.flush_signal_cache_logged("phash-mismatch-fallback", None)
.await;
}
}
client.sender_key_device_cache.invalidate(&jid_str).await;
if invalidate_group_cache {
client.get_group_cache().await.invalidate(&jid).await;
}
}
}))
.detach();
}
async fn ensure_status_participants(
&self,
stanza: wacore_binary::Node,
group_info: &wacore::client::context::GroupInfo,
) -> Result<wacore_binary::Node, anyhow::Error> {
Ok(wacore::send::ensure_status_participants(stanza, group_info))
}
pub async fn revoke_message(
&self,
to: Jid,
message_id: impl Into<String>,
revoke_type: RevokeType,
) -> Result<(), anyhow::Error> {
let message_id = message_id.into();
self.require_pn().await?;
let (from_me, participant, edit_attr) = match &revoke_type {
RevokeType::Sender => {
(
true,
None,
crate::types::message::EditAttribute::SenderRevoke,
)
}
RevokeType::Admin { original_sender } => {
if !to.is_group() {
return Err(anyhow!("Admin revoke is only valid for group chats"));
}
let participant_str = original_sender.to_non_ad().to_string();
log::debug!(
"Admin revoke: using participant {} for MessageKey",
participant_str
);
(
false,
Some(participant_str),
crate::types::message::EditAttribute::AdminRevoke,
)
}
};
let revoke_message = build_revoke_message(&to, from_me, message_id, participant);
let force_skdm = matches!(revoke_type, RevokeType::Admin { .. });
self.send_message_impl(
to,
&revoke_message,
None,
false,
force_skdm,
Some(edit_attr),
vec![],
)
.await
}
pub async fn pin_message(
&self,
chat: Jid,
key: wa::MessageKey,
duration: PinDuration,
) -> Result<(), anyhow::Error> {
self.send_pin(
chat,
key,
wa::message::pin_in_chat_message::Type::PinForAll,
duration.as_secs(),
)
.await
}
pub async fn unpin_message(&self, chat: Jid, key: wa::MessageKey) -> Result<(), anyhow::Error> {
self.send_pin(
chat,
key,
wa::message::pin_in_chat_message::Type::UnpinForAll,
0,
)
.await
}
async fn send_pin(
&self,
chat: Jid,
key: wa::MessageKey,
pin_type: wa::message::pin_in_chat_message::Type,
duration_secs: u32,
) -> Result<(), anyhow::Error> {
let message = wa::Message {
pin_in_chat_message: Some(wa::message::PinInChatMessage {
key: Some(key),
r#type: Some(pin_type as i32),
sender_timestamp_ms: Some(wacore::time::now_millis()),
}),
message_context_info: Some(wa::MessageContextInfo {
message_add_on_duration_in_secs: Some(duration_secs),
..Default::default()
}),
..Default::default()
};
self.send_message_impl(
chat,
&message,
None,
false,
false,
Some(crate::types::message::EditAttribute::PinInChat),
vec![],
)
.await
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn send_message_impl(
&self,
to: Jid,
message: &wa::Message,
request_id_override: Option<String>,
peer: bool,
force_key_distribution: bool,
edit: Option<crate::types::message::EditAttribute>,
extra_stanza_nodes: Vec<Node>,
) -> Result<(), anyhow::Error> {
let (to, is_status_addon) = if to.is_status_broadcast() {
let author = message
.reaction_message
.as_ref()
.and_then(|rm| rm.key.as_ref())
.and_then(|k| k.participant.as_ref())
.and_then(|p| p.parse::<Jid>().ok())
.filter(|jid| jid.is_pn() || jid.is_lid())
.ok_or_else(|| {
anyhow!(
"send_message to status@broadcast requires \
reaction_message.key.participant = status author (user JID). \
Use client.status() for posting new statuses."
)
})?;
(author, true)
} else {
(to, false)
};
let request_id = match request_id_override {
Some(id) => id,
None => self.generate_message_id().await,
};
struct SkdmUpdate {
to_str: String,
devices: Vec<Jid>,
stale_users: Vec<String>,
}
let mut skdm_update: Option<SkdmUpdate> = None;
let mut should_issue_tc_token_after_send = false;
let mut used_cached_tc_token_key: Option<String> = None;
let tc_issue_target = to.clone();
let mut dm_phash: Option<String> = None;
let stanza_to_send: wacore_binary::Node = if peer && !to.is_group() {
let encryption_jid = self.resolve_encryption_jid(&to).await;
let signal_addr = encryption_jid.to_protocol_address();
let session_mutex = self.session_lock_for(signal_addr.as_str()).await;
let _session_guard = session_mutex.lock().await;
let mut store_adapter = self.signal_adapter().await;
wacore::send::prepare_peer_stanza(
&mut store_adapter.session_store,
&mut store_adapter.identity_store,
to,
&signal_addr,
message,
request_id,
)
.await?
} else if to.is_group() {
let mut group_info = self.groups().query_info(&to).await?;
let mut device_snapshot = self.persistence_manager.get_device_snapshot().await;
let account_info = device_snapshot.account.take();
let own_jid = device_snapshot
.pn
.take()
.ok_or(crate::client::ClientError::NotLoggedIn)?;
let own_lid = device_snapshot
.lid
.take()
.ok_or_else(|| anyhow!("LID not set, cannot send to group"))?;
self.add_recent_message(&to, &request_id, message).await;
let device_store_arc = self.persistence_manager.get_device_arc().await;
let to_str = to.to_string();
let (own_sending_jid, _) = match group_info.addressing_mode {
crate::types::message::AddressingMode::Lid => (own_lid.clone(), "lid"),
crate::types::message::AddressingMode::Pn => (own_jid.clone(), "pn"),
};
if !group_info
.participants
.iter()
.any(|participant| participant.is_same_user_as(&own_sending_jid))
{
group_info.participants.push(own_sending_jid.to_non_ad());
}
let force_skdm = {
use wacore::libsignal::store::sender_key_name::SenderKeyName;
let sender_address = own_sending_jid.to_protocol_address();
let sender_key_name = SenderKeyName::from_parts(&to_str, sender_address.as_str());
let device_guard = device_store_arc.read().await;
let record = self
.signal_cache
.get_sender_key(&sender_key_name, &*device_guard.backend)
.await?;
let key_exists = record.is_some();
const SENDER_KEY_ROTATION_THRESHOLD: u32 = 1000;
let needs_rotation = record
.and_then(|mut r| r.sender_key_state_mut().ok().cloned())
.and_then(|state| state.sender_chain_key().map(|ck| ck.iteration()))
.is_some_and(|iter| iter >= SENDER_KEY_ROTATION_THRESHOLD);
drop(device_guard);
if needs_rotation {
log::info!(
"Periodic sender-key rotation for {to} (chain iteration ≥ {SENDER_KEY_ROTATION_THRESHOLD})"
);
self.signal_cache
.delete_sender_key(sender_key_name.cache_key())
.await;
if let Err(e) = self
.persistence_manager
.clear_sender_key_devices(&to_str)
.await
{
log::warn!("periodic rotation: clear_sender_key_devices failed: {e}");
}
self.sender_key_device_cache.invalidate(&to_str).await;
}
force_key_distribution || !key_exists || needs_rotation
};
let mut store_adapter = self.signal_adapter_from(device_store_arc.clone());
let mut stores = store_adapter.as_signal_stores();
let skdm_target_devices: Option<Vec<Jid>> = if force_skdm {
None
} else {
self.resolve_skdm_targets(&to_str, &group_info, &own_sending_jid)
.await
};
match wacore::send::prepare_group_stanza(
&*self.runtime,
&mut stores,
self,
&mut group_info,
&own_jid,
&own_lid,
account_info.as_ref(),
to.clone(),
message,
request_id.clone(),
force_skdm,
skdm_target_devices,
edit.clone(),
&extra_stanza_nodes,
)
.await
{
Ok(prepared) => {
skdm_update = Some(SkdmUpdate {
to_str: to_str.clone(),
devices: prepared.skdm_devices,
stale_users: prepared.stale_device_users,
});
prepared.node
}
Err(e) => {
if let Some(SignalProtocolError::NoSenderKeyState(_)) =
e.downcast_ref::<SignalProtocolError>()
{
log::warn!("No sender key for group {}, forcing distribution.", to);
if let Err(e) = self
.persistence_manager
.clear_sender_key_devices(&to_str)
.await
{
log::warn!("Failed to clear SKDM recipients: {:?}", e);
}
self.sender_key_device_cache.invalidate(&to_str).await;
let mut store_adapter_retry =
self.signal_adapter_from(device_store_arc.clone());
let mut stores_retry = store_adapter_retry.as_signal_stores();
let retry_prepared = wacore::send::prepare_group_stanza(
&*self.runtime,
&mut stores_retry,
self,
&mut group_info,
&own_jid,
&own_lid,
account_info.as_ref(),
to,
message,
request_id,
true,
None,
edit.clone(),
&extra_stanza_nodes,
)
.await?;
skdm_update = Some(SkdmUpdate {
to_str,
devices: retry_prepared.skdm_devices,
stale_users: retry_prepared.stale_device_users,
});
retry_prepared.node
} else {
return Err(e);
}
}
}
} else {
if is_status_addon {
self.add_recent_message(&Jid::status_broadcast(), &request_id, message)
.await;
} else {
self.add_recent_message(&to, &request_id, message).await;
}
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let own_jid = device_snapshot
.pn
.as_ref()
.ok_or(crate::client::ClientError::NotLoggedIn)?;
if to.is_pn() && self.lid_pn_cache.get_current_lid(&to.user).await.is_none() {
let sid = self.generate_request_id();
let spec = wacore::iq::usync::LidQuerySpec::new(vec![to.to_non_ad()], sid);
match self.execute(spec).await {
Ok(resp) => {
for mapping in &resp.lid_mappings {
if let Err(e) = self
.add_lid_pn_mapping(
&mapping.lid,
&mapping.phone_number,
crate::lid_pn_cache::LearningSource::Usync,
)
.await
{
log::warn!(
"Failed to persist LID mapping {} -> {}: {e:?}",
mapping.phone_number,
mapping.lid
);
}
}
}
Err(e) => {
log::warn!("LID query failed for {}, falling back to PN: {e:?}", to);
}
}
}
let recipient_bare = self.resolve_encryption_jid(&to).await.to_non_ad();
let mut recipient_cached = self.get_devices_from_registry(&recipient_bare).await;
if recipient_cached.is_none() {
let _ = self.get_user_devices(std::slice::from_ref(&to)).await;
recipient_cached = self.get_devices_from_registry(&recipient_bare).await;
}
let mut own_cached = self.get_devices_from_registry(own_jid).await;
if own_cached.is_none() {
let _ = self.get_user_devices(std::slice::from_ref(own_jid)).await;
own_cached = self.get_devices_from_registry(own_jid).await;
}
let mut all_dm_jids = match recipient_cached {
Some(mut devices) => {
devices.retain(|j| !j.is_hosted());
devices
}
None => vec![recipient_bare],
};
if let Some(mut own_devices) = own_cached {
own_devices.retain(|j| !j.is_hosted());
all_dm_jids.append(&mut own_devices);
}
let own_lid = device_snapshot.lid.as_ref();
all_dm_jids.retain(|j| {
let is_sender = (j.is_same_user_as(own_jid) && j.device == own_jid.device)
|| own_lid.is_some_and(|lid| j.is_same_user_as(lid) && j.device == lid.device);
!is_sender
});
wacore::types::jid::sort_dedup_by_device(&mut all_dm_jids);
self.ensure_e2e_sessions(&all_dm_jids).await?;
let mut extra_stanza_nodes = extra_stanza_nodes;
if !to.is_group() && !to.is_newsletter() && !is_status_addon {
let (should_issue_after_send, cached_token_key) = self
.maybe_include_tc_token(&to, &mut extra_stanza_nodes)
.await;
should_issue_tc_token_after_send = should_issue_after_send;
if should_issue_after_send {
used_cached_tc_token_key = cached_token_key;
}
}
if should_issue_tc_token_after_send {
debug!(target: "Client/TcToken", "Scheduled tc token issuance after send for {}", to);
}
let lock_jids = self.build_session_lock_keys(&all_dm_jids).await;
let _session_mutexes = self.session_mutexes_for(&lock_jids).await;
let mut _session_guards = Vec::with_capacity(_session_mutexes.len());
for mutex in &_session_mutexes {
_session_guards.push(mutex.lock().await);
}
let mut store_adapter = self.signal_adapter().await;
let mut stores = store_adapter.as_signal_stores();
let prepared = wacore::send::prepare_dm_stanza(
&*self.runtime,
&mut stores,
self,
own_jid,
device_snapshot.lid.as_ref(),
device_snapshot.account.as_ref(),
to,
message,
request_id,
edit,
&extra_stanza_nodes,
all_dm_jids,
)
.await?;
dm_phash = prepared.phash;
prepared.node
};
let ack = if let Some(phash) = dm_phash
&& let Some(msg_id) = stanza_to_send
.attrs()
.optional_string("id")
.map(|s| s.into_owned())
{
let rx = self.register_ack_waiter(&msg_id).await;
Some((rx, phash, msg_id))
} else {
None
};
let mut stanza_to_send = stanza_to_send;
if is_status_addon {
stanza_to_send.attrs.insert("to", Jid::status_broadcast());
}
if let Err(e) = self.send_node(stanza_to_send).await {
if let Some((_, _, ref msg_id)) = ack {
self.response_waiters.lock().await.remove(msg_id);
}
return Err(e.into());
}
if let Some((rx, phash, msg_id)) = ack {
let invalidate_group = tc_issue_target.is_group();
self.spawn_phash_validation(
rx,
phash,
tc_issue_target.clone(),
invalidate_group,
msg_id,
);
}
if let Some(update) = skdm_update {
self.update_sender_key_devices(&update.to_str, &update.devices)
.await;
for user in &update.stale_users {
self.invalidate_device_cache(user).await;
}
}
self.flush_signal_cache_logged("send_message_impl", None)
.await;
if should_issue_tc_token_after_send {
if let Some(client) = self.self_weak.get().and_then(|w| w.upgrade()) {
let target = tc_issue_target;
let cached_key = used_cached_tc_token_key;
self.runtime
.spawn(Box::pin(async move {
let issued_ok = client.issue_tc_token_after_send(&target).await;
if issued_ok && let Some(token_key) = cached_key {
client.mark_tc_token_used_after_send(&token_key).await;
}
}))
.detach();
} else {
log::debug!(target: "Client/TcToken", "Skipping fire-and-forget issuance: client dropped");
}
}
Ok(())
}
async fn maybe_include_tc_token(
&self,
to: &Jid,
extra_nodes: &mut Vec<Node>,
) -> (bool, Option<String>) {
use wacore::iq::props::config_codes;
use wacore::iq::tctoken::{
build_cs_token_node, build_tc_token_node, compute_cs_token, is_tc_token_expired_with,
should_send_new_tc_token_with,
};
let snapshot = self.persistence_manager.get_device_snapshot().await;
let is_self = snapshot
.pn
.as_ref()
.is_some_and(|pn| pn.is_same_user_as(to))
|| snapshot
.lid
.as_ref()
.is_some_and(|lid| lid.is_same_user_as(to));
if is_self {
return (false, None);
}
if to.is_bot() || to.is_status_broadcast() {
return (false, None);
}
let cached_lid = if to.is_lid() {
None
} else {
self.lid_pn_cache.get_current_lid(&to.user).await
};
let resolved_lid_user: Option<&str> = if to.is_lid() {
Some(&to.user)
} else {
cached_lid.as_deref()
};
let token_jid: &str = resolved_lid_user.unwrap_or(&to.user);
let backend = self.persistence_manager.backend();
let tc_config = self.tc_token_config().await;
let existing = match backend.get_tc_token(token_jid).await {
Ok(entry) => entry,
Err(e) => {
log::warn!(target: "Client/TcToken", "Failed to get tc_token for {}: {e}", token_jid);
None
}
};
let should_issue_after_send = should_send_new_tc_token_with(
existing.as_ref().and_then(|entry| entry.sender_timestamp),
&tc_config,
);
let token_send_enabled = self
.ab_props
.is_enabled_or(config_codes::PRIVACY_TOKEN_ON_ALL_1_ON_1_MESSAGES, false)
.await;
if token_send_enabled {
match existing {
Some(ref entry)
if !is_tc_token_expired_with(entry.token_timestamp, &tc_config)
&& !entry.token.is_empty() =>
{
extra_nodes.push(build_tc_token_node(&entry.token));
return (should_issue_after_send, Some(token_jid.to_string()));
}
_ => {
let nct_send_enabled = self
.ab_props
.is_enabled_or(config_codes::NCT_TOKEN_SEND_ENABLED, false)
.await;
if nct_send_enabled
&& let Some(salt) = &snapshot.nct_salt
&& let Some(lid_user) = &resolved_lid_user
{
let recipient_lid =
wacore_binary::Jid::new(*lid_user, Server::Lid).to_string();
let cs_token = compute_cs_token(salt, &recipient_lid);
extra_nodes.push(build_cs_token_node(&cs_token));
log::debug!(target: "Client/CsToken", "Attached cstoken for {} (NCT fallback)", to);
} else {
log::debug!(target: "Client/CsToken", "No tctoken or NCT salt/LID available for {}", to);
}
}
}
}
(should_issue_after_send, None)
}
async fn issue_tc_token_after_send(&self, to: &Jid) -> bool {
use wacore::iq::tctoken::IssuePrivacyTokensSpec;
if to.is_bot() || to.is_status_broadcast() {
return false;
}
let issuance_jid = self.resolve_issuance_jid(to).await;
let Ok(response) = self
.execute(IssuePrivacyTokensSpec::new(std::slice::from_ref(
&issuance_jid,
)))
.await
else {
log::debug!(target: "Client/TcToken", "Failed to issue tc_token for {}", issuance_jid);
return false;
};
self.store_issued_tc_tokens(&response.tokens).await
}
pub(crate) async fn store_issued_tc_tokens(
&self,
tokens: &[wacore::iq::tctoken::ReceivedTcToken],
) -> bool {
use wacore::store::traits::TcTokenEntry;
if tokens.is_empty() {
return false;
}
let backend = self.persistence_manager.backend();
let now = wacore::time::now_secs();
let mut any_stored = false;
for received in tokens {
if received.token.is_empty() {
log::warn!(target: "Client/TcToken", "Server returned empty tc_token for {}, skipping", received.jid);
continue;
}
let entry = TcTokenEntry {
token: received.token.clone(),
token_timestamp: received.timestamp,
sender_timestamp: Some(now),
};
if let Err(e) = backend.put_tc_token(&received.jid.user, &entry).await {
log::warn!(target: "Client/TcToken", "Failed to store issued tc_token: {e}");
} else {
any_stored = true;
}
}
any_stored
}
async fn store_issued_tc_tokens_with_sender_ts(
&self,
tokens: &[wacore::iq::tctoken::ReceivedTcToken],
sender_ts: i64,
) {
use wacore::store::traits::TcTokenEntry;
let backend = self.persistence_manager.backend();
for received in tokens {
if received.token.is_empty() {
continue;
}
let entry = TcTokenEntry {
token: received.token.clone(),
token_timestamp: received.timestamp,
sender_timestamp: Some(sender_ts),
};
if let Err(e) = backend.put_tc_token(&received.jid.user, &entry).await {
log::warn!(target: "Client/TcToken", "Failed to store re-issued tc_token: {e}");
}
}
}
async fn mark_tc_token_used_after_send(&self, token_key: &str) {
use wacore::store::traits::TcTokenEntry;
let backend = self.persistence_manager.backend();
let existing = match backend.get_tc_token(token_key).await {
Ok(entry) => entry,
Err(e) => {
log::warn!(target: "Client/TcToken", "Failed to reload tc_token for {}: {e}", token_key);
return;
}
};
let Some(entry) = existing else {
return;
};
if entry.token.is_empty() {
return;
}
let updated_entry = TcTokenEntry {
sender_timestamp: Some(wacore::time::now_secs()),
..entry
};
if let Err(e) = backend.put_tc_token(token_key, &updated_entry).await {
log::warn!(target: "Client/TcToken", "Failed to update sender_timestamp for {}: {e}", token_key);
}
}
pub(crate) async fn reissue_tc_token_after_identity_change(&self, sender: &Jid) {
use wacore::iq::tctoken::{IssuePrivacyTokensSpec, is_sender_tc_token_expired};
let bare = sender.to_non_ad().to_string();
let mutex = self.session_lock_for(&bare).await;
let Some(_guard) = mutex.try_lock() else {
return;
};
let resolved_lid = if sender.is_lid() {
None
} else {
self.lid_pn_cache.get_current_lid(&sender.user).await
};
let token_jid: &str = resolved_lid.as_deref().unwrap_or(&sender.user);
let backend = self.persistence_manager.backend();
let entry = match backend.get_tc_token(token_jid).await {
Ok(Some(e)) => e,
_ => return,
};
let Some(sender_ts) = entry.sender_timestamp else {
return;
};
let tc_config = self.tc_token_config().await;
if is_sender_tc_token_expired(sender_ts, &tc_config) {
return;
}
let issuance_jid = self.resolve_issuance_jid(sender).await;
match self
.execute(IssuePrivacyTokensSpec::with_timestamp(
std::slice::from_ref(&issuance_jid),
sender_ts,
))
.await
{
Ok(response) => {
self.store_issued_tc_tokens_with_sender_ts(&response.tokens, sender_ts)
.await;
log::debug!(
target: "Client/TcToken",
"Re-issued tctoken after identity change for {}",
sender
);
}
Err(e) => {
log::debug!(
target: "Client/TcToken",
"Failed to re-issue tctoken after identity change for {}: {e}",
sender
);
}
}
}
pub(crate) async fn lookup_tc_token_for_jid(&self, jid: &Jid) -> Option<Vec<u8>> {
use wacore::iq::tctoken::is_tc_token_expired_with;
let resolved_lid = if jid.is_lid() {
None
} else {
self.lid_pn_cache.get_current_lid(&jid.user).await
};
let token_jid: &str = resolved_lid.as_deref().unwrap_or(&jid.user);
let tc_config = self.tc_token_config().await;
let backend = self.persistence_manager.backend();
match backend.get_tc_token(token_jid).await {
Ok(Some(entry))
if !entry.token.is_empty()
&& !is_tc_token_expired_with(entry.token_timestamp, &tc_config) =>
{
Some(entry.token)
}
Ok(_) => None,
Err(e) => {
log::warn!(target: "Client/TcToken", "Failed to get tc_token for {}: {e}", token_jid);
None
}
}
}
pub(crate) async fn build_session_lock_keys(&self, device_jids: &[Jid]) -> Vec<Jid> {
let mut keys: Vec<Jid> = Vec::with_capacity(device_jids.len());
for jid in device_jids {
keys.push(self.resolve_encryption_jid(jid).await);
}
keys.sort_unstable_by(wacore::types::jid::cmp_for_lock_order);
keys.dedup_by(|a, b| wacore::types::jid::cmp_for_lock_order(a, b).is_eq());
keys
}
pub(crate) async fn session_mutexes_for(
&self,
jids: &[Jid],
) -> Vec<std::sync::Arc<async_lock::Mutex<()>>> {
let mut mutexes = Vec::with_capacity(jids.len());
let mut buf = wacore::types::jid::make_address_buffer();
for jid in jids {
wacore::types::jid::write_protocol_address_to(jid, &mut buf);
mutexes.push(self.session_lock_for(&buf).await);
}
mutexes
}
pub(crate) async fn tc_token_config(&self) -> wacore::iq::tctoken::TcTokenConfig {
use wacore::iq::props::config_codes;
use wacore::iq::tctoken::{TC_TOKEN_BUCKET_DURATION, TC_TOKEN_NUM_BUCKETS, TcTokenConfig};
TcTokenConfig {
bucket_duration: self
.ab_props
.get_int(config_codes::TCTOKEN_DURATION, TC_TOKEN_BUCKET_DURATION)
.await,
num_buckets: self
.ab_props
.get_int(config_codes::TCTOKEN_NUM_BUCKETS, TC_TOKEN_NUM_BUCKETS)
.await,
sender_bucket_duration: self
.ab_props
.get_int(
config_codes::TCTOKEN_DURATION_SENDER,
TC_TOKEN_BUCKET_DURATION,
)
.await,
sender_num_buckets: self
.ab_props
.get_int(
config_codes::TCTOKEN_NUM_BUCKETS_SENDER,
TC_TOKEN_NUM_BUCKETS,
)
.await,
}
.clamped()
}
async fn resolve_to_lid_jid(&self, jid: &Jid) -> Jid {
if jid.is_lid() {
return jid.to_non_ad();
}
if let Some(lid_user) = self.lid_pn_cache.get_current_lid(&jid.user).await {
Jid::new(&lid_user, Server::Lid)
} else {
jid.to_non_ad()
}
}
async fn resolve_issuance_jid(&self, jid: &Jid) -> Jid {
use wacore::iq::props::config_codes;
let issue_to_lid = self
.ab_props
.is_enabled_or(config_codes::LID_TRUSTED_TOKEN_ISSUE_TO_LID, true)
.await;
let resolved = if issue_to_lid {
self.resolve_to_lid_jid(jid).await
} else if jid.is_lid() {
if let Some(pn) = self.lid_pn_cache.get_phone_number(&jid.user).await {
Jid::new(&pn, Server::Pn)
} else {
jid.to_non_ad()
}
} else {
jid.to_non_ad()
};
resolved.to_non_ad()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
#[tokio::test]
async fn send_message_to_status_without_reaction_errors() {
let client = crate::test_utils::create_test_client().await;
let to = Jid::status_broadcast();
let err = client
.send_message(
to,
wa::Message {
conversation: Some("hi".into()),
..Default::default()
},
)
.await
.expect_err("status@broadcast without reaction must error");
let msg = format!("{err}");
assert!(
msg.contains("reaction_message") || msg.contains("status"),
"unexpected error: {msg}"
);
}
#[tokio::test]
async fn send_message_to_status_reaction_rejects_non_user_participant() {
let client = crate::test_utils::create_test_client().await;
let to = Jid::status_broadcast();
let err = client
.send_message(
to,
wa::Message {
reaction_message: Some(wa::message::ReactionMessage {
key: Some(wa::MessageKey {
remote_jid: Some("status@broadcast".into()),
from_me: Some(false),
id: Some("ORIGID".into()),
participant: Some("120363040237990503@g.us".into()),
}),
text: Some("❤️".into()),
sender_timestamp_ms: Some(1),
..Default::default()
}),
..Default::default()
},
)
.await
.expect_err("group JID as participant must error");
assert!(
format!("{err}").contains("user JID"),
"expected user-JID error, got: {err}"
);
}
#[tokio::test]
async fn send_message_to_status_reaction_without_participant_errors() {
let client = crate::test_utils::create_test_client().await;
let to = Jid::status_broadcast();
let err = client
.send_message(
to,
wa::Message {
reaction_message: Some(wa::message::ReactionMessage {
key: Some(wa::MessageKey {
remote_jid: Some("status@broadcast".into()),
from_me: Some(false),
id: Some("ORIGID".into()),
participant: None,
}),
text: Some("❤️".into()),
sender_timestamp_ms: Some(1),
..Default::default()
}),
..Default::default()
},
)
.await
.expect_err("reaction without key.participant must error");
assert!(
format!("{err}").contains("participant"),
"expected participant error, got: {err}"
);
}
#[test]
fn test_revoke_type_default_is_sender() {
let revoke_type = RevokeType::default();
assert_eq!(revoke_type, RevokeType::Sender);
}
#[test]
fn test_force_skdm_only_for_admin_revoke() {
let sender_jid = Jid::from_str("123456@s.whatsapp.net").unwrap();
let sender_revoke = RevokeType::Sender;
let admin_revoke = RevokeType::Admin {
original_sender: sender_jid,
};
let force_skdm_sender = matches!(sender_revoke, RevokeType::Admin { .. });
let force_skdm_admin = matches!(admin_revoke, RevokeType::Admin { .. });
assert!(!force_skdm_sender, "Sender revoke should NOT force SKDM");
assert!(force_skdm_admin, "Admin revoke MUST force SKDM");
}
#[test]
fn test_sender_revoke_message_key_structure() {
let to = Jid::from_str("120363040237990503@g.us").unwrap();
let message_id = "3EB0ABC123".to_string();
let (from_me, participant, edit_attr) = match RevokeType::Sender {
RevokeType::Sender => (
true,
None,
crate::types::message::EditAttribute::SenderRevoke,
),
RevokeType::Admin { original_sender } => (
false,
Some(original_sender.to_non_ad().to_string()),
crate::types::message::EditAttribute::AdminRevoke,
),
};
assert!(from_me, "Sender revoke must have from_me=true");
assert!(
participant.is_none(),
"Sender revoke must NOT set participant"
);
assert_eq!(edit_attr.to_string_val(), "7");
let revoke_message = build_revoke_message(&to, from_me, message_id.clone(), participant);
let proto_msg = revoke_message.protocol_message.unwrap();
let key = proto_msg.key.unwrap();
assert_eq!(key.from_me, Some(true));
assert_eq!(key.participant, None);
assert_eq!(key.id, Some(message_id));
}
#[test]
fn test_admin_revoke_message_key_structure() {
let to = Jid::from_str("120363040237990503@g.us").unwrap();
let message_id = "3EB0ABC123".to_string();
let original_sender = Jid::from_str("236395184570386:22@lid").unwrap();
let revoke_type = RevokeType::Admin {
original_sender: original_sender.clone(),
};
let (from_me, participant, edit_attr) = match revoke_type {
RevokeType::Sender => (
true,
None,
crate::types::message::EditAttribute::SenderRevoke,
),
RevokeType::Admin { original_sender } => (
false,
Some(original_sender.to_non_ad().to_string()),
crate::types::message::EditAttribute::AdminRevoke,
),
};
assert!(!from_me, "Admin revoke must have from_me=false");
assert!(
participant.is_some(),
"Admin revoke MUST set participant to original sender"
);
assert_eq!(edit_attr.to_string_val(), "8");
let revoke_message =
build_revoke_message(&to, from_me, message_id.clone(), participant.clone());
let proto_msg = revoke_message.protocol_message.unwrap();
let key = proto_msg.key.unwrap();
assert_eq!(key.from_me, Some(false));
assert_eq!(key.participant, Some("236395184570386@lid".to_string()));
assert_eq!(key.id, Some(message_id));
}
#[test]
fn test_admin_revoke_preserves_lid_format() {
let lid_sender = Jid::from_str("236395184570386:22@lid").unwrap();
let participant_str = lid_sender.to_non_ad().to_string();
assert_eq!(participant_str, "236395184570386@lid");
assert!(
participant_str.ends_with("@lid"),
"LID participant must preserve @lid suffix"
);
}
#[test]
fn test_skdm_recipient_filtering_basic() {
use std::collections::HashSet;
let known_recipients: Vec<Jid> = [
"1234567890:0@s.whatsapp.net",
"1234567890:5@s.whatsapp.net",
"9876543210:0@s.whatsapp.net",
]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let all_devices: Vec<Jid> = [
"1234567890:0@s.whatsapp.net",
"1234567890:5@s.whatsapp.net",
"9876543210:0@s.whatsapp.net",
"5555555555:0@s.whatsapp.net", ]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let known_set: HashSet<DeviceKey<'_>> =
known_recipients.iter().map(|j| j.device_key()).collect();
let new_devices: Vec<Jid> = all_devices
.into_iter()
.filter(|device| !known_set.contains(&device.device_key()))
.collect();
assert_eq!(new_devices.len(), 1);
assert_eq!(new_devices[0].user, "5555555555");
}
#[test]
fn test_skdm_recipient_filtering_lid_jids() {
use std::collections::HashSet;
let known_recipients: Vec<Jid> = [
"236395184570386:91@lid",
"129171292463295:0@lid",
"45857667830004:14@lid",
]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let all_devices: Vec<Jid> = [
"236395184570386:91@lid",
"129171292463295:0@lid",
"45857667830004:14@lid",
"45857667830004:15@lid", ]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let known_set: HashSet<DeviceKey<'_>> =
known_recipients.iter().map(|j| j.device_key()).collect();
let new_devices: Vec<Jid> = all_devices
.into_iter()
.filter(|device| !known_set.contains(&device.device_key()))
.collect();
assert_eq!(new_devices.len(), 1);
assert_eq!(new_devices[0].user, "45857667830004");
assert_eq!(new_devices[0].device, 15);
}
#[test]
fn test_skdm_recipient_filtering_all_known() {
use std::collections::HashSet;
let known_recipients: Vec<Jid> =
["1234567890:0@s.whatsapp.net", "1234567890:5@s.whatsapp.net"]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let all_devices: Vec<Jid> = ["1234567890:0@s.whatsapp.net", "1234567890:5@s.whatsapp.net"]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let known_set: HashSet<DeviceKey<'_>> =
known_recipients.iter().map(|j| j.device_key()).collect();
let new_devices: Vec<Jid> = all_devices
.into_iter()
.filter(|device| !known_set.contains(&device.device_key()))
.collect();
assert!(new_devices.is_empty());
}
#[test]
fn test_skdm_recipient_filtering_all_new() {
use std::collections::HashSet;
let known_recipients: Vec<Jid> = vec![];
let all_devices: Vec<Jid> = ["1234567890:0@s.whatsapp.net", "9876543210:0@s.whatsapp.net"]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let known_set: HashSet<DeviceKey<'_>> =
known_recipients.iter().map(|j| j.device_key()).collect();
let new_devices: Vec<Jid> = all_devices
.clone()
.into_iter()
.filter(|device| !known_set.contains(&device.device_key()))
.collect();
assert_eq!(new_devices.len(), all_devices.len());
}
#[test]
fn test_device_key_comparison() {
let test_cases = [
(
"1234567890:0@s.whatsapp.net",
"1234567890@s.whatsapp.net",
true,
),
(
"1234567890:5@s.whatsapp.net",
"1234567890:5@s.whatsapp.net",
true,
),
(
"1234567890:5@s.whatsapp.net",
"1234567890:6@s.whatsapp.net",
false,
),
("236395184570386:91@lid", "236395184570386:91@lid", true),
("236395184570386:0@lid", "236395184570386@lid", true),
("user1@s.whatsapp.net", "user2@s.whatsapp.net", false),
];
for (jid1_str, jid2_str, should_match) in test_cases {
let jid1: Jid = jid1_str.parse().expect("should parse jid1");
let jid2: Jid = jid2_str.parse().expect("should parse jid2");
let key1 = jid1.device_key();
let key2 = jid2.device_key();
assert_eq!(
key1 == key2,
should_match,
"DeviceKey comparison failed for '{}' vs '{}': expected match={}, got match={}",
jid1_str,
jid2_str,
should_match,
key1 == key2
);
assert_eq!(
jid1.device_eq(&jid2),
should_match,
"device_eq failed for '{}' vs '{}'",
jid1_str,
jid2_str
);
}
}
#[test]
fn empty_sender_key_device_map_marks_all_devices_for_skdm() {
use crate::sender_key_device_cache::SenderKeyDeviceMap;
let map = SenderKeyDeviceMap::from_db_rows(&[]);
assert_eq!(map.device_has_key("271060335329480", 0), None);
assert!(!map.is_user_forgotten("271060335329480"));
let all_resolved_devices: Vec<Jid> = [
"271060335329480@lid",
"77610646245392@lid",
"276661023027320:5@lid",
]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let needs_skdm: Vec<&Jid> = all_resolved_devices
.iter()
.filter(|device| {
!map.device_has_key(&device.user, device.device)
.unwrap_or(false)
|| map.is_user_forgotten(&device.user)
})
.collect();
assert_eq!(needs_skdm.len(), all_resolved_devices.len());
}
#[tokio::test]
async fn resolve_skdm_targets_distributes_when_cache_empty_but_devices_known() {
use wacore::client::context::GroupInfo;
use wacore::store::traits::{DeviceInfo, DeviceListRecord};
use wacore::types::message::AddressingMode;
let client = crate::test_utils::create_test_client().await;
let group_jid = "120363161500776365@g.us";
let own_lid = Jid::from_str("193832511623409:13@lid").unwrap();
let participant_users = ["271060335329480", "77610646245392", "276661023027320"];
for user in &participant_users {
let record = DeviceListRecord {
user: (*user).into(),
devices: vec![DeviceInfo {
device_id: 0,
key_index: None,
}],
timestamp: wacore::time::now_secs(),
phash: None,
raw_id: None,
};
client
.device_registry_cache
.insert((*user).into(), record)
.await;
}
let participants: Vec<Jid> = participant_users
.iter()
.map(|u| Jid::from_str(&format!("{u}@lid")).unwrap())
.collect();
let group_info = GroupInfo::new(participants.clone(), AddressingMode::Lid);
let result = client
.resolve_skdm_targets(group_jid, &group_info, &own_lid)
.await
.expect("None means the empty-cache early-exit is back");
assert_eq!(result.len(), participants.len());
for user in &participant_users {
assert!(result.iter().any(|j| j.user == *user));
}
}
#[test]
fn single_forgotten_row_keeps_full_distribution() {
use crate::sender_key_device_cache::SenderKeyDeviceMap;
let map = SenderKeyDeviceMap::from_db_rows(&[("271060335329480@lid".to_string(), false)]);
assert_eq!(map.device_has_key("271060335329480", 0), Some(false));
assert!(map.is_user_forgotten("271060335329480"));
let all_resolved_devices: Vec<Jid> = [
"271060335329480@lid",
"77610646245392@lid",
"276661023027320:5@lid",
]
.into_iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let needs_skdm: Vec<&Jid> = all_resolved_devices
.iter()
.filter(|device| {
!map.device_has_key(&device.user, device.device)
.unwrap_or(false)
|| map.is_user_forgotten(&device.user)
})
.collect();
assert_eq!(
needs_skdm.len(),
3,
"after retry inserts one row, ALL devices correctly flagged for SKDM \
(this is what unblocks redistribution on the SECOND message)"
);
}
#[test]
fn test_skdm_filtering_large_group() {
use std::collections::HashSet;
let mut known_recipients: Vec<Jid> = Vec::with_capacity(1000);
let mut all_devices: Vec<Jid> = Vec::with_capacity(1010);
for i in 0..1000i64 {
let jid_str = format!("{}:1@lid", 100000000000000i64 + i);
let jid = Jid::from_str(&jid_str).unwrap();
known_recipients.push(jid.clone());
all_devices.push(jid);
}
for i in 1000i64..1010i64 {
let jid_str = format!("{}:1@lid", 100000000000000i64 + i);
all_devices.push(Jid::from_str(&jid_str).unwrap());
}
let known_set: HashSet<DeviceKey<'_>> =
known_recipients.iter().map(|j| j.device_key()).collect();
let new_devices: Vec<Jid> = all_devices
.into_iter()
.filter(|device| !known_set.contains(&device.device_key()))
.collect();
assert_eq!(new_devices.len(), 10);
}
mod infer_stanza {
use super::*;
#[test]
fn regular_message_returns_none() {
let msg = wa::Message {
conversation: Some("hello".into()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
assert!(node.is_none());
}
#[test]
fn pin_returns_edit_attribute() {
let msg = wa::Message {
pin_in_chat_message: Some(wa::message::PinInChatMessage::default()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert_eq!(edit, Some(EditAttribute::PinInChat));
assert!(node.is_none());
}
#[test]
fn poll_creation_v3_returns_meta_node() {
let msg = wa::Message {
poll_creation_message_v3: Some(Box::default()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
let node = node.expect("should have meta node");
assert_eq!(node.tag, "meta");
let mut attrs = node.attrs();
assert_eq!(
attrs.optional_string("polltype").unwrap().as_ref(),
"creation"
);
}
#[test]
fn event_returns_meta_node() {
let msg = wa::Message {
event_message: Some(Box::default()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
let node = node.expect("should have meta node");
assert_eq!(node.tag, "meta");
let mut attrs = node.attrs();
assert_eq!(
attrs.optional_string("event_type").unwrap().as_ref(),
"creation"
);
}
#[test]
fn empty_message_returns_none() {
let (edit, node) = infer_stanza_metadata(&wa::Message::default());
assert!(edit.is_none());
assert!(node.is_none());
}
#[test]
fn poll_creation_v1_returns_meta_node() {
let msg = wa::Message {
poll_creation_message: Some(Box::default()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
let node = node.expect("should have meta node");
assert_eq!(node.tag, "meta");
let mut attrs = node.attrs();
assert_eq!(
attrs.optional_string("polltype").unwrap().as_ref(),
"creation"
);
}
#[test]
fn poll_creation_v2_returns_meta_node() {
let msg = wa::Message {
poll_creation_message_v2: Some(Box::default()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
let node = node.expect("should have meta node");
assert_eq!(node.tag, "meta");
let mut attrs = node.attrs();
assert_eq!(
attrs.optional_string("polltype").unwrap().as_ref(),
"creation"
);
}
#[test]
fn poll_vote_returns_meta_node() {
let msg = wa::Message {
poll_update_message: Some(wa::message::PollUpdateMessage {
vote: Some(wa::message::PollEncValue::default()),
..Default::default()
}),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
let node = node.expect("should have meta node");
assert_eq!(node.tag, "meta");
let mut attrs = node.attrs();
assert_eq!(attrs.optional_string("polltype").unwrap().as_ref(), "vote");
}
#[test]
fn event_response_returns_meta_node() {
let msg = wa::Message {
enc_event_response_message: Some(Default::default()),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
let node = node.expect("should have meta node");
assert_eq!(node.tag, "meta");
let mut attrs = node.attrs();
assert_eq!(
attrs.optional_string("event_type").unwrap().as_ref(),
"response"
);
}
#[test]
fn poll_update_without_vote_returns_none() {
let msg = wa::Message {
poll_update_message: Some(wa::message::PollUpdateMessage {
vote: None,
..Default::default()
}),
..Default::default()
};
let (edit, node) = infer_stanza_metadata(&msg);
assert!(edit.is_none());
assert!(node.is_none());
}
}
mod infer_biz {
use super::*;
use wa::message::interactive_message::{
self, NativeFlowMessage, native_flow_message::NativeFlowButton,
};
fn msg_with_native_flow(button_name: &str) -> wa::Message {
wa::Message {
document_with_caption_message: Some(Box::new(wa::message::FutureProofMessage {
message: Some(Box::new(wa::Message {
interactive_message: Some(Box::new(wa::message::InteractiveMessage {
interactive_message: Some(
interactive_message::InteractiveMessage::NativeFlowMessage(
NativeFlowMessage {
buttons: vec![NativeFlowButton {
name: Some(button_name.to_string()),
button_params_json: None,
}],
message_version: Some(1),
message_params_json: None,
},
),
),
..Default::default()
})),
..Default::default()
})),
})),
..Default::default()
}
}
fn assert_biz_node(node: &Node, expected_flow_name: &str) {
assert_eq!(node.tag, "biz");
assert!(
node.attrs().optional_string("native_flow_name").is_none(),
"should NOT use simple attribute form"
);
let interactive = node.get_optional_child("interactive").unwrap();
let mut attrs = interactive.attrs();
assert_eq!(
attrs.optional_string("type").unwrap().as_ref(),
"native_flow"
);
assert_eq!(attrs.optional_string("v").unwrap().as_ref(), "1");
let nf = interactive.get_optional_child("native_flow").unwrap();
let mut nf_attrs = nf.attrs();
assert_eq!(
nf_attrs.optional_string("name").unwrap().as_ref(),
expected_flow_name
);
}
#[test]
fn all_button_types_use_nested_structure() {
for (button, expected_flow) in [
("cta_url", "cta_url"),
("payment_info", "payment_info"),
("review_and_pay", "order_details"),
("cta_catalog", "cta_catalog"),
("mpm", "mpm"),
("quick_reply", "quick_reply"),
] {
let node = infer_biz_node(&msg_with_native_flow(button))
.unwrap_or_else(|| panic!("{button} should produce biz node"));
assert_biz_node(&node, expected_flow);
}
}
#[test]
fn no_interactive_returns_none() {
let msg = wa::Message {
conversation: Some("hello".into()),
..Default::default()
};
assert!(infer_biz_node(&msg).is_none());
}
#[test]
fn interactive_without_native_flow_returns_none() {
let msg = wa::Message {
interactive_message: Some(Box::new(wa::message::InteractiveMessage {
interactive_message: Some(
interactive_message::InteractiveMessage::CollectionMessage(
Default::default(),
),
),
..Default::default()
})),
..Default::default()
};
assert!(infer_biz_node(&msg).is_none());
}
#[test]
fn native_flow_without_buttons_returns_none() {
let msg = wa::Message {
interactive_message: Some(Box::new(wa::message::InteractiveMessage {
interactive_message: Some(
interactive_message::InteractiveMessage::NativeFlowMessage(
NativeFlowMessage {
buttons: vec![],
message_version: Some(1),
message_params_json: None,
},
),
),
..Default::default()
})),
..Default::default()
};
assert!(infer_biz_node(&msg).is_none());
}
#[test]
fn direct_interactive_message_without_wrapper() {
let msg = wa::Message {
interactive_message: Some(Box::new(wa::message::InteractiveMessage {
interactive_message: Some(
interactive_message::InteractiveMessage::NativeFlowMessage(
NativeFlowMessage {
buttons: vec![NativeFlowButton {
name: Some("cta_url".to_string()),
button_params_json: None,
}],
message_version: Some(1),
message_params_json: None,
},
),
),
..Default::default()
})),
..Default::default()
};
let node = infer_biz_node(&msg).unwrap();
assert_biz_node(&node, "cta_url");
}
}
mod session_lock_regression {
use super::*;
#[tokio::test]
async fn per_device_lock_keys_cover_all_devices() {
let client = crate::test_utils::create_test_client().await;
let devices: Vec<Jid> = [
"100000012345678@lid",
"100000012345678:5@lid",
"100000012345678:33@lid",
]
.iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let send_lock_keys = client.build_session_lock_keys(&devices).await;
assert_eq!(send_lock_keys.len(), 3);
assert_eq!(send_lock_keys[0].device, 0);
assert_eq!(send_lock_keys[1].device, 5);
assert_eq!(send_lock_keys[2].device, 33);
for device_jid in &devices {
assert!(
send_lock_keys.contains(device_jid),
"device {device_jid} not in send keys: {send_lock_keys:?}"
);
}
let bare_key = devices[0].to_protocol_address_string();
let device5_key = devices[1].to_protocol_address_string();
assert_ne!(bare_key, device5_key);
}
#[tokio::test]
async fn per_device_lock_serializes_concurrent_session_access() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
let session_locks: crate::cache::Cache<String, Arc<async_lock::Mutex<()>>> =
crate::cache::Cache::builder().max_capacity(100).build();
let lock_key = "100000012345678:5@lid.0".to_string();
let access_counter = Arc::new(AtomicU32::new(0));
let max_concurrent = Arc::new(AtomicU32::new(0));
let mut handles = Vec::new();
for _ in 0..10 {
let locks = session_locks.clone();
let key = lock_key.clone();
let counter = access_counter.clone();
let max = max_concurrent.clone();
handles.push(tokio::spawn(async move {
let mutex: Arc<async_lock::Mutex<()>> = locks
.get_with_by_ref(&key, async { Arc::new(async_lock::Mutex::new(())) })
.await;
let _guard = mutex.lock_arc().await;
let active = counter.fetch_add(1, Ordering::SeqCst) + 1;
max.fetch_max(active, Ordering::SeqCst);
tokio::task::yield_now().await;
counter.fetch_sub(1, Ordering::SeqCst);
}));
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(max_concurrent.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn different_device_locks_are_independent() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
let session_locks: crate::cache::Cache<String, Arc<async_lock::Mutex<()>>> =
crate::cache::Cache::builder().max_capacity(100).build();
let max_concurrent = Arc::new(AtomicU32::new(0));
let counter = Arc::new(AtomicU32::new(0));
let barrier = Arc::new(tokio::sync::Barrier::new(2));
let keys = ["100000012345678@lid.0", "100000012345678:5@lid.0"];
let mut handles = Vec::new();
for key in keys {
let locks = session_locks.clone();
let key = key.to_string();
let c = counter.clone();
let m = max_concurrent.clone();
let b = barrier.clone();
handles.push(tokio::spawn(async move {
let mutex: Arc<async_lock::Mutex<()>> = locks
.get_with_by_ref(&key, async { Arc::new(async_lock::Mutex::new(())) })
.await;
let _guard = mutex.lock_arc().await;
let active = c.fetch_add(1, Ordering::SeqCst) + 1;
m.fetch_max(active, Ordering::SeqCst);
b.wait().await;
c.fetch_sub(1, Ordering::SeqCst);
}));
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(max_concurrent.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn dm_recipient_uses_bare_address() {
let client = crate::test_utils::create_test_client().await;
let recipient_device33 = Jid::from_str("100000012345678:33@lid").unwrap();
let own_device_5 = Jid::from_str("999999999999:5@s.whatsapp.net").unwrap();
let recipient_bare = client
.resolve_encryption_jid(&recipient_device33)
.await
.to_non_ad();
let all_dm_jids = vec![recipient_bare.clone(), own_device_5.clone()];
let lock_jids = client.build_session_lock_keys(&all_dm_jids).await;
assert_eq!(
recipient_bare.to_protocol_address_string(),
"100000012345678@lid.0"
);
assert!(lock_jids.contains(&recipient_bare));
assert!(lock_jids.contains(&own_device_5));
assert!(
!lock_jids.contains(&recipient_device33),
"recipient must NOT use device-specific address"
);
}
#[test]
fn bare_normalization_deduplicates_recipient_devices() {
let devices: Vec<Jid> = [
"100000012345678@lid",
"100000012345678:5@lid",
"100000012345678:33@lid",
]
.iter()
.map(|s| Jid::from_str(s).unwrap())
.collect();
let bare: Vec<Jid> = devices.iter().map(|j| j.to_non_ad()).collect();
assert!(bare.windows(2).all(|w| w[0] == w[1]));
assert_eq!(
bare[0].to_protocol_address_string(),
"100000012345678@lid.0"
);
}
}
}