mod error;
#[cfg(feature = "async")]
mod async_channel;
#[cfg(feature = "async")]
pub use async_channel::{AsyncChannelCapacity, AsyncChannelPair, AsyncChannelPairBuilder};
pub use error::ChannelPairError;
use crossbeam_channel::{
bounded, select, unbounded, Receiver as CrossbeamReceiver, RecvTimeoutError,
Sender as CrossbeamSender, TryRecvError, TrySendError,
};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use zmq::{Context, PollEvents, Socket};
/// A multipart ZeroMQ message: one `Vec<u8>` per frame.
pub type ZmqMessage = Vec<Vec<u8>>;

/// Monotonic counter that gives every internal inproc endpoint a unique address.
static UNIQUE_INDEX: AtomicU64 = AtomicU64::new(0_u64);

/// Index of the bound end of an internal socket pair (the end messages are written to).
const PAIR_IN: usize = 0;
/// Index of the connected end of an internal socket pair (the end messages are read from).
const PAIR_OUT: usize = PAIR_IN + 1;
/// Sizing of the outbound and inbound message queues of a `ChannelPair`.
///
/// The same capacity is applied to both queues.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ChannelCapacity {
    /// No limit: sending never blocks because the queue is full.
    #[default]
    Unbounded,
    /// At most this many messages per queue; blocking sends wait while it is full.
    ///
    /// `Bounded(0)` yields zero-capacity (rendezvous) queues.
    Bounded(usize),
}
#[derive(Debug, Clone)]
pub struct Sender {
inner: CrossbeamSender<ZmqMessage>,
}
impl Sender {
pub fn send(&self, msg: ZmqMessage) -> Result<(), ChannelPairError> {
self.inner
.send(msg)
.map_err(|e| ChannelPairError::ChannelDisconnected(format!("send failed: {}", e)))
}
pub fn try_send(&self, msg: ZmqMessage) -> Result<(), TrySendError<ZmqMessage>> {
self.inner.try_send(msg)
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn is_full(&self) -> bool {
self.inner.is_full()
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn capacity(&self) -> Option<usize> {
self.inner.capacity()
}
}
#[derive(Debug, Clone)]
pub struct Receiver {
inner: CrossbeamReceiver<ZmqMessage>,
}
impl Receiver {
pub fn recv(&self) -> Result<ZmqMessage, ChannelPairError> {
self.inner
.recv()
.map_err(|_| ChannelPairError::ChannelDisconnected("receive channel closed".into()))
}
pub fn recv_timeout(&self, timeout: Duration) -> Result<ZmqMessage, RecvTimeoutError> {
self.inner.recv_timeout(timeout)
}
pub fn try_recv(&self) -> Result<ZmqMessage, TryRecvError> {
self.inner.try_recv()
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn iter(&self) -> impl Iterator<Item = ZmqMessage> + '_ {
self.inner.iter()
}
pub fn try_iter(&self) -> impl Iterator<Item = ZmqMessage> + '_ {
self.inner.try_iter()
}
}
/// Builder for a `ChannelPair` with non-default queue sizing.
///
/// Nothing is created or started until [`ChannelPairBuilder::build`] is
/// called. Because every configuration method consumes and returns the
/// builder, dropping one of those return values silently discards the
/// configuration; `#[must_use]` turns that mistake into a warning.
#[must_use = "a ChannelPairBuilder does nothing until `build()` is called"]
pub struct ChannelPairBuilder<'a> {
    context: &'a Context,
    socket: Socket,
    capacity: ChannelCapacity,
}

impl<'a> ChannelPairBuilder<'a> {
    /// Starts a builder for `socket` on `context` with the default
    /// (unbounded) queue capacity.
    pub fn new(context: &'a Context, socket: Socket) -> Self {
        Self {
            context,
            socket,
            capacity: ChannelCapacity::default(),
        }
    }

    /// Sets the capacity applied to both the outbound and the inbound queue.
    pub fn with_capacity(mut self, capacity: ChannelCapacity) -> Self {
        self.capacity = capacity;
        self
    }

    /// Shorthand for `with_capacity(ChannelCapacity::Bounded(depth))`.
    pub fn with_bounded_queue(self, depth: usize) -> Self {
        self.with_capacity(ChannelCapacity::Bounded(depth))
    }

