use crate::contacts::{ContactStore, TrustLevel};
use crate::error::{NetworkError, NetworkResult};
use crate::identity::AgentId;
use crate::network::NetworkNode;
use bytes::Bytes;
use saorsa_gossip_pubsub::{PlumtreePubSub, PubSub};
use saorsa_gossip_types::{PeerId, TopicId};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
/// Domain-separation prefix prepended to every v2 signing payload
/// (see `build_signing_payload`).
const MSG_V2_PREFIX: &[u8] = b"x0x-msg-v2";
/// Leading version byte of a signed (v2) envelope. `decode_auto` treats any
/// frame whose first byte equals this value as v2 and everything else as v1.
const VERSION_V2: u8 = 0x02;
/// Key material used to sign outgoing v2 pubsub envelopes.
///
/// Holds the agent's encoded ML-DSA public and secret keys. The secret key is
/// private and is rendered as `<REDACTED>` by the `Debug` implementation.
pub struct SigningContext {
    /// Identity of the signing agent; embedded in every v2 envelope.
    pub agent_id: AgentId,
    /// Encoded public key, shipped alongside each signature.
    pub public_key_bytes: Vec<u8>,
    /// Encoded secret key; never exposed outside this type.
    secret_key_bytes: Vec<u8>,
}

impl std::fmt::Debug for SigningContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let public_len = self.public_key_bytes.len();
        let mut out = f.debug_struct("SigningContext");
        out.field("agent_id", &self.agent_id);
        out.field("public_key_bytes_len", &public_len);
        // Only a placeholder is printed so the secret never reaches logs.
        out.field("secret_key", &"<REDACTED>");
        out.finish()
    }
}

impl SigningContext {
    /// Captures the identity and serialized key pair of `kp`.
    pub fn from_keypair(kp: &crate::identity::AgentKeypair) -> Self {
        let (public_key_bytes, secret_key_bytes) = kp.to_bytes();
        let agent_id = kp.agent_id();
        Self {
            agent_id,
            public_key_bytes,
            secret_key_bytes,
        }
    }

    /// Signs `message` with the stored ML-DSA secret key.
    ///
    /// # Errors
    /// Returns `NetworkError::SerializationError` if the stored secret key
    /// cannot be parsed or if the signing primitive fails.
    pub fn sign(&self, message: &[u8]) -> NetworkResult<Vec<u8>> {
        use ant_quic::crypto::raw_public_keys::pqc::sign_with_ml_dsa;

        let key = match ant_quic::MlDsaSecretKey::from_bytes(&self.secret_key_bytes) {
            Ok(key) => key,
            Err(e) => {
                return Err(NetworkError::SerializationError(format!(
                    "invalid secret key: {:?}",
                    e
                )))
            }
        };
        sign_with_ml_dsa(&key, message)
            .map(|sig| sig.as_bytes().to_vec())
            .map_err(|e| NetworkError::SerializationError(format!("signing failed: {:?}", e)))
    }
}
/// A decoded pubsub message as handed to subscribers.
#[derive(Debug, Clone)]
pub struct PubSubMessage {
    /// Topic name carried inside the envelope.
    pub topic: String,
    /// Application payload (everything after the envelope header).
    pub payload: Bytes,
    /// Claimed sender; `None` for unsigned (v1) envelopes.
    pub sender: Option<AgentId>,
    /// Sender's encoded public key as carried in a v2 envelope; `None` for v1.
    pub sender_public_key: Option<Vec<u8>>,
    /// `true` only when a v2 signature verified and the embedded sender ID
    /// matched the embedded public key; always `false` for v1.
    pub verified: bool,
    /// Sender's trust level, filled in only when a contact store was consulted
    /// during delivery.
    pub trust_level: Option<TrustLevel>,
}
pub struct Subscription {
topic: String,
receiver: mpsc::Receiver<PubSubMessage>,
topic_ref_counts: Arc<RwLock<HashMap<String, usize>>>,
}
impl Subscription {
#[must_use]
pub fn topic(&self) -> &str {
&self.topic
}
pub async fn recv(&mut self) -> Option<PubSubMessage> {
self.receiver.recv().await
}
}
impl Drop for Subscription {
fn drop(&mut self) {
let topic = self.topic.clone();
let topic_ref_counts = self.topic_ref_counts.clone();
tokio::spawn(async move {
let mut counts = topic_ref_counts.write().await;
if let Some(count) = counts.get_mut(&topic) {
if *count > 1 {
*count -= 1;
} else {
counts.remove(&topic);
}
}
});
}
}
/// Topic-based publish/subscribe front end over a PlumTree gossip instance.
///
/// Outgoing payloads are framed as v2 (signed) when a [`SigningContext`] is
/// configured and as v1 (unsigned) otherwise. Incoming frames for a
/// subscription are decoded and filtered by `decode_for_delivery` before they
/// reach the subscriber.
pub struct PubSubManager {
    network: Arc<NetworkNode>,
    plumtree: Arc<PlumtreePubSub<NetworkNode>>,
    /// Number of live subscriptions per topic name.
    topic_ref_counts: Arc<RwLock<HashMap<String, usize>>>,
    /// Signing material; `None` means messages are published unsigned (v1).
    signing: Option<Arc<SigningContext>>,
    /// Optional contact store used to filter incoming messages. It is shared
    /// (via `Arc`) with every subscription's forwarding task, so a store
    /// installed with `set_contacts` after `subscribe` still applies to that
    /// subscription.
    contacts: Arc<std::sync::OnceLock<Arc<tokio::sync::RwLock<ContactStore>>>>,
}

