use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use bsv::auth::clients::auth_fetch::{AuthFetch, AuthFetchResponse};
use bsv::remittance::types::PeerMessage;
use bsv::services::overlay_tools::Network;
use bsv::wallet::interfaces::{GetPublicKeyArgs, WalletInterface};
use tokio::sync::{Mutex, OnceCell};
use crate::delivery::DeliveryMode;
use crate::error::MessageBoxError;
use crate::types::{ListMessagesParams, ListMessagesResponse};
type SubscriptionCallback = Arc<dyn Fn(PeerMessage) + Send + Sync>;
pub struct MessageBoxClient<W: WalletInterface + Clone + 'static> {
host: String,
auth_fetch: Arc<Mutex<AuthFetch<W>>>,
wallet: W,
originator: Option<String>,
identity_key: OnceCell<String>,
pub(crate) init_once: OnceCell<()>,
pub(crate) network: Network,
ws_state: Mutex<Option<crate::websocket::MessageBoxWebSocket>>,
joined_rooms: Arc<Mutex<std::collections::HashSet<String>>>,
subscriptions: Arc<Mutex<HashMap<String, SubscriptionCallback>>>,
}
impl<W: WalletInterface + Clone + 'static + Send + Sync> MessageBoxClient<W> {
pub fn new(host: String, wallet: W, originator: Option<String>, network: Network) -> Self {
MessageBoxClient {
host: host.trim().to_string(),
auth_fetch: Arc::new(Mutex::new(AuthFetch::new(wallet.clone()))),
wallet,
originator,
identity_key: OnceCell::new(),
init_once: OnceCell::new(),
network,
ws_state: Mutex::new(None),
joined_rooms: Arc::new(Mutex::new(std::collections::HashSet::new())),
subscriptions: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn new_mainnet(host: String, wallet: W, originator: Option<String>) -> Self {
Self::new(host, wallet, originator, Network::Mainnet)
}
/// Base URL of the home server, already trimmed of surrounding whitespace.
pub fn host(&self) -> &str {
    self.host.as_str()
}
/// The wallet this client was constructed with.
pub fn wallet(&self) -> &W {
    &self.wallet
}
/// Originator string forwarded to wallet calls, if one was configured.
pub fn originator(&self) -> Option<&str> {
    self.originator.as_deref()
}
/// Network this client was configured for.
pub fn network(&self) -> &Network {
    &self.network
}
/// Returns this wallet's identity public key as DER hex.
///
/// The key is fetched from the wallet on first use and cached afterwards.
///
/// # Errors
/// `MessageBoxError::Wallet` if the wallet's `get_public_key` call fails;
/// nothing is cached in that case, so a later call retries.
pub async fn get_identity_key(&self) -> Result<String, MessageBoxError> {
    if let Some(cached) = self.identity_key.get() {
        return Ok(cached.clone());
    }
    let request = GetPublicKeyArgs {
        identity_key: true,
        protocol_id: None,
        key_id: None,
        counterparty: None,
        privileged: false,
        privileged_reason: None,
        for_self: None,
        seek_permission: None,
    };
    let response = self
        .wallet
        .get_public_key(request, self.originator.as_deref())
        .await
        .map_err(|err| MessageBoxError::Wallet(err.to_string()))?;
    let key_hex = response.public_key.to_der_hex();
    // `set` only fails when another caller already filled the cell; the cached
    // value is kept and this freshly fetched one is still returned.
    let _ = self.identity_key.set(key_hex.clone());
    Ok(key_hex)
}
/// Runs the one-time initialisation for the home host (at most once per client).
///
/// Looks up this identity's advertisements for [`Self::host`]. If none of them
/// names that host, the host is anointed. A failed advertisement query is
/// treated as "no advertisements", and a failed anoint is only logged to stderr.
///
/// # Errors
/// Only an identity-key lookup failure is returned; the once-cell then stays
/// empty, so a later call tries again.
pub(crate) async fn assert_initialized(&self) -> Result<(), MessageBoxError> {
    let home = self.host.trim();
    self.init_once
        .get_or_try_init(|| async {
            let identity_key = self.get_identity_key().await?;
            let ads = self
                .query_advertisements(Some(&identity_key), Some(&self.host))
                .await
                .unwrap_or_default();
            let already_advertised = ads.iter().any(|ad| ad.host.trim() == home);
            if !already_advertised {
                if let Err(e) = self.anoint_host(&self.host).await {
                    eprintln!("Warning: failed to anoint host: {e}");
                }
            }
            Ok::<(), MessageBoxError>(())
        })
        .await?;
    Ok(())
}
/// Initialises the client once.
///
/// With `Some(host)`, the one-time step fetches the identity key and anoints
/// `host` without looking up existing advertisements; anoint failures are only
/// logged. With `None`, this defers to `assert_initialized` for the home host.
/// The one-time step runs at most once per client, so later calls return
/// `Ok(())` without doing anything, whatever `target_host` is.
///
/// # Errors
/// Identity-key lookup failure; initialisation can then be retried.
pub async fn init(&self, target_host: Option<&str>) -> Result<(), MessageBoxError> {
    let host = match target_host {
        None => return self.assert_initialized().await,
        Some(host) => host,
    };
    self.init_once
        .get_or_try_init(|| async {
            self.get_identity_key().await?;
            if let Err(e) = self.anoint_host(host).await {
                eprintln!("Warning: failed to anoint host: {e}");
            }
            Ok::<(), MessageBoxError>(())
        })
        .await
        .map(|_| ())
}
/// Makes sure a connected WebSocket exists, opening one if necessary.
///
/// `override_host` selects the server to connect to instead of [`Self::host`];
/// it has no effect when a connected socket already exists.
///
/// # Errors
/// Identity-key lookup or socket connection failures.
pub async fn initialize_connection(&self, override_host: Option<&str>) -> Result<(), MessageBoxError> {
    self.ensure_ws_connected(override_host).await
}
pub fn get_joined_rooms(&self) -> std::collections::HashSet<String> {
self.joined_rooms.blocking_lock().clone()
}
/// Test-only view of the socket slot: `None` if no socket is stored, otherwise
/// whether it reports itself connected.
///
/// Like `get_joined_rooms`, a non-blocking `try_lock` is tried first, so this
/// can be called from async tests. The `blocking_lock` fallback (used only
/// while another task holds the slot) panics inside an async runtime.
#[cfg(test)]
pub fn test_socket(&self) -> Option<bool> {
    match self.ws_state.try_lock() {
        Ok(slot) => slot.as_ref().map(|ws| ws.is_connected()),
        Err(_) => {
            let slot = self.ws_state.blocking_lock();
            slot.as_ref().map(|ws| ws.is_connected())
        }
    }
}
/// Reports the socket state.
///
/// Returns `None` when no socket is stored (never opened, or removed by
/// [`Self::disconnect_web_socket`]); otherwise `Some(is_connected)`.
pub async fn is_ws_connected(&self) -> Option<bool> {
    let slot = self.ws_state.lock().await;
    slot.as_ref().map(|socket| socket.is_connected())
}
pub async fn join_room(&self, message_box: &str) -> Result<(), MessageBoxError> {
let identity_key = self.get_identity_key().await?;
let room_id = format!("{identity_key}-{message_box}");
self.ensure_ws_connected(None).await?;
{
let guard = self.ws_state.lock().await;
if let Some(ref ws) = *guard {
ws.join_room(&room_id).await?;
}
}
self.joined_rooms.lock().await.insert(room_id);
Ok(())
}
/// POSTs `body_bytes` to `url` with `content-type: application/json` through
/// the authenticated HTTP client.
///
/// # Errors
/// `MessageBoxError::Auth` if the request itself fails, or
/// `MessageBoxError::Http(status, url)` for a non-2xx status.
pub(crate) async fn post_json(
    &self,
    url: &str,
    body_bytes: Vec<u8>,
) -> Result<AuthFetchResponse, MessageBoxError> {
    let headers = HashMap::from([(
        "content-type".to_string(),
        "application/json".to_string(),
    )]);
    let response = {
        let mut client = self.auth_fetch.lock().await;
        client
            .fetch(url, "POST", Some(body_bytes), Some(headers))
            .await
            .map_err(|err| MessageBoxError::Auth(err.to_string()))?
    };
    if !(200..300).contains(&response.status) {
        return Err(MessageBoxError::Http(response.status, url.to_string()));
    }
    Ok(response)
}
/// Opens a WebSocket unless a connected one is already stored.
///
/// The socket slot stays locked for the whole connect, so concurrent callers
/// of this method cannot open two sockets. Before the new socket is stored,
/// every registered subscription is replayed onto it: the room is re-joined
/// and, only if that succeeds, its `sendMessage-{room_id}` callback is
/// re-subscribed. Replay failures are logged, not returned.
///
/// # Errors
/// Identity-key lookup or socket connection failures.
async fn ensure_ws_connected(&self, override_host: Option<&str>) -> Result<(), MessageBoxError> {
    let mut slot = self.ws_state.lock().await;
    let already_connected = match slot.as_ref() {
        Some(ws) => ws.is_connected(),
        None => false,
    };
    if already_connected {
        return Ok(());
    }
    let identity_key = self.get_identity_key().await?;
    let target = match override_host {
        Some(host) => host.to_string(),
        None => self.host().to_string(),
    };
    let socket = crate::websocket::MessageBoxWebSocket::connect(
        &target,
        &identity_key,
        self.wallet.clone(),
        self.originator.clone(),
    )
    .await?;
    {
        let registry = self.subscriptions.lock().await;
        for (room_id, callback) in registry.iter() {
            match socket.join_room(room_id).await {
                Err(e) => {
                    tracing::warn!(room_id, error = %e, "joinRoom replay failed on reconnect");
                }
                Ok(_) => {
                    socket
                        .subscribe(format!("sendMessage-{room_id}"), callback.clone())
                        .await;
                    tracing::info!(room_id, "replayed subscription on reconnected socket");
                }
            }
        }
    }
    *slot = Some(socket);
    Ok(())
}
pub async fn listen_for_live_messages(
&self,
message_box: &str,
on_message: Arc<dyn Fn(PeerMessage) + Send + Sync>,
override_host: Option<&str>,
) -> Result<(), MessageBoxError> {
let identity_key = self.get_identity_key().await?;
let room_id = format!("{identity_key}-{message_box}");
let event_key = format!("sendMessage-{room_id}");
self.ensure_ws_connected(override_host).await?;
let deduped = exactly_once(on_message);
let ws_activity = Arc::new(std::sync::atomic::AtomicU64::new(0));
let ws_callback = record_ws_activity(deduped.clone(), ws_activity.clone());
{
let guard = self.ws_state.lock().await;
if let Some(ref ws) = *guard {
ws.join_room(&room_id).await?;
ws.subscribe(event_key.clone(), ws_callback.clone()).await;
}
}
self.subscriptions.lock().await.insert(room_id.clone(), ws_callback.clone());
self.joined_rooms.lock().await.insert(room_id.clone());
let poll_auth_fetch = self.auth_fetch.clone();
let poll_joined_rooms = self.joined_rooms.clone();
let poll_host = self.host.clone();
let poll_message_box = message_box.to_string();
let poll_identity_key = identity_key.clone();
let poll_wallet = self.wallet.clone();
let poll_originator = self.originator.clone();
let poll_room_id = room_id.clone();
let poll_callback = deduped;
let poll_ws_activity = ws_activity;
tokio::spawn(async move {
use std::sync::atomic::Ordering;
let mut last_activity = poll_ws_activity.load(Ordering::Relaxed);
let mut skipped: u32 = 0;
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
if !poll_joined_rooms.lock().await.contains(&poll_room_id) {
break;
}
let activity = poll_ws_activity.load(Ordering::Relaxed);
if !poll_should_run(activity, &mut last_activity, &mut skipped) {
continue;
}
match poll_list_messages(
&poll_auth_fetch,
&poll_host,
&poll_message_box,
&poll_identity_key,
&poll_wallet,
poll_originator.as_deref(),
)
.await
{
Ok(messages) => {
for msg in messages {
poll_callback(msg);
}
}
Err(e) => {
tracing::warn!(
room_id = %poll_room_id,
error = %e,
"poll backstop failed — message catch-up unavailable this interval"
);
}
}
last_activity = poll_ws_activity.load(Ordering::Relaxed);
}
});
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn send_live_message(
&self,
recipient: &str,
message_box: &str,
body: &str,
skip_encryption: bool,
check_permissions: bool,
message_id: Option<&str>,
override_host: Option<&str>,
) -> Result<DeliveryMode, MessageBoxError> {
if override_host.is_none() {
let resolved = self
.resolve_host_for_recipient(recipient)
.await
.unwrap_or_else(|_| self.host().to_string());
if resolved.trim() != self.host().trim() {
let msg_id = self
.send_message(
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?;
return Ok(DeliveryMode::Persisted { message_id: msg_id });
}
}
{
let guard = self.ws_state.lock().await;
if guard.as_ref().map(|ws| !ws.is_connected()).unwrap_or(false) {
drop(guard);
if let Err(e) = self.disconnect_web_socket().await {
eprintln!("Warning: stale WebSocket disconnect failed (proceeding to reconnect): {e}");
}
}
}
if let Err(e) = self.ensure_ws_connected(override_host).await {
eprintln!("Warning: WebSocket connection failed, falling back to HTTP: {e}");
let msg_id = match override_host {
Some(host) => {
self.send_message_to_host(
host,
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?
}
None => {
self.send_message(
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?
}
};
return Ok(DeliveryMode::Persisted { message_id: msg_id });
}
let identity_key = self.get_identity_key().await?;
let my_room = format!("{identity_key}-{message_box}");
{
let guard = self.ws_state.lock().await;
if let Some(ref ws) = *guard {
if ws.join_room(&my_room).await.is_err() {
drop(guard);
let msg_id = match override_host {
Some(host) => {
self.send_message_to_host(
host,
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?
}
None => {
self.send_message(
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?
}
};
return Ok(DeliveryMode::Persisted { message_id: msg_id });
}
} else {
drop(guard);
let msg_id = match override_host {
Some(host) => {
self.send_message_to_host(
host,
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?
}
None => {
self.send_message(
recipient,
message_box,
body,
skip_encryption,
check_permissions,
message_id,
None,
)
.await?
}
};
return Ok(DeliveryMode::Persisted { message_id: msg_id });
}
}
let encrypted = if skip_encryption {
body.to_string()
} else {
crate::encryption::encrypt_body(self.wallet(), body, recipient, self.originator())
.await?
};
let message_id = if let Some(id) = message_id {
id.to_string()
} else {
crate::encryption::generate_message_id(
self.wallet(),
body,
recipient,
self.originator(),
)
.await?
};
let room_id = format!("{recipient}-{message_box}");
let ack_key = format!("sendMessageAck-{room_id}");
let (ack_tx, ack_rx) = tokio::sync::oneshot::channel::<bool>();
let payload = serde_json::json!({
"roomId": room_id,
"message": {
"messageId": message_id,
"recipient": recipient,
"body": encrypted
}
});
{
let guard = self.ws_state.lock().await;
if let Some(ref ws) = *guard {
ws.emit_send_message(payload, ack_key.clone(), ack_tx)
.await?;
}
}
match tokio::time::timeout(std::time::Duration::from_secs(10), ack_rx).await {
Ok(Ok(true)) => Ok(DeliveryMode::Live { message_id }),
_ => {
let guard = self.ws_state.lock().await;
if let Some(ref ws) = *guard {
ws.remove_pending_ack(&ack_key).await;
}
drop(guard);
tracing::debug!(
"send_live_message: WS ack timed out or failed; falling back to HTTP"
);
let http_id = match override_host {
Some(host) => {
self.send_message_to_host(
host,
recipient,
message_box,
body,
skip_encryption,
check_permissions,
None,
None,
)
.await?
}
None => {
self.send_message(
recipient,
message_box,
body,
skip_encryption,
check_permissions,
None,
None,
)
.await?
}
};
Ok(DeliveryMode::Persisted { message_id: http_id })
}
}
}
/// Leaves this client's room for `message_box` (`{identity_key}-{message_box}`).
///
/// Local state is cleared *before* the server-side leave is attempted, so that
/// even if the socket call fails:
/// - the room is no longer listed as joined, which stops its poll backstop;
/// - its callback is no longer replayed onto a reconnected socket.
///
/// Previously a failed leave left both in place.
///
/// `override_host` is accepted but currently unused.
///
/// # Errors
/// Identity-key lookup, or the server-side leave (local state is already
/// cleared when this is returned).
pub async fn leave_room(
    &self,
    message_box: &str,
    override_host: Option<&str>,
) -> Result<(), MessageBoxError> {
    let _ = override_host;
    let identity_key = self.get_identity_key().await?;
    let room_id = format!("{identity_key}-{message_box}");
    self.joined_rooms.lock().await.remove(&room_id);
    self.subscriptions.lock().await.remove(&room_id);
    let guard = self.ws_state.lock().await;
    if let Some(ref ws) = *guard {
        ws.leave_room(&room_id).await?;
    }
    Ok(())
}
/// Removes the stored socket, if any, and disconnects it.
///
/// The socket slot stays locked during the disconnect, and afterwards it is
/// empty even if the disconnect reports an error.
///
/// # Errors
/// Whatever the socket's `disconnect` returns.
pub async fn disconnect_web_socket(&self) -> Result<(), MessageBoxError> {
    let mut slot = self.ws_state.lock().await;
    let previous = slot.take();
    match previous {
        Some(socket) => {
            socket.disconnect().await?;
            Ok(())
        }
        None => Ok(()),
    }
}
/// Sends an authenticated GET to `url`.
///
/// # Errors
/// `MessageBoxError::Auth` if the request itself fails, or
/// `MessageBoxError::Http(status, url)` for a non-2xx status.
pub(crate) async fn get_json(&self, url: &str) -> Result<AuthFetchResponse, MessageBoxError> {
    let response = {
        let mut client = self.auth_fetch.lock().await;
        client
            .fetch(url, "GET", None, None)
            .await
            .map_err(|err| MessageBoxError::Auth(err.to_string()))?
    };
    if !(200..300).contains(&response.status) {
        return Err(MessageBoxError::Http(response.status, url.to_string()));
    }
    Ok(response)
}
}
/// Maximum number of consecutive ticks the HTTP poll backstop is skipped while
/// the WebSocket keeps delivering; the tick after that always polls.
const MAX_POLL_SKIPS: u32 = 7;
/// Wraps `inner` so each distinct `message_id` reaches it only once.
///
/// Only the ids of the most recent 10 000 distinct messages are remembered;
/// an id evicted from that window would be delivered again. The id lock is
/// released before `inner` runs, and a poisoned lock is recovered rather than
/// treated as fatal.
fn exactly_once(
    inner: Arc<dyn Fn(PeerMessage) + Send + Sync>,
) -> Arc<dyn Fn(PeerMessage) + Send + Sync> {
    let seen_ids = Arc::new(std::sync::Mutex::new(BoundedIdSet::new(10_000)));
    Arc::new(move |msg: PeerMessage| {
        let first_sighting = {
            let mut ids = seen_ids
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            ids.insert(msg.message_id.clone())
        };
        if first_sighting {
            inner(msg);
        }
    })
}
/// Wraps `inner` so every delivery first increments `activity`.
///
/// The poll backstop reads this counter to tell that the WebSocket path is live.
fn record_ws_activity(
    inner: Arc<dyn Fn(PeerMessage) + Send + Sync>,
    activity: Arc<std::sync::atomic::AtomicU64>,
) -> Arc<dyn Fn(PeerMessage) + Send + Sync> {
    use std::sync::atomic::Ordering;
    let stamped = move |msg: PeerMessage| {
        activity.fetch_add(1, Ordering::Relaxed);
        inner(msg);
    };
    Arc::new(stamped)
}
/// Decides whether the poll backstop should run on this tick.
///
/// If the activity counter changed since the previous tick and fewer than
/// `MAX_POLL_SKIPS` ticks in a row were skipped, the tick is skipped. Otherwise
/// the poll runs and the skip streak resets. `last_activity` always advances to
/// `current_activity`.
fn poll_should_run(current_activity: u64, last_activity: &mut u64, skipped: &mut u32) -> bool {
    let ws_was_busy = current_activity != *last_activity;
    *last_activity = current_activity;
    if ws_was_busy && *skipped < MAX_POLL_SKIPS {
        *skipped += 1;
        false
    } else {
        *skipped = 0;
        true
    }
}
/// A set of message ids with a fixed capacity and first-in-first-out eviction.
///
/// Inserting a new id into a full set evicts the id that was inserted earliest.
struct BoundedIdSet {
    /// Membership index.
    set: HashSet<String>,
    /// Insertion order, oldest at the front; drives eviction.
    order: VecDeque<String>,
    /// Maximum number of ids retained.
    capacity: usize,
}
impl BoundedIdSet {
    /// Creates an empty set holding at most `capacity` ids.
    ///
    /// # Panics
    /// If `capacity` is zero.
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "BoundedIdSet capacity must be at least 1");
        let set = HashSet::with_capacity(capacity);
        let order = VecDeque::with_capacity(capacity);
        Self { set, order, capacity }
    }
    /// Records `id`, returning `true` if it was not already present.
    fn insert(&mut self, id: String) -> bool {
        if self.set.contains(&id) {
            return false;
        }
        let full = self.order.len() >= self.capacity;
        if full {
            if let Some(evicted) = self.order.pop_front() {
                self.set.remove(&evicted);
            }
        }
        self.order.push_back(id.clone());
        self.set.insert(id);
        debug_assert_eq!(self.set.len(), self.order.len(),
            "BoundedIdSet internal invariant violated: set/deque size mismatch");
        true
    }
    /// Number of ids currently held.
    #[cfg(test)]
    fn len(&self) -> usize {
        self.order.len()
    }
    /// Whether `id` is currently held.
    #[cfg(test)]
    fn contains(&self, id: &str) -> bool {
        self.set.contains(id)
    }
}
/// Lists `message_box` on `host` over authenticated HTTP and turns each entry
/// into a `PeerMessage` addressed to `identity_key`.
///
/// Each body is first unwrapped by `extract_plain_body`, then passed through
/// `try_decrypt_message` with the entry's sender.
///
/// # Errors
/// Serialization, request (`Auth`), non-2xx status (`Http`), an
/// `{"status":"error"}` body, or an unparseable listing.
async fn poll_list_messages<W>(
    auth_fetch: &Arc<Mutex<AuthFetch<W>>>,
    host: &str,
    message_box: &str,
    identity_key: &str,
    wallet: &W,
    originator: Option<&str>,
) -> Result<Vec<PeerMessage>, MessageBoxError>
where
    W: WalletInterface + Clone + Send + Sync + 'static,
{
    let payload = serde_json::to_vec(&ListMessagesParams {
        message_box: message_box.to_string(),
    })?;
    let url = format!("{host}/listMessages");
    let headers = HashMap::from([(
        "content-type".to_string(),
        "application/json".to_string(),
    )]);
    let response = auth_fetch
        .lock()
        .await
        .fetch(&url, "POST", Some(payload), Some(headers))
        .await
        .map_err(|err| MessageBoxError::Auth(err.to_string()))?;
    if !(200..300).contains(&response.status) {
        return Err(MessageBoxError::Http(response.status, url));
    }
    check_status_error(&response.body)?;
    let listing: ListMessagesResponse = serde_json::from_slice(&response.body)?;
    let mut delivered = Vec::with_capacity(listing.messages.len());
    for entry in listing.messages {
        let plain = extract_plain_body(&entry.body);
        let body =
            crate::encryption::try_decrypt_message(wallet, &plain, &entry.sender, originator)
                .await;
        delivered.push(PeerMessage {
            message_id: entry.message_id,
            sender: entry.sender,
            recipient: identity_key.to_string(),
            message_box: message_box.to_string(),
            body,
        });
    }
    Ok(delivered)
}
/// Unwraps a stored message body.
///
/// If `body` parses as JSON with a `message` field, that field is returned: its
/// string contents for a JSON string, otherwise its JSON text. Anything else is
/// returned unchanged.
fn extract_plain_body(body: &str) -> String {
    let parsed = serde_json::from_str::<serde_json::Value>(body).ok();
    match parsed.as_ref().and_then(|value| value.get("message")) {
        Some(serde_json::Value::String(text)) => text.clone(),
        Some(other) => other.to_string(),
        None => body.to_string(),
    }
}
pub(crate) fn check_status_error(body: &[u8]) -> Result<(), MessageBoxError> {
if let Ok(v) = serde_json::from_slice::<serde_json::Value>(body) {
if v.get("status").and_then(|s| s.as_str()) == Some("error") {
let description = v
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("unknown error")
.to_string();
return Err(MessageBoxError::Auth(description));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use bsv::primitives::private_key::PrivateKey;
use bsv::services::overlay_tools::Network;
use bsv::wallet::error::WalletError;
use bsv::wallet::interfaces::*;
use bsv::wallet::proto_wallet::ProtoWallet;
use std::sync::Arc;
/// Cloneable test wallet: a `ProtoWallet` behind an `Arc`, so clones share one key.
#[derive(Clone)]
struct ArcWallet(Arc<ProtoWallet>);
impl ArcWallet {
    /// Builds a wallet around a freshly generated random private key.
    fn new() -> Self {
        let key = PrivateKey::from_random().expect("random key");
        Self(Arc::new(ProtoWallet::new(key)))
    }
}
/// Forwards every `WalletInterface` call unchanged to the wrapped `ProtoWallet`.
#[async_trait::async_trait]
impl WalletInterface for ArcWallet {
    async fn create_action(&self, args: CreateActionArgs, originator: Option<&str>) -> Result<CreateActionResult, WalletError> {
        self.0.create_action(args, originator).await
    }
    async fn sign_action(&self, args: SignActionArgs, originator: Option<&str>) -> Result<SignActionResult, WalletError> {
        self.0.sign_action(args, originator).await
    }
    async fn abort_action(&self, args: AbortActionArgs, originator: Option<&str>) -> Result<AbortActionResult, WalletError> {
        self.0.abort_action(args, originator).await
    }
    async fn list_actions(&self, args: ListActionsArgs, originator: Option<&str>) -> Result<ListActionsResult, WalletError> {
        self.0.list_actions(args, originator).await
    }
    async fn internalize_action(&self, args: InternalizeActionArgs, originator: Option<&str>) -> Result<InternalizeActionResult, WalletError> {
        self.0.internalize_action(args, originator).await
    }
    async fn list_outputs(&self, args: ListOutputsArgs, originator: Option<&str>) -> Result<ListOutputsResult, WalletError> {
        self.0.list_outputs(args, originator).await
    }
    async fn relinquish_output(&self, args: RelinquishOutputArgs, originator: Option<&str>) -> Result<RelinquishOutputResult, WalletError> {
        self.0.relinquish_output(args, originator).await
    }
    async fn get_public_key(&self, args: GetPublicKeyArgs, originator: Option<&str>) -> Result<GetPublicKeyResult, WalletError> {
        self.0.get_public_key(args, originator).await
    }
    async fn reveal_counterparty_key_linkage(&self, args: RevealCounterpartyKeyLinkageArgs, originator: Option<&str>) -> Result<RevealCounterpartyKeyLinkageResult, WalletError> {
        self.0.reveal_counterparty_key_linkage(args, originator).await
    }
    async fn reveal_specific_key_linkage(&self, args: RevealSpecificKeyLinkageArgs, originator: Option<&str>) -> Result<RevealSpecificKeyLinkageResult, WalletError> {
        self.0.reveal_specific_key_linkage(args, originator).await
    }
    async fn encrypt(&self, args: EncryptArgs, originator: Option<&str>) -> Result<EncryptResult, WalletError> {
        self.0.encrypt(args, originator).await
    }
    async fn decrypt(&self, args: DecryptArgs, originator: Option<&str>) -> Result<DecryptResult, WalletError> {
        self.0.decrypt(args, originator).await
    }
    async fn create_hmac(&self, args: CreateHmacArgs, originator: Option<&str>) -> Result<CreateHmacResult, WalletError> {
        self.0.create_hmac(args, originator).await
    }
    async fn verify_hmac(&self, args: VerifyHmacArgs, originator: Option<&str>) -> Result<VerifyHmacResult, WalletError> {
        self.0.verify_hmac(args, originator).await
    }
    async fn create_signature(&self, args: CreateSignatureArgs, originator: Option<&str>) -> Result<CreateSignatureResult, WalletError> {
        self.0.create_signature(args, originator).await
    }
    async fn verify_signature(&self, args: VerifySignatureArgs, originator: Option<&str>) -> Result<VerifySignatureResult, WalletError> {
        self.0.verify_signature(args, originator).await
    }
    async fn acquire_certificate(&self, args: AcquireCertificateArgs, originator: Option<&str>) -> Result<Certificate, WalletError> {
        self.0.acquire_certificate(args, originator).await
    }
    async fn list_certificates(&self, args: ListCertificatesArgs, originator: Option<&str>) -> Result<ListCertificatesResult, WalletError> {
        self.0.list_certificates(args, originator).await
    }
    async fn prove_certificate(&self, args: ProveCertificateArgs, originator: Option<&str>) -> Result<ProveCertificateResult, WalletError> {
        self.0.prove_certificate(args, originator).await
    }
    async fn relinquish_certificate(&self, args: RelinquishCertificateArgs, originator: Option<&str>) -> Result<RelinquishCertificateResult, WalletError> {
        self.0.relinquish_certificate(args, originator).await
    }
    async fn discover_by_identity_key(&self, args: DiscoverByIdentityKeyArgs, originator: Option<&str>) -> Result<DiscoverCertificatesResult, WalletError> {
        self.0.discover_by_identity_key(args, originator).await
    }
    async fn discover_by_attributes(&self, args: DiscoverByAttributesArgs, originator: Option<&str>) -> Result<DiscoverCertificatesResult, WalletError> {
        self.0.discover_by_attributes(args, originator).await
    }
    async fn is_authenticated(&self, originator: Option<&str>) -> Result<AuthenticatedResult, WalletError> {
        self.0.is_authenticated(originator).await
    }
    async fn wait_for_authentication(&self, originator: Option<&str>) -> Result<AuthenticatedResult, WalletError> {
        self.0.wait_for_authentication(originator).await
    }
    async fn get_height(&self, originator: Option<&str>) -> Result<GetHeightResult, WalletError> {
        self.0.get_height(originator).await
    }
    async fn get_header_for_height(&self, args: GetHeaderArgs, originator: Option<&str>) -> Result<GetHeaderResult, WalletError> {
        self.0.get_header_for_height(args, originator).await
    }
    async fn get_network(&self, originator: Option<&str>) -> Result<GetNetworkResult, WalletError> {
        self.0.get_network(originator).await
    }
    async fn get_version(&self, originator: Option<&str>) -> Result<GetVersionResult, WalletError> {
        self.0.get_version(originator).await
    }
}
/// `new` must strip surrounding whitespace from the host URL.
#[tokio::test]
async fn new_trims_host_url() {
    let padded = String::from("https://example.com ");
    let client = MessageBoxClient::new(padded, ArcWallet::new(), None, Network::Mainnet);
    assert_eq!(client.host(), "https://example.com");
}
/// The identity key comes back as a non-empty hex string.
#[tokio::test]
async fn get_identity_key_returns_non_empty_hex() {
    let client = MessageBoxClient::new(
        String::from("https://example.com"),
        ArcWallet::new(),
        None,
        Network::Mainnet,
    );
    let key = client.get_identity_key().await.expect("get_identity_key");
    assert!(!key.is_empty(), "identity key must be non-empty");
    let all_hex = key.bytes().all(|b| b.is_ascii_hexdigit());
    assert!(all_hex, "identity key must be hex");
}
// Compile-time check only: proves `get_json` is callable from here with the test
// wallet. It is never invoked, hence `dead_code` is allowed.
#[allow(dead_code)]
fn get_json_compiles(client: &MessageBoxClient<ArcWallet>) {
    let _pending = client.get_json("https://example.com/test");
}
/// A freshly built client has no registered subscriptions.
#[tokio::test]
async fn subscription_registry_starts_empty() {
    let client = MessageBoxClient::new(
        "https://example.com".to_owned(),
        ArcWallet::new(),
        None,
        Network::Mainnet,
    );
    let registry = client.subscriptions.lock().await;
    assert!(registry.is_empty(), "subscriptions must be empty on new client");
}
/// A callback stored under a room id can be looked up and invoked.
#[tokio::test]
async fn subscription_registry_insert_and_lookup() {
    use std::sync::atomic::{AtomicBool, Ordering};
    let wallet = ArcWallet::new();
    let client = MessageBoxClient::new(
        "https://example.com".to_owned(),
        wallet.clone(),
        None,
        Network::Mainnet,
    );
    let identity_key = client.get_identity_key().await.expect("identity key");
    let room_id = format!("{identity_key}-test_inbox");
    let fired = Arc::new(AtomicBool::new(false));
    let fired_flag = Arc::clone(&fired);
    let callback: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
        Arc::new(move |_msg| {
            fired_flag.store(true, Ordering::SeqCst);
        });
    client
        .subscriptions
        .lock()
        .await
        .insert(room_id.clone(), Arc::clone(&callback));
    let found = {
        let registry = client.subscriptions.lock().await;
        assert!(registry.contains_key(&room_id), "room_id must be in registry");
        assert_eq!(registry.len(), 1, "registry must have exactly one entry");
        registry.get(&room_id).cloned().expect("callback must exist")
    };
    found(bsv::remittance::types::PeerMessage {
        message_id: "test".to_string(),
        sender: "03sender".to_string(),
        recipient: identity_key.clone(),
        message_box: "test_inbox".to_string(),
        body: "hello".to_string(),
    });
    assert!(fired.load(Ordering::SeqCst), "callback must have been invoked");
}
/// `leave_room` drops the room's subscription and its joined-room entry.
///
/// No socket is open here, so `leave_room` only has to clear local state. The
/// test goes through the real API instead of editing the map directly.
#[tokio::test]
async fn subscription_registry_remove_on_leave() {
    let client = MessageBoxClient::new(
        "https://example.com".to_string(),
        ArcWallet::new(),
        None,
        Network::Mainnet,
    );
    let identity_key = client.get_identity_key().await.expect("identity key");
    let room_id = format!("{identity_key}-inbox");
    let cb: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
        Arc::new(|_| {});
    client.subscriptions.lock().await.insert(room_id.clone(), cb);
    client.joined_rooms.lock().await.insert(room_id.clone());
    assert_eq!(client.subscriptions.lock().await.len(), 1, "inserted");
    client
        .leave_room("inbox", None)
        .await
        .expect("leave_room without a socket succeeds");
    assert!(client.subscriptions.lock().await.is_empty(), "removed");
    assert!(
        !client.joined_rooms.lock().await.contains(&room_id),
        "room no longer joined"
    );
}
/// A `"success"` status body is not an error.
#[test]
fn check_status_error_passes_success_body() {
    let ok_body: &[u8] = br#"{"status":"success","data":{}}"#;
    assert!(check_status_error(ok_body).is_ok());
}
/// An `"error"` status body maps to `Auth` with the body's description.
#[test]
fn check_status_error_returns_err_for_error_body() {
    let err = check_status_error(br#"{"status":"error","description":"permission denied"}"#)
        .expect_err("error status must map to Err");
    assert!(matches!(err, crate::error::MessageBoxError::Auth(_)));
    assert_eq!(err.to_string(), "auth error: permission denied");
}
/// Repeated identity-key lookups return the same value.
#[tokio::test]
async fn get_identity_key_caches_result() {
    let client = MessageBoxClient::new(
        "https://example.com".to_owned(),
        ArcWallet::new(),
        None,
        Network::Mainnet,
    );
    let first = client.get_identity_key().await.expect("first call");
    let second = client.get_identity_key().await.expect("second call");
    assert_eq!(first, second, "OnceCell must return the same value on re-call");
}
/// Type-checks `init` and the `init_once` field without running anything.
#[test]
fn test_init_compiles() {
    let client = MessageBoxClient::new(
        "https://example.com".to_owned(),
        ArcWallet::new(),
        None,
        Network::Mainnet,
    );
    let _cell: &OnceCell<()> = &client.init_once;
    // The future is built only to type-check `init`; it is dropped unpolled.
    let pending = client.init(None);
    drop(pending);
}
/// A repeated id is rejected and not stored twice.
#[test]
fn bounded_id_set_rejects_duplicates() {
    let mut ids = BoundedIdSet::new(100);
    assert!(ids.insert("a".into()), "first insert returns true");
    assert!(!ids.insert("a".into()), "duplicate returns false");
    assert_eq!(ids.len(), 1);
}
/// A full set evicts its oldest id, which may then be inserted again.
#[test]
fn bounded_id_set_evicts_oldest() {
    let mut ids = BoundedIdSet::new(3);
    for id in ["a", "b", "c"] {
        ids.insert(id.to_string());
    }
    assert!(ids.insert("d".to_string()));
    assert!(!ids.contains("a"), "oldest entry must be evicted");
    for kept in ["b", "c", "d"] {
        assert!(ids.contains(kept));
    }
    assert!(ids.insert("a".to_string()), "evicted entry can be re-inserted");
}
/// Zero capacity is rejected at construction.
#[test]
#[should_panic(expected = "capacity must be at least 1")]
fn bounded_id_set_rejects_zero_capacity() {
    let _ = BoundedIdSet::new(0);
}
/// Builds a `PeerMessage` with the given id and fixed placeholder fields.
fn peer_msg(id: &str) -> bsv::remittance::types::PeerMessage {
    bsv::remittance::types::PeerMessage {
        message_id: id.to_owned(),
        sender: String::from("03sender"),
        recipient: String::from("02recipient"),
        message_box: String::from("inbox"),
        body: String::from("body"),
    }
}
/// Repeated ids are swallowed; each distinct id is forwarded once.
#[test]
fn exactly_once_delivers_each_message_id_once() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    let count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&count);
    let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
        Arc::new(move |_m| {
            counter.fetch_add(1, Ordering::SeqCst);
        });
    let deduped = exactly_once(inner);
    for id in ["m1", "m1", "m1", "m2"] {
        deduped(peer_msg(id));
    }
    assert_eq!(count.load(Ordering::SeqCst), 2, "m1 once + m2 once");
}
/// Each delivery bumps the activity counter and still reaches the inner callback.
#[test]
fn record_ws_activity_bumps_counter_and_forwards() {
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
    let activity = Arc::new(AtomicU64::new(0));
    let count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&count);
    let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
        Arc::new(move |_m| {
            counter.fetch_add(1, Ordering::SeqCst);
        });
    let wrapped = record_ws_activity(inner, Arc::clone(&activity));
    for id in ["m1", "m2"] {
        wrapped(peer_msg(id));
    }
    assert_eq!(activity.load(Ordering::Relaxed), 2, "counter bumped per delivery");
    assert_eq!(count.load(Ordering::SeqCst), 2, "inner callback forwarded each time");
}
/// No socket activity since the last tick: the poll runs.
#[test]
fn poll_should_run_runs_when_ws_quiet() {
    let (mut last, mut skipped) = (5u64, 0u32);
    assert!(poll_should_run(5, &mut last, &mut skipped));
    assert_eq!(skipped, 0);
}
/// Fresh socket activity: the poll stands down and the streak grows.
#[test]
fn poll_should_run_stands_down_when_ws_active() {
    let (mut last, mut skipped) = (5u64, 0u32);
    assert!(!poll_should_run(6, &mut last, &mut skipped));
    assert_eq!(last, 6, "last_activity advances to current");
    assert_eq!(skipped, 1);
}
/// Under constant activity a catch-up poll is still forced periodically.
#[test]
fn poll_should_run_forces_catch_up_after_max_skips() {
    let (mut last, mut skipped) = (0u64, 0u32);
    let ticks = 1..=(u64::from(MAX_POLL_SKIPS) * 3);
    let runs = ticks
        .filter(|&tick| poll_should_run(tick, &mut last, &mut skipped))
        .count();
    assert!(runs >= 2, "forced catch-up must fire periodically, got {runs}");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn exactly_once_is_safe_under_concurrent_delivery() {
use std::sync::atomic::{AtomicUsize, Ordering};
let count = Arc::new(AtomicUsize::new(0));
let c = count.clone();
let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
Arc::new(move |_m| {
c.fetch_add(1, Ordering::SeqCst);
});
let deduped = super::exactly_once(inner);
let mut handles = Vec::new();
for _ in 0..8 {
let cb = deduped.clone();
handles.push(tokio::spawn(async move {
for i in 0..100 {
cb(peer_msg(&format!("m{i}")));
}
}));
}
for h in handles {
h.await.unwrap();
}
assert_eq!(count.load(Ordering::SeqCst), 100);
}
/// The same id arriving via socket then poll is delivered once; only the socket
/// path stamps activity.
#[test]
fn exactly_once_composes_with_record_ws_activity() {
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
    let activity = Arc::new(AtomicU64::new(0));
    let count = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&count);
    let inner: Arc<dyn Fn(bsv::remittance::types::PeerMessage) + Send + Sync> =
        Arc::new(move |_m| {
            counter.fetch_add(1, Ordering::SeqCst);
        });
    let poll_path = exactly_once(inner);
    let ws_path = record_ws_activity(Arc::clone(&poll_path), Arc::clone(&activity));
    ws_path(peer_msg("m1"));
    poll_path(peer_msg("m1"));
    assert_eq!(count.load(Ordering::SeqCst), 1, "delivered exactly once across paths");
    assert_eq!(activity.load(Ordering::Relaxed), 1, "only the WS path stamps activity");
}
}