1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
//! Pre-key management for Signal Protocol.
//!
//! Pre-key IDs use a persistent monotonic counter (Device::next_pre_key_id)
//! matching WhatsApp Web's NEXT_PK_ID pattern. IDs only increase to prevent
//! collisions when prekeys are consumed non-sequentially from the store.
use crate::client::Client;
use anyhow;
use log;
use std::sync::atomic::Ordering;
use wacore::iq::prekeys::{
DigestKeyBundleSpec, PreKeyCountSpec, PreKeyFetchReason, PreKeyFetchSpec, PreKeyUploadSpec,
};
use wacore::libsignal::protocol::{KeyPair, PreKeyBundle, PublicKey};
use wacore::libsignal::store::record_helpers::new_pre_key_record;
use wacore::store::commands::DeviceCommand;
use wacore_binary::Jid;
pub use wacore::prekeys::PreKeyUtils;
/// Matches WA Web's UPLOAD_KEYS_COUNT from WAWebSignalStoreApi.
const WANTED_PRE_KEY_COUNT: usize = 812;
const MIN_PRE_KEY_COUNT: usize = 5;
impl Client {
pub(crate) async fn fetch_pre_keys(
&self,
jids: &[Jid],
reason: Option<PreKeyFetchReason>,
) -> Result<std::collections::HashMap<Jid, PreKeyBundle>, anyhow::Error> {
let spec = match reason {
Some(r) => PreKeyFetchSpec::with_reason(jids.to_vec(), r),
None => PreKeyFetchSpec::new(jids.to_vec()),
};
let bundles = self.execute(spec).await?;
for jid in bundles.keys() {
log::debug!("Successfully parsed pre-key bundle for {jid}");
}
Ok(bundles)
}
/// Query the WhatsApp server for how many pre-keys it currently has for this device.
pub(crate) async fn get_server_pre_key_count(&self) -> Result<usize, crate::request::IqError> {
let response = self.execute(PreKeyCountSpec::new()).await?;
Ok(response.count)
}
/// Upload prekeys at login if the persisted flag indicates they're needed.
/// Matches WA Web's PassiveTasks.js:30 which checks `getServerHasPreKeys()`.
pub(crate) async fn upload_pre_keys_at_login(&self) -> Result<(), anyhow::Error> {
let has_prekeys = self
.persistence_manager
.get_device_snapshot()
.await
.server_has_prekeys;
if has_prekeys {
log::debug!("Server has prekeys (persisted flag), skipping login upload.");
return Ok(());
}
// Serialize with prekey-low/digest paths to avoid duplicate uploads
let _guard = self.prekey_upload_lock.lock().await;
// Re-check after acquiring lock (another task may have uploaded)
if self
.persistence_manager
.get_device_snapshot()
.await
.server_has_prekeys
{
return Ok(());
}
log::info!("Server missing prekeys (persisted flag), uploading.");
self.upload_pre_keys_inner().await
}
/// Ensure the server has enough pre-keys, uploading if below threshold.
/// When `force` is true, skips the count guard (used by digest key repair).
pub(crate) async fn upload_pre_keys(&self, force: bool) -> Result<(), anyhow::Error> {
let server_count = self
.get_server_pre_key_count()
.await
.map_err(|e| anyhow::anyhow!(e))?;
if !force && server_count >= MIN_PRE_KEY_COUNT {
log::debug!("Server has {server_count} pre-keys, no upload needed.");
return Ok(());
}
log::debug!("Server has {server_count} pre-keys, uploading.");
self.upload_pre_keys_inner().await
}
/// Generate and upload WANTED_PRE_KEY_COUNT pre-keys. Shared by
/// `upload_pre_keys` and `upload_pre_keys_at_login` to avoid
/// redundant server count queries.
async fn upload_pre_keys_inner(&self) -> Result<(), anyhow::Error> {
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
let device_store = self.persistence_manager.get_device_arc().await;
let backend = {
let device_guard = device_store.read().await;
device_guard.backend.clone()
};
// Use the persistent counter, falling back to max(store_id)+1 for migration.
// The counter is the source of truth after the first upload.
let max_id = backend.get_max_prekey_id().await?;
let raw_start = if device_snapshot.next_pre_key_id > 0 {
std::cmp::max(device_snapshot.next_pre_key_id, max_id + 1)
} else {
log::info!(
"Migrating pre-key counter: MAX(key_id) in store = {}, starting from {}",
max_id,
max_id + 1
);
max_id + 1
};
// WA Web uses 24-bit PreKey IDs (max 2^24 - 1 = 16777215).
// Wrap into valid range so lingering high-ID rows don't pin start_id
// above the boundary and cause repeated overwrites of low IDs.
const MAX_PREKEY_ID: u32 = 16777215;
let start_id = ((raw_start as u64 - 1) % MAX_PREKEY_ID as u64) as u32 + 1;
let mut keys_to_upload = Vec::with_capacity(WANTED_PRE_KEY_COUNT);
let mut key_pairs_to_upload = Vec::with_capacity(WANTED_PRE_KEY_COUNT);
for i in 0..WANTED_PRE_KEY_COUNT {
let pre_key_id =
(((start_id as u64 - 1) + i as u64) % (MAX_PREKEY_ID as u64)) as u32 + 1;
let key_pair = KeyPair::generate(&mut rand::make_rng::<rand::rngs::StdRng>());
let pre_key_record = new_pre_key_record(pre_key_id, &key_pair);
keys_to_upload.push((pre_key_id, pre_key_record));
key_pairs_to_upload.push((pre_key_id, key_pair));
}
// Encode all prekey records into a single contiguous buffer, then slice
// into Bytes sub-views. This replaces 812 individual encode_to_vec() allocs
// with one large allocation + zero-copy slicing.
let encoded_batch: Vec<(u32, bytes::Bytes)> = {
use prost::Message;
let total_len: usize = keys_to_upload.iter().map(|(_, r)| r.encoded_len()).sum();
let mut buf = Vec::with_capacity(total_len);
let mut offsets = Vec::with_capacity(keys_to_upload.len());
for (id, record) in &keys_to_upload {
let start = buf.len();
record
.encode(&mut buf)
.expect("prost encode into pre-sized Vec");
offsets.push((*id, start..buf.len()));
}
let shared = bytes::Bytes::from(buf);
offsets
.into_iter()
.map(|(id, range)| (id, shared.slice(range)))
.collect()
};
// Persist the freshly generated prekeys before uploading them so they are
// already available for local decryption if the server starts sending
// pkmsg traffic immediately after accepting the upload.
// Propagate errors — uploading a key we can't store locally would cause
// decryption failures when the server hands it out.
backend.store_prekeys_batch(&encoded_batch, false).await?;
let pre_key_pairs: Vec<(u32, PublicKey)> = key_pairs_to_upload
.iter()
.map(|(id, key_pair)| (*id, key_pair.public_key))
.collect();
let spec = PreKeyUploadSpec::new(
device_snapshot.registration_id,
device_snapshot.identity_key.public_key,
device_snapshot.signed_pre_key_id,
device_snapshot.signed_pre_key.public_key,
device_snapshot.signed_pre_key_signature.to_vec(),
pre_key_pairs,
);
self.execute(spec).await?;
// Mark the uploaded prekeys as server-synced (reuse encoded batch)
if let Err(e) = backend.store_prekeys_batch(&encoded_batch, true).await {
log::warn!("Failed to mark prekeys as uploaded: {:?}", e);
}
// IDs wrap modulo MAX_PREKEY_ID. If the counter wraps while unconsumed
// high-ID prekeys still exist, the upsert (.on_conflict.do_update)
// silently overwrites them. Acceptable: the server consumes keys well
// before a full 16M cycle completes.
let next_id = (((start_id as u64 - 1) + key_pairs_to_upload.len() as u64)
% (MAX_PREKEY_ID as u64)) as u32
+ 1;
self.persistence_manager
.process_command(DeviceCommand::SetNextPreKeyId(next_id))
.await;
// Persist flag matching WA Web's setServerHasPreKeys(true) (PreKeysJob.js:79)
self.persistence_manager
.modify_device(|d| d.server_has_prekeys = true)
.await;
log::debug!(
"Successfully uploaded {} new pre-keys with sequential IDs starting from {}.",
key_pairs_to_upload.len(),
start_id
);
Ok(())
}
/// Upload pre-keys with Fibonacci retry backoff matching WA Web's `PromiseRetryLoop`.
///
/// Retry schedule: 1s, 2s, 3s, 5s, 8s, 13s, ... capped at 610s.
/// Verified against WA Web JS: `{ algo: { type: "fibonacci", first: 1e3, second: 2e3 }, max: 61e4 }`
///
/// When `force` is true, bypasses the count guard (used by digest repair path).
pub(crate) async fn upload_pre_keys_with_retry(
&self,
force: bool,
) -> Result<(), anyhow::Error> {
let mut delay_a: u64 = 1;
let mut delay_b: u64 = 2;
const MAX_DELAY_SECS: u64 = 610;
loop {
match self.upload_pre_keys(force).await {
Ok(()) => {
log::info!("Pre-key upload succeeded");
return Ok(());
}
Err(e) => {
let delay = delay_a.min(MAX_DELAY_SECS);
log::warn!("Pre-key upload failed, retrying in {}s: {:?}", delay, e);
self.runtime
.sleep(std::time::Duration::from_secs(delay))
.await;
// Bail if disconnected during retry wait
if !self.is_logged_in.load(Ordering::Relaxed) {
return Err(anyhow::anyhow!(
"Connection lost during pre-key upload retry"
));
}
let next = delay_a + delay_b;
delay_a = delay_b;
delay_b = next;
}
}
}
}
/// Force-refresh the server's one-time pre-key pool with a fresh batch.
///
/// Intended for callers that just restored a device from an external source
/// (e.g., migrating a Baileys session into an `InMemoryBackend`). The server
/// may still hold pre-key IDs whose private key material the caller cannot
/// reconstruct; any `pkmsg` referencing those IDs will fail forever with
/// `InvalidPreKeyId`. Uploading a fresh batch gives the server new IDs the
/// caller *does* have locally, and old unmatched IDs drain as peers consume
/// them.
///
/// Acquires `prekey_upload_lock` for the duration so this force-upload
/// cannot race on `start_id` with the count-based and digest-repair paths.
pub async fn refresh_pre_keys(&self) -> Result<(), anyhow::Error> {
let _guard = self.prekey_upload_lock.lock().await;
self.upload_pre_keys_with_retry(true).await
}
/// Validate server key bundle digest, re-uploading only when the server has no record.
///
/// Matches WA Web's `WAWebDigestKeyJob.digestKey()`:
/// 1. Queries server for key bundle digest (identity + signed prekey + prekey IDs + SHA-1 hash)
/// 2. If server returns 404 (no record): triggers `upload_pre_keys_with_retry()`
/// 3. If server returns 406/503/other error: logs and does nothing
/// 4. On success: loads local keys and computes SHA-1 over the same material
/// 5. If validation fails (regId mismatch, missing prekey, hash mismatch): logs warning,
/// does NOT re-upload — WA Web catches all `validateLocalKeyBundle` exceptions without
/// re-uploading; the normal `RotateKeyJob` will eventually refresh keys
pub(crate) async fn validate_digest_key(&self) -> Result<(), anyhow::Error> {
let response = match self.execute(DigestKeyBundleSpec::new()).await {
Ok(resp) => resp,
Err(crate::request::IqError::ServerError { code: 404, .. }) => {
log::warn!("digestKey: no record found for current user, re-uploading");
return self.upload_pre_keys_with_retry(true).await;
}
Err(crate::request::IqError::ServerError { code: 406, .. }) => {
log::warn!("digestKey: malformed request");
return Ok(());
}
Err(crate::request::IqError::ServerError { code: 503, .. }) => {
log::warn!("digestKey: service unavailable");
return Ok(());
}
Err(crate::request::IqError::ParseError(e)) => {
// WA Web catches parse failures without re-uploading
log::debug!("digestKey: unparseable digest response ({e}), skipping");
return Ok(());
}
Err(e) => {
if !self.is_shutting_down() {
log::warn!("digestKey: server error: {:?}", e);
}
return Ok(());
}
};
// WA Web's validateLocalKeyBundle validates but catches ALL exceptions without
// re-uploading. The catch block in digestKey() sets a=false for any throw from y(),
// meaning only 404 triggers re-upload. We match that: log warnings, return Ok(()).
let device_snapshot = self.persistence_manager.get_device_snapshot().await;
if response.reg_id != device_snapshot.registration_id {
log::warn!(
"digestKey: registration ID mismatch (server={}, local={}), skipping",
response.reg_id,
device_snapshot.registration_id
);
return Ok(());
}
// Compute local SHA-1 digest over the same material as WA Web's validateLocalKeyBundle:
// identity_pub_key + signed_prekey_pub + signed_prekey_signature + (for each prekey ID: load 32-byte pubkey)
let identity_bytes = device_snapshot.identity_key.public_key.public_key_bytes();
let skey_pub_bytes = device_snapshot.signed_pre_key.public_key.public_key_bytes();
let skey_sig_bytes = &device_snapshot.signed_pre_key_signature;
let device_store = self.persistence_manager.get_device_arc().await;
let backend = {
let guard = device_store.read().await;
guard.backend.clone()
};
// Batch-load all prekeys referenced by the server digest
let loaded = match backend.load_prekeys_batch(&response.prekey_ids).await {
Ok(v) => v,
Err(e) => {
log::warn!("digestKey: failed to batch-load prekeys: {:?}, skipping", e);
return Ok(());
}
};
// Build a lookup so we preserve the server-requested order.
// Dedupe the expected count since the server may send duplicate IDs.
let loaded_map: std::collections::HashMap<u32, bytes::Bytes> = loaded.into_iter().collect();
let unique_requested: std::collections::HashSet<&u32> =
response.prekey_ids.iter().collect();
if loaded_map.len() < unique_requested.len() {
log::warn!(
"digestKey: missing {} local prekeys, skipping",
unique_requested.len() - loaded_map.len()
);
return Ok(());
}
// Extract public keys directly from stored protobuf bytes without full decode
let mut prekey_pubkeys = Vec::with_capacity(response.prekey_ids.len());
for prekey_id in &response.prekey_ids {
let Some(record_bytes) = loaded_map.get(prekey_id) else {
log::warn!("digestKey: missing local prekey {}, skipping", prekey_id);
return Ok(());
};
match wacore::prekeys::extract_prekey_public_key(record_bytes) {
Some(pk) => prekey_pubkeys.push(pk),
None => {
log::warn!(
"digestKey: prekey {} has no public key, skipping",
prekey_id
);
return Ok(());
}
}
}
let local_hash = wacore::prekeys::compute_key_bundle_digest(
identity_bytes,
skey_pub_bytes,
skey_sig_bytes,
&prekey_pubkeys,
);
if local_hash.as_slice() != response.hash.as_slice() {
log::warn!(
"digestKey: hash mismatch (server={}, local={}), skipping",
hex::encode(&response.hash),
hex::encode(local_hash)
);
return Ok(());
}
log::debug!("digestKey: key bundle validation successful");
Ok(())
}
}