#![cfg(feature = "curve")]
mod curve {
use bytes::Bytes;
use crypto_box::{aead::OsRng, SecretKey};
use rustzmq2::__async_rt as async_rt;
use rustzmq2::prelude::*;
use rustzmq2::ZmqMessage;
/// Generates a fresh key pair using `OsRng`.
///
/// Returns `(public_key, secret_key)` as raw 32-byte arrays, in that order.
fn gen_keypair() -> ([u8; 32], [u8; 32]) {
    let secret = SecretKey::generate(&mut OsRng);
    // Copy the public half out of the temporary before it is dropped.
    let public_bytes: [u8; 32] = *secret.public_key().as_bytes();
    (public_bytes, secret.to_bytes())
}
#[async_rt::test]
async fn curve_auth_success() {
let (server_pub, server_sec) = gen_keypair();
let (client_pub, client_sec) = gen_keypair();
let mut rep = rustzmq2::RepSocket::builder()
.curve_server(server_pub, server_sec)
.build();
let endpoint = rep.bind("tcp://127.0.0.1:0").await.expect("bind");
let mut dealer = rustzmq2::DealerSocket::builder()
.curve_client(server_pub, client_pub, client_sec)
.build();
dealer
.connect(endpoint.to_string().as_str())
.await
.expect("connect");
async_rt::task::sleep(std::time::Duration::from_millis(150)).await;
let mut msg = ZmqMessage::from(Bytes::from_static(b""));
msg.push_back(Bytes::from_static(b"curve-hello"));
dealer.send(msg).await.expect("send");
let got = async_rt::task::timeout(std::time::Duration::from_secs(2), rep.recv())
.await
.expect("timeout")
.expect("recv");
assert_eq!(got.get(0).expect("no frame").as_ref(), b"curve-hello");
}
#[async_rt::test]
async fn curve_encrypted_message_roundtrip() {
let (server_pub, server_sec) = gen_keypair();
let (client_pub, client_sec) = gen_keypair();
let mut rep = rustzmq2::RepSocket::builder()
.curve_server(server_pub, server_sec)
.build();
let endpoint = rep.bind("tcp://127.0.0.1:0").await.expect("bind");
let mut dealer = rustzmq2::DealerSocket::builder()
.curve_client(server_pub, client_pub, client_sec)
.build();
dealer
.connect(endpoint.to_string().as_str())
.await
.expect("connect");
async_rt::task::sleep(std::time::Duration::from_millis(150)).await;
let mut send_msg = ZmqMessage::from(Bytes::from_static(b""));
send_msg.push_back(Bytes::from_static(b"hello"));
send_msg.push_back(Bytes::from_static(b"world"));
dealer.send(send_msg).await.expect("dealer send");
let server_got = async_rt::task::timeout(std::time::Duration::from_secs(2), rep.recv())
.await
.expect("server recv timeout")
.expect("server recv");
assert_eq!(
server_got.get(0).map(|f| f.as_ref()),
Some(b"hello" as &[u8])
);
assert_eq!(
server_got.get(1).map(|f| f.as_ref()),
Some(b"world" as &[u8])
);
let mut reply = ZmqMessage::from(Bytes::from_static(b"echo-hello"));
reply.push_back(Bytes::from_static(b"echo-world"));
rep.send(reply).await.expect("rep send");
let client_got = async_rt::task::timeout(std::time::Duration::from_secs(2), dealer.recv())
.await
.expect("client recv timeout")
.expect("client recv");
assert_eq!(
client_got.get(1).map(|f| f.as_ref()),
Some(b"echo-hello" as &[u8])
);
assert_eq!(
client_got.get(2).map(|f| f.as_ref()),
Some(b"echo-world" as &[u8])
);
}
/// Checks that connecting with the wrong server public key fails.
///
/// The REP socket is built with one server key pair, while the DEALER is
/// given the public key of a different, unrelated pair. The test passes when
/// `connect` returns an error.
///
/// NOTE(review): the `connect` await has no timeout, so a stalled handshake
/// would hang this test rather than fail it.
#[async_rt::test]
async fn curve_wrong_server_key_rejected() {
    let (real_server_pub, real_server_sec) = gen_keypair();
    // Only the public half of the mismatched pair is needed.
    let (wrong_server_pub, _) = gen_keypair();
    let (client_pub, client_sec) = gen_keypair();

    let mut server = rustzmq2::RepSocket::builder()
        .curve_server(real_server_pub, real_server_sec)
        .build();
    let bound = server.bind("tcp://127.0.0.1:0").await.expect("bind");

    let mut client = rustzmq2::DealerSocket::builder()
        .curve_client(wrong_server_pub, client_pub, client_sec)
        .build();

    let addr = bound.to_string();
    let outcome = client.connect(addr.as_str()).await;
    assert!(
        outcome.is_err(),
        "expected handshake failure with wrong server key"
    );
}
}