use anyhow::Result;
use log::{debug, info, warn};
use wacore_binary::jid::Jid;
use super::Client;
/// Identifiers under which a user's device-list record may be stored.
///
/// When a LID↔PN mapping is known both identifiers are carried so lookups can
/// try the LID (the canonical key) first and fall back to the phone number.
/// Without a mapping only the identifier the caller supplied is available.
#[derive(Debug, Clone)]
enum UserLookupKeys {
    /// The supplied identifier resolved as a LID with a known phone number.
    LidWithPn { lid: String, pn: String },
    /// The supplied identifier resolved as a phone number with a known LID.
    PnWithLid { lid: String, pn: String },
    /// No mapping is known; `user` is the identifier exactly as supplied.
    Unknown { user: String },
}

impl UserLookupKeys {
    /// Every key worth trying, canonical key first (LID before PN when mapped).
    fn all_keys(&self) -> Vec<&str> {
        let mut keys = vec![self.canonical_key()];
        if let Self::LidWithPn { pn, .. } | Self::PnWithLid { pn, .. } = self {
            keys.push(pn.as_str());
        }
        keys
    }

    /// The key records are stored under: the LID when known, else the raw user.
    fn canonical_key(&self) -> &str {
        match self {
            Self::Unknown { user } => user.as_str(),
            Self::LidWithPn { lid, .. } | Self::PnWithLid { lid, .. } => lid.as_str(),
        }
    }
}
impl Client {
/// Test-only helper: the key `user` (LID or PN) resolves to for storage.
#[cfg(test)]
pub(crate) async fn resolve_to_canonical_key(&self, user: &str) -> String {
    let lookup = self.resolve_lookup_keys(user).await;
    String::from(lookup.canonical_key())
}
/// Classifies `user` against the LID↔PN cache.
///
/// `user` is first treated as a LID (looking up its phone number), then as a
/// phone number (looking up its current LID). If neither lookup succeeds the
/// identifier is returned unchanged as `UserLookupKeys::Unknown`.
async fn resolve_lookup_keys(&self, user: &str) -> UserLookupKeys {
    let pn_for_user = self.lid_pn_cache.get_phone_number(user).await;
    if let Some(pn) = pn_for_user {
        return UserLookupKeys::LidWithPn {
            lid: user.to_owned(),
            pn,
        };
    }
    match self.lid_pn_cache.get_current_lid(user).await {
        Some(lid) => UserLookupKeys::PnWithLid {
            lid,
            pn: user.to_owned(),
        },
        None => UserLookupKeys::Unknown {
            user: user.to_owned(),
        },
    }
}
/// Owned copies of every key to try for `user`, LID first when a mapping is known.
pub(crate) async fn get_lookup_keys(&self, user: &str) -> Vec<String> {
    let lookup = self.resolve_lookup_keys(user).await;
    lookup.all_keys().iter().map(|&k| k.to_owned()).collect()
}
/// Reports whether `device_id` is registered for `user` (LID or PN).
///
/// Device 0 is always reported present. Otherwise the registry cache is
/// consulted for each lookup key (LID first) and the first cached record found
/// decides the answer. On a complete cache miss the backend is queried in the
/// same key order. The first stored record found is cached under its `user`
/// field and decides the answer. Backend errors are logged and the next key is
/// tried. Returns `false` when no record exists under any key.
pub(crate) async fn has_device(&self, user: &str, device_id: u32) -> bool {
    if device_id == 0 {
        return true;
    }

    let candidates = self.get_lookup_keys(user).await;

    for candidate in candidates.iter() {
        let Some(cached) = self.device_registry_cache.get(candidate).await else {
            continue;
        };
        return cached.devices.iter().any(|d| d.device_id == device_id);
    }

    let backend = self.persistence_manager.backend();
    for key in candidates.iter() {
        let stored = match backend.get_devices(key).await {
            Ok(Some(stored)) => stored,
            Ok(None) => continue,
            Err(e) => {
                warn!("Failed to check device registry for {}: {e}", key);
                continue;
            }
        };
        let present = stored.devices.iter().any(|d| d.device_id == device_id);
        self.device_registry_cache
            .insert(stored.user.clone(), stored)
            .await;
        return present;
    }

    false
}
/// Stores `record` in the device registry under the user's canonical key.
///
/// The record's `user` is rewritten to the canonical key before the record is
/// written to the in-memory cache and persisted through the backend. The
/// canonical key is the LID when a LID↔PN mapping is known, otherwise the
/// identifier as given. Cache entries under every non-canonical alias (for
/// example the phone number) are invalidated so a stale alias record cannot
/// shadow the fresh one.
///
/// # Errors
/// Returns an error if the backend fails to persist the record. The cache
/// has already been updated at that point.
pub(crate) async fn update_device_list(
    &self,
    mut record: wacore::store::traits::DeviceListRecord,
) -> Result<()> {
    use anyhow::Context;

    let original_user = record.user.clone();
    let lookup = self.resolve_lookup_keys(&original_user).await;
    let canonical_key = lookup.canonical_key().to_string();
    record.user.clone_from(&canonical_key);

    self.device_registry_cache
        .insert(canonical_key.clone(), record.clone())
        .await;

    // Drop every alias, not just the identifier we were called with.
    // Otherwise a call made with the LID leaves an older PN-keyed entry
    // behind. Lookups that fall back to the PN key would then serve it once
    // the LID entry is evicted.
    for alias in lookup.all_keys() {
        if alias != canonical_key {
            self.device_registry_cache.invalidate(alias).await;
        }
    }

    let backend = self.persistence_manager.backend();
    backend
        .update_device_list(record)
        .await
        .context("Failed to update device list in backend")?;

    if canonical_key != original_user {
        debug!(
            "Device registry: stored under LID {} (resolved from {})",
            canonical_key, original_user
        );
    }
    Ok(())
}
/// Drops every cached view of `user`'s devices.
///
/// Removes the registry-cache entries for all lookup keys. It then removes the
/// device-cache entries for the LID JID and the PN JID. When no mapping is
/// known, the supplied identifier is used for both JID kinds, because its kind
/// cannot be determined.
pub(crate) async fn invalidate_device_cache(&self, user: &str) {
    let lookup = self.resolve_lookup_keys(user).await;
    for key in lookup.all_keys() {
        self.device_registry_cache.invalidate(key).await;
    }

    let device_cache = self.get_device_cache().await;
    let (lid_user, pn_user) = match &lookup {
        UserLookupKeys::LidWithPn { lid, pn } | UserLookupKeys::PnWithLid { lid, pn } => {
            (lid.as_str(), pn.as_str())
        }
        UserLookupKeys::Unknown { user: raw } => (raw.as_str(), raw.as_str()),
    };
    device_cache.invalidate(&Jid::lid(lid_user)).await;
    device_cache.invalidate(&Jid::pn(pn_user)).await;

    debug!("Invalidated device cache for user: {} ({:?})", user, lookup);
}
/// Applies a "device added" notification to the cached device state.
///
/// Every device-cache entry returned by `jids_for_lookup` that is already
/// cached gets a JID for the new device, unless one with that device number is
/// already present. Next, if a registry record is cached under any lookup key
/// (LID first), the device and its `key_index` are added when missing. The
/// record is then persisted via `update_device_list`. Cache misses are left
/// untouched.
pub(crate) async fn patch_device_add(
    &self,
    user: &str,
    from_jid: &Jid,
    device: &wacore::stanza::devices::DeviceElement,
) {
    let device_id = device.device_id();
    let device_hw = device_id as u16;
    let device_cache = self.get_device_cache().await;
    let lookup = self.resolve_lookup_keys(user).await;

    for base in self.jids_for_lookup(&lookup, from_jid) {
        let Some(mut known) = device_cache.get(&base).await else {
            continue;
        };
        if known.iter().all(|j| j.device != device_hw) {
            let added = Jid {
                user: base.user.clone(),
                server: base.server.clone(),
                device: device_hw,
                ..Default::default()
            };
            known.push(added);
            device_cache.insert(base, known).await;
        }
    }

    for key in lookup.all_keys() {
        let Some(mut record) = self.device_registry_cache.get(key).await else {
            continue;
        };
        let already_registered = record.devices.iter().any(|d| d.device_id == device_id);
        if !already_registered {
            record.devices.push(wacore::store::traits::DeviceInfo {
                device_id,
                key_index: device.key_index,
            });
            if let Err(e) = self.update_device_list(record).await {
                warn!("patch_device_add: failed to persist: {e}");
            }
        }
        return;
    }
}
/// Applies a "device removed" notification to the cached device state.
///
/// Removes the device from every relevant device-cache entry that contains it.
/// Next, if a registry record is cached under any lookup key (LID first), the
/// device is dropped from it. When something was actually removed, the record
/// is persisted via `update_device_list`. Cache misses are left untouched.
pub(crate) async fn patch_device_remove(&self, user: &str, from_jid: &Jid, device_id: u32) {
    let device_cache = self.get_device_cache().await;
    let lookup = self.resolve_lookup_keys(user).await;

    // JID device numbers are `u16`, so an id outside that range cannot be in
    // the device cache. A truncating `as` cast would wrap it onto an
    // unrelated device number and remove that device instead.
    if let Ok(device_hw) = u16::try_from(device_id) {
        for jid in self.jids_for_lookup(&lookup, from_jid) {
            if let Some(mut devices) = device_cache.get(&jid).await {
                let before = devices.len();
                devices.retain(|d| d.device != device_hw);
                if devices.len() != before {
                    device_cache.insert(jid, devices).await;
                }
            }
        }
    }

    for key in lookup.all_keys() {
        if let Some(mut record) = self.device_registry_cache.get(key).await {
            let before = record.devices.len();
            record.devices.retain(|d| d.device_id != device_id);
            if record.devices.len() != before
                && let Err(e) = self.update_device_list(record).await
            {
                warn!("patch_device_remove: failed to persist: {e}");
            }
            return;
        }
    }
}
pub(crate) async fn patch_device_update(
&self,
user: &str,
device: &wacore::stanza::devices::DeviceElement,
) {
let device_id = device.device_id();
let lookup = self.resolve_lookup_keys(user).await;
for key in lookup.all_keys() {
if let Some(mut record) = self.device_registry_cache.get(key).await {
if let Some(d) = record.devices.iter_mut().find(|d| d.device_id == device_id) {
d.key_index = device.key_index;
if let Err(e) = self.update_device_list(record).await {
warn!("patch_device_update: failed to persist: {e}");
}
}
return;
}
}
}
/// Device-cache keys that describe this user.
///
/// With a known mapping, both the LID JID and the PN JID are returned (LID
/// first). Otherwise only `from_jid.to_non_ad()` is returned.
fn jids_for_lookup(&self, lookup: &UserLookupKeys, from_jid: &Jid) -> Vec<Jid> {
    if let UserLookupKeys::LidWithPn { lid, pn } | UserLookupKeys::PnWithLid { lid, pn } =
        lookup
    {
        return vec![Jid::lid(lid), Jid::pn(pn)];
    }
    vec![from_jid.to_non_ad()]
}
/// Background task body for the device registry.
///
/// Does no periodic work of its own. It waits until the client's shutdown
/// notifier fires, logs, and returns.
pub(super) async fn device_registry_cleanup_loop(&self) {
    let shutdown = self.shutdown_notifier.listen();
    shutdown.await;
    debug!(target: "Client/DeviceRegistry", "Shutdown signaled, exiting cleanup loop");
}
/// Re-keys a device-registry record persisted under phone number `pn` to `lid`.
///
/// Intended for when the LID for `pn` becomes known (per the function name).
/// If the backend holds a record under `pn`, its `user` is rewritten to `lid`
/// and the record is written through the backend. On success, cached state for
/// `pn` is invalidated and the migrated record is then cached under `lid`.
/// Backend errors are logged and the migration is abandoned. The PN-keyed row
/// itself is not deleted here.
pub(crate) async fn migrate_device_registry_on_lid_discovery(&self, pn: &str, lid: &str) {
    let backend = self.persistence_manager.backend();
    let mut record = match backend.get_devices(pn).await {
        Ok(Some(record)) => record,
        Ok(None) => return,
        Err(e) => {
            warn!("Failed to check for PN device registry entry: {}", e);
            return;
        }
    };

    info!(
        "Migrating device registry entry from PN {} to LID {} ({} devices)",
        pn,
        lid,
        record.devices.len()
    );
    record.user = lid.to_string();
    // NOTE(review): if the backend upserts by user, this replaces any record
    // already stored under `lid`, even a newer one. Confirm this is intended.
    if let Err(e) = backend.update_device_list(record.clone()).await {
        warn!("Failed to migrate device registry to LID: {}", e);
        return;
    }

    // Invalidate *before* caching. Once the LID↔PN mapping is resolvable,
    // `invalidate_device_cache(pn)` covers both keys and would otherwise evict
    // the LID entry inserted here straight away.
    self.invalidate_device_cache(pn).await;
    self.device_registry_cache
        .insert(lid.to_string(), record)
        .await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::lid_pn_cache::LearningSource;
use crate::test_utils::create_test_client_with_failing_http;
use std::sync::Arc;
/// Shared constructor for these tests. Delegates to
/// `create_test_client_with_failing_http`, passing "device_registry" as its argument.
async fn create_test_client() -> Arc<Client> {
create_test_client_with_failing_http("device_registry").await
}
/// Without a LID↔PN mapping an identifier is its own canonical key.
#[tokio::test]
async fn test_resolve_to_canonical_key_unknown_user() {
    let user = "15551234567";
    let client = create_test_client().await;
    assert_eq!(client.resolve_to_canonical_key(user).await, user);
}
/// With a mapping, both the PN and the LID resolve to the LID.
#[tokio::test]
async fn test_resolve_to_canonical_key_with_lid_mapping() {
    use crate::lid_pn_cache::LidPnEntry;
    let client = create_test_client().await;
    let (lid, pn) = ("100000000000001", "15551234567");
    client
        .lid_pn_cache
        .add(LidPnEntry::new(
            lid.to_owned(),
            pn.to_owned(),
            LearningSource::Usync,
        ))
        .await;
    for input in [pn, lid] {
        assert_eq!(client.resolve_to_canonical_key(input).await, lid);
    }
}
/// Without a mapping the only lookup key is the identifier itself.
#[tokio::test]
async fn test_get_lookup_keys_unknown_user() {
    let client = create_test_client().await;
    let expected = ["15551234567"];
    assert_eq!(client.get_lookup_keys("15551234567").await, expected);
}
/// With a mapping, lookup keys are `[lid, pn]` whichever identifier is supplied.
#[tokio::test]
async fn test_get_lookup_keys_with_lid_mapping() {
    use crate::lid_pn_cache::LidPnEntry;
    let client = create_test_client().await;
    let lid = "100000000000001";
    let pn = "15551234567";
    let mapping = LidPnEntry::new(lid.to_owned(), pn.to_owned(), LearningSource::Usync);
    client.lid_pn_cache.add(mapping).await;
    for input in [pn, lid] {
        let keys = client.get_lookup_keys(input).await;
        assert_eq!(keys, [lid, pn]);
    }
}
/// A 15-digit LID is recognised as canonical and paired with its PN.
#[tokio::test]
async fn test_15_digit_lid_handling() {
    use crate::lid_pn_cache::LidPnEntry;
    let client = create_test_client().await;
    let lid = "100000000000001";
    let pn = "15551234567";
    assert_eq!(lid.len(), 15, "LID should be 15 digits");
    client
        .lid_pn_cache
        .add(LidPnEntry::new(
            lid.to_owned(),
            pn.to_owned(),
            LearningSource::Usync,
        ))
        .await;

    assert_eq!(client.resolve_to_canonical_key(lid).await, lid);

    let keys = client.get_lookup_keys(lid).await;
    assert_eq!(keys.len(), 2);
    assert_eq!(keys, [lid, pn]);
}
/// Device 0 is reported present for any user without consulting any cache.
#[tokio::test]
async fn test_has_device_primary_always_exists() {
    let client = create_test_client().await;
    let primary_device = 0;
    assert!(
        client.has_device("anyuser", primary_device).await,
        "device 0 must always be reported present"
    );
}
/// On a fresh test client, a non-primary device of an unseen user is absent.
#[tokio::test]
async fn test_has_device_unknown_device() {
    let client = create_test_client().await;
    let present = client.has_device("15551234567", 5).await;
    assert!(!present);
}
/// A LID-keyed cached record answers queries made with either the LID or the PN.
#[tokio::test]
async fn test_has_device_with_cached_record() {
    use crate::lid_pn_cache::LidPnEntry;
    use wacore::store::traits::{DeviceInfo, DeviceListRecord};
    let client = create_test_client().await;
    let lid = "100000000000001";
    let pn = "15551234567";
    client
        .lid_pn_cache
        .add(LidPnEntry::new(
            lid.to_owned(),
            pn.to_owned(),
            LearningSource::Usync,
        ))
        .await;

    let cached = DeviceListRecord {
        user: lid.to_owned(),
        devices: vec![DeviceInfo {
            device_id: 1,
            key_index: None,
        }],
        timestamp: 12345,
        phash: None,
    };
    client
        .device_registry_cache
        .insert(lid.to_owned(), cached)
        .await;

    // The PN query resolves to the LID key and hits the cached record.
    assert!(client.has_device(pn, 1).await);
    assert!(client.has_device(lid, 1).await);
    assert!(!client.has_device(lid, 99).await);
}
/// Checks that invalidating a mapped user clears the LID-keyed registry entry
/// and both the LID and PN JID device-cache entries. The check is run once
/// when called with the PN and once when called with the LID.
#[tokio::test]
async fn test_invalidate_device_cache_uses_correct_jid_types() {
use crate::lid_pn_cache::LidPnEntry;
use wacore_binary::jid::Jid;
let client = create_test_client().await;
let lid = "100000000000001";
let pn = "15551234567";
// Register the LID↔PN mapping so lookups resolve to both keys.
let entry = LidPnEntry::new(lid.to_string(), pn.to_string(), LearningSource::Usync);
client.lid_pn_cache.add(entry).await;
// Seed the registry cache under the LID.
let record = wacore::store::traits::DeviceListRecord {
user: lid.to_string(),
devices: vec![wacore::store::traits::DeviceInfo {
device_id: 1,
key_index: None,
}],
timestamp: 12345,
phash: None,
};
client
.device_registry_cache
.insert(lid.to_string(), record)
.await;
// Seed the device cache under both JID kinds.
let lid_jid = Jid::lid(lid);
let pn_jid = Jid::pn(pn);
let device_cache = client.get_device_cache().await;
device_cache
.insert(lid_jid.clone(), vec![lid_jid.clone()])
.await;
device_cache
.insert(pn_jid.clone(), vec![pn_jid.clone()])
.await;
assert!(
client.device_registry_cache.get(lid).await.is_some(),
"Device registry cache should have LID entry before invalidation"
);
assert!(
device_cache.get(&lid_jid).await.is_some(),
"Device cache should have LID JID entry before invalidation"
);
assert!(
device_cache.get(&pn_jid).await.is_some(),
"Device cache should have PN JID entry before invalidation"
);
// Round 1: invalidate via the phone number.
client.invalidate_device_cache(pn).await;
assert!(
client.device_registry_cache.get(lid).await.is_none(),
"Device registry cache should be invalidated for LID"
);
assert!(
device_cache.get(&lid_jid).await.is_none(),
"Device cache should be invalidated for LID JID"
);
assert!(
device_cache.get(&pn_jid).await.is_none(),
"Device cache should be invalidated for PN JID"
);
// Re-seed all three entries for round 2.
let record2 = wacore::store::traits::DeviceListRecord {
user: lid.to_string(),
devices: vec![wacore::store::traits::DeviceInfo {
device_id: 2,
key_index: None,
}],
timestamp: 12346,
phash: None,
};
client
.device_registry_cache
.insert(lid.to_string(), record2)
.await;
device_cache
.insert(lid_jid.clone(), vec![lid_jid.clone()])
.await;
device_cache
.insert(pn_jid.clone(), vec![pn_jid.clone()])
.await;
// Round 2: invalidate via the LID.
client.invalidate_device_cache(lid).await;
assert!(
client.device_registry_cache.get(lid).await.is_none(),
"Device registry cache should be invalidated for LID (called with LID)"
);
assert!(
device_cache.get(&lid_jid).await.is_none(),
"Device cache should be invalidated for LID JID (called with LID)"
);
assert!(
device_cache.get(&pn_jid).await.is_none(),
"Device cache should be invalidated for PN JID (called with LID)"
);
}
/// For users with no LID↔PN mapping, invalidation must clear both the LID and
/// the PN JID entries, since the identifier's kind cannot be determined.
/// Covered for one LID-shaped and one PN-shaped identifier.
#[tokio::test]
async fn test_invalidate_device_cache_unknown_user_invalidates_both_types() {
    use wacore_binary::jid::Jid;
    let client = create_test_client().await;
    let device_cache = client.get_device_cache().await;

    for unknown_user in ["100000000000999", "15559998888"] {
        let lid_jid = Jid::lid(unknown_user);
        let pn_jid = Jid::pn(unknown_user);

        // Seed both entries so each post-invalidation assertion actually
        // observes an eviction rather than an entry that never existed.
        device_cache
            .insert(lid_jid.clone(), vec![lid_jid.clone()])
            .await;
        device_cache
            .insert(pn_jid.clone(), vec![pn_jid.clone()])
            .await;
        assert!(
            device_cache.get(&lid_jid).await.is_some(),
            "Device cache should have LID JID entry before invalidation ({unknown_user})"
        );
        assert!(
            device_cache.get(&pn_jid).await.is_some(),
            "Device cache should have PN JID entry before invalidation ({unknown_user})"
        );

        client.invalidate_device_cache(unknown_user).await;

        assert!(
            device_cache.get(&lid_jid).await.is_none(),
            "Device cache should be invalidated for LID JID (unknown user {unknown_user})"
        );
        assert!(
            device_cache.get(&pn_jid).await.is_none(),
            "Device cache should be invalidated for PN JID (unknown user {unknown_user})"
        );
    }
}
/// Builds a `DeviceElement` for device `device_id` of PN user 15551234567
/// on `s.whatsapp.net`, with the given `key_index` and no LID.
fn make_device_element(
    device_id: u16,
    key_index: Option<u32>,
) -> wacore::stanza::devices::DeviceElement {
    let jid = Jid {
        user: "15551234567".into(),
        server: "s.whatsapp.net".into(),
        device: device_id,
        ..Default::default()
    };
    wacore::stanza::devices::DeviceElement {
        jid,
        key_index,
        lid: None,
    }
}
/// A new device is appended to an already-cached device list.
#[tokio::test]
async fn test_patch_device_add_to_existing_cache() {
    let client = create_test_client().await;
    let from_jid = Jid::pn("15551234567");
    let cache_key = from_jid.to_non_ad();
    let device_cache = client.get_device_cache().await;
    let primary = Jid {
        user: "15551234567".into(),
        server: "s.whatsapp.net".into(),
        device: 0,
        ..Default::default()
    };
    device_cache.insert(cache_key.clone(), vec![primary]).await;

    let added = make_device_element(3, Some(5));
    client
        .patch_device_add("15551234567", &from_jid, &added)
        .await;

    let devices = device_cache.get(&cache_key).await.unwrap();
    assert_eq!(devices.len(), 2);
    assert!(devices.iter().any(|d| d.device == 3));
}
/// Adding a device already present in the cached list leaves the list unchanged.
#[tokio::test]
async fn test_patch_device_add_deduplicates() {
    let client = create_test_client().await;
    let from_jid = Jid::pn("15551234567");
    let cache_key = from_jid.to_non_ad();
    let existing = Jid {
        user: "15551234567".into(),
        server: "s.whatsapp.net".into(),
        device: 3,
        ..Default::default()
    };
    let device_cache = client.get_device_cache().await;
    device_cache.insert(cache_key.clone(), vec![existing]).await;

    let duplicate = make_device_element(3, None);
    client
        .patch_device_add("15551234567", &from_jid, &duplicate)
        .await;

    let devices = device_cache.get(&cache_key).await.unwrap();
    assert_eq!(devices.len(), 1, "device 3 must not be added twice");
}
/// With no cached device list, an add notification does not create one.
#[tokio::test]
async fn test_patch_device_add_noop_on_miss() {
    let client = create_test_client().await;
    let from_jid = Jid::pn("15551234567");
    client
        .patch_device_add("15551234567", &from_jid, &make_device_element(3, None))
        .await;
    let device_cache = client.get_device_cache().await;
    let cached = device_cache.get(&from_jid.to_non_ad()).await;
    assert!(cached.is_none());
}
#[tokio::test]
async fn test_patch_device_remove() {
let client = create_test_client().await;
let from_jid = Jid::pn("15551234567");
let non_ad = from_jid.to_non_ad();
let dev0 = Jid {
user: "15551234567".into(),
server: "s.whatsapp.net".into(),
device: 0,
..Default::default()
};
let dev3 = Jid {
user: "15551234567".into(),
server: "s.whatsapp.net".into(),
device: 3,
..Default::default()
};
let device_cache = client.get_device_cache().await;
device_cache.insert(non_ad.clone(), vec![dev0, dev3]).await;
client
.patch_device_remove("15551234567", &from_jid, 3)
.await;
let devices = device_cache.get(&non_ad).await.unwrap();
assert_eq!(devices.len(), 1);
assert_eq!(devices[0].device, 0);
}
/// An update notification replaces the cached device's key index.
#[tokio::test]
async fn test_patch_device_update_key_index() {
    use wacore::store::traits::{DeviceInfo, DeviceListRecord};
    let client = create_test_client().await;
    let user = "15551234567";
    let record = DeviceListRecord {
        user: user.to_string(),
        devices: vec![
            DeviceInfo {
                device_id: 0,
                key_index: None,
            },
            DeviceInfo {
                device_id: 3,
                key_index: Some(1),
            },
        ],
        timestamp: 1000,
        phash: None,
    };
    client
        .device_registry_cache
        .insert(user.to_string(), record)
        .await;

    client
        .patch_device_update(user, &make_device_element(3, Some(5)))
        .await;

    let updated = client.device_registry_cache.get(user).await.unwrap();
    let key_index = updated
        .devices
        .iter()
        .find(|d| d.device_id == 3)
        .map(|d| d.key_index);
    assert_eq!(key_index, Some(Some(5)));
}
/// An add notification also appends the device, with its key index, to a
/// cached registry record.
#[tokio::test]
async fn test_patch_device_add_updates_registry() {
    use wacore::store::traits::{DeviceInfo, DeviceListRecord};
    let client = create_test_client().await;
    let user = "15551234567";
    let from_jid = Jid::pn(user);
    let record = DeviceListRecord {
        user: user.to_string(),
        devices: vec![DeviceInfo {
            device_id: 0,
            key_index: None,
        }],
        timestamp: 1000,
        phash: None,
    };
    client
        .device_registry_cache
        .insert(user.to_string(), record)
        .await;

    client
        .patch_device_add(user, &from_jid, &make_device_element(3, Some(2)))
        .await;

    let updated = client.device_registry_cache.get(user).await.unwrap();
    assert_eq!(updated.devices.len(), 2);
    let key_index = updated
        .devices
        .iter()
        .find(|d| d.device_id == 3)
        .map(|d| d.key_index);
    assert_eq!(key_index, Some(Some(2)));
}
}