use std::sync::Arc;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::broadcast;
/// Builds a notification callback that decodes a JSON payload into `T` and
/// forwards it on a broadcast channel.
///
/// The returned closure takes the raw `params` value of a notification and:
///
/// * deserializes it into `T` with `serde_json::from_value`;
/// * on success, sends the decoded event through `tx`, discarding the send
///   result so a failed send never propagates to the caller;
/// * on failure, emits a `tracing` warning carrying `method_name` and the
///   decode error, and sends nothing.
///
/// `method_name` is used only as a field of that warning.
pub fn build_broadcast_listener<T>(
    method_name: &'static str,
    tx: broadcast::Sender<T>,
) -> Arc<dyn Fn(Value) + Send + Sync>
where
    T: DeserializeOwned + Clone + Send + 'static,
{
    let listener = move |params: Value| match serde_json::from_value::<T>(params) {
        Ok(event) => {
            // Best effort: the outcome of the send is deliberately ignored.
            let _ = tx.send(event);
        }
        Err(decode_err) => {
            tracing::warn!(
                method = method_name,
                error = %decode_err,
                "notification listener: decode failed"
            );
        }
    };
    Arc::new(listener)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    /// Minimal event type used to exercise the listener.
    #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
    struct Pinger {
        n: u32,
    }

    /// A well-formed payload is decoded and delivered to a subscribed receiver.
    #[tokio::test]
    async fn listener_decodes_and_broadcasts() {
        let (tx, mut rx) = broadcast::channel::<Pinger>(8);
        let listener = build_broadcast_listener::<Pinger>("test/ping", tx);
        listener(serde_json::json!({"n": 42}));
        let got = tokio::time::timeout(Duration::from_millis(50), rx.recv())
            .await
            .expect("recv timeout")
            .expect("recv ok");
        assert_eq!(got, Pinger { n: 42 });
    }

    /// A payload that does not decode into `Pinger` is dropped: the listener
    /// returns normally and nothing reaches the channel.
    #[tokio::test]
    async fn listener_drops_malformed_payload_without_panic() {
        let (tx, mut rx) = broadcast::channel::<Pinger>(8);
        let listener = build_broadcast_listener::<Pinger>("test/ping", tx);
        listener(serde_json::json!({"not": "a real ping"}));
        // The listener runs synchronously, so the queue can be inspected right
        // away instead of waiting for a wall-clock timeout to elapse.
        assert!(
            matches!(rx.try_recv(), Err(broadcast::error::TryRecvError::Empty)),
            "no event should be broadcast on malformed input"
        );
    }

    /// With no receiver subscribed the listener must swallow the failed send
    /// and remain usable for receivers that subscribe afterwards.
    #[tokio::test]
    async fn listener_continues_when_no_subscribers() {
        let (tx, rx) = broadcast::channel::<Pinger>(8);
        drop(rx);
        // Keep one sender handle in the test so a receiver can subscribe later.
        let listener = build_broadcast_listener::<Pinger>("test/ping", tx.clone());

        // Nobody is listening: this call must neither panic nor poison the listener.
        listener(serde_json::json!({"n": 7}));

        let mut late_rx = tx.subscribe();
        listener(serde_json::json!({"n": 8}));

        // Only the event sent after subscribing is visible to the late receiver.
        assert_eq!(late_rx.try_recv().ok(), Some(Pinger { n: 8 }));
        assert!(
            matches!(late_rx.try_recv(), Err(broadcast::error::TryRecvError::Empty)),
            "the event sent while no receiver existed must not be replayed"
        );
    }
}