    /// Shorthand for `with_capacity(ChannelCapacity::Unbounded)`.
    pub fn with_unbounded_queue(self) -> Self {
        self.with_capacity(ChannelCapacity::Unbounded)
    }

    /// Creates the pair and starts its background threads.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ChannelPair::with_capacity`.
    pub fn build(self) -> Result<Arc<ChannelPair>, ChannelPairError> {
        ChannelPair::with_capacity(self.context, self.socket, self.capacity)
    }
}
/// What the socket thread is doing between polls: waiting for work, or
/// holding one outbound message until the socket can take it.
enum SocketState {
    Idle,
    ReadyToSend(ZmqMessage),
}

impl SocketState {
    /// Drops any held message and goes back to `Idle`.
    fn reset(&mut self) {
        let _ = std::mem::replace(self, Self::Idle);
    }
}
/// Bridges one ZeroMQ socket to crossbeam channels using two background threads.
///
/// Created through `ChannelPair::new`, `ChannelPair::with_capacity` or
/// `ChannelPairBuilder`, all of which return it inside an `Arc` that is also
/// cloned into the `zmq-socket-io` and `zmq-channel-bridge` threads.
pub struct ChannelPair {
    /// The wrapped user socket; polled and read/written by the socket thread.
    z_sock: Socket,
    /// Internal inproc PAIR for outbound messages: `[PAIR_IN]` is written by the
    /// bridge thread, `[PAIR_OUT]` is read by the socket thread.
    z_tx: Vec<Socket>,
    /// Internal inproc PAIR used to wake the socket thread: `[PAIR_IN]` is
    /// written by `shutdown`/`stop`, `[PAIR_OUT]` is polled by the socket thread.
    z_control: Vec<Socket>,
    /// Outbound queue: users send via `sender()`; drained by the bridge thread
    /// (and by the socket thread while it shuts down).
    tx_chan: (CrossbeamSender<ZmqMessage>, CrossbeamReceiver<ZmqMessage>),
    /// Inbound queue: filled by the socket thread, read via `receiver()`.
    rx_chan: (CrossbeamSender<ZmqMessage>, CrossbeamReceiver<ZmqMessage>),
    /// Fatal errors reported through `on_err`, read via `error_receiver()`.
    error_chan: (
        CrossbeamSender<ChannelPairError>,
        CrossbeamReceiver<ChannelPairError>,
    ),
    /// Tells the bridge thread to exit: `true` from `shutdown`, `false` from
    /// `stop` and `on_err`.
    control_chan: (CrossbeamSender<bool>, CrossbeamReceiver<bool>),
    /// Set by whichever of `shutdown`/`stop` first wins the compare-exchange
    /// (and stored again in `Drop`).
    is_shutdown: AtomicBool,
}
// SAFETY (review): ZeroMQ sockets must not be used from several threads at
// once, which is why `Sync` is asserted by hand here. Soundness rests on each
// socket being confined to one thread at a time:
// - `z_sock`, `z_tx[PAIR_OUT]` and `z_control[PAIR_OUT]` are used only by the
//   socket thread (`run_sockets` / `handle_shutdown`) once construction has
//   configured them and spawned the threads;
// - `z_tx[PAIR_IN]` is used only by the bridge thread (`run_channels`);
// - `z_control[PAIR_IN]` is used only by the single caller that wins the
//   `is_shutdown` compare-exchange in `shutdown`/`stop`.
// Any new `&self` method that touches a socket must preserve this confinement.
// NOTE(review): if `zmq::Socket` is already `Send`, the `Send` impl below is
// redundant; confirm against the zmq crate version in use.
unsafe impl Send for ChannelPair {}
unsafe impl Sync for ChannelPair {}
impl ChannelPair {
    /// Wraps `socket` using unbounded outbound and inbound queues.
    ///
    /// Shorthand for [`ChannelPair::with_capacity`] with
    /// [`ChannelCapacity::Unbounded`].
    ///
    /// # Errors
    ///
    /// Same as [`ChannelPair::with_capacity`].
    pub fn new(context: &Context, socket: Socket) -> Result<Arc<Self>, ChannelPairError> {
        Self::with_capacity(context, socket, ChannelCapacity::Unbounded)
    }

