bsv_messagebox_client/client.rs
1use std::collections::{HashMap, HashSet, VecDeque};
2use std::sync::Arc;
3
4use bsv::auth::clients::auth_fetch::{AuthFetch, AuthFetchResponse};
5use bsv::remittance::types::PeerMessage;
6use bsv::services::overlay_tools::Network;
7use bsv::wallet::interfaces::{GetPublicKeyArgs, WalletInterface};
8use tokio::sync::{Mutex, OnceCell};
9
10use crate::delivery::DeliveryMode;
11use crate::error::MessageBoxError;
12use crate::types::{ListMessagesParams, ListMessagesResponse};
13
14/// Callback type for live message subscriptions.
15type SubscriptionCallback = Arc<dyn Fn(PeerMessage) + Send + Sync>;
16
17/// Authenticated HTTP client for the MessageBox protocol.
18///
19/// `MessageBoxClient<W>` wraps `AuthFetch<W>` behind a `tokio::sync::Mutex`
20/// because `AuthFetch::fetch()` takes `&mut self` and must be called from
21/// async context. A `std::sync::Mutex` held across `.await` panics under
22/// Tokio — this is load-bearing and must not be changed.
23pub struct MessageBoxClient<W: WalletInterface + Clone + 'static> {
24 /// Base URL of the MessageBox server (trailing whitespace trimmed on construction).
25 host: String,
26 /// BRC-31 authenticated HTTP client (needs `&mut self`, hence Arc<Mutex> for sharing).
27 /// Arc allows cloning into background polling tasks without duplicating the wallet auth state.
28 auth_fetch: Arc<Mutex<AuthFetch<W>>>,
29 /// Wallet retained for direct encrypt / decrypt calls in `http_ops`.
30 wallet: W,
31 /// Optional originator string forwarded to wallet operations.
32 originator: Option<String>,
33 /// Cached identity public key (hex) — populated on first call.
34 identity_key: OnceCell<String>,
35 /// Ensures assert_initialized runs the full init path at most once.
36 pub(crate) init_once: OnceCell<()>,
37 /// Network preset for overlay tools (LookupResolver, TopicBroadcaster).
38 /// Defaults to Mainnet; pass Network::Local for localhost integration tests.
39 pub(crate) network: Network,
40 /// WebSocket connection state (None until first live message call).
41 ws_state: Mutex<Option<crate::websocket::MessageBoxWebSocket>>,
42 /// Tracks which message box rooms have been joined via join_room.
43 /// Updated on join_room (insert) and leave_room (remove).
44 /// Arc allows sharing with background polling tasks.
45 joined_rooms: Arc<Mutex<std::collections::HashSet<String>>>,
46 /// Registry of active subscriptions: room_id → callback. Survives reconnects.
47 /// Entries are added by listen_for_live_messages and removed by leave_room.
48 /// On WS reconnect, ensure_ws_connected replays joinRoom + re-subscribes
49 /// each entry on the fresh socket.
50 subscriptions: Arc<Mutex<HashMap<String, SubscriptionCallback>>>,
51}
52
53impl<W: WalletInterface + Clone + 'static + Send + Sync> MessageBoxClient<W> {
54 /// Construct a new `MessageBoxClient`.
55 ///
56 /// * `host` — Base URL of the MessageBox server. Trailing whitespace is
57 /// trimmed so callers do not need to sanitize.
58 /// * `wallet` — Any `WalletInterface` implementation.
59 /// * `originator` — Optional originator string forwarded to wallet ops.
60 /// * `network` — Network preset for overlay tools (use `Network::Local` for localhost).
61 pub fn new(host: String, wallet: W, originator: Option<String>, network: Network) -> Self {
62 MessageBoxClient {
63 host: host.trim().to_string(),
64 auth_fetch: Arc::new(Mutex::new(AuthFetch::new(wallet.clone()))),
65 wallet,
66 originator,
67 identity_key: OnceCell::new(),
68 init_once: OnceCell::new(),
69 network,
70 ws_state: Mutex::new(None),
71 joined_rooms: Arc::new(Mutex::new(std::collections::HashSet::new())),
72 subscriptions: Arc::new(Mutex::new(HashMap::new())),
73 }
74 }
75
76 /// Convenience constructor defaulting to `Network::Mainnet`.
77 pub fn new_mainnet(host: String, wallet: W, originator: Option<String>) -> Self {
78 Self::new(host, wallet, originator, Network::Mainnet)
79 }
80
81 // -----------------------------------------------------------------------
82 // Public getters (needed by http_ops)
83 // -----------------------------------------------------------------------
84
85 /// Return the trimmed host URL.
86 pub fn host(&self) -> &str {
87 &self.host
88 }
89
90 /// Return a reference to the underlying wallet.
91 pub fn wallet(&self) -> &W {
92 &self.wallet
93 }
94
95 /// Return the originator string, if any.
96 pub fn originator(&self) -> Option<&str> {
97 self.originator.as_deref()
98 }
99
100 /// Return the network preset used for overlay operations.
101 pub fn network(&self) -> &Network {
102 &self.network
103 }
104
105 // -----------------------------------------------------------------------
106 // Identity key
107 // -----------------------------------------------------------------------
108
109 /// Return the wallet's identity public key as a DER hex string.
110 ///
111 /// The result is cached in a `OnceCell` — subsequent calls return the
112 /// cached value without calling the wallet again.
113 pub async fn get_identity_key(&self) -> Result<String, MessageBoxError> {
114 if let Some(k) = self.identity_key.get() {
115 return Ok(k.clone());
116 }
117
118 let result = self
119 .wallet
120 .get_public_key(
121 GetPublicKeyArgs {
122 identity_key: true,
123 protocol_id: None,
124 key_id: None,
125 counterparty: None,
126 privileged: false,
127 privileged_reason: None,
128 for_self: None,
129 seek_permission: None,
130 },
131 self.originator.as_deref(),
132 )
133 .await
134 .map_err(|e| MessageBoxError::Wallet(e.to_string()))?;
135
136 let key = result.public_key.to_der_hex();
137 // Ignore the error — if another caller set the cell first, we just use
138 // the stored value.
139 let _ = self.identity_key.set(key.clone());
140 Ok(key)
141 }
142
143 // -----------------------------------------------------------------------
144 // Initialization guard
145 // -----------------------------------------------------------------------
146
147 /// Ensure the client is initialized before performing any HTTP operation.
148 ///
149 /// Uses `init_once.get_or_try_init` so the full init path runs at most once
150 /// even under concurrent callers — matching the TS `initializeConnection`
151 /// pattern which defers work until first use.
152 ///
153 /// Init sequence:
154 /// 1. Cache identity key.
155 /// 2. Query overlay advertisements for this identity + host.
156 /// 3. If no matching ad exists, call `anoint_host`.
157 /// 4. CRITICAL TS PARITY: catch anoint errors and continue — TS logs
158 /// "Failed to anoint host, continuing with default functionality".
159 pub(crate) async fn assert_initialized(&self) -> Result<(), MessageBoxError> {
160 self.init_once
161 .get_or_try_init(|| async {
162 let identity_key = self.get_identity_key().await?;
163 // Query existing advertisements for this identity+host pair.
164 // unwrap_or_default() because query_advertisements never fails (TS parity).
165 let ads = self
166 .query_advertisements(Some(&identity_key), Some(&self.host))
167 .await
168 .unwrap_or_default();
169 if ads.iter().all(|ad| ad.host.trim() != self.host.trim()) {
170 // No matching advertisement — anoint this host.
171 // CRITICAL TS PARITY: catch anoint errors and continue.
172 // TS: "Failed to anoint host, continuing with default functionality"
173 if let Err(e) = self.anoint_host(&self.host).await {
174 eprintln!("Warning: failed to anoint host: {e}");
175 }
176 }
177 Ok::<(), MessageBoxError>(())
178 })
179 .await?;
180 Ok(())
181 }
182
183 /// Initialize the client — ensures overlay advertisement exists.
184 ///
185 /// User-facing wrapper for `assert_initialized`. Safe to call multiple times —
186 /// the init path runs exactly once due to `init_once` OnceCell semantics.
187 ///
188 /// `target_host`: when Some, uses that host for the anoint_host call instead of
189 /// `self.host()`. Matches the TS `init(targetHost?)` signature.
190 pub async fn init(&self, target_host: Option<&str>) -> Result<(), MessageBoxError> {
191 match target_host {
192 Some(host) => {
193 // TS parity: if targetHost provided, anoint THAT host directly
194 // instead of going through assert_initialized's self.host logic.
195 self.init_once
196 .get_or_try_init(|| async {
197 let _identity_key = self.get_identity_key().await?;
198 // CRITICAL TS PARITY: catch anoint errors and continue.
199 if let Err(e) = self.anoint_host(host).await {
200 eprintln!("Warning: failed to anoint host: {e}");
201 }
202 Ok::<(), MessageBoxError>(())
203 })
204 .await?;
205 Ok(())
206 }
207 None => self.assert_initialized().await,
208 }
209 }
210
211 /// Ensure the WebSocket connection is established.
212 ///
213 /// User-facing wrapper for `ensure_ws_connected`. Mirrors the TS
214 /// `initializeConnection` method. When `override_host` is Some, the WS
215 /// connection uses that host instead of `self.host()`.
216 pub async fn initialize_connection(
217 &self,
218 override_host: Option<&str>,
219 ) -> Result<(), MessageBoxError> {
220 self.ensure_ws_connected(override_host).await
221 }
222
223 /// Returns a clone of the set of currently joined message box room names.
224 ///
225 /// Each entry is a raw room ID of the form `{identityKey}-{messageBox}`.
226 /// Mirrors the TS `joinedRooms` Map accessor.
227 pub fn get_joined_rooms(&self) -> std::collections::HashSet<String> {
228 // Use blocking lock — this is only called from synchronous test contexts.
229 // For async callers, they should be fine as the lock is never held across await.
230 self.joined_rooms.blocking_lock().clone()
231 }
232
233 /// Returns `Some(true)` if a WebSocket connection is active, `None` otherwise.
234 ///
235 /// Sync test utility mirroring the TS `testSocket` accessor. Only callable from
236 /// sync test contexts (e.g., unit tests) where `blocking_lock` is safe.
237 #[cfg(test)]
238 pub fn test_socket(&self) -> Option<bool> {
239 let guard = self.ws_state.blocking_lock();
240 guard.as_ref().map(|ws| ws.is_connected())
241 }
242
243 /// Returns the current WebSocket connection state.
244 ///
245 /// `None` = no WebSocket connected yet. `Some(true)` = connected and authenticated.
246 /// `Some(false)` = connected but BRC-103 handshake not yet complete.
247 ///
248 /// Async version of `test_socket` — safe to call from integration tests and any
249 /// async context. Mirrors the TS `testSocket` accessor.
250 pub async fn is_ws_connected(&self) -> Option<bool> {
251 let guard = self.ws_state.lock().await;
252 guard.as_ref().map(|ws| ws.is_connected())
253 }
254
255 /// Join a Socket.IO room for a message box and track it in `joined_rooms`.
256 ///
257 /// Constructs the room ID as `{identityKey}-{messageBox}` (matching TS joinRoom).
258 /// Ensures the WebSocket is connected before joining.
259 /// No-op if the room is already joined (idempotent like TS joinRoom).
260 pub async fn join_room(&self, message_box: &str) -> Result<(), MessageBoxError> {
261 let identity_key = self.get_identity_key().await?;
262 let room_id = format!("{identity_key}-{message_box}");
263
264 self.ensure_ws_connected(None).await?;
265
266 {
267 let guard = self.ws_state.lock().await;
268 if let Some(ref ws) = *guard {
269 ws.join_room(&room_id).await?;
270 }
271 }
272
273 self.joined_rooms.lock().await.insert(room_id);
274 Ok(())
275 }
276
277 /// POST JSON bytes to `url` using BRC-31 authenticated transport.
278 ///
279 /// The entire `fetch()` call executes while the lock is held. This is
280 /// correct because `fetch()` is the outermost operation; no re-entrant
281 /// locking occurs on the Phase 1 code path.
282 pub(crate) async fn post_json(
283 &self,
284 url: &str,
285 body_bytes: Vec<u8>,
286 ) -> Result<AuthFetchResponse, MessageBoxError> {
287 let mut headers = HashMap::new();
288 headers.insert("content-type".to_string(), "application/json".to_string());
289
290 let response = self
291 .auth_fetch
292 .lock()
293 .await
294 .fetch(url, "POST", Some(body_bytes), Some(headers))
295 .await
296 .map_err(|e| MessageBoxError::Auth(e.to_string()))?;
297
298 if response.status < 200 || response.status >= 300 {
299 return Err(MessageBoxError::Http(response.status, url.to_string()));
300 }
301
302 Ok(response)
303 }
304
305 // -----------------------------------------------------------------------
306 // WebSocket live messaging
307 // -----------------------------------------------------------------------
308
309 /// Ensure a WebSocket connection is established and authenticated.
310 ///
311 /// If no connection exists or the existing connection is no longer connected,
312 /// creates a new `MessageBoxWebSocket` with the current identity key.
313 /// `rust_socketio` handles the Socket.IO handshake and HTTP-to-WS upgrade
314 /// internally — we pass the same base URL used for HTTP requests.
315 async fn ensure_ws_connected(
316 &self,
317 override_host: Option<&str>,
318 ) -> Result<(), MessageBoxError> {
319 let mut guard = self.ws_state.lock().await;
320 if guard.as_ref().map(|ws| ws.is_connected()).unwrap_or(false) {
321 return Ok(());
322 }
323 let identity_key = self.get_identity_key().await?;
324 let ws_url = override_host.unwrap_or_else(|| self.host()).to_string();
325 let ws = crate::websocket::MessageBoxWebSocket::connect(
326 &ws_url,
327 &identity_key,
328 self.wallet.clone(),
329 self.originator.clone(),
330 )
331 .await?;
332
333 // Replay subscriptions on the fresh socket so general_msg_dispatcher
334 // has callbacks registered for every active room. Without this, events
335 // delivered to the reconnected socket are silently dropped.
336 {
337 let subs = self.subscriptions.lock().await;
338 for (room_id, callback) in subs.iter() {
339 let event_key = format!("sendMessage-{room_id}");
340 // Re-join the room — ignore errors (server may already know us)
341 if let Err(e) = ws.join_room(room_id).await {
342 tracing::warn!(room_id, error = %e, "joinRoom replay failed on reconnect");
343 } else {
344 ws.subscribe(event_key, callback.clone()).await;
345 tracing::info!(room_id, "replayed subscription on reconnected socket");
346 }
347 }
348 }
349
350 *guard = Some(ws);
351 Ok(())
352 }
353
354 /// Listen for live messages on a message box via WebSocket.
355 ///
356 /// Joins the Socket.IO room `{identity_key}-{message_box}` and registers
357 /// the provided callback. Messages can arrive via three paths, all funnelled
358 /// through one shared dedup wrapper so the callback fires at most once per id:
359 ///
360 /// 1. **WebSocket push (primary):** the BRC-103 `general_msg_dispatcher` fires
361 /// the callback when the server broadcasts a signed `sendMessage-{roomId}`.
362 /// 2. **WebSocket `on_any` (fallback):** the same callback, for servers that
363 /// emit raw (unsigned) room events.
364 /// 3. **HTTP poll (backstop):** a background task that *stands down* for any
365 /// interval in which WS push already delivered (so it does not poll every
366 /// 2 s when live push is healthy), and otherwise polls `/listMessages` to
367 /// catch anything the WS paths missed — forcing a catch-up at least every
368 /// `MAX_POLL_SKIPS` intervals. Stops when `leave_room` removes the room
369 /// from `joined_rooms`.
370 ///
371 /// All paths deliver `PeerMessage` with decrypted body to the same callback.
372 ///
373 /// Establishes a WebSocket connection if one is not already active.
374 /// `override_host` is reserved for future multi-host WS routing.
375 pub async fn listen_for_live_messages(
376 &self,
377 message_box: &str,
378 on_message: Arc<dyn Fn(PeerMessage) + Send + Sync>,
379 override_host: Option<&str>,
380 ) -> Result<(), MessageBoxError> {
381 let identity_key = self.get_identity_key().await?;
382 let room_id = format!("{identity_key}-{message_box}");
383 let event_key = format!("sendMessage-{room_id}");
384
385 self.ensure_ws_connected(override_host).await?;
386
387 // One user callback, three possible delivery paths (WS primary
388 // dispatcher, WS `on_any` fallback, HTTP poll). Wrap once so a
389 // `message_id` reaches the user at most once regardless of which path
390 // wins the race.
391 let deduped = exactly_once(on_message);
392
393 // WS-only delivery counter. Bumped on every WebSocket delivery so the
394 // HTTP poll backstop below can detect healthy live-push and stand down,
395 // instead of hitting /listMessages every 2s for every active mailbox.
396 let ws_activity = Arc::new(std::sync::atomic::AtomicU64::new(0));
397 let ws_callback = record_ws_activity(deduped.clone(), ws_activity.clone());
398
399 {
400 let guard = self.ws_state.lock().await;
401 if let Some(ref ws) = *guard {
402 ws.join_room(&room_id).await?;
403 ws.subscribe(event_key.clone(), ws_callback.clone()).await;
404 }
405 }
406
407 // Register in the subscription registry so ensure_ws_connected can replay
408 // this subscription (with its dedup + activity wrappers) on any reconnect.
409 self.subscriptions.lock().await.insert(room_id.clone(), ws_callback.clone());
410
411 self.joined_rooms.lock().await.insert(room_id.clone());
412
413 // Spawn the HTTP poll BACKSTOP. With the server signing room broadcasts
414 // onto the client's authenticated primary path, live WS push is the fast
415 // path. The poll now exists only to (a) support servers that don't push
416 // and (b) recover if the WS connection silently stalls — so it stands
417 // down for any interval in which a WS delivery already occurred. This
418 // keeps redundant /listMessages load off the server under high-frequency,
419 // many-connection traffic while preserving a correctness backstop.
420 let poll_auth_fetch = self.auth_fetch.clone();
421 let poll_joined_rooms = self.joined_rooms.clone();
422 let poll_host = self.host.clone();
423 let poll_message_box = message_box.to_string();
424 let poll_identity_key = identity_key.clone();
425 let poll_wallet = self.wallet.clone();
426 let poll_originator = self.originator.clone();
427 let poll_room_id = room_id.clone();
428 let poll_callback = deduped;
429 let poll_ws_activity = ws_activity;
430
431 tokio::spawn(async move {
432 use std::sync::atomic::Ordering;
433 let mut last_activity = poll_ws_activity.load(Ordering::Relaxed);
434 let mut skipped: u32 = 0;
435
436 loop {
437 tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
438
439 // Stop when the room is no longer active (leave_room was called)
440 if !poll_joined_rooms.lock().await.contains(&poll_room_id) {
441 break;
442 }
443
444 // Stand down while WS push is healthy, but force a catch-up at
445 // least every MAX_POLL_SKIPS intervals so partial push loss on an
446 // otherwise-active connection can't permanently suppress the poll.
447 let activity = poll_ws_activity.load(Ordering::Relaxed);
448 if !poll_should_run(activity, &mut last_activity, &mut skipped) {
449 continue;
450 }
451
452 // Backstop poll. Cheap no-op when the mailbox is simply idle; a
453 // genuine catch-up when live push stalled. The shared `deduped`
454 // callback suppresses anything WS already delivered.
455 match poll_list_messages(
456 &poll_auth_fetch,
457 &poll_host,
458 &poll_message_box,
459 &poll_identity_key,
460 &poll_wallet,
461 poll_originator.as_deref(),
462 )
463 .await
464 {
465 Ok(messages) => {
466 for msg in messages {
467 poll_callback(msg);
468 }
469 }
470 // The poll IS the correctness backstop — never swallow its
471 // error silently. A persistent failure here means both live
472 // push (already stalled) and the backstop are down.
473 Err(e) => {
474 tracing::warn!(
475 room_id = %poll_room_id,
476 error = %e,
477 "poll backstop failed — message catch-up unavailable this interval"
478 );
479 }
480 }
481
482 // Refresh in case a WS delivery landed while we were polling.
483 last_activity = poll_ws_activity.load(Ordering::Relaxed);
484 }
485 });
486
487 Ok(())
488 }
489
490 /// Send a message via WebSocket with 10-second ack timeout and HTTP fallback.
491 ///
492 /// Mirrors TS `sendLiveMessage`: auto-connects if needed, joins the sender's
493 /// own room (required for ack routing), then emits. Falls back to HTTP if the
494 /// connection cannot be established or the ack times out / fails.
495 ///
496 /// TS parity: HTTP fallback resolves recipient's host via overlay before sending.
497 /// The WS path connects to `self.host()` — overlay resolution affects the fallback path.
498 /// `override_host`: when Some, the HTTP fallback path sends to that host directly.
499 /// Send a message via WebSocket with 10-second ack timeout and HTTP fallback.
500 ///
501 /// Mirrors TS `sendLiveMessage(params, overrideHost?)` which accepts the full
502 /// `SendMessageParams` including `skipEncryption`, `checkPermissions`, and `messageId`.
503 ///
504 /// - `skip_encryption`: when true, sends body as-is without BRC-78 encryption.
505 /// - `check_permissions`: when true, the HTTP fallback path fetches quotes and pays fees.
506 /// - `message_id`: when Some, uses caller-supplied ID instead of HMAC-derived ID.
507 /// - `override_host`: when Some, the HTTP fallback sends to that host directly.
508 #[allow(clippy::too_many_arguments)]
509 pub async fn send_live_message(
510 &self,
511 recipient: &str,
512 message_box: &str,
513 body: &str,
514 skip_encryption: bool,
515 check_permissions: bool,
516 message_id: Option<&str>,
517 override_host: Option<&str>,
518 ) -> Result<DeliveryMode, MessageBoxError> {
519 // If no host override, resolve the recipient's host via overlay.
520 // If the recipient is on a DIFFERENT host than ours, we must use HTTP
521 // (send_message with overlay resolution) since our WebSocket is only
522 // connected to self.host.
523 if override_host.is_none() {
524 let resolved = self
525 .resolve_host_for_recipient(recipient)
526 .await
527 .unwrap_or_else(|_| self.host().to_string());
528 if resolved.trim() != self.host().trim() {
529 // Recipient is on a different MessageBox server — use HTTP with overlay.
530 // This is a persisted delivery (no live WS ack possible cross-host).
531 let msg_id = self
532 .send_message(
533 recipient,
534 message_box,
535 body,
536 skip_encryption,
537 check_permissions,
538 message_id,
539 None,
540 )
541 .await?;
542 return Ok(DeliveryMode::Persisted { message_id: msg_id });
543 }
544 }
545
546 // Auto-connect or reconnect — if the WebSocket is disconnected, drop the
547 // stale connection and establish a fresh one with a new BRC-103 handshake.
548 // This handles proxy timeouts and server restarts without falling through
549 // to the slower HTTP fallback path.
550 {
551 let guard = self.ws_state.lock().await;
552 if guard.as_ref().map(|ws| !ws.is_connected()).unwrap_or(false) {
553 drop(guard);
554 // Stale connection — tear down and reconnect
555 if let Err(e) = self.disconnect_web_socket().await {
556 eprintln!("Warning: stale WebSocket disconnect failed (proceeding to reconnect): {e}");
557 }
558 }
559 }
560 if let Err(e) = self.ensure_ws_connected(override_host).await {
561 eprintln!("Warning: WebSocket connection failed, falling back to HTTP: {e}");
562 // HTTP fallback: use override_host if provided, otherwise overlay resolution.
563 // Returns Persisted because no WS ack was received.
564 let msg_id = match override_host {
565 Some(host) => {
566 self.send_message_to_host(
567 host,
568 recipient,
569 message_box,
570 body,
571 skip_encryption,
572 check_permissions,
573 message_id,
574 None,
575 )
576 .await?
577 }
578 None => {
579 self.send_message(
580 recipient,
581 message_box,
582 body,
583 skip_encryption,
584 check_permissions,
585 message_id,
586 None,
587 )
588 .await?
589 }
590 };
591 return Ok(DeliveryMode::Persisted { message_id: msg_id });
592 }
593
594 // Join sender's own room before send — TS calls joinRoom(messageBox) which
595 // joins `${myIdentityKey}-${messageBox}`. Required so the server can route
596 // the sendMessageAck back to this socket.
597 let identity_key = self.get_identity_key().await?;
598 let my_room = format!("{identity_key}-{message_box}");
599 {
600 let guard = self.ws_state.lock().await;
601 if let Some(ref ws) = *guard {
602 if ws.join_room(&my_room).await.is_err() {
603 drop(guard);
604 let msg_id = match override_host {
605 Some(host) => {
606 self.send_message_to_host(
607 host,
608 recipient,
609 message_box,
610 body,
611 skip_encryption,
612 check_permissions,
613 message_id,
614 None,
615 )
616 .await?
617 }
618 None => {
619 self.send_message(
620 recipient,
621 message_box,
622 body,
623 skip_encryption,
624 check_permissions,
625 message_id,
626 None,
627 )
628 .await?
629 }
630 };
631 return Ok(DeliveryMode::Persisted { message_id: msg_id });
632 }
633 } else {
634 drop(guard);
635 let msg_id = match override_host {
636 Some(host) => {
637 self.send_message_to_host(
638 host,
639 recipient,
640 message_box,
641 body,
642 skip_encryption,
643 check_permissions,
644 message_id,
645 None,
646 )
647 .await?
648 }
649 None => {
650 self.send_message(
651 recipient,
652 message_box,
653 body,
654 skip_encryption,
655 check_permissions,
656 message_id,
657 None,
658 )
659 .await?
660 }
661 };
662 return Ok(DeliveryMode::Persisted { message_id: msg_id });
663 }
664 }
665
666 // Encrypt (unless skip_encryption) and resolve message ID for the WebSocket path
667 let encrypted = if skip_encryption {
668 body.to_string()
669 } else {
670 crate::encryption::encrypt_body(self.wallet(), body, recipient, self.originator())
671 .await?
672 };
673 let message_id = if let Some(id) = message_id {
674 id.to_string()
675 } else {
676 crate::encryption::generate_message_id(
677 self.wallet(),
678 body,
679 recipient,
680 self.originator(),
681 )
682 .await?
683 };
684
685 let room_id = format!("{recipient}-{message_box}");
686 let ack_key = format!("sendMessageAck-{room_id}");
687
688 let (ack_tx, ack_rx) = tokio::sync::oneshot::channel::<bool>();
689
690 let payload = serde_json::json!({
691 "roomId": room_id,
692 "message": {
693 "messageId": message_id,
694 "recipient": recipient,
695 "body": encrypted
696 }
697 });
698
699 // Emit — acquire lock briefly then release before awaiting ack
700 {
701 let guard = self.ws_state.lock().await;
702 if let Some(ref ws) = *guard {
703 ws.emit_send_message(payload, ack_key.clone(), ack_tx)
704 .await?;
705 }
706 }
707
708 // Await ack with 10-second timeout.
709 // Live: server acked via WS within the window → DeliveryMode::Live.
710 // Anything else (timeout, channel error, ack=false): fall back to HTTP
711 // and return DeliveryMode::Persisted.
712 match tokio::time::timeout(std::time::Duration::from_secs(10), ack_rx).await {
713 Ok(Ok(true)) => Ok(DeliveryMode::Live { message_id }),
714 _ => {
715 // Clean up the pending ack to prevent channel leaks (Pitfall 7)
716 let guard = self.ws_state.lock().await;
717 if let Some(ref ws) = *guard {
718 ws.remove_pending_ack(&ack_key).await;
719 }
720 drop(guard);
721 tracing::debug!(
722 "send_live_message: WS ack timed out or failed; falling back to HTTP"
723 );
724 // Fall back to HTTP — pass through all feature params.
725 // The HTTP path generates a fresh message ID; use that for the Persisted ID.
726 let http_id = match override_host {
727 Some(host) => {
728 self.send_message_to_host(
729 host,
730 recipient,
731 message_box,
732 body,
733 skip_encryption,
734 check_permissions,
735 None,
736 None,
737 )
738 .await?
739 }
740 None => {
741 self.send_message(
742 recipient,
743 message_box,
744 body,
745 skip_encryption,
746 check_permissions,
747 None,
748 None,
749 )
750 .await?
751 }
752 };
753 Ok(DeliveryMode::Persisted { message_id: http_id })
754 }
755 }
756 }
757
758 /// Leave a Socket.IO room and remove its subscription.
759 ///
760 /// Mirrors TS `leaveRoom(messageBox)`. Constructs the room ID as
761 /// `{identityKey}-{messageBox}` and emits `leaveRoom` to the server.
762 /// No-op if the WebSocket is not connected.
763 /// `override_host` is reserved for future multi-host WS routing.
764 pub async fn leave_room(
765 &self,
766 message_box: &str,
767 override_host: Option<&str>,
768 ) -> Result<(), MessageBoxError> {
769 let _ = override_host;
770 let identity_key = self.get_identity_key().await?;
771 let room_id = format!("{identity_key}-{message_box}");
772 {
773 let guard = self.ws_state.lock().await;
774 if let Some(ref ws) = *guard {
775 ws.leave_room(&room_id).await?;
776 }
777 }
778 self.joined_rooms.lock().await.remove(&room_id);
779 // Remove from subscription registry so future reconnects don't replay it.
780 self.subscriptions.lock().await.remove(&room_id);
781 Ok(())
782 }
783
784 /// Disconnect the WebSocket connection and clear its state.
785 ///
786 /// Safe to call when no connection is active (no-op).
787 pub async fn disconnect_web_socket(&self) -> Result<(), MessageBoxError> {
788 let mut guard = self.ws_state.lock().await;
789 if let Some(ws) = guard.take() {
790 ws.disconnect().await?;
791 }
792 Ok(())
793 }
794
795 // -----------------------------------------------------------------------
796 // Internal HTTP helpers
797 // -----------------------------------------------------------------------
798
799 /// GET `url` using BRC-31 authenticated transport.
800 ///
801 /// Mirrors `post_json` but sends no body and no content-type header.
802 /// The caller is responsible for building the full URL including query string.
803 pub(crate) async fn get_json(&self, url: &str) -> Result<AuthFetchResponse, MessageBoxError> {
804 let response = self
805 .auth_fetch
806 .lock()
807 .await
808 .fetch(url, "GET", None, None)
809 .await
810 .map_err(|e| MessageBoxError::Auth(e.to_string()))?;
811
812 if response.status < 200 || response.status >= 300 {
813 return Err(MessageBoxError::Http(response.status, url.to_string()));
814 }
815
816 Ok(response)
817 }
818}
819
820// ---------------------------------------------------------------------------
821// Delivery callback wrappers
822// ---------------------------------------------------------------------------
823
/// Maximum number of consecutive poll-backstop intervals that `poll_should_run`
/// may skip because WS push was active. Once this many skips have accrued, the
/// next interval polls regardless of WS activity and the counter resets.
///
/// With the 2 s sleep used by the poll loop in `listen_for_live_messages`, the
/// longest gap between polls is (7 + 1) × 2 s = 16 s. A chatty room whose WS
/// *activity* never stops therefore cannot permanently suppress catch-up of
/// *individual* pushes that were lost.
const MAX_POLL_SKIPS: u32 = 7;
829
830/// Wrap a subscriber callback so each `message_id` is delivered **at most once**,
831/// no matter which path produced it: the WS primary dispatcher, the WS `on_any`
832/// fallback, or the HTTP poll backstop. Without this, a server that both pushes
833/// and is polled (or an `on_any`/dispatcher race) fires the callback twice.
834///
835/// Dedup is bounded to the most recent 10,000 distinct ids (FIFO eviction), so a
836/// duplicate separated from its original by more than 10,000 intervening ids can
837/// still slip through — acceptable for the single-use-mailbox traffic this
838/// serves. This suppresses duplicates only; delivery liveness comes from the
839/// WS + poll paths, not from this wrapper.
840fn exactly_once(
841 inner: Arc<dyn Fn(PeerMessage) + Send + Sync>,
842) -> Arc<dyn Fn(PeerMessage) + Send + Sync> {
843 let seen = Arc::new(std::sync::Mutex::new(BoundedIdSet::new(10_000)));
844 Arc::new(move |msg: PeerMessage| {
845 // Recover rather than panic on poison: a poisoned dedup set is harmless
846 // (worst case one duplicate delivery), whereas panicking here would kill
847 // the WS dispatcher or the spawned poll task and stop delivery entirely.
848 let fresh = match seen.lock() {
849 Ok(mut g) => g.insert(msg.message_id.clone()),
850 Err(poisoned) => poisoned.into_inner().insert(msg.message_id.clone()),
851 };
852 if fresh {
853 inner(msg);
854 }
855 })
856}
857
858/// Wrap a callback so every invocation first bumps `activity`. Installed only on
859/// the WebSocket delivery path, so the HTTP poll backstop can tell when live
860/// push is healthy and stand down (see the poll loop in
861/// [`MessageBoxClient::listen_for_live_messages`]).
862fn record_ws_activity(
863 inner: Arc<dyn Fn(PeerMessage) + Send + Sync>,
864 activity: Arc<std::sync::atomic::AtomicU64>,
865) -> Arc<dyn Fn(PeerMessage) + Send + Sync> {
866 Arc::new(move |msg: PeerMessage| {
867 activity.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
868 inner(msg);
869 })
870}
871
872/// Decide whether the poll backstop should run this interval, given the current
873/// WS-activity counter vs the value at the last check.
874///
875/// Skips (returns `false`) while WS push is delivering — *unless* `MAX_POLL_SKIPS`
876/// consecutive skips have accrued, in which case it forces a poll so partial push
877/// loss on an otherwise-active connection can't permanently suppress catch-up.
878/// Updates `last_activity` and `skipped` in place.
879fn poll_should_run(current_activity: u64, last_activity: &mut u64, skipped: &mut u32) -> bool {
880 if current_activity != *last_activity && *skipped < MAX_POLL_SKIPS {
881 *last_activity = current_activity;
882 *skipped += 1;
883 return false; // WS push is healthy this interval — stand down
884 }
885 *last_activity = current_activity;
886 *skipped = 0;
887 true
888}
889
890// ---------------------------------------------------------------------------
891// Bounded dedup set
892// ---------------------------------------------------------------------------
893
894/// A bounded set with FIFO eviction for deduplicating message IDs.
895///
896/// Prevents unbounded memory growth in the HTTP polling fallback task.
897/// When the set reaches `capacity`, the oldest entry is evicted before
898/// inserting the new one.
899struct BoundedIdSet {
900 set: HashSet<String>,
901 order: VecDeque<String>,
902 capacity: usize,
903}
904
905impl BoundedIdSet {
906 fn new(capacity: usize) -> Self {
907 assert!(capacity > 0, "BoundedIdSet capacity must be at least 1");
908 Self {
909 set: HashSet::with_capacity(capacity),
910 order: VecDeque::with_capacity(capacity),
911 capacity,
912 }
913 }
914
915 /// Insert an ID. Returns `true` if the ID was new (not previously seen).
916 fn insert(&mut self, id: String) -> bool {
917 if self.set.contains(&id) {
918 return false;
919 }
920 if self.order.len() >= self.capacity {
921 if let Some(old) = self.order.pop_front() {
922 self.set.remove(&old);
923 }
924 }
925 self.set.insert(id.clone());
926 self.order.push_back(id);
927 debug_assert_eq!(self.set.len(), self.order.len(),
928 "BoundedIdSet internal invariant violated: set/deque size mismatch");
929 true
930 }
931
932 /// Returns the number of IDs currently tracked.
933 #[cfg(test)]
934 fn len(&self) -> usize {
935 self.order.len()
936 }
937
938 /// Returns true if the given ID is currently in the set.
939 #[cfg(test)]
940 fn contains(&self, id: &str) -> bool {
941 self.set.contains(id)
942 }
943}
944
945// ---------------------------------------------------------------------------
946// Standalone helpers
947// ---------------------------------------------------------------------------
948
949/// Poll `/listMessages` for a given message box using a shared authenticated HTTP client.
950///
951/// This is used by the background polling task spawned in `listen_for_live_messages`
952/// to deliver messages as a fallback when the server does not broadcast via WS push.
953///
954/// Returns a `Vec<PeerMessage>` with decrypted bodies, or an empty vec on error.
955/// `accept_payments` is always false — the polling path does not handle delivery fees.
956async fn poll_list_messages<W>(
957 auth_fetch: &Arc<Mutex<AuthFetch<W>>>,
958 host: &str,
959 message_box: &str,
960 identity_key: &str,
961 wallet: &W,
962 originator: Option<&str>,
963) -> Result<Vec<PeerMessage>, MessageBoxError>
964where
965 W: WalletInterface + Clone + Send + Sync + 'static,
966{
967 let params = ListMessagesParams {
968 message_box: message_box.to_string(),
969 };
970 let body_bytes = serde_json::to_vec(¶ms)?;
971 let url = format!("{host}/listMessages");
972 let mut headers = HashMap::new();
973 headers.insert("content-type".to_string(), "application/json".to_string());
974
975 let response = auth_fetch
976 .lock()
977 .await
978 .fetch(&url, "POST", Some(body_bytes), Some(headers))
979 .await
980 .map_err(|e| MessageBoxError::Auth(e.to_string()))?;
981
982 if response.status < 200 || response.status >= 300 {
983 return Err(MessageBoxError::Http(response.status, url));
984 }
985
986 check_status_error(&response.body)?;
987
988 let list_response: ListMessagesResponse = serde_json::from_slice(&response.body)?;
989
990 let mut result = Vec::with_capacity(list_response.messages.len());
991 for msg in list_response.messages {
992 // Simple body extraction: if the body is a wrapped envelope, extract the message field.
993 let plain_body = extract_plain_body(&msg.body);
994 // Decrypt if encrypted
995 let decrypted =
996 crate::encryption::try_decrypt_message(wallet, &plain_body, &msg.sender, originator)
997 .await;
998
999 result.push(PeerMessage {
1000 message_id: msg.message_id,
1001 sender: msg.sender,
1002 recipient: identity_key.to_string(),
1003 message_box: message_box.to_string(),
1004 body: decrypted,
1005 });
1006 }
1007
1008 Ok(result)
1009}
1010
1011/// Extract the plain message body from a potentially server-wrapped envelope.
1012///
1013/// The server sometimes wraps messages as `{"message": "...", "payment": {...}}`.
1014/// This helper unwraps it. If the body isn't a wrapped envelope, returns it as-is.
1015fn extract_plain_body(body: &str) -> String {
1016 if let Ok(v) = serde_json::from_str::<serde_json::Value>(body) {
1017 if let Some(message) = v.get("message") {
1018 return match message {
1019 serde_json::Value::String(s) => s.clone(),
1020 other => other.to_string(),
1021 };
1022 }
1023 }
1024 body.to_string()
1025}
1026
1027/// Check if a successful (2xx) HTTP response body contains a server-level
1028/// error indicator (`{"status": "error", "description": "..."}`).
1029///
1030/// The MessageBox server can return HTTP 200 with a logical error payload —
1031/// this helper normalises that into `MessageBoxError::Auth`.
1032pub(crate) fn check_status_error(body: &[u8]) -> Result<(), MessageBoxError> {
1033 // Attempt a lightweight parse — ignore failures (malformed JSON is not
1034 // a server error in this sense).
1035 if let Ok(v) = serde_json::from_slice::<serde_json::Value>(body) {
1036 if v.get("status").and_then(|s| s.as_str()) == Some("error") {
1037 let description = v
1038 .get("description")
1039 .and_then(|d| d.as_str())
1040 .unwrap_or("unknown error")
1041 .to_string();
1042 return Err(MessageBoxError::Auth(description));
1043 }
1044 }
1045 Ok(())
1046}
1047
1048// ---------------------------------------------------------------------------
1049// Tests
1050// ---------------------------------------------------------------------------
1051
1052#[cfg(test)]
1053mod tests {
1054 use super::*;
1055 use bsv::primitives::private_key::PrivateKey;
1056 use bsv::services::overlay_tools::Network;
1057 use bsv::wallet::error::WalletError;
1058 use bsv::wallet::interfaces::*;
1059 use bsv::wallet::proto_wallet::ProtoWallet;
1060 use std::sync::Arc;
1061
1062 /// Thin Arc wrapper that makes ProtoWallet clone-able for test purposes.
1063 ///
1064 /// ProtoWallet does not implement Clone because it holds a non-Clone
1065 /// KeyDeriver. Wrapping in Arc satisfies the Clone bound while sharing
1066 /// the same underlying wallet across the clone and the AuthFetch instance.
1067 #[derive(Clone)]
1068 struct ArcWallet(Arc<ProtoWallet>);
1069
1070 impl ArcWallet {
1071 fn new() -> Self {
1072 let key = PrivateKey::from_random().expect("random key");
1073 ArcWallet(Arc::new(ProtoWallet::new(key)))
1074 }
1075 }
1076
1077 // Delegate every WalletInterface method to the inner ProtoWallet.
1078 #[async_trait::async_trait]
1079 impl WalletInterface for ArcWallet {
1080 async fn create_action(
1081 &self,
1082 args: CreateActionArgs,
1083 orig: Option<&str>,
1084 ) -> Result<CreateActionResult, WalletError> {
1085 self.0.create_action(args, orig).await
1086 }
1087 async fn sign_action(
1088 &self,
1089 args: SignActionArgs,
1090 orig: Option<&str>,
1091 ) -> Result<SignActionResult, WalletError> {
1092 self.0.sign_action(args, orig).await
1093 }
1094 async fn abort_action(
1095 &self,
1096 args: AbortActionArgs,
1097 orig: Option<&str>,
1098 ) -> Result<AbortActionResult, WalletError> {
1099 self.0.abort_action(args, orig).await
1100 }
1101 async fn list_actions(
1102 &self,
1103 args: ListActionsArgs,
1104 orig: Option<&str>,
1105 ) -> Result<ListActionsResult, WalletError> {
1106 self.0.list_actions(args, orig).await
1107 }
1108 async fn internalize_action(
1109 &self,
1110 args: InternalizeActionArgs,
1111 orig: Option<&str>,
1112 ) -> Result<InternalizeActionResult, WalletError> {
1113 self.0.internalize_action(args, orig).await
1114 }
1115 async fn list_outputs(
1116 &self,
1117 args: ListOutputsArgs,
1118 orig: Option<&str>,
1119 ) -> Result<ListOutputsResult, WalletError> {
1120 self.0.list_outputs(args, orig).await
1121 }
1122 async fn relinquish_output(
1123 &self,
1124 args: RelinquishOutputArgs,
1125 orig: Option<&str>,
1126 ) -> Result<RelinquishOutputResult, WalletError> {
1127 self.0.relinquish_output(args, orig).await
1128 }
1129 async fn get_public_key(
1130 &self,
1131 args: GetPublicKeyArgs,
1132 orig: Option<&str>,
1133 ) -> Result<GetPublicKeyResult, WalletError> {
1134 self.0.get_public_key(args, orig).await
1135 }
1136 async fn reveal_counterparty_key_linkage(
1137 &self,
1138 args: RevealCounterpartyKeyLinkageArgs,
1139 orig: Option<&str>,
1140 ) -> Result<RevealCounterpartyKeyLinkageResult, WalletError> {
1141 self.0.reveal_counterparty_key_linkage(args, orig).await
1142 }
1143 async fn reveal_specific_key_linkage(
1144 &self,
1145 args: RevealSpecificKeyLinkageArgs,
1146 orig: Option<&str>,
1147 ) -> Result<RevealSpecificKeyLinkageResult, WalletError> {
1148 self.0.reveal_specific_key_linkage(args, orig).await
1149 }
1150 async fn encrypt(
1151 &self,
1152 args: EncryptArgs,
1153 orig: Option<&str>,
1154 ) -> Result<EncryptResult, WalletError> {
1155 self.0.encrypt(args, orig).await
1156 }
1157 async fn decrypt(
1158 &self,
1159 args: DecryptArgs,
1160 orig: Option<&str>,
1161 ) -> Result<DecryptResult, WalletError> {
1162 self.0.decrypt(args, orig).await
1163 }
1164 async fn create_hmac(
1165 &self,
1166 args: CreateHmacArgs,
1167 orig: Option<&str>,
1168 ) -> Result<CreateHmacResult, WalletError> {
1169 self.0.create_hmac(args, orig).await
1170 }
1171 async fn verify_hmac(
1172 &self,
1173 args: VerifyHmacArgs,
1174 orig: Option<&str>,
1175 ) -> Result<VerifyHmacResult, WalletError> {
1176 self.0.verify_hmac(args, orig).await
1177 }
1178 async fn create_signature(
1179 &self,
1180 args: CreateSignatureArgs,
1181 orig: Option<&str>,
1182 ) -> Result<CreateSignatureResult, WalletError> {
1183 self.0.create_signature(args, orig).await
1184 }
1185 async fn verify_signature(
1186 &self,
1187 args: VerifySignatureArgs,
1188 orig: Option<&str>,
1189 ) -> Result<VerifySignatureResult, WalletError> {
1190 self.0.verify_signature(args, orig).await
1191 }
1192 async fn acquire_certificate(
1193 &self,
1194 args: AcquireCertificateArgs,
1195 orig: Option<&str>,
1196 ) -> Result<Certificate, WalletError> {
1197 self.0.acquire_certificate(args, orig).await
1198 }
1199 async fn list_certificates(
1200 &self,
1201 args: ListCertificatesArgs,
1202 orig: Option<&str>,
1203 ) -> Result<ListCertificatesResult, WalletError> {
1204 self.0.list_certificates(args, orig).await
1205 }
1206 async fn prove_certificate(
1207 &self,
1208 args: ProveCertificateArgs,
1209 orig: Option<&str>,
1210 ) -> Result<ProveCertificateResult, WalletError> {
1211 self.0.prove_certificate(args, orig).await
1212 }
1213 async fn relinquish_certificate(
1214 &self,
1215 args: RelinquishCertificateArgs,
1216 orig: Option<&str>,
1217 ) -> Result<RelinquishCertificateResult, WalletError> {
1218 self.0.relinquish_certificate(args, orig).await
1219 }
1220 async fn discover_by_identity_key(
1221 &self,
1222 args: DiscoverByIdentityKeyArgs,
1223 orig: Option<&str>,
1224 ) -> Result<DiscoverCertificatesResult, WalletError> {
1225 self.0.discover_by_identity_key(args, orig).await
1226 }
1227 async fn discover_by_attributes(
1228 &self,
1229 args: DiscoverByAttributesArgs,
1230 orig: Option<&str>,
1231 ) -> Result<DiscoverCertificatesResult, WalletError> {
1232 self.0.discover_by_attributes(args, orig).await
1233 }
1234 async fn is_authenticated(
1235 &self,
1236 orig: Option<&str>,
1237 ) -> Result<AuthenticatedResult, WalletError> {
1238 self.0.is_authenticated(orig).await
1239 }
1240 async fn wait_for_authentication(
1241 &self,
1242 orig: Option<&str>,
1243 ) -> Result<AuthenticatedResult, WalletError> {
1244 self.0.wait_for_authentication(orig).await
1245 }
1246 async fn get_height(&self, orig: Option<&str>) -> Result<GetHeightResult, WalletError> {
1247 self.0.get_height(orig).await
1248 }
1249 async fn get_header_for_height(
1250 &self,
1251 args: GetHeaderArgs,
1252 orig: Option<&str>,
1253 ) -> Result<GetHeaderResult, WalletError> {
1254 self.0.get_header_for_height(args, orig).await
1255 }
1256 async fn get_network(&self, orig: Option<&str>) -> Result<GetNetworkResult, WalletError> {
1257 self.0.get_network(orig).await
1258 }
1259 async fn get_version(&self, orig: Option<&str>) -> Result<GetVersionResult, WalletError> {
1260 self.0.get_version(orig).await
1261 }
1262 }
1263
1264 /// `new()` must trim leading/trailing whitespace from the host URL.
1265 #[tokio::test]
1266 async fn new_trims_host_url() {
1267 let wallet = ArcWallet::new();
1268 let client = MessageBoxClient::new(
1269 "https://example.com ".to_string(),
1270 wallet,
1271 None,
1272 Network::Mainnet,
1273 );
1274 assert_eq!(client.host(), "https://example.com");
1275 }
1276
1277 /// `get_identity_key` returns a non-empty hex string.
1278 #[tokio::test]
1279 async fn get_identity_key_returns_non_empty_hex() {
1280 let wallet = ArcWallet::new();
1281 let client = MessageBoxClient::new(
1282 "https://example.com".to_string(),
1283 wallet,
1284 None,
1285 Network::Mainnet,
1286 );
1287 let key = client.get_identity_key().await.expect("get_identity_key");
1288 assert!(!key.is_empty(), "identity key must be non-empty");
1289 assert!(
1290 key.chars().all(|c| c.is_ascii_hexdigit()),
1291 "identity key must be hex"
1292 );
1293 }
1294
1295 /// `get_json` exists — compile check via type coercion to async fn pointer.
1296 ///
1297 /// We verify the method resolves without calling it (no live network needed).
1298 #[allow(dead_code)]
1299 fn get_json_compiles(client: &MessageBoxClient<ArcWallet>) {
1300 // If get_json does not exist or has wrong signature, this fn fails to compile.
1301 let _fut = client.get_json("https://example.com/test");
1302 }
1303
1304 // -----------------------------------------------------------------------
1305 // Fix 2: Subscription registry tests
1306 // -----------------------------------------------------------------------
1307
1308 /// Subscription registry starts empty on construction.
1309 #[tokio::test]
1310 async fn subscription_registry_starts_empty() {
1311 let wallet = ArcWallet::new();
1312 let client = MessageBoxClient::new(
1313 "https://example.com".to_string(),
1314 wallet,
1315 None,
1316 Network::Mainnet,
1317 );
1318 let subs = client.subscriptions.lock().await;
1319 assert!(subs.is_empty(), "subscriptions must be empty on new client");
1320 }
1321
1322 /// Subscription registry can be populated and queried directly.
1323 ///
1324 /// This exercises the same path as listen_for_live_messages inserting into
1325 /// the registry. Full reconnect replay requires a live Socket.IO server.
1326 #[tokio::test]
1327 async fn subscription_registry_insert_and_lookup() {
1328 use std::sync::atomic::{AtomicBool, Ordering};
1329 let wallet = ArcWallet::new();
1330 let client = MessageBoxClient::new(
1331 "https://example.com".to_string(),
1332 wallet.clone(),
1333 None,
1334 Network::Mainnet,
1335 );
1336 let identity_key = client.get_identity_key().await.expect("identity key");
1337
1338 // Simulate what listen_for_live_messages does: build room_id and insert callback.
1339 let room_id = format!("{identity_key}-test_inbox");
1340 let fired = Arc::new(AtomicBool::new(false));
1341 let fired_clone = fired.clone();
1342 let callback: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
1343 Arc::new(move |_msg| {
1344 fired_clone.store(true, Ordering::SeqCst);
1345 });
1346
1347 client
1348 .subscriptions
1349 .lock()
1350 .await
1351 .insert(room_id.clone(), callback.clone());
1352
1353 // Verify it was stored
1354 let subs = client.subscriptions.lock().await;
1355 assert!(subs.contains_key(&room_id), "room_id must be in registry");
1356 assert_eq!(subs.len(), 1, "registry must have exactly one entry");
1357
1358 // Verify the stored callback is callable
1359 let cb = subs.get(&room_id).cloned().expect("callback must exist");
1360 drop(subs);
1361 cb(bsv::remittance::types::PeerMessage {
1362 message_id: "test".to_string(),
1363 sender: "03sender".to_string(),
1364 recipient: identity_key.clone(),
1365 message_box: "test_inbox".to_string(),
1366 body: "hello".to_string(),
1367 });
1368 assert!(fired.load(Ordering::SeqCst), "callback must have been invoked");
1369 }
1370
1371 /// Subscription registry entry is removed when leave_room is called — compile check.
1372 ///
1373 /// Full removal requires ws_state to be Some (live connection). Here we verify
1374 /// the direct remove path on the registry in isolation.
1375 #[tokio::test]
1376 async fn subscription_registry_remove_on_leave() {
1377 let wallet = ArcWallet::new();
1378 let client = MessageBoxClient::new(
1379 "https://example.com".to_string(),
1380 wallet,
1381 None,
1382 Network::Mainnet,
1383 );
1384 let identity_key = client.get_identity_key().await.expect("identity key");
1385 let room_id = format!("{identity_key}-inbox");
1386
1387 // Insert a dummy callback
1388 let cb: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
1389 Arc::new(|_| {});
1390 client.subscriptions.lock().await.insert(room_id.clone(), cb);
1391 assert_eq!(client.subscriptions.lock().await.len(), 1, "inserted");
1392
1393 // Remove it directly (simulates what leave_room does)
1394 client.subscriptions.lock().await.remove(&room_id);
1395 assert!(client.subscriptions.lock().await.is_empty(), "removed");
1396 }
1397
1398 /// `check_status_error` returns Ok for success body.
1399 #[test]
1400 fn check_status_error_passes_success_body() {
1401 use super::check_status_error;
1402 let body = br#"{"status":"success","data":{}}"#;
1403 assert!(check_status_error(body).is_ok());
1404 }
1405
1406 /// `check_status_error` returns Err for server error body.
1407 #[test]
1408 fn check_status_error_returns_err_for_error_body() {
1409 use super::check_status_error;
1410 let body = br#"{"status":"error","description":"permission denied"}"#;
1411 let err = check_status_error(body).unwrap_err();
1412 assert!(matches!(err, crate::error::MessageBoxError::Auth(_)));
1413 assert_eq!(err.to_string(), "auth error: permission denied");
1414 }
1415
1416 /// `get_identity_key` returns the same value on a second call (OnceCell cache).
1417 #[tokio::test]
1418 async fn get_identity_key_caches_result() {
1419 let wallet = ArcWallet::new();
1420 let client = MessageBoxClient::new(
1421 "https://example.com".to_string(),
1422 wallet,
1423 None,
1424 Network::Mainnet,
1425 );
1426 let key1 = client.get_identity_key().await.expect("first call");
1427 let key2 = client.get_identity_key().await.expect("second call");
1428 assert_eq!(key1, key2, "OnceCell must return the same value on re-call");
1429 }
1430
1431 /// `init_once` field is of type `OnceCell<()>` — compile check.
1432 ///
1433 /// The init_once field must be retained so assert_initialized can be wired
1434 /// through it in Phase 5. This test verifies the field type and existence.
1435 #[test]
1436 fn test_init_compiles() {
1437 let wallet = ArcWallet::new();
1438 let client = MessageBoxClient::new(
1439 "https://example.com".to_string(),
1440 wallet,
1441 None,
1442 Network::Mainnet,
1443 );
1444 // Verify the init_once field can be referenced and the public init() method exists.
1445 // If init_once were removed or its type changed, this compile-check fails.
1446 let _cell: &OnceCell<()> = &client.init_once;
1447 // Public init() must exist — verified by type resolution.
1448 // Verify init() exists and returns a Future — drop without awaiting.
1449 drop(client.init(None));
1450 }
1451
1452 // -----------------------------------------------------------------------
1453 // BoundedIdSet tests
1454 // -----------------------------------------------------------------------
1455
1456 /// BoundedIdSet rejects duplicate inserts.
1457 #[test]
1458 fn bounded_id_set_rejects_duplicates() {
1459 let mut set = super::BoundedIdSet::new(100);
1460 assert!(set.insert("a".to_string()), "first insert returns true");
1461 assert!(!set.insert("a".to_string()), "duplicate returns false");
1462 assert_eq!(set.len(), 1);
1463 }
1464
1465 /// BoundedIdSet evicts oldest entries when at capacity.
1466 #[test]
1467 fn bounded_id_set_evicts_oldest() {
1468 let mut set = super::BoundedIdSet::new(3);
1469 set.insert("a".to_string());
1470 set.insert("b".to_string());
1471 set.insert("c".to_string());
1472 // At capacity — next insert should evict "a"
1473 assert!(set.insert("d".to_string()));
1474 assert!(!set.contains("a"), "oldest entry must be evicted");
1475 assert!(set.contains("b"));
1476 assert!(set.contains("c"));
1477 assert!(set.contains("d"));
1478 // "a" is now unknown — should be insertable again
1479 assert!(
1480 set.insert("a".to_string()),
1481 "evicted entry can be re-inserted"
1482 );
1483 }
1484
1485 /// BoundedIdSet panics on capacity=0.
1486 #[test]
1487 #[should_panic(expected = "capacity must be at least 1")]
1488 fn bounded_id_set_rejects_zero_capacity() {
1489 super::BoundedIdSet::new(0);
1490 }
1491
1492 fn peer_msg(id: &str) -> bsv::remittance::types::PeerMessage {
1493 bsv::remittance::types::PeerMessage {
1494 message_id: id.to_string(),
1495 sender: "03sender".to_string(),
1496 recipient: "02recipient".to_string(),
1497 message_box: "inbox".to_string(),
1498 body: "body".to_string(),
1499 }
1500 }
1501
1502 #[test]
1503 fn exactly_once_delivers_each_message_id_once() {
1504 use std::sync::atomic::{AtomicUsize, Ordering};
1505 let count = Arc::new(AtomicUsize::new(0));
1506 let c = count.clone();
1507 let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
1508 Arc::new(move |_m| {
1509 c.fetch_add(1, Ordering::SeqCst);
1510 });
1511 let deduped = super::exactly_once(inner);
1512
1513 // Same id arriving on three "paths" (WS dispatcher, WS on_any, poll).
1514 deduped(peer_msg("m1"));
1515 deduped(peer_msg("m1"));
1516 deduped(peer_msg("m1"));
1517 // A distinct id still gets through.
1518 deduped(peer_msg("m2"));
1519
1520 assert_eq!(count.load(Ordering::SeqCst), 2, "m1 once + m2 once");
1521 }
1522
1523 #[test]
1524 fn record_ws_activity_bumps_counter_and_forwards() {
1525 use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
1526 let activity = Arc::new(AtomicU64::new(0));
1527 let count = Arc::new(AtomicUsize::new(0));
1528 let c = count.clone();
1529 let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
1530 Arc::new(move |_m| {
1531 c.fetch_add(1, Ordering::SeqCst);
1532 });
1533 let wrapped = super::record_ws_activity(inner, activity.clone());
1534
1535 wrapped(peer_msg("m1"));
1536 wrapped(peer_msg("m2"));
1537
1538 assert_eq!(activity.load(Ordering::Relaxed), 2, "counter bumped per delivery");
1539 assert_eq!(count.load(Ordering::SeqCst), 2, "inner callback forwarded each time");
1540 }
1541
1542 #[test]
1543 fn poll_should_run_runs_when_ws_quiet() {
1544 // current == last → WS delivered nothing this interval → poll runs.
1545 let mut last = 5u64;
1546 let mut skipped = 0u32;
1547 assert!(super::poll_should_run(5, &mut last, &mut skipped));
1548 assert_eq!(skipped, 0);
1549 }
1550
1551 #[test]
1552 fn poll_should_run_stands_down_when_ws_active() {
1553 // current != last → WS delivered → stand down (within the skip budget).
1554 let mut last = 5u64;
1555 let mut skipped = 0u32;
1556 assert!(!super::poll_should_run(6, &mut last, &mut skipped));
1557 assert_eq!(last, 6, "last_activity advances to current");
1558 assert_eq!(skipped, 1);
1559 }
1560
1561 #[test]
1562 fn poll_should_run_forces_catch_up_after_max_skips() {
1563 // A perpetually-active connection must still be polled at least every
1564 // MAX_POLL_SKIPS intervals so partial push loss can't suppress catch-up.
1565 let mut last = 0u64;
1566 let mut skipped = 0u32;
1567 let mut runs = 0u32;
1568 for tick in 1..=(super::MAX_POLL_SKIPS as u64 * 3) {
1569 // WS "delivers" every interval → counter always changes.
1570 if super::poll_should_run(tick, &mut last, &mut skipped) {
1571 runs += 1;
1572 }
1573 }
1574 // Over 3*MAX_POLL_SKIPS always-active intervals, the forced poll fires
1575 // roughly every (MAX_POLL_SKIPS+1) intervals — at least twice.
1576 assert!(runs >= 2, "forced catch-up must fire periodically, got {runs}");
1577 }
1578
1579 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1580 async fn exactly_once_is_safe_under_concurrent_delivery() {
1581 use std::sync::atomic::{AtomicUsize, Ordering};
1582 // The deduped callback is shared across the WS dispatcher, on_any, and
1583 // the poll task concurrently. Hammer the same ids from many tasks and
1584 // assert each id is delivered exactly once.
1585 let count = Arc::new(AtomicUsize::new(0));
1586 let c = count.clone();
1587 let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
1588 Arc::new(move |_m| {
1589 c.fetch_add(1, Ordering::SeqCst);
1590 });
1591 let deduped = super::exactly_once(inner);
1592
1593 let mut handles = Vec::new();
1594 for _ in 0..8 {
1595 let cb = deduped.clone();
1596 handles.push(tokio::spawn(async move {
1597 for i in 0..100 {
1598 cb(peer_msg(&format!("m{i}")));
1599 }
1600 }));
1601 }
1602 for h in handles {
1603 h.await.unwrap();
1604 }
1605 // 100 distinct ids, each delivered exactly once despite 8 racing tasks.
1606 assert_eq!(count.load(Ordering::SeqCst), 100);
1607 }
1608
1609 #[test]
1610 fn exactly_once_composes_with_record_ws_activity() {
1611 // Mirrors the wiring in listen_for_live_messages: the WS path stamps
1612 // activity then delivers through the shared dedup; the poll path shares
1613 // the same dedup. A message delivered by WS must not be re-delivered by
1614 // a later poll of the same id.
1615 use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
1616 let activity = Arc::new(AtomicU64::new(0));
1617 let count = Arc::new(AtomicUsize::new(0));
1618 let c = count.clone();
1619 let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
1620 Arc::new(move |_m| {
1621 c.fetch_add(1, Ordering::SeqCst);
1622 });
1623 let deduped = super::exactly_once(inner);
1624 let ws_path = super::record_ws_activity(deduped.clone(), activity.clone());
1625 let poll_path = deduped;
1626
1627 ws_path(peer_msg("m1")); // delivered via WS, stamps activity
1628 poll_path(peer_msg("m1")); // poll re-sees the same id → suppressed
1629
1630 assert_eq!(count.load(Ordering::SeqCst), 1, "delivered exactly once across paths");
1631 assert_eq!(activity.load(Ordering::Relaxed), 1, "only the WS path stamps activity");
1632 }
1633}