impl std::fmt::Debug for PubSubManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PubSubManager")
            .field("network", &self.network)
            .field("topic_count", &"<dynamic>")
            .field("signing_enabled", &self.signing.is_some())
            .finish_non_exhaustive()
    }
}

impl PubSubManager {
    /// Creates a manager bound to `network`.
    ///
    /// A fresh ML-DSA key pair is generated for the PlumTree layer itself;
    /// `signing`, when present, is used to sign application payloads.
    ///
    /// # Errors
    /// Returns `NetworkError::NodeCreation` if the PlumTree key cannot be generated.
    pub fn new(
        network: Arc<NetworkNode>,
        signing: Option<Arc<SigningContext>>,
    ) -> NetworkResult<Self> {
        let peer_id = saorsa_gossip_transport::GossipTransport::local_peer_id(network.as_ref());
        let plumtree_signing_key =
            saorsa_gossip_identity::MlDsaKeyPair::generate().map_err(|e| {
                NetworkError::NodeCreation(format!("failed to create PlumTree signing key: {e}"))
            })?;
        let plumtree = Arc::new(PlumtreePubSub::new(
            peer_id,
            Arc::clone(&network),
            plumtree_signing_key,
        ));
        Ok(Self {
            network,
            plumtree,
            topic_ref_counts: Arc::new(RwLock::new(HashMap::new())),
            signing,
            contacts: Arc::new(std::sync::OnceLock::new()),
        })
    }

    /// Installs the contact store used to filter incoming messages.
    ///
    /// Only the first call has an effect; later calls are ignored (and logged).
    /// The store applies to existing subscriptions as well as future ones.
    pub fn set_contacts(&self, store: Arc<tokio::sync::RwLock<ContactStore>>) {
        if self.contacts.set(store).is_err() {
            tracing::debug!("contact store already installed; ignoring replacement");
        }
    }

    /// Subscribes to `topic` and returns a handle that yields decoded messages.
    ///
    /// Spawns a task that decodes each PlumTree frame for the topic and forwards
    /// it to the returned [`Subscription`]. The task stops when the PlumTree
    /// stream ends or when the subscriber's channel is closed.
    pub async fn subscribe(&self, topic: String) -> Subscription {
        let topic_id = TopicId::from_entity(topic.as_bytes());
        self.initialize_topic_peers(topic_id).await;
        let mut plumtree_rx = self.plumtree.subscribe(topic_id);
        // NOTE(review): yields once before continuing, presumably so the
        // PlumTree subscription can settle — confirm whether still required.
        tokio::task::yield_now().await;
        let (tx, rx) = mpsc::channel(10_000);
        // Share the cell itself (not a snapshot of its contents) so contacts
        // installed later are honoured by this subscription.
        let contacts = Arc::clone(&self.contacts);
        {
            let mut counts = self.topic_ref_counts.write().await;
            *counts.entry(topic.clone()).or_insert(0) += 1;
        }
        let sub_topic = topic.clone();
        tokio::spawn(async move {
            while let Some((_peer, encoded_payload)) = plumtree_rx.recv().await {
                tracing::info!(
                    topic = %sub_topic,
                    payload_len = encoded_payload.len(),
                    "[4/6 pubsub] received from PlumTree, decoding"
                );
                let Some(message) = decode_for_delivery(encoded_payload, contacts.get()).await
                else {
                    tracing::warn!(
                        topic = %sub_topic,
                        "[4/6 pubsub] decode_for_delivery returned None, skipping"
                    );
                    continue;
                };
                tracing::info!(
                    topic = %sub_topic,
                    msg_topic = %message.topic,
                    "[4/6 pubsub] decoded, forwarding to subscriber channel"
                );
                if tx.send(message).await.is_err() {
                    tracing::info!(topic = %sub_topic, "[4/6 pubsub] subscriber channel closed");
                    break;
                }
            }
        });
        Subscription {
            topic,
            receiver: rx,
            topic_ref_counts: Arc::clone(&self.topic_ref_counts),
        }
    }

    /// Publishes `payload` on `topic`, signing it when a signing context is set.
    ///
    /// # Errors
    /// Propagates signing/encoding errors, and maps a PlumTree publish failure
    /// to `NetworkError::ConnectionFailed`.
    pub async fn publish(&self, topic: String, payload: Bytes) -> NetworkResult<()> {
        let encoded = if let Some(ref ctx) = self.signing {
            let signing_payload =
                build_signing_payload(ctx.agent_id.as_bytes(), topic.as_bytes(), &payload);
            let signature = ctx.sign(&signing_payload)?;
            encode_v2(
                &ctx.agent_id,
                &ctx.public_key_bytes,
                &signature,
                &topic,
                &payload,
            )?
        } else {
            encode_v1(&topic, &payload)?
        };
        let topic_id = TopicId::from_entity(topic.as_bytes());
        self.initialize_topic_peers(topic_id).await;
        self.plumtree
            .publish(topic_id, encoded)
            .await
            .map_err(|e| NetworkError::ConnectionFailed(format!("PlumTree publish failed: {e}")))
    }

    /// Feeds a raw PlumTree frame received from `peer` into the gossip layer.
    /// Failures are logged, not returned.
    pub async fn handle_incoming(&self, peer: PeerId, data: Bytes) {
        if let Err(e) = self.plumtree.handle_message(peer, data).await {
            tracing::warn!("Failed to handle PlumTree pubsub message from {peer}: {e}");
        }
    }

    /// Number of distinct topics with at least one tracked subscription.
    pub async fn subscription_count(&self) -> usize {
        self.topic_ref_counts.read().await.len()
    }

    /// Drops every subscription reference for `topic` (regardless of how many
    /// handles exist) and unsubscribes from it at the PlumTree layer.
    pub async fn unsubscribe(&self, topic: &str) {
        self.topic_ref_counts.write().await.remove(topic);
        let topic_id = TopicId::from_entity(topic.as_bytes());
        if let Err(e) = self.plumtree.unsubscribe(topic_id).await {
            tracing::debug!("PlumTree unsubscribe failed for topic '{topic}': {e}");
        }
    }

    /// Pushes the current connected-peer list to every subscribed topic and to
    /// every other topic the PlumTree layer knows about.
    pub async fn refresh_topic_peers(&self) {
        let peers = self.connected_gossip_peers().await;
        let subscribed: Vec<String> = self.topic_ref_counts.read().await.keys().cloned().collect();
        if !peers.is_empty() && !subscribed.is_empty() {
            tracing::debug!(
                "[4/6 pubsub] refresh_topic_peers: {} connected peers, {} subscribed topics",
                peers.len(),
                subscribed.len()
            );
        }

        let mut subscribed_ids: std::collections::HashSet<TopicId> =
            std::collections::HashSet::with_capacity(subscribed.len());
        for topic in &subscribed {
            let topic_id = TopicId::from_entity(topic.as_bytes());
            subscribed_ids.insert(topic_id);
            self.plumtree.set_topic_peers(topic_id, peers.clone()).await;
        }

        // Topics PlumTree tracks without a local subscription (e.g. publish-only)
        // get the same peer list.
        let all_plumtree_topics = self.plumtree.all_topic_ids().await;
        for topic_id in all_plumtree_topics {
            if !subscribed_ids.contains(&topic_id) {
                self.plumtree.set_topic_peers(topic_id, peers.clone()).await;
            }
        }
    }

    /// Seeds PlumTree's peer list for `topic` with the currently connected peers.
    async fn initialize_topic_peers(&self, topic: TopicId) {
        let peers = self.connected_gossip_peers().await;
        self.plumtree.initialize_topic_peers(topic, peers).await;
    }

