use anyhow::Result;
use async_stream::stream;
use async_trait::async_trait;
use bytes::Bytes;
use std::sync::{Arc, Mutex};
/// Send high-water mark applied to every PUB socket via `set_sndhwm`.
const ZMQ_SNDHWM: i32 = 100_000;
/// Receive high-water mark applied to every SUB socket via `set_rcvhwm`.
const ZMQ_RCVHWM: i32 = 100_000;
/// SUB receive timeout in milliseconds (`set_rcvtimeo`); bounds how long one
/// socket-pump poll blocks before coming back empty-handed.
const ZMQ_RCVTIMEOUT_MS: i32 = 100;
use super::codec::MsgpackCodec;
use super::frame::Frame;
use super::transport::{EventTransportRx, EventTransportTx, WireStream};
use crate::discovery::EventTransportKind;
/// One multipart message read by the SUB socket pump: topic, publisher id,
/// sequence number, and the encoded frame bytes.
struct ZmqMessage {
    // Carried along for completeness; nothing reads it yet, hence the lint allowance.
    #[allow(dead_code)]
    topic: Vec<u8>,
    // Decoded from an 8-byte big-endian frame.
    publisher_id: u64,
    // Decoded from an 8-byte big-endian frame.
    sequence: u64,
    // Encoded `Frame` bytes; the pump passes these to `Frame::decode`.
    data: Vec<u8>,
}
/// Publishing side of the ZMQ event transport, wrapping a single PUB socket.
///
/// The socket sits behind an `Arc<Mutex<_>>` so `publish` can move it into a
/// blocking-pool task; the mutex also serializes the multipart sends of
/// concurrent publishes so their parts do not interleave.
pub struct ZmqPubTransport {
    socket: Arc<Mutex<zmq::Socket>>,
    // Written as the first part of every published message.
    topic: String,
}
impl ZmqPubTransport {
pub async fn bind(endpoint: &str, topic: &str) -> Result<(Self, String)> {
let actual_endpoint = if endpoint.ends_with(":0") {
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await?;
let actual_addr = listener.local_addr()?;
let port = actual_addr.port();
drop(listener);
format!("tcp://0.0.0.0:{}", port)
} else {
endpoint.to_string()
};
let endpoint_for_closure = actual_endpoint.clone();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::PUB)?;
socket.set_sndhwm(ZMQ_SNDHWM)?;
socket.set_sndtimeo(0)?;
socket.bind(&endpoint_for_closure)?;
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
endpoint = %actual_endpoint,
topic = %topic,
sndhwm = ZMQ_SNDHWM,
"ZMQ PUB transport bound with configured HWM"
);
Ok((
Self {
socket: Arc::new(Mutex::new(socket)),
topic: topic.to_string(),
},
actual_endpoint,
))
}
pub fn topic(&self) -> &str {
&self.topic
}
pub async fn connect(xsub_endpoint: &str, topic: &str) -> Result<Self> {
let endpoint_owned = xsub_endpoint.to_string();
let topic_owned = topic.to_string();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::PUB)?;
socket.set_sndhwm(ZMQ_SNDHWM)?;
socket.set_sndtimeo(0)?;
socket.connect(&endpoint_owned)?;
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
endpoint = %xsub_endpoint,
topic = %topic,
sndhwm = ZMQ_SNDHWM,
"ZMQ PUB transport connected to broker XSUB"
);
Ok(Self {
socket: Arc::new(Mutex::new(socket)),
topic: topic_owned,
})
}
pub async fn connect_multiple(xsub_endpoints: &[String], topic: &str) -> Result<Self> {
if xsub_endpoints.is_empty() {
anyhow::bail!("Cannot connect to zero endpoints");
}
let endpoints_owned = xsub_endpoints.to_vec();
let topic_owned = topic.to_string();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::PUB)?;
socket.set_sndhwm(ZMQ_SNDHWM)?;
socket.set_sndtimeo(0)?;
for endpoint in &endpoints_owned {
socket.connect(endpoint)?;
tracing::debug!(endpoint = %endpoint, "ZMQ PUB connected to broker XSUB");
}
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
num_endpoints = xsub_endpoints.len(),
topic = %topic,
sndhwm = ZMQ_SNDHWM,
"ZMQ PUB transport connected to multiple broker XSUBs with configured HWM"
);
Ok(Self {
socket: Arc::new(Mutex::new(socket)),
topic: topic_owned,
})
}
}
#[async_trait]
impl EventTransportTx for ZmqPubTransport {
async fn publish(&self, _subject: &str, envelope_bytes: Bytes) -> Result<()> {
let codec = MsgpackCodec;
let envelope = codec.decode_envelope(&envelope_bytes)?;
let frame = Frame::new(envelope_bytes);
let frame_bytes = frame.encode();
let topic_bytes = self.topic.as_bytes().to_vec();
let publisher_id_bytes = envelope.publisher_id.to_be_bytes().to_vec();
let sequence_bytes = envelope.sequence.to_be_bytes().to_vec();
let frame_vec = frame_bytes.to_vec();
let socket = Arc::clone(&self.socket);
tokio::task::spawn_blocking(move || -> Result<()> {
let socket = socket.lock().unwrap();
socket.send(&topic_bytes, zmq::SNDMORE)?;
socket.send(&publisher_id_bytes, zmq::SNDMORE)?;
socket.send(&sequence_bytes, zmq::SNDMORE)?;
socket.send(&frame_vec, 0)?;
Ok(())
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
Ok(())
}
fn kind(&self) -> EventTransportKind {
EventTransportKind::Zmq
}
}
/// Subscribing side of the ZMQ event transport.
///
/// A background "socket pump" task reads from the SUB socket and fans decoded
/// frame payloads out over a tokio broadcast channel. Each `subscribe` call
/// gets its own receiver on that channel.
pub struct ZmqSubTransport {
    // Shared with the socket pump task.
    socket: Arc<Mutex<zmq::Socket>>,
    // Sender half of the fan-out channel; `subscribe` creates receivers from it.
    broadcast_tx: tokio::sync::broadcast::Sender<Bytes>,
    // Handle to the pump task. NOTE(review): dropping a tokio JoinHandle detaches
    // the task rather than cancelling it.
    _socket_pump_handle: tokio::task::JoinHandle<()>,
}
impl ZmqSubTransport {
pub async fn connect(endpoint: &str, topic: &str) -> Result<Self> {
let endpoint_owned = endpoint.to_string();
let topic_owned = topic.to_string();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::SUB)?;
socket.set_rcvhwm(ZMQ_RCVHWM)?;
socket.set_rcvtimeo(ZMQ_RCVTIMEOUT_MS)?;
socket.connect(&endpoint_owned)?;
socket.set_subscribe(topic_owned.as_bytes())?;
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
endpoint = %endpoint,
topic = %topic,
rcvhwm = ZMQ_RCVHWM,
"ZMQ SUB transport connected with configured HWM"
);
let socket = Arc::new(Mutex::new(socket));
let (broadcast_tx, _) = tokio::sync::broadcast::channel(1024);
let pump_handle = Self::start_socket_pump(Arc::clone(&socket), broadcast_tx.clone());
Ok(Self {
socket,
broadcast_tx,
_socket_pump_handle: pump_handle,
})
}
pub async fn connect_broker(xpub_endpoint: &str, topic: &str) -> Result<Self> {
Self::connect(xpub_endpoint, topic).await
}
pub async fn connect_broker_multiple(xpub_endpoints: &[String], topic: &str) -> Result<Self> {
Self::connect_multiple(xpub_endpoints, topic).await
}
pub async fn connect_multiple(endpoints: &[String], topic: &str) -> Result<Self> {
if endpoints.is_empty() {
anyhow::bail!("Cannot connect to zero endpoints");
}
let endpoints_owned = endpoints.to_vec();
let topic_owned = topic.to_string();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::SUB)?;
socket.set_rcvhwm(ZMQ_RCVHWM)?;
socket.set_rcvtimeo(ZMQ_RCVTIMEOUT_MS)?;
for endpoint in &endpoints_owned {
socket.connect(endpoint)?;
tracing::debug!(endpoint = %endpoint, "ZMQ SUB connected to endpoint");
}
socket.set_subscribe(topic_owned.as_bytes())?;
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
num_endpoints = endpoints.len(),
topic = %topic,
rcvhwm = ZMQ_RCVHWM,
"ZMQ SUB transport connected to multiple endpoints with configured HWM"
);
let socket = Arc::new(Mutex::new(socket));
let (broadcast_tx, _) = tokio::sync::broadcast::channel(1024);
let pump_handle = Self::start_socket_pump(Arc::clone(&socket), broadcast_tx.clone());
Ok(Self {
socket,
broadcast_tx,
_socket_pump_handle: pump_handle,
})
}
fn start_socket_pump(
socket: Arc<Mutex<zmq::Socket>>,
broadcast_tx: tokio::sync::broadcast::Sender<Bytes>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
loop {
let socket_clone = Arc::clone(&socket);
let result = tokio::task::spawn_blocking(move || -> Result<Option<ZmqMessage>> {
let socket = socket_clone.lock().unwrap();
let topic = match socket.recv_bytes(0) {
Ok(data) => data,
Err(zmq::Error::EAGAIN) => return Ok(None), Err(e) => return Err(e.into()),
};
let publisher_id_bytes = socket.recv_bytes(0)?;
if publisher_id_bytes.len() != 8 {
anyhow::bail!(
"Invalid publisher_id frame: expected 8 bytes, got {}",
publisher_id_bytes.len()
);
}
let publisher_id = u64::from_be_bytes(publisher_id_bytes.try_into().unwrap());
let sequence_bytes = socket.recv_bytes(0)?;
if sequence_bytes.len() != 8 {
anyhow::bail!(
"Invalid sequence frame: expected 8 bytes, got {}",
sequence_bytes.len()
);
}
let sequence = u64::from_be_bytes(sequence_bytes.try_into().unwrap());
let data = socket.recv_bytes(0)?;
Ok(Some(ZmqMessage {
topic,
publisher_id,
sequence,
data,
}))
})
.await;
match result {
Ok(Ok(Some(ZmqMessage {
publisher_id,
sequence,
data: frame_bytes,
..
}))) => {
tracing::trace!(
publisher_id = publisher_id,
sequence = sequence,
"Socket pump received ZMQ message"
);
let frame_bytes = Bytes::from(frame_bytes);
match Frame::decode(frame_bytes) {
Ok(frame) => {
let _ = broadcast_tx.send(frame.payload);
}
Err(e) => {
tracing::warn!(error = %e, "Failed to decode ZMQ frame in socket pump");
continue;
}
}
}
Ok(Ok(None)) => {
continue;
}
Ok(Err(e)) => {
tracing::error!(error = %e, "ZMQ receive error in socket pump");
break;
}
Err(e) => {
tracing::error!(error = %e, "Task join error in socket pump");
break;
}
}
}
tracing::info!("ZMQ socket pump task terminated");
})
}
}
#[async_trait]
impl EventTransportRx for ZmqSubTransport {
async fn subscribe(&self, _subject: &str) -> Result<WireStream> {
let mut receiver = self.broadcast_tx.subscribe();
let stream = stream! {
loop {
match receiver.recv().await {
Ok(payload) => {
yield Ok(payload);
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
tracing::warn!(
skipped = skipped,
"Subscriber lagged behind, skipped messages"
);
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
tracing::info!("Broadcast channel closed");
break;
}
}
}
};
Ok(Box::pin(stream))
}
fn kind(&self) -> EventTransportKind {
EventTransportKind::Zmq
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transports::event_plane::{EventEnvelope, MsgpackCodec};
    use futures::StreamExt;
    use tokio::time::{Duration, timeout};

    /// Binds a publisher for `topic` on an OS-assigned loopback port and
    /// returns it with an endpoint a subscriber can connect to.
    ///
    /// Asking for port `0` instead of hard-coding port numbers keeps these tests
    /// from colliding with each other or with other processes on the machine.
    async fn bind_ephemeral(topic: &str) -> (ZmqPubTransport, String) {
        let (publisher, bound_endpoint) = ZmqPubTransport::bind("tcp://127.0.0.1:0", topic)
            .await
            .expect("Failed to create publisher");
        // Only the port is taken from the reported endpoint; its host part may be
        // a wildcard address, so connect over loopback explicitly.
        let port = bound_endpoint
            .rsplit(':')
            .next()
            .expect("rsplit always yields at least one piece");
        (publisher, format!("tcp://127.0.0.1:{}", port))
    }

    #[tokio::test]
    async fn test_zmq_pubsub_basic() {
        let topic = "test-topic";
        let (publisher, endpoint) = bind_ephemeral(topic).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let subscriber = ZmqSubTransport::connect(&endpoint, topic)
            .await
            .expect("Failed to create subscriber");
        let mut stream = subscriber
            .subscribe(topic)
            .await
            .expect("Failed to create subscription");
        // Give the subscription time to reach the publisher before sending.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let codec = MsgpackCodec;
        let envelope = EventEnvelope {
            publisher_id: 12345,
            sequence: 1,
            published_at: 1700000000000,
            topic: topic.to_string(),
            payload: Bytes::from("test payload"),
        };
        let envelope_bytes = codec.encode_envelope(&envelope).unwrap();
        publisher.publish(topic, envelope_bytes).await.unwrap();
        let result = timeout(Duration::from_secs(2), stream.next()).await;
        assert!(result.is_ok(), "Timeout waiting for message");
        let received_bytes = result.unwrap().unwrap().unwrap();
        let decoded = codec.decode_envelope(&received_bytes).unwrap();
        assert_eq!(decoded.publisher_id, 12345);
        assert_eq!(decoded.sequence, 1);
        assert_eq!(decoded.topic, topic);
        assert_eq!(decoded.payload, Bytes::from("test payload"));
    }

    #[tokio::test]
    async fn test_zmq_multiple_messages() {
        let topic = "multi-test";
        let (publisher, endpoint) = bind_ephemeral(topic).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let subscriber = ZmqSubTransport::connect(&endpoint, topic).await.unwrap();
        let mut stream = subscriber.subscribe(topic).await.unwrap();
        // Give the subscription time to reach the publisher before sending.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let codec = MsgpackCodec;
        for i in 0..5 {
            let envelope = EventEnvelope {
                publisher_id: 99999,
                sequence: i,
                published_at: 1700000000000 + i,
                topic: topic.to_string(),
                payload: Bytes::from(format!("message {}", i)),
            };
            let bytes = codec.encode_envelope(&envelope).unwrap();
            publisher.publish(topic, bytes).await.unwrap();
        }
        for i in 0..5 {
            let result = timeout(Duration::from_secs(2), stream.next()).await;
            assert!(result.is_ok(), "Timeout on message {}", i);
            let received = result.unwrap().unwrap().unwrap();
            let decoded = codec.decode_envelope(&received).unwrap();
            assert_eq!(decoded.sequence, i);
            assert_eq!(decoded.topic, topic);
        }
    }
}