    /// Wraps `socket` and starts the background threads that service it.
    ///
    /// Two internal inproc `PAIR` connections are created on `context`. One
    /// carries outbound messages from the `zmq-channel-bridge` thread to the
    /// `zmq-socket-io` thread, which performs all I/O on `socket`. The other
    /// wakes the socket thread when shutdown is requested. `capacity` sizes
    /// both the outbound and the inbound queue. Every socket involved,
    /// including `socket`, gets zero send/receive timeouts.
    ///
    /// # Errors
    ///
    /// Fails if an internal socket cannot be created, bound, connected or
    /// configured, or if a thread cannot be spawned. If only the second thread
    /// fails to spawn, the already-running socket thread is told to stop so it
    /// does not outlive the failed constructor.
    pub fn with_capacity(
        context: &Context,
        socket: Socket,
        capacity: ChannelCapacity,
    ) -> Result<Arc<Self>, ChannelPairError> {
        let z_tx = Self::new_pair(context)?;
        let z_control = Self::new_pair(context)?;
        let (tx_chan, rx_chan) = match capacity {
            ChannelCapacity::Unbounded => (unbounded(), unbounded()),
            ChannelCapacity::Bounded(cap) => (bounded(cap), bounded(cap)),
        };
        let channel_pair = Self {
            z_tx,
            z_control,
            z_sock: socket,
            tx_chan,
            rx_chan,
            error_chan: unbounded(),
            control_chan: unbounded(),
            is_shutdown: AtomicBool::new(false),
        };
        channel_pair.configure_socket()?;
        let channel_pair = Arc::new(channel_pair);

        let cp_sockets = Arc::clone(&channel_pair);
        std::thread::Builder::new()
            .name("zmq-socket-io".into())
            .spawn(move || cp_sockets.run_sockets())
            .map_err(|e| {
                ChannelPairError::Other(format!("failed to spawn socket thread: {}", e))
            })?;

        let cp_channels = Arc::clone(&channel_pair);
        let bridge = std::thread::Builder::new()
            .name("zmq-channel-bridge".into())
            .spawn(move || cp_channels.run_channels());
        if let Err(e) = bridge {
            // The socket thread is already blocked in `poll` and holds its own
            // Arc; wake it so it exits instead of leaking with the socket.
            channel_pair.stop();
            return Err(ChannelPairError::Other(format!(
                "failed to spawn channel thread: {}",
                e
            )));
        }
        Ok(channel_pair)
    }

    /// Queues `msg` for the socket; shorthand for `self.sender().send(msg)`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Sender::send`].
    pub fn send(&self, msg: ZmqMessage) -> Result<(), ChannelPairError> {
        self.sender().send(msg)
    }

    /// Non-blocking variant of [`ChannelPair::send`].
    pub fn try_send(&self, msg: ZmqMessage) -> Result<(), TrySendError<ZmqMessage>> {
        self.sender().try_send(msg)
    }

    /// Blocks until a message received from the socket is available.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Receiver::recv`].
    pub fn recv(&self) -> Result<ZmqMessage, ChannelPairError> {
        self.receiver().recv()
    }

    /// Waits at most `timeout` for a message received from the socket.
    pub fn recv_timeout(&self, timeout: Duration) -> Result<ZmqMessage, RecvTimeoutError> {
        self.receiver().recv_timeout(timeout)
    }

    /// Returns a received message if one is waiting, without blocking.
    pub fn try_recv(&self) -> Result<ZmqMessage, TryRecvError> {
        self.receiver().try_recv()
    }

    /// Returns a cloneable handle onto the outbound queue.
    pub fn sender(&self) -> Sender {
        Sender {
            inner: self.tx_chan.0.clone(),
        }
    }

    /// Returns a cloneable handle onto the inbound queue.
    pub fn receiver(&self) -> Receiver {
        Receiver {
            inner: self.rx_chan.1.clone(),
        }
    }

    /// Receiver for fatal errors reported by the background threads.
    pub fn error_receiver(&self) -> &CrossbeamReceiver<ChannelPairError> {
        &self.error_chan.1
    }

    /// Requests a graceful shutdown of both background threads.
    ///
    /// Returns immediately, and only the first call to `shutdown` or
    /// [`stop`](Self::stop) has any effect. Before exiting, the socket thread
    /// writes out the message it was holding, everything already handed to it,
    /// and everything still waiting in the outbound queue. During this flush
    /// the socket's send timeout is raised to its linger value.
    pub fn shutdown(&self) {
        self.request_shutdown(true);
    }

    /// Requests that both background threads stop.
    ///
    /// Only the first call to `stop` or [`shutdown`](Self::shutdown) has an
    /// effect. The socket thread performs the same flush on every shutdown
    /// path. The two methods differ only in the flag they place on the
    /// internal control channel.
    pub fn stop(&self) {
        self.request_shutdown(false);
    }

    /// Returns `true` once `shutdown` or `stop` has taken effect. This also
    /// includes the internal `stop` issued when the bridge thread hits a fatal
    /// error.
    pub fn is_shutdown(&self) -> bool {
        self.is_shutdown.load(Ordering::SeqCst)
    }

    /// Deprecated: borrow of the inbound queue's receiver.
    #[deprecated(since = "2.0.0", note = "use receiver() instead")]
    pub fn rx(&self) -> &CrossbeamReceiver<ZmqMessage> {
        &self.rx_chan.1
    }

    /// Deprecated: borrow of the outbound queue's sender.
    #[deprecated(since = "2.0.0", note = "use sender() instead")]
    pub fn tx(&self) -> &CrossbeamSender<ZmqMessage> {
        &self.tx_chan.0
    }

    /// Deprecated: borrow of the error queue's receiver.
    #[deprecated(since = "2.0.0", note = "use error_receiver() instead")]
    pub fn rx_err_chan(&self) -> &CrossbeamReceiver<ChannelPairError> {
        &self.error_chan.1
    }

    /// Writing end of the inbound queue (used by the socket thread).
    fn rx_writer(&self) -> &CrossbeamSender<ZmqMessage> {
        &self.rx_chan.0
    }

    /// Reading end of the outbound queue.
    fn tx_reader(&self) -> &CrossbeamReceiver<ZmqMessage> {
        &self.tx_chan.1
    }

    /// Writing end of the bridge-thread control channel.
    fn tx_control_chan(&self) -> &CrossbeamSender<bool> {
        &self.control_chan.0
    }

    /// Reading end of the bridge-thread control channel.
    fn rx_control_chan(&self) -> &CrossbeamReceiver<bool> {
        &self.control_chan.1
    }

    /// Writing end of the error queue.
    fn tx_err_chan(&self) -> &CrossbeamSender<ChannelPairError> {
        &self.error_chan.0
    }

    /// Reports a fatal error and tells the bridge thread to exit.
    fn on_err(&self, error: ChannelPairError) {
        let _ = self.tx_err_chan().send(error);
        let _ = self.tx_control_chan().send(false);
    }

    /// Shared body of [`shutdown`](Self::shutdown) and [`stop`](Self::stop).
    ///
    /// The compare-exchange guarantees that at most one caller ever writes to
    /// `z_control[PAIR_IN]`, which keeps that ZeroMQ socket confined to a
    /// single thread.
    fn request_shutdown(&self, drain: bool) {
        if self
            .is_shutdown
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
        {
            let _ = self.tx_control_chan().send(drain);
            // Wake the socket thread, which is blocked in `zmq::poll`.
            let _ = self.z_control[PAIR_IN].send("", 0);
        }
    }

    /// Gives the user socket and every internal socket zero send/receive
    /// timeouts, so that no ZeroMQ call made by the threads blocks on its own.
    fn configure_socket(&self) -> Result<(), ChannelPairError> {
        let all_sockets = std::iter::once(&self.z_sock)
            .chain(&self.z_tx)
            .chain(&self.z_control);
        for socket in all_sockets {
            socket.set_rcvtimeo(0)?;
            socket.set_sndtimeo(0)?;
        }
        Ok(())
    }

    /// Creates a connected pair of inproc `PAIR` sockets on a unique endpoint.
    ///
    /// Index `PAIR_IN` is the bound end; index `PAIR_OUT` is the connected end.
    fn new_pair(context: &Context) -> Result<Vec<Socket>, ChannelPairError> {
        let id = UNIQUE_INDEX.fetch_add(1, Ordering::SeqCst);
        let addr = format!("inproc://_channelpair_internal-{}", id);
        let server = context.socket(zmq::PAIR)?;
        server.bind(&addr)?;
        let client = context.socket(zmq::PAIR)?;
        client.connect(&addr)?;
        Ok(vec![server, client])
    }

    /// Body of the `zmq-socket-io` thread, the only thread that performs I/O on
    /// `z_sock` once construction is done.
    ///
    /// Holds at most one outbound message at a time. While holding one, it
    /// waits for the socket to become writable, but it keeps polling for input.
    /// Otherwise, two pairs that are both blocked on sending to each other
    /// would each stop draining the other's traffic and deadlock. The thread
    /// exits after a fatal error (reported through `on_err`), or when woken on
    /// the control socket, in which case it flushes pending output first.
    fn run_sockets(&self) {
        const SOCK_IDX: usize = 0;
        const TX_IDX: usize = 1;
        const CTRL_IDX: usize = 2;
        let mut state = SocketState::Idle;
        let mut items = [
            self.z_sock.as_poll_item(PollEvents::empty()),
            self.z_tx[PAIR_OUT].as_poll_item(PollEvents::empty()),
            self.z_control[PAIR_OUT].as_poll_item(PollEvents::POLLIN),
        ];
        loop {
            // Inbound traffic is always watched. Writability is watched (and new
            // outbound work refused) only while a message is waiting to go out.
            let (sock_events, tx_events) = match state {
                SocketState::Idle => (zmq::POLLIN, zmq::POLLIN),
                SocketState::ReadyToSend(_) => {
                    (zmq::POLLIN | zmq::POLLOUT, PollEvents::empty())
                }
            };
            items[SOCK_IDX].set_events(sock_events);
            items[TX_IDX].set_events(tx_events);

            match zmq::poll(&mut items, -1) {
                Ok(_) => {}
                // A signal interrupted the wait; nothing is wrong, poll again.
                Err(zmq::Error::EINTR) => continue,
                Err(e) => {
                    self.on_err(ChannelPairError::Zmq(e));
                    return;
                }
            }

            if items[SOCK_IDX].is_readable() {
                match self.z_sock.recv_multipart(0) {
                    Ok(msg) => {
                        if let Err(err) = self.rx_writer().send(msg) {
                            self.on_err(ChannelPairError::ChannelDisconnected(format!(
                                "failed to forward received message: {}",
                                err
                            )));
                            return;
                        }
                    }
                    Err(e) => {
                        self.on_err(ChannelPairError::Zmq(e));
                        return;
                    }
                }
            }
            if items[SOCK_IDX].is_writable() {
                if let SocketState::ReadyToSend(msg) = &state {
                    match self.z_sock.send_multipart(msg, 0) {
                        Ok(()) => state.reset(),
                        Err(e) => {
                            self.on_err(ChannelPairError::Zmq(e));
                            return;
                        }
                    }
                }
            }
            if items[TX_IDX].is_readable() {
                match self.z_tx[PAIR_OUT].recv_multipart(0) {
                    Ok(msg) => state = SocketState::ReadyToSend(msg),
                    Err(e) => {
                        self.on_err(ChannelPairError::Zmq(e));
                        return;
                    }
                }
            }
            if items[CTRL_IDX].is_readable() {
                if let Err(e) = self.z_control[PAIR_OUT].recv_multipart(0) {
                    self.on_err(ChannelPairError::Zmq(e));
                    return;
                }
                self.handle_shutdown(&mut state);
                return;
            }
        }
    }

    /// Flushes outbound traffic before the socket thread exits.
    ///
    /// First, raises the socket's send timeout to its linger value (0 if the
    /// linger cannot be read). Then it sends, best-effort and in this order:
    /// 1. the message held in `state`;
    /// 2. everything arriving on the internal outbound pipe until the pipe has
    ///    been quiet for `DRAIN_GRACE_MS`, which also catches a message the
    ///    bridge thread was forwarding when shutdown was requested;
    /// 3. whatever is still waiting in the outbound queue.
    ///
    /// Send errors are ignored.
    fn handle_shutdown(&self, state: &mut SocketState) {
        const DRAIN_GRACE_MS: i64 = 10;
        let linger = self.z_sock.get_linger().unwrap_or(0);
        let _ = self.z_sock.set_sndtimeo(linger);
        if let SocketState::ReadyToSend(msg) = &*state {
            let _ = self.z_sock.send_multipart(msg, 0);
            state.reset();
        }
        let mut items = [self.z_tx[PAIR_OUT].as_poll_item(zmq::POLLIN)];
        loop {
            match zmq::poll(&mut items, DRAIN_GRACE_MS) {
                Ok(_) if items[0].is_readable() => {
                    if let Ok(msg) = self.z_tx[PAIR_OUT].recv_multipart(0) {
                        let _ = self.z_sock.send_multipart(msg, 0);
                    }
                }
                Err(zmq::Error::EINTR) => {}
                _ => break,
            }
        }
        for msg in self.tx_reader().try_iter() {
            let _ = self.z_sock.send_multipart(msg, 0);
        }
    }