    /// Snapshot of the transport's connected peers, converted to gossip `PeerId`s.
    async fn connected_gossip_peers(&self) -> Vec<PeerId> {
        self.network
            .connected_peers()
            .await
            .into_iter()
            .map(|peer| PeerId::new(peer.0))
            .collect()
    }
}
/// Decodes a raw PlumTree frame and applies the delivery policy.
///
/// Returns `None` (dropping the message) when the frame cannot be decoded,
/// when it names a sender whose signature did not verify, or — if a contact
/// store is supplied — when that sender is revoked or blocked. Surviving
/// messages that carry a sender get their trust level attached whenever a
/// contact store is present.
async fn decode_for_delivery(
    encoded_payload: Bytes,
    contacts: Option<&Arc<tokio::sync::RwLock<ContactStore>>>,
) -> Option<PubSubMessage> {
    let mut message = match decode_auto(encoded_payload) {
        Err(e) => {
            tracing::warn!("Failed to decode x0x payload from PlumTree message: {}", e);
            return None;
        }
        Ok(decoded) => decoded,
    };

    // A claimed sender without a valid signature is treated as forged.
    let forged = message.sender.is_some() && !message.verified;
    if forged {
        tracing::warn!(
            "Dropping pubsub payload with invalid signature from sender {:?}",
            message.sender
        );
        return None;
    }

    // Unsigned messages, or no contact store: deliver as-is.
    let (Some(store), Some(sender)) = (contacts, message.sender) else {
        return Some(message);
    };

    let trust = {
        let contact_book = store.read().await;
        if contact_book.is_revoked(&sender) {
            tracing::debug!("Dropping delivered payload from revoked sender {}", sender);
            return None;
        }
        contact_book.trust_level(&sender)
    };

    if trust == TrustLevel::Blocked {
        tracing::debug!("Dropping delivered payload from blocked sender {}", sender);
        return None;
    }
    message.trust_level = Some(trust);
    Some(message)
}
fn encode_v1(topic: &str, payload: &Bytes) -> NetworkResult<Bytes> {
let topic_bytes = topic.as_bytes();
let topic_len = u16::try_from(topic_bytes.len())
.map_err(|_| NetworkError::SerializationError("Topic too long".to_string()))?;
let mut buf = Vec::with_capacity(2 + topic_bytes.len() + payload.len());
buf.extend_from_slice(&topic_len.to_be_bytes());
buf.extend_from_slice(topic_bytes);
buf.extend_from_slice(payload);
Ok(Bytes::from(buf))
}
/// Decodes an unsigned (v1) envelope: `[topic_len: u16 BE][topic][payload]`.
///
/// The result carries no sender and is never marked verified.
///
/// # Errors
/// Fails when the frame is shorter than the length prefix, when the declared
/// topic length exceeds the remaining bytes, or when the topic is not UTF-8.
fn decode_v1(data: &[u8]) -> NetworkResult<PubSubMessage> {
    let [hi, lo, rest @ ..] = data else {
        return Err(NetworkError::SerializationError(
            "Message too short".to_string(),
        ));
    };
    let topic_len = usize::from(u16::from_be_bytes([*hi, *lo]));
    if rest.len() < topic_len {
        return Err(NetworkError::SerializationError(
            "Invalid topic length".to_string(),
        ));
    }
    let (topic_raw, payload_raw) = rest.split_at(topic_len);
    let topic = String::from_utf8(topic_raw.to_vec())
        .map_err(|e| NetworkError::SerializationError(format!("Invalid UTF-8: {}", e)))?;
    Ok(PubSubMessage {
        topic,
        payload: Bytes::copy_from_slice(payload_raw),
        sender: None,
        sender_public_key: None,
        verified: false,
        trust_level: None,
    })
}
/// Encodes a signed (v2) envelope.
///
/// Layout: `[VERSION_V2][agent_id: 32][pk_len: u16 BE][pk][sig_len: u16 BE][sig]
/// [topic_len: u16 BE][topic][payload]`.
///
/// # Errors
/// Returns `NetworkError::SerializationError` if the topic, public key or
/// signature is longer than `u16::MAX` bytes (checked in that order).
fn encode_v2(
    agent_id: &AgentId,
    public_key: &[u8],
    signature: &[u8],
    topic: &str,
    payload: &Bytes,
) -> NetworkResult<Bytes> {
    let to_u16 = |len: usize, reason: &str| {
        u16::try_from(len).map_err(|_| NetworkError::SerializationError(reason.to_string()))
    };
    let topic_bytes = topic.as_bytes();
    let topic_len = to_u16(topic_bytes.len(), "Topic too long")?;
    let pk_len = to_u16(public_key.len(), "Public key too long")?;
    let sig_len = to_u16(signature.len(), "Signature too long")?;

    let mut frame = Vec::with_capacity(
        1 + 32 + 2 + public_key.len() + 2 + signature.len() + 2 + topic_bytes.len() + payload.len(),
    );
    frame.push(VERSION_V2);
    frame.extend_from_slice(agent_id.as_bytes());
    // Each variable-length field is written as a u16 BE length followed by its bytes.
    for (len, field) in [(pk_len, public_key), (sig_len, signature), (topic_len, topic_bytes)] {
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(field);
    }
    frame.extend_from_slice(payload);
    Ok(Bytes::from(frame))
}
/// Decodes a signed (v2) envelope and verifies its ML-DSA signature.
///
/// The version byte at index 0 is not inspected here; the caller is expected
/// to have dispatched on it. A frame that parses but fails verification is
/// still returned, with `verified == false` and a warning logged.
///
/// # Errors
/// Fails on truncation of any length-prefixed field or a non-UTF-8 topic.
fn decode_v2(data: &[u8]) -> NetworkResult<PubSubMessage> {
    /// Splits `n` bytes off the front of `cursor`, or fails with `reason`.
    fn take<'a>(cursor: &mut &'a [u8], n: usize, reason: &str) -> NetworkResult<&'a [u8]> {
        let current: &'a [u8] = *cursor;
        if current.len() < n {
            return Err(NetworkError::SerializationError(reason.to_string()));
        }
        let (head, tail) = current.split_at(n);
        *cursor = tail;
        Ok(head)
    }

    /// Reads a big-endian `u16` length prefix, or fails with `reason`.
    fn take_len(cursor: &mut &[u8], reason: &str) -> NetworkResult<usize> {
        let raw = take(cursor, 2, reason)?;
        Ok(usize::from(u16::from_be_bytes([raw[0], raw[1]])))
    }

    // Smallest valid frame: version + agent id + three empty length-prefixed fields.
    if data.len() < 39 {
        return Err(NetworkError::SerializationError(
            "V2 message too short".to_string(),
        ));
    }
    let (id_raw, mut cursor) = data[1..].split_at(32);
    let mut agent_id_bytes = [0u8; 32];
    agent_id_bytes.copy_from_slice(id_raw);
    let agent_id = AgentId(agent_id_bytes);

    let pk_len = take_len(&mut cursor, "Truncated pubkey length")?;
    let public_key_bytes = take(&mut cursor, pk_len, "Truncated public key")?.to_vec();
    let sig_len = take_len(&mut cursor, "Truncated signature length")?;
    let signature_bytes = take(&mut cursor, sig_len, "Truncated signature")?;
    let topic_len = take_len(&mut cursor, "Truncated topic length")?;
    let topic_raw = take(&mut cursor, topic_len, "Truncated topic")?;
    let topic = String::from_utf8(topic_raw.to_vec())
        .map_err(|e| NetworkError::SerializationError(format!("Invalid UTF-8: {}", e)))?;
    // Whatever remains after the topic is the application payload.
    let payload = Bytes::copy_from_slice(cursor);

    let verified = verify_signature(
        &public_key_bytes,
        &agent_id_bytes,
        topic.as_bytes(),
        &payload,
        signature_bytes,
    );
    if !verified {
        tracing::warn!(
            "ML-DSA-65 signature verification failed for sender {}",
            agent_id
        );
    }
    Ok(PubSubMessage {
        topic,
        payload,
        sender: Some(agent_id),
        sender_public_key: Some(public_key_bytes),
        verified,
        trust_level: None,
    })
}
/// Decodes a frame as v2 when its first byte is `VERSION_V2`, otherwise as v1.
///
/// # Errors
/// Fails on an empty frame or when the selected decoder rejects it.
fn decode_auto(data: Bytes) -> NetworkResult<PubSubMessage> {
    match data.first().copied() {
        None => Err(NetworkError::SerializationError(
            "Empty message".to_string(),
        )),
        Some(VERSION_V2) => decode_v2(&data),
        Some(_) => decode_v1(&data),
    }
}
fn build_signing_payload(agent_id: &[u8; 32], topic: &[u8], payload: &[u8]) -> Vec<u8> {
let mut buf = Vec::with_capacity(MSG_V2_PREFIX.len() + 32 + topic.len() + payload.len());
buf.extend_from_slice(MSG_V2_PREFIX);
buf.extend_from_slice(agent_id);
buf.extend_from_slice(topic);
buf.extend_from_slice(payload);
buf
}
/// Checks a v2 signature.
///
/// Returns `true` only if the public key parses, the agent ID derived from it
/// equals the embedded `agent_id`, the signature parses, and the signature
/// verifies over `build_signing_payload(agent_id, topic, payload)`.
fn verify_signature(
    public_key_bytes: &[u8],
    agent_id: &[u8; 32],
    topic: &[u8],
    payload: &[u8],
    signature_bytes: &[u8],
) -> bool {
    use ant_quic::crypto::raw_public_keys::pqc::{verify_with_ml_dsa, MlDsaSignature};

    let Ok(public_key) = ant_quic::MlDsaPublicKey::from_bytes(public_key_bytes) else {
        return false;
    };
    // Bind the claimed identity to the key actually used for verification.
    let id_matches_key = crate::identity::AgentId::from_public_key(&public_key).0 == *agent_id;
    if !id_matches_key {
        tracing::warn!("Agent ID mismatch: embedded ID does not match public key");
        return false;
    }
    let Ok(signature) = MlDsaSignature::from_bytes(signature_bytes) else {
        return false;
    };
    let signed_bytes = build_signing_payload(agent_id, topic, payload);
    verify_with_ml_dsa(&public_key, &signed_bytes, &signature).is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::AgentKeypair;
    use crate::network::NetworkConfig;

    /// Creates a network node with the default configuration.
    async fn test_node() -> Arc<NetworkNode> {
        let node = NetworkNode::new(NetworkConfig::default(), None, None)
            .await
            .expect("Failed to create test node");
        Arc::new(node)
    }

    /// Creates an unsigned manager on a fresh node.
    async fn unsigned_manager() -> PubSubManager {
        PubSubManager::new(test_node().await, None).expect("manager")
    }

    /// Signs `(topic, payload)` with `signer`'s key.
    fn signature_over(signer: &SigningContext, topic: &str, payload: &[u8]) -> Vec<u8> {
        let to_sign =
            build_signing_payload(signer.agent_id.as_bytes(), topic.as_bytes(), payload);
        signer.sign(&to_sign).expect("sign")
    }

    /// Frames a v2 envelope that claims to come from `claimed`.
    fn frame_as(claimed: &SigningContext, signature: &[u8], topic: &str, payload: &Bytes) -> Bytes {
        encode_v2(
            &claimed.agent_id,
            &claimed.public_key_bytes,
            signature,
            topic,
            payload,
        )
        .expect("encode")
    }

    /// Signs and frames with the same identity.
    fn signed_frame(ctx: &SigningContext, topic: &str, payload: &Bytes) -> Bytes {
        let signature = signature_over(ctx, topic, payload);
        frame_as(ctx, &signature, topic, payload)
    }

    #[test]
    fn test_v1_encode_decode_roundtrip() {
        let payload = Bytes::from(&b"hello world"[..]);
        let frame = encode_v1("test-topic", &payload).expect("Encoding failed");
        let msg = decode_v1(&frame).expect("Decoding failed");
        assert_eq!(msg.topic, "test-topic");
        assert_eq!(msg.payload, payload);
        assert!(msg.sender.is_none());
        assert!(!msg.verified);
    }

    #[test]
    fn test_v1_empty_topic() {
        let frame = encode_v1("", &Bytes::from("data")).expect("Encoding failed");
        let msg = decode_v1(&frame).expect("Decoding failed");
        assert_eq!(msg.topic, "");
        assert_eq!(msg.payload, Bytes::from("data"));
    }

    #[test]
    fn test_v1_empty_payload() {
        let frame = encode_v1("topic", &Bytes::new()).expect("Encoding failed");
        let msg = decode_v1(&frame).expect("Decoding failed");
        assert_eq!(msg.topic, "topic");
        assert!(msg.payload.is_empty());
    }

    #[test]
    fn test_v1_unicode_topic() {
        let topic = "тема/главная/система";
        let frame = encode_v1(topic, &Bytes::from(&b"data"[..])).expect("Encoding failed");
        let msg = decode_v1(&frame).expect("Decoding failed");
        assert_eq!(msg.topic, topic);
    }

    #[test]
    fn test_v1_too_long_topic() {
        let oversized = "a".repeat(70000);
        assert!(encode_v1(&oversized, &Bytes::from("data")).is_err());
    }

    #[test]
    fn test_v1_too_short() {
        assert!(decode_v1(&[0x12]).is_err());
    }

    #[test]
    fn test_v1_invalid_utf8() {
        let frame = vec![0, 3, 0xFF, 0xFF, 0xFF];
        assert!(decode_v1(&frame).is_err());
    }

    #[test]
    fn test_v2_encode_decode_roundtrip() {
        let kp = AgentKeypair::generate().expect("keygen");
        let ctx = SigningContext::from_keypair(&kp);
        let payload = Bytes::from("hello signed world");
        let msg = decode_v2(&signed_frame(&ctx, "chat", &payload)).expect("decode");
        assert_eq!(msg.topic, "chat");
        assert_eq!(msg.payload, payload);
        assert_eq!(msg.sender, Some(ctx.agent_id));
        assert!(msg.verified);
    }

    #[test]
    fn test_v2_tampered_payload_fails_verification() {
        let kp = AgentKeypair::generate().expect("keygen");
        let ctx = SigningContext::from_keypair(&kp);
        let signature = signature_over(&ctx, "chat", b"original");
        // Ship a different payload under the original signature.
        let frame = frame_as(&ctx, &signature, "chat", &Bytes::from("TAMPERED"));
        let msg = decode_v2(&frame).expect("decode");
        assert!(!msg.verified);
    }

    #[test]
    fn test_v2_wrong_sender_fails() {
        let kp1 = AgentKeypair::generate().expect("keygen1");
        let kp2 = AgentKeypair::generate().expect("keygen2");
        let signer = SigningContext::from_keypair(&kp1);
        let payload = Bytes::from("hello");
        let signature = signature_over(&signer, "chat", &payload);
        // Claim a different identity than the one that signed.
        let impostor = SigningContext::from_keypair(&kp2);
        let msg = decode_v2(&frame_as(&impostor, &signature, "chat", &payload)).expect("decode");
        assert!(!msg.verified);
    }

    #[test]
    fn test_v2_empty_payload() {
        let kp = AgentKeypair::generate().expect("keygen");
        let ctx = SigningContext::from_keypair(&kp);
        let msg = decode_v2(&signed_frame(&ctx, "ping", &Bytes::new())).expect("decode");
        assert!(msg.verified);
        assert!(msg.payload.is_empty());
    }

    #[test]
    fn test_v2_truncated_data() {
        assert!(decode_v2(&[VERSION_V2, 0, 0, 0]).is_err());
    }

    #[test]
    fn test_auto_detect_v1() {
        let frame = encode_v1("topic", &Bytes::from("data")).expect("encode");
        let msg = decode_auto(frame).expect("decode");
        assert_eq!(msg.topic, "topic");
        assert!(msg.sender.is_none());
        assert!(!msg.verified);
    }

    #[test]
    fn test_auto_detect_v2() {
        let kp = AgentKeypair::generate().expect("keygen");
        let ctx = SigningContext::from_keypair(&kp);
        let frame = signed_frame(&ctx, "test", &Bytes::from("signed"));
        let msg = decode_auto(frame).expect("decode");
        assert_eq!(msg.topic, "test");
        assert!(msg.sender.is_some());
        assert!(msg.verified);
    }

    #[test]
    fn test_auto_detect_empty() {
        assert!(decode_auto(Bytes::new()).is_err());
    }

    #[test]
    fn test_build_signing_payload_deterministic() {
        let agent_id = [42u8; 32];
        let first = build_signing_payload(&agent_id, b"topic", b"payload");
        let again = build_signing_payload(&agent_id, b"topic", b"payload");
        assert_eq!(first, again);
        let other_topic = build_signing_payload(&agent_id, b"other", b"payload");
        assert_ne!(first, other_topic);
    }

    #[tokio::test]
    async fn test_pubsub_creation() {
        let _manager = unsigned_manager().await;
    }

    #[tokio::test]
    async fn test_subscribe_to_topic() {
        let manager = unsigned_manager().await;
        let sub = manager.subscribe("test-topic".to_string()).await;
        assert_eq!(sub.topic(), "test-topic");
    }

    #[tokio::test]
    async fn test_publish_local_delivery_unsigned() {
        let manager = unsigned_manager().await;
        let mut sub = manager.subscribe("chat".to_string()).await;
        manager
            .publish("chat".to_string(), Bytes::from("hello"))
            .await
            .expect("Publish failed");
        let msg = sub.recv().await.expect("Failed to receive message");
        assert_eq!(msg.topic, "chat");
        assert_eq!(msg.payload, Bytes::from("hello"));
        assert!(msg.sender.is_none());
        assert!(!msg.verified);
    }

    #[tokio::test]
    async fn test_publish_local_delivery_signed() {
        let node = test_node().await;
        let kp = AgentKeypair::generate().expect("keygen");
        let ctx = Arc::new(SigningContext::from_keypair(&kp));
        let manager = PubSubManager::new(node, Some(Arc::clone(&ctx))).expect("manager");
        let mut sub = manager.subscribe("chat".to_string()).await;
        manager
            .publish("chat".to_string(), Bytes::from("signed hello"))
            .await
            .expect("Publish failed");
        let msg = sub.recv().await.expect("Failed to receive");
        assert_eq!(msg.topic, "chat");
        assert_eq!(msg.payload, Bytes::from("signed hello"));
        assert_eq!(msg.sender, Some(kp.agent_id()));
        assert!(msg.verified);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let manager = unsigned_manager().await;
        let mut first = manager.subscribe("news".to_string()).await;
        let mut second = manager.subscribe("news".to_string()).await;
        manager
            .publish("news".to_string(), Bytes::from("breaking"))
            .await
            .expect("Publish failed");
        let from_first = first.recv().await.expect("sub1 failed");
        let from_second = second.recv().await.expect("sub2 failed");
        assert_eq!(from_first.payload, Bytes::from("breaking"));
        assert_eq!(from_second.payload, Bytes::from("breaking"));
    }

    #[tokio::test]
    async fn test_publish_no_subscribers() {
        let manager = unsigned_manager().await;
        let outcome = manager
            .publish("empty".to_string(), Bytes::from("nothing"))
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let manager = unsigned_manager().await;
        let mut sub = manager.subscribe("temp".to_string()).await;
        manager
            .publish("temp".to_string(), Bytes::from("msg1"))
            .await
            .expect("Publish");
        assert!(sub.recv().await.is_some());
        manager.unsubscribe("temp").await;
        manager
            .publish("temp".to_string(), Bytes::from("msg2"))
            .await
            .expect("Publish");
        assert!(sub.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_subscription_count() {
        let manager = unsigned_manager().await;
        assert_eq!(manager.subscription_count().await, 0);
        let _sub_t1 = manager.subscribe("t1".to_string()).await;
        assert_eq!(manager.subscription_count().await, 1);
        let _sub_t2 = manager.subscribe("t2".to_string()).await;
        assert_eq!(manager.subscription_count().await, 2);
        // A second handle on an existing topic does not add a new topic.
        let _sub_t1_b = manager.subscribe("t1".to_string()).await;
        assert_eq!(manager.subscription_count().await, 2);
        manager.unsubscribe("t1").await;
        assert_eq!(manager.subscription_count().await, 1);
    }

    #[tokio::test]
    async fn test_handle_incoming_invalid() {
        let manager = unsigned_manager().await;
        let _sub = manager.subscribe("test".to_string()).await;
        let peer = PeerId::new([1; 32]);
        // Must not panic on a garbage frame.
        manager
            .handle_incoming(peer, Bytes::from(&[0x12][..]))
            .await;
    }

    #[tokio::test]
    async fn test_multiple_subscribers_not_starved_by_replay_cache() {
        let manager = unsigned_manager().await;
        let mut sub1 = manager.subscribe("multi".to_string()).await;
        let mut sub2 = manager.subscribe("multi".to_string()).await;
        let mut sub3 = manager.subscribe("multi".to_string()).await;
        manager
            .publish("multi".to_string(), Bytes::from("msg-a"))
            .await
            .expect("publish a");
        let m1 = sub1.recv().await.expect("sub1");
        let m2 = sub2.recv().await.expect("sub2");
        let m3 = sub3.recv().await.expect("sub3");
        for received in [&m1, &m2, &m3] {
            assert_eq!(received.payload, Bytes::from("msg-a"));
        }
    }

    #[tokio::test]
    async fn test_local_duplicate_publishes_are_delivered() {
        let manager = unsigned_manager().await;
        let mut sub = manager.subscribe("dedup".to_string()).await;
        manager
            .publish("dedup".to_string(), Bytes::from("hello"))
            .await
            .expect("publish 1");
        manager
            .publish("dedup".to_string(), Bytes::from("hello"))
            .await
            .expect("publish 2 (same content, intentional)");
        let first = sub.recv().await.expect("should receive first message");
        assert_eq!(first.payload, Bytes::from("hello"));
        let second = sub.recv().await.expect("should receive second message");
        assert_eq!(
            second.payload,
            Bytes::from("hello"),
            "Local duplicate publishes should both be delivered"
        );
    }
}