use std::sync::Arc;
use bsv::auth::types::RequestedCertificateSet;
use bsv::wallet::interfaces::Certificate;
use dashmap::DashMap;
use tokio::sync::mpsc;
use tokio::sync::Notify;
use crate::config::OnCertificatesReceived;
/// Tracks identity keys whose certificates are still outstanding and lets
/// tasks wait until they arrive.
///
/// Each pending identity key maps to a shared [`Notify`]. Cloning the gate is
/// cheap: every clone shares the same underlying map through an [`Arc`].
#[derive(Clone)]
pub struct CertificateGate {
    /// Identity key -> notifier fired (and removed) by [`CertificateGate::release`].
    pending: Arc<DashMap<String, Arc<Notify>>>,
}

impl CertificateGate {
    /// Creates an empty gate with no pending identity keys.
    pub fn new() -> Self {
        Self {
            pending: Arc::new(DashMap::new()),
        }
    }

    /// Registers interest in certificates from `identity_key` and returns the
    /// notifier that [`release`](Self::release) will fire for that key.
    ///
    /// Repeated calls for the same key (before it is released) return the same
    /// shared [`Notify`], so all registrants are woken by one release.
    pub fn register(&self, identity_key: &str) -> Arc<Notify> {
        self.pending
            .entry(identity_key.to_string())
            .or_insert_with(|| Arc::new(Notify::new()))
            .clone()
    }

    /// Wakes everyone waiting on `identity_key` and forgets the key.
    ///
    /// Does nothing if the key was never registered or was already released.
    pub fn release(&self, identity_key: &str) {
        if let Some((_, notify)) = self.pending.remove(identity_key) {
            // `notify_waiters` wakes every `Notified` future that already
            // exists, but it stores no permit. A task that registered and has
            // not yet called `notified()` (for example because it was still
            // awaiting its outgoing request) would miss the release and wait
            // forever. Storing a single permit afterwards lets one such later
            // `notified().await` complete, which covers the usual single-waiter
            // case. The entry was just removed from the map, so this leftover
            // permit cannot leak into a future registration of the same key.
            notify.notify_waiters();
            notify.notify_one();
        }
    }
}

