use rzmq::socket::options::SNDTIMEO;
use rzmq::socket::SocketEvent;
use rzmq::{Msg, SocketType, ZmqError};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify; use tokio::task::{self, JoinHandle};
use tokio::time::timeout;
mod common;
/// Short bound used for individual sends/receives and passed as the second
/// timeout argument of `common::wait_for_monitor_event`.
const SHORT_TIMEOUT: Duration = Duration::from_millis(250);
/// Longer receive bound for cases where delivery may take a while to arrive.
const LONG_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Timeout passed as the first timeout argument of `common::wait_for_monitor_event`.
const MONITOR_EVENT_TIMEOUT: Duration = Duration::from_millis(3_000);
#[tokio::test]
async fn test_context_term_closes_sockets() -> Result<(), ZmqError> {
println!("Starting test_context_term_closes_sockets...");
let ctx = common::test_context();
let push = ctx.socket(SocketType::Push)?;
let pull = ctx.socket(SocketType::Pull)?;
println!("Setting up monitor for PUSH socket...");
let push_monitor_rx = push.monitor_default().await?;
println!("PUSH monitor setup.");
let endpoint = "inproc://term-test";
println!("Binding PULL...");
pull.bind(endpoint).await?;
println!("Connecting PUSH...");
push.connect(endpoint).await?;
println!("Main test: Waiting for PUSH monitor Connected event...");
common::wait_for_monitor_event(
&push_monitor_rx,
MONITOR_EVENT_TIMEOUT,
SHORT_TIMEOUT,
|e| matches!(e, SocketEvent::Connected { endpoint: ep, .. } if *ep == endpoint),
)
.await
.expect("Monitor did not receive Connected event");
println!("Main test: PUSH monitor received Connected event.");
tokio::time::sleep(Duration::from_millis(20)).await;
println!("PUSH sending message...");
push.send(Msg::from_static(b"Before Term")).await?;
println!("PUSH sent.");
println!("PULL receiving message...");
let msg1 = common::recv_timeout(&pull, SHORT_TIMEOUT).await?;
assert_eq!(msg1.data().unwrap(), b"Before Term");
println!("PULL received.");
println!("Initiating context termination...");
let shutdown_task = ctx.shutdown();
println!("Waiting for Disconnected event for {}...", endpoint);
let event_wait_task = common::wait_for_monitor_event(
&push_monitor_rx,
MONITOR_EVENT_TIMEOUT, SHORT_TIMEOUT,
|e| matches!(e, SocketEvent::Disconnected { endpoint: ep } if *ep == endpoint),
);
let (_, event_wait_result) = futures::future::join(shutdown_task, event_wait_task).await;
match event_wait_result {
Ok(_) => println!("PUSH monitor received Disconnected event as expected."),
Err(e) => panic!("PUSH monitor wait failed: {}", e),
}
println!("Event received or timed out. Now waiting for term() to complete...");
let term_result = ctx.term().await; println!("ctx.term() completed with: {:?}", term_result);
assert!(term_result.is_ok(), "Context termination failed");
println!("Context terminated.");
println!("Attempting PUSH set_option_raw after term (should fail)...");
let setopt_res = push
.set_option_raw(rzmq::socket::options::SNDTIMEO, &(0i32).to_ne_bytes())
.await;
println!("PUSH set_option_raw result: {:?}", setopt_res);
assert!(
setopt_res.is_err(),
"Expected error setting option after term, got {:?}",
setopt_res
);
println!("PUSH set_option_raw correctly failed: {:?}", setopt_res.err().unwrap());
println!("Attempting PUSH send after term (should fail)...");
let send_res = push.send(Msg::from_static(b"After Term")).await;
assert!(
send_res.is_err(),
"Expected error sending after term and disconnect, got {:?}",
send_res
);
println!("PUSH send correctly failed after term: {:?}", send_res.err().unwrap());
println!("Attempting PULL recv after term (should fail)...");
let recv_res = pull.recv().await;
assert!(
recv_res.is_err(),
"Expected error receiving after term, got {:?}",
recv_res
);
println!("PULL recv correctly failed: {:?}", recv_res.err().unwrap());
println!("Test test_context_term_closes_sockets finished.");
Ok(())
}
#[tokio::test]
async fn test_socket_close_stops_connection() -> Result<(), ZmqError> {
println!("Starting test_socket_close_stops_connection...");
let ctx = common::test_context();
let push = ctx.socket(SocketType::Push)?;
let pull = ctx.socket(SocketType::Pull)?;
let endpoint = "tcp://127.0.0.1:5640";
println!("Binding PUSH...");
push.bind(endpoint).await?; tokio::time::sleep(Duration::from_millis(50)).await;
println!("Connecting PULL...");
pull.connect(endpoint).await?;
tokio::time::sleep(Duration::from_millis(150)).await;
println!("PUSH sending Message 1...");
push.send(Msg::from_static(b"Message 1")).await?;
let msg1 = common::recv_timeout(&pull, LONG_TIMEOUT).await?;
assert_eq!(msg1.data().unwrap(), b"Message 1");
println!("PULL received Message 1.");
println!("Closing PULL socket...");
pull.close().await?;
println!("PULL socket closed.");
tokio::time::sleep(Duration::from_millis(200)).await;
println!("PUSH sending Message 2 (after PULL closed)...");
push
.set_option_raw(rzmq::socket::options::SNDTIMEO, &(0i32).to_ne_bytes())
.await?;
let send_res = push.send(Msg::from_static(b"Message 2")).await;
println!("PUSH send result: {:?}", send_res);
assert!(
matches!(send_res, Err(ZmqError::ResourceLimitReached)),
"Expected ResourceLimitReached sending after peer close, got {:?}",
send_res
);
println!("PUSH correctly failed sending after PULL closed.");
println!("Terminating context...");
ctx.term().await?; println!("Test test_socket_close_stops_connection finished.");
Ok(())
}
#[tokio::test]
async fn test_socket_explicit_close_triggers_disconnect_event() -> anyhow::Result<()> {
println!("Starting test_socket_explicit_close_triggers_disconnect_event...");
let ctx = common::test_context();
let push = ctx.socket(SocketType::Push)?;
println!("Setting up PUSH monitor...");
let push_monitor = push.monitor_default().await?;
let endpoint = "tcp://127.0.0.1:5641";
println!("Binding PUSH...");
push.bind(endpoint).await?;
println!("Expecting Listening event...");
common::wait_for_monitor_event(
&push_monitor,
MONITOR_EVENT_TIMEOUT,
SHORT_TIMEOUT,
|e| matches!(e, SocketEvent::Listening { endpoint: ep } if ep == endpoint),
)
.await
.map_err(|e| anyhow::anyhow!("Listening event wait failed: {}", e))?;
println!("PUSH Monitor: Received Listening event.");
tokio::time::sleep(Duration::from_millis(50)).await;
let disconnected_endpoint_uri: String;
{
let pull = ctx.socket(SocketType::Pull)?;
println!("Connecting PULL...");
pull.connect(endpoint).await?;
println!("Expecting Accepted/Handshake event...");
let event2 = common::wait_for_monitor_event(&push_monitor, MONITOR_EVENT_TIMEOUT, SHORT_TIMEOUT, |e| {
matches!(e, SocketEvent::HandshakeSucceeded { .. })
})
.await
.map_err(|e| anyhow::anyhow!("Accepted/Handshake event wait failed: {}", e))?;
println!("PUSH Monitor: Received connection event: {:?}", event2);
disconnected_endpoint_uri = match event2 {
SocketEvent::Accepted { endpoint: _, peer_addr } => format!("tcp://{}", peer_addr),
SocketEvent::HandshakeSucceeded { endpoint: ep } => ep,
_ => panic!("Unexpected event type received: {:?}", event2),
};
println!("Determined peer endpoint URI: {}", disconnected_endpoint_uri);
tokio::time::sleep(Duration::from_millis(150)).await;
println!("PUSH sending Message 1...");
push.send(Msg::from_static(b"Message 1")).await?;
let msg1 = common::recv_timeout(&pull, LONG_TIMEOUT).await?;
assert_eq!(msg1.data().unwrap(), b"Message 1");
println!("PULL received Message 1.");
println!("PULL socket closing...");
pull.close().await?;
println!("PULL socket closed (explicitly).");
}
println!("Waiting for Disconnected event for {}...", disconnected_endpoint_uri);
common::wait_for_monitor_event(
&push_monitor,
MONITOR_EVENT_TIMEOUT,
SHORT_TIMEOUT,
|e| matches!(e, SocketEvent::Disconnected { endpoint: ep } if *ep == disconnected_endpoint_uri),
)
.await
.map_err(|e| anyhow::anyhow!("Disconnected event wait failed: {}", e))?;
println!(
"PUSH Monitor: Received Disconnected event for {}.",
disconnected_endpoint_uri
);
println!("PUSH setting SNDTIMEO=0...");
push.set_option_raw(SNDTIMEO, &(0i32).to_ne_bytes()).await?;
println!("PUSH sending Message 2 (after PULL disconnected)...");
let send_res = push.send(Msg::from_static(b"Message 2")).await;
println!("PUSH send result: {:?}", send_res);
assert!(
matches!(send_res, Err(ZmqError::ResourceLimitReached)),
"Expected ResourceLimitReached sending after peer disconnect, got {:?}",
send_res
);
println!("PUSH correctly failed sending after PULL disconnected.");
println!("Closing PUSH socket explicitly before context term...");
let push_close_res = push.close().await;
println!("PUSH close result: {:?}", push_close_res);
assert!(push_close_res.is_ok(), "PUSH close failed unexpectedly");
tokio::time::sleep(Duration::from_millis(50)).await;
println!("Terminating context...");
ctx.term().await?;
println!("Checking if monitor channel is closed...");
match timeout(SHORT_TIMEOUT, push_monitor.recv()).await {
Ok(Ok(event)) => {
match event {
SocketEvent::ConnectDelayed { .. } => {
anyhow::bail!(
"Received unexpected ConnectDelayed event after context term: {:?}",
event
);
}
_ => {
println!(
"PUSH Monitor received final event {:?} (expected closed/empty). Tolerating.",
event
);
}
}
}
Ok(Err(_recv_err)) => {
println!("PUSH Monitor channel correctly closed after context term.");
}
Err(_) => {
println!("PUSH Monitor channel timed out (likely closed and empty) after context term.");
}
}
println!("Test test_socket_explicit_close_triggers_disconnect_event finished.");
Ok(())
}
/// Races context termination against a task that sends continuously on a PUSH socket.
///
/// The send task loops until a send fails or times out, which is expected once
/// `ctx.term()` tears the socket down, and then signals `finished_sending`.
/// The main task waits for the termination task to complete, whichever side
/// finishes first, all within one 5-second deadline. It then checks that
/// `ctx.term()` returned `Ok` and joins the send task, re-raising any panic it
/// produced.
#[tokio::test]
async fn test_concurrent_term_and_op() -> Result<(), ZmqError> {
    println!("Starting test_concurrent_term_and_op...");
    let ctx = common::test_context();
    let push = Arc::new(ctx.socket(SocketType::Push)?);
    let pull = ctx.socket(SocketType::Pull)?;
    let endpoint = "inproc://concurrent-term";
    println!("Binding PULL...");
    pull.bind(endpoint).await?;
    println!("Connecting PUSH...");
    push.connect(endpoint).await?;
    tokio::time::sleep(Duration::from_millis(50)).await;

    let push_clone = push.clone();
    let finished_sending = Arc::new(Notify::new());
    let finished_sending_clone = finished_sending.clone();
    let send_task: JoinHandle<()> = task::spawn(async move {
        let mut count = 0;
        loop {
            let msg = Msg::from_vec(format!("Msg {}", count).into_bytes());
            match common::send_timeout(&*push_clone, msg, SHORT_TIMEOUT).await {
                Ok(()) => {
                    count += 1;
                    tokio::task::yield_now().await;
                }
                Err(ZmqError::Timeout) => {
                    println!("Send task: Send timed out.");
                    break;
                }
                Err(e) => {
                    println!("Send task: Send failed: {}", e);
                    break;
                }
            }
            if count % 1000 == 0 {
                println!("Send task: Sent {} messages", count);
            }
        }
        println!("Send task finished.");
        // `notify_one` stores a permit, so the main task still observes this
        // even if it only starts waiting after the send task has finished.
        finished_sending_clone.notify_one();
    });

    println!("Main task receiving first message...");
    let _ = common::recv_timeout(&pull, LONG_TIMEOUT).await?;
    println!("Main task received one message.");
    tokio::time::sleep(Duration::from_millis(10)).await;

    println!("Main task initiating context termination...");
    let mut term_task: JoinHandle<Result<(), ZmqError>> = task::spawn(async move {
        let result = ctx.term().await;
        println!("Termination task: ctx.term() finished with result: {:?}", result);
        result
    });

    // Wait on the JoinHandle itself rather than a `Notify::notify_waiters`
    // signal. `notify_waiters` stores no permit, so a completion that happened
    // before a fresh `notified()` was created would be lost and the test would
    // hang. One shared deadline also bounds the second wait after the send
    // task finishes first, which the old code left unbounded.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    let term_final_result = tokio::select! {
        joined = &mut term_task => {
            println!("Termination completed signal received first.");
            joined
        }
        _ = finished_sending.notified() => {
            println!("Send task finished notification received first.");
            println!("Awaiting final termination signal after send task finished...");
            let joined = tokio::time::timeout_at(deadline, &mut term_task)
                .await
                .expect("Test timed out waiting for termination or send task completion");
            println!("Final termination signal received after send task finished.");
            joined
        }
        _ = tokio::time::sleep_until(deadline) => {
            panic!("Test timed out waiting for termination or send task completion")
        }
    };

    match term_final_result {
        Ok(Ok(())) => println!("Termination task joined successfully."),
        Ok(Err(e)) => return Err(e),
        Err(join_err) => panic!("Term task panicked: {:?}", join_err),
    }

    if let Err(e) = send_task.await {
        match e.try_into_panic() {
            Ok(payload) => std::panic::resume_unwind(payload),
            Err(join_err) if !join_err.is_cancelled() => {
                panic!("Send task failed to join normally: {:?}", join_err);
            }
            _ => {
                println!("Send task was cancelled (expected outcome possible).");
            }
        }
    }
    println!("Test test_concurrent_term_and_op finished.");
    Ok(())
}