    /// Body of the `zmq-channel-bridge` thread.
    ///
    /// Moves messages from the outbound queue onto the internal pipe read by
    /// the socket thread. It exits when any control message arrives or the
    /// queue disconnects. The control message is only a signal: the socket
    /// thread does all shutdown flushing, so nothing is written to the pipe
    /// here. (An empty frame written at this point would be forwarded to the
    /// peer as a bogus message.) On a fatal ZeroMQ error, the error is reported
    /// and the socket thread is woken via `stop`, so it does not stay blocked
    /// in `poll` forever.
    fn run_channels(&self) {
        loop {
            select! {
                recv(self.tx_reader()) -> msg => {
                    let msg = match msg {
                        Ok(msg) => msg,
                        Err(_) => return,
                    };
                    match self.forward_to_socket_thread(&msg) {
                        Ok(true) => {}
                        Ok(false) => return,
                        Err(e) => {
                            self.on_err(ChannelPairError::Zmq(e));
                            self.stop();
                            return;
                        }
                    }
                }
                recv(self.rx_control_chan()) -> _ => {
                    return;
                }
            }
        }
    }

    /// Writes `msg` to the internal outbound pipe, waiting out back-pressure.
    ///
    /// The pipe uses a zero send timeout, so a full pipe (the socket thread is
    /// waiting on a slow or absent peer) shows up as `EAGAIN`. That is not
    /// fatal. The method waits for the pipe to become writable and, between
    /// waits, checks whether a control message is pending.
    ///
    /// Returns:
    /// - `Ok(true)` once the message is queued;
    /// - `Ok(false)` if it was abandoned because the pipe stayed full while a
    ///   control message was pending;
    /// - `Err` for any other ZeroMQ error.
    fn forward_to_socket_thread(&self, msg: &ZmqMessage) -> Result<bool, zmq::Error> {
        const BACKPRESSURE_POLL_MS: i64 = 50;
        let pipe = &self.z_tx[PAIR_IN];
        loop {
            match pipe.send_multipart(msg, 0) {
                Ok(()) => return Ok(true),
                Err(zmq::Error::EAGAIN) => {
                    let mut items = [pipe.as_poll_item(zmq::POLLOUT)];
                    match zmq::poll(&mut items, BACKPRESSURE_POLL_MS) {
                        Ok(_) | Err(zmq::Error::EINTR) => {}
                        Err(e) => return Err(e),
                    }
                    if !items[0].is_writable() && !self.rx_control_chan().is_empty() {
                        return Ok(false);
                    }
                }
                Err(e) => return Err(e),
            }
        }
    }
}
impl Drop for ChannelPair {
fn drop(&mut self) {
self.is_shutdown.store(true, Ordering::SeqCst);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Pause after wiring up two pairs so the inproc connection settles.
    const SETTLE: Duration = Duration::from_millis(10);
    /// Upper bound on how long a test waits for an expected message.
    const RECV_WAIT: Duration = Duration::from_secs(1);

    /// Returns two PAIR sockets connected over a fresh inproc endpoint.
    fn create_pair_sockets(ctx: &Context) -> (Socket, Socket) {
        let endpoint = format!(
            "inproc://test-{}",
            UNIQUE_INDEX.fetch_add(1, Ordering::SeqCst)
        );
        let bound = ctx.socket(zmq::PAIR).expect("server socket");
        bound.bind(&endpoint).expect("bind");
        let connected = ctx.socket(zmq::PAIR).expect("client socket");
        connected.connect(&endpoint).expect("connect");
        (bound, connected)
    }

    /// Wraps both ends of a fresh socket pair in unbounded `ChannelPair`s
    /// and waits for the connection to settle.
    fn connected_pairs(ctx: &Context) -> (Arc<ChannelPair>, Arc<ChannelPair>) {
        let (server_sock, client_sock) = create_pair_sockets(ctx);
        let server = ChannelPair::new(ctx, server_sock).unwrap();
        let client = ChannelPair::new(ctx, client_sock).unwrap();
        thread::sleep(SETTLE);
        (server, client)
    }

    #[test]
    fn test_send_receive_basic() {
        let ctx = Context::new();
        let (server, client) = connected_pairs(&ctx);
        let payload = vec![b"Hello".to_vec(), b"World".to_vec()];
        client.send(payload.clone()).unwrap();
        assert_eq!(server.recv_timeout(RECV_WAIT).unwrap(), payload);
        server.shutdown();
        client.shutdown();
    }

    #[test]
    fn test_bounded_channel() {
        let ctx = Context::new();
        let (server_sock, client_sock) = create_pair_sockets(&ctx);
        let build = |sock| {
            ChannelPairBuilder::new(&ctx, sock)
                .with_bounded_queue(2)
                .build()
                .unwrap()
        };
        let server = build(server_sock);
        let client = build(client_sock);
        thread::sleep(SETTLE);
        for body in &[b"1", b"2"] {
            client.send(vec![body.to_vec()]).unwrap();
        }
        for body in &[b"1", b"2"] {
            assert_eq!(server.recv_timeout(RECV_WAIT).unwrap(), vec![body.to_vec()]);
        }
        server.shutdown();
        client.shutdown();
    }

    #[test]
    fn test_echo() {
        const NUM_MESSAGES: usize = 10;
        let ctx = Context::new();
        let (server, client) = connected_pairs(&ctx);
        let echo_side = Arc::clone(&server);
        let echo_handle = thread::spawn(move || {
            for _ in 0..NUM_MESSAGES {
                let incoming = echo_side.recv_timeout(RECV_WAIT).unwrap();
                echo_side.send(incoming).unwrap();
            }
        });
        for i in 0..NUM_MESSAGES {
            let outgoing = vec![format!("message-{}", i).into_bytes()];
            client.send(outgoing.clone()).unwrap();
            assert_eq!(client.recv_timeout(RECV_WAIT).unwrap(), outgoing);
        }
        echo_handle.join().unwrap();
        server.shutdown();
        client.shutdown();
    }

    #[test]
    fn test_sender_receiver_handles() {
        let ctx = Context::new();
        let (server, client) = connected_pairs(&ctx);
        let sender = client.sender();
        let receiver = server.receiver();
        thread::spawn(move || {
            sender.send(vec![b"from handle".to_vec()]).unwrap();
        })
        .join()
        .unwrap();
        assert_eq!(
            receiver.recv_timeout(RECV_WAIT).unwrap(),
            vec![b"from handle".to_vec()]
        );
        server.shutdown();
        client.shutdown();
    }

    #[test]
    fn test_try_recv_empty() {
        let ctx = Context::new();
        let (server_sock, _client_sock) = create_pair_sockets(&ctx);
        let server = ChannelPair::new(&ctx, server_sock).unwrap();
        let outcome = server.try_recv();
        if !matches!(outcome, Err(TryRecvError::Empty)) {
            panic!("expected Empty, got {:?}", outcome);
        }
        server.shutdown();
    }

    #[test]
    fn test_graceful_shutdown() {
        let ctx = Context::new();
        let (server, client) = connected_pairs(&ctx);
        let payload = vec![b"test".to_vec()];
        client.send(payload.clone()).unwrap();
        client.shutdown();
        assert!(client.is_shutdown());
        assert_eq!(server.recv_timeout(RECV_WAIT).unwrap(), payload);
        server.shutdown();
    }

    #[test]
    fn test_multipart_message() {
        let ctx = Context::new();
        let (server, client) = connected_pairs(&ctx);
        let frames = vec![
            b"identity".to_vec(),
            b"".to_vec(),
            b"header".to_vec(),
            b"body".to_vec(),
        ];
        client.send(frames.clone()).unwrap();
        let received = server.recv_timeout(RECV_WAIT).unwrap();
        assert_eq!(received.len(), 4);
        assert_eq!(received, frames);
        server.shutdown();
        client.shutdown();
    }
}