impl Default for CertificateGate {
    /// Equivalent to [`CertificateGate::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Drives certificate-related events received from peers.
///
/// * Each `(sender_key, certificates)` message on `cert_rx` is logged. It is
///   handed to `callback` (if any) on a separately spawned task. Any
///   [`CertificateGate`] waiters registered for `sender_key` are then released.
/// * Each message on `cert_req_rx` is only logged; certificate requests are
///   answered by the peer itself.
///
/// The task returns once `cert_rx` is closed. If `cert_req_rx` closes first,
/// the task stops polling it but keeps processing certificates, so gate
/// waiters are still released.
///
/// The callback runs on its own spawned task, so a panicking callback does not
/// take this listener down. The gate is released as soon as the callback has
/// been spawned, not after it completes.
pub async fn certificate_listener_task(
    mut cert_rx: mpsc::Receiver<(String, Vec<Certificate>)>,
    mut cert_req_rx: mpsc::Receiver<(String, RequestedCertificateSet)>,
    gate: CertificateGate,
    callback: Option<Arc<OnCertificatesReceived>>,
) {
    // Once the request channel is closed, `recv()` resolves to `None`
    // immediately on every poll; this flag disables that branch instead of
    // tearing down certificate processing with it.
    let mut requests_open = true;

    loop {
        tokio::select! {
            msg = cert_rx.recv() => {
                match msg {
                    Some((sender_key, certs)) => {
                        tracing::info!(
                            sender = %sender_key,
                            count = certs.len(),
                            "certificates received from peer"
                        );
                        if let Some(cb) = &callback {
                            let cb = Arc::clone(cb);
                            let key = sender_key.clone();
                            // Detached so a slow or panicking callback neither
                            // blocks nor kills this listener.
                            tokio::spawn(async move {
                                let fut = cb(key, certs);
                                fut.await;
                            });
                        }
                        gate.release(&sender_key);
                    }
                    None => {
                        tracing::debug!("certificate receiver closed");
                        break;
                    }
                }
            }
            msg = cert_req_rx.recv(), if requests_open => {
                match msg {
                    Some((sender_key, _requested)) => {
                        tracing::debug!(
                            sender = %sender_key,
                            "certificate request received from peer (handled by Peer internally)"
                        );
                    }
                    None => {
                        tracing::debug!("certificate request receiver closed");
                        requests_open = false;
                    }
                }
            }
        }
    }
    tracing::debug!("certificate listener task exiting");
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    #[test]
    fn test_gate_register_returns_notify() {
        let gate = CertificateGate::new();
        let notify = gate.register("identity_key_1");
        // One reference is held by the gate's map, one is returned to us.
        assert_eq!(Arc::strong_count(&notify), 2);
    }

    #[test]
    fn test_gate_register_same_key_returns_same_notify() {
        let gate = CertificateGate::new();
        let notify1 = gate.register("identity_key_1");
        let notify2 = gate.register("identity_key_1");
        assert!(Arc::ptr_eq(&notify1, &notify2));
    }

    #[tokio::test]
    async fn test_gate_release_wakes_waiter() {
        let gate = CertificateGate::new();
        let notify = gate.register("identity_key_1");
        let gate_clone = gate.clone();
        let handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            gate_clone.release("identity_key_1");
        });
        let result = tokio::time::timeout(Duration::from_secs(2), notify.notified()).await;
        assert!(result.is_ok(), "notified() should have completed");
        handle.await.unwrap();
    }

    #[test]
    fn test_gate_release_unknown_key_does_not_panic() {
        let gate = CertificateGate::new();
        gate.release("unknown_key");
    }

    #[tokio::test]
    async fn test_listener_invokes_callback_on_certificate() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let callback: OnCertificatesReceived = Box::new(move |_key, _certs| {
            let called = called_clone.clone();
            Box::pin(async move {
                called.store(true, Ordering::SeqCst);
            })
        });
        let gate = CertificateGate::new();
        let (cert_tx, cert_rx) = mpsc::channel(8);
        let (cert_req_tx, cert_req_rx) = mpsc::channel(8);
        let task = tokio::spawn(certificate_listener_task(
            cert_rx,
            cert_req_rx,
            gate.clone(),
            Some(Arc::new(callback)),
        ));
        cert_tx
            .send(("sender_1".to_string(), vec![]))
            .await
            .unwrap();
        // Poll rather than relying on one fixed sleep; the timeout bounds it.
        let observed = tokio::time::timeout(Duration::from_secs(2), async {
            while !called.load(Ordering::SeqCst) {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        })
        .await;
        assert!(observed.is_ok(), "callback should have been invoked");
        drop(cert_tx);
        drop(cert_req_tx);
        let _ = tokio::time::timeout(Duration::from_secs(2), task).await;
    }

    #[tokio::test]
    async fn test_listener_releases_gate_on_certificate() {
        let gate = CertificateGate::new();
        let notify = gate.register("sender_1");
        // Create the `Notified` future before the certificate is sent: it is
        // guaranteed to observe any later `notify_waiters()` call, so the
        // release cannot slip past it even on a multi-threaded runtime.
        let released = notify.notified();
        let (cert_tx, cert_rx) = mpsc::channel(8);
        let (cert_req_tx, cert_req_rx) = mpsc::channel(8);
        let task = tokio::spawn(certificate_listener_task(
            cert_rx,
            cert_req_rx,
            gate.clone(),
            None,
        ));
        cert_tx
            .send(("sender_1".to_string(), vec![]))
            .await
            .unwrap();
        let result = tokio::time::timeout(Duration::from_secs(2), released).await;
        assert!(result.is_ok(), "gate should have been released");
        drop(cert_tx);
        drop(cert_req_tx);
        let _ = tokio::time::timeout(Duration::from_secs(2), task).await;
    }

    #[tokio::test]
    async fn test_listener_exits_when_channels_close() {
        let gate = CertificateGate::new();
        let (cert_tx, cert_rx) = mpsc::channel::<(String, Vec<Certificate>)>(8);
        let (cert_req_tx, cert_req_rx) = mpsc::channel::<(String, RequestedCertificateSet)>(8);
        let task = tokio::spawn(certificate_listener_task(cert_rx, cert_req_rx, gate, None));
        drop(cert_tx);
        drop(cert_req_tx);
        let result = tokio::time::timeout(Duration::from_secs(2), task).await;
        assert!(result.is_ok(), "task should have completed");
        assert!(result.unwrap().is_ok(), "task should not have panicked");
    }

    #[tokio::test]
    async fn test_listener_handles_callback_panic_gracefully() {
        let callback: OnCertificatesReceived = Box::new(|_key, _certs| {
            Box::pin(async {
                panic!("callback panicked intentionally");
            })
        });
        let gate = CertificateGate::new();
        let (cert_tx, cert_rx) = mpsc::channel(8);
        let (cert_req_tx, cert_req_rx) = mpsc::channel(8);
        let task = tokio::spawn(certificate_listener_task(
            cert_rx,
            cert_req_rx,
            gate.clone(),
            Some(Arc::new(callback)),
        ));
        cert_tx
            .send(("sender_1".to_string(), vec![]))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        drop(cert_tx);
        drop(cert_req_tx);
        let result = tokio::time::timeout(Duration::from_secs(2), task).await;
        assert!(result.is_ok(), "listener task should have completed");
        assert!(
            result.unwrap().is_ok(),
            "listener task should not have panicked"
        );
    }
}