use rustzmq2::{prelude::*, PubSocket, SubSocket, ZmqMessage};
use std::time::{Duration, Instant};
use tokio::runtime::Builder;
/// Returns the benchmark payload size in bytes.
///
/// The value is taken from the `MSG_SIZE` environment variable. When the
/// variable is unset, is not valid Unicode, or does not parse as a `usize`,
/// the default of 256 bytes is used instead.
fn msg_size() -> usize {
    const DEFAULT_SIZE: usize = 256;

    match std::env::var("MSG_SIZE") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_SIZE),
        Err(_) => DEFAULT_SIZE,
    }
}
/// Reports whether the benchmark should run in throughput mode.
///
/// Returns `true` only when the `MODE` environment variable is set to exactly
/// `throughput` (case-sensitive); an unset, non-Unicode, or different value
/// yields `false`.
fn throughput_mode() -> bool {
    match std::env::var("MODE") {
        Ok(value) => value == "throughput",
        Err(_) => false,
    }
}
/// Number of messages sent and received during the measured phase of the
/// benchmark; also the divisor used for the per-iteration figures.
const ITERS: usize = 200_000;
fn main() {
let rt = Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.expect("tokio runtime");
rt.block_on(async {
let mut p = PubSocket::new();
let bound = p
.bind("tcp://127.0.0.1:0")
.await
.expect("pub bind")
.to_string();
let mut s = SubSocket::new();
s.connect(&bound).await.expect("sub connect");
s.subscribe("").await.expect("subscribe");
tokio::time::sleep(Duration::from_millis(200)).await;
let payload: Vec<u8> = vec![0xAB; msg_size()];
let mode = if throughput_mode() {
"throughput"
} else {
"rtt"
};
eprintln!("MODE={mode} MSG_SIZE={} ITERS={}", payload.len(), ITERS);
for _ in 0..2_000 {
p.send(ZmqMessage::from(payload.clone()))
.await
.expect("warmup send");
let _ = s.recv().await.expect("warmup recv");
}
rustzmq2::__bench::wake_dump_and_reset("warmup");
let elapsed = if throughput_mode() {
let mut s = s;
const BATCH: usize = 64;
let ack = std::sync::Arc::new(tokio::sync::Notify::new());
let ack_for_sub = ack.clone();
let payload_for_sub = payload.clone();
let recv_task = tokio::spawn(async move {
for i in 0..ITERS {
let m = s.recv().await.expect("sub recv");
debug_assert_eq!(m.get(0).map_or(0, |f| f.len()), payload_for_sub.len());
std::hint::black_box(m);
if (i + 1) % BATCH == 0 {
ack_for_sub.notify_one();
}
}
ack_for_sub.notify_one();
});
let start = Instant::now();
for i in 0..ITERS {
p.send(ZmqMessage::from(payload.clone()))
.await
.expect("send");
if (i + 1) % BATCH == 0 {
ack.notified().await;
}
}
recv_task.await.expect("recv task");
start.elapsed()
} else {
let start = Instant::now();
for _ in 0..ITERS {
p.send(ZmqMessage::from(payload.clone()))
.await
.expect("send");
let m = s.recv().await.expect("recv");
std::hint::black_box(m);
}
start.elapsed()
};
let per_iter_ns = elapsed.as_nanos() as f64 / ITERS as f64;
let mb_per_s = (payload.len() as f64 * ITERS as f64) / elapsed.as_secs_f64() / 1_048_576.0;
eprintln!(
"iters={} total={:?} per_iter={:.2} ns throughput={:.1} MiB/s",
ITERS, elapsed, per_iter_ns, mb_per_s
);
rustzmq2::__bench::wake_dump_and_reset("measured");
});
}