use std::{io, time::Duration};
use bytes::Bytes;
use msg_common::constants::KiB;
use msg_transport::Address;
use thiserror::Error;
use tokio::sync::oneshot;
mod driver;
mod socket;
pub use socket::*;
mod stats;
use stats::RepStats;
use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, Profile, stats::SocketStats};
#[derive(Debug, Error)]
pub enum RepError {
#[error("IO error: {0:?}")]
Io(#[from] std::io::Error),
#[error("Wire protocol error: {0:?}")]
Wire(#[from] msg_wire::reqrep::Error),
#[error("Socket closed")]
SocketClosed,
#[error("Could not connect to any valid endpoints")]
NoValidEndpoints,
}
impl RepError {
pub fn is_connection_reset(&self) -> bool {
match self {
Self::Io(e) | Self::Wire(msg_wire::reqrep::Error::Io(e)) => {
e.kind() == io::ErrorKind::ConnectionReset
}
_ => false,
}
}
}
/// Configuration for a reply socket.
///
/// Build one with [`RepOptions::default`] or a profile preset, then refine it
/// with the `with_*` builder methods.
// `Debug` so options can be logged; `Clone` so one configuration can seed
// several sockets. All fields are plain data, so both derives are cheap.
#[derive(Debug, Clone)]
pub struct RepOptions {
    /// Maximum number of connected clients. `None` (the default) presumably
    /// means no limit.
    pub(crate) max_clients: Option<usize>,
    /// Size threshold (bytes) used when deciding whether to compress a payload.
    pub(crate) min_compress_size: usize,
    /// Size of the write buffer, in bytes.
    pub(crate) write_buffer_size: usize,
    /// How long buffered writes may linger before being flushed. `None`
    /// presumably disables lingering; confirm in the driver.
    pub(crate) write_buffer_linger: Option<Duration>,
    /// Maximum number of responses that may be pending at once (high-water mark).
    pub(crate) max_pending_responses: usize,
}
impl Default for RepOptions {
    /// Default options:
    /// - no client limit;
    /// - `DEFAULT_BUFFER_SIZE` for both the compression threshold and the write buffer;
    /// - a 100µs write-buffer linger;
    /// - `DEFAULT_QUEUE_SIZE` pending responses.
    fn default() -> Self {
        let linger = Duration::from_micros(100);
        Self {
            write_buffer_size: DEFAULT_BUFFER_SIZE,
            min_compress_size: DEFAULT_BUFFER_SIZE,
            write_buffer_linger: Some(linger),
            max_pending_responses: DEFAULT_QUEUE_SIZE,
            max_clients: None,
        }
    }
}
impl RepOptions {
    /// Returns the preset options matching `profile`.
    pub fn new(profile: Profile) -> Self {
        match profile {
            Profile::Balanced => Self::balanced(),
            Profile::Latency => Self::low_latency(),
            Profile::Throughput => Self::high_throughput(),
        }
    }

    /// Latency-oriented preset: 8 KiB write buffer, 50µs linger.
    /// All other fields keep their defaults.
    pub fn low_latency() -> Self {
        Self::default()
            .with_write_buffer_size(8 * KiB as usize)
            .with_write_buffer_linger(Some(Duration::from_micros(50)))
    }

    /// Throughput-oriented preset: 256 KiB write buffer, 200µs linger.
    /// All other fields keep their defaults.
    pub fn high_throughput() -> Self {
        Self::default()
            .with_write_buffer_size(256 * KiB as usize)
            .with_write_buffer_linger(Some(Duration::from_micros(200)))
    }

    /// Balanced preset: 32 KiB write buffer, 100µs linger.
    /// All other fields keep their defaults.
    pub fn balanced() -> Self {
        Self::default()
            .with_write_buffer_size(32 * KiB as usize)
            .with_write_buffer_linger(Some(Duration::from_micros(100)))
    }
}
impl RepOptions {
    /// Limits the number of concurrently connected clients to `max_clients`.
    // `#[must_use]`: these builders take `self` by value, so ignoring the
    // return value silently discards the configuration change.
    #[must_use]
    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
        self.max_clients = Some(max_clients);
        self
    }

    /// Sets the size threshold (bytes) used when deciding whether to compress
    /// a payload. Presumably smaller payloads are sent uncompressed; confirm
    /// in the driver.
    #[must_use]
    pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
        self.min_compress_size = min_compress_size;
        self
    }

    /// Sets the write buffer size, in bytes.
    #[must_use]
    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
        self.write_buffer_size = size;
        self
    }

    /// Sets how long buffered writes may linger before being flushed.
    /// `None` presumably disables lingering.
    #[must_use]
    pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
        self.write_buffer_linger = duration;
        self
    }

    /// Sets the maximum number of pending responses (high-water mark).
    #[must_use]
    pub fn with_max_pending_responses(mut self, hwm: usize) -> Self {
        self.max_pending_responses = hwm;
        self
    }
}
/// State shared with the reply socket; currently it only holds the socket's
/// statistics.
#[derive(Default, Debug)]
pub(crate) struct SocketState {
    /// Reply-socket statistics.
    pub(crate) stats: SocketStats<RepStats>,
}
/// A request received by the reply socket, together with the one-shot channel
/// used to send its reply.
pub struct Request<A: Address> {
    /// Source address of the request, as returned by `source`.
    source: A,
    // NOTE(review): not read anywhere in this file; presumably the compression
    // type of the incoming payload, consumed by the driver. Confirm.
    compression_type: u8,
    /// Sender half used by `respond` to deliver the reply.
    response: oneshot::Sender<Bytes>,
    /// The request payload.
    msg: Bytes,
}
impl<A: Address> Request<A> {
    /// Returns the source address of this request.
    pub fn source(&self) -> &A {
        let Self { source, .. } = self;
        source
    }

    /// Returns the request payload.
    pub fn msg(&self) -> &Bytes {
        let Self { msg, .. } = self;
        msg
    }

    /// Consumes the request and sends `response` back over its reply channel.
    ///
    /// # Errors
    ///
    /// Returns [`RepError::SocketClosed`] if the receiving side of the reply
    /// channel has been dropped. In that case the response is discarded.
    pub fn respond(self, response: Bytes) -> Result<(), RepError> {
        match self.response.send(response) {
            Ok(()) => Ok(()),
            Err(_unsent) => Err(RepError::SocketClosed),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, time::Duration};

    use futures::StreamExt;
    use msg_transport::tcp::Tcp;
    use msg_wire::compression::{GzipCompressor, SnappyCompressor};
    use rand::Rng;
    use tracing::{debug, info};

    use crate::{
        ReqOptions,
        hooks::token::{ClientHook, ServerHook},
        req::ReqSocket,
    };

    use super::*;

    /// Loopback address with an OS-assigned port, so concurrently running tests
    /// never fight over the same port.
    fn localhost() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    /// Reserves a currently free loopback port by binding a throwaway listener
    /// and releasing it. Used when the address must be known before the rep
    /// socket binds. This is far less collision-prone than guessing a random port.
    fn free_localhost_addr() -> SocketAddr {
        std::net::TcpListener::bind(localhost()).unwrap().local_addr().unwrap()
    }

    /// Generates `count` random payloads of `len` bytes each.
    fn random_payloads(count: usize, len: usize) -> Vec<Bytes> {
        let mut rng = rand::rng();
        (0..count)
            .map(|_| {
                let mut buf = vec![0u8; len];
                rng.fill(&mut buf[..]);
                Bytes::from(buf)
            })
            .collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_simple() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut rep = RepSocket::new(Tcp::default());
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::new(Tcp::default());
        req.connect(rep.local_addr().unwrap()).await.unwrap();

        tokio::spawn(async move {
            loop {
                let req = rep.next().await.unwrap();
                req.respond(Bytes::from("hello")).unwrap();
            }
        });

        let n_reqs = 1000;
        let msgs = random_payloads(n_reqs, 512);

        let start = std::time::Instant::now();
        for msg in msgs {
            let res = req.request(msg).await.unwrap();
            assert_eq!(res, Bytes::from("hello"));
        }
        let elapsed = start.elapsed();
        info!("{} reqs in {:?}", n_reqs, elapsed);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_durable() {
        let _ = tracing_subscriber::fmt::try_init();

        // The req socket must start connecting before the rep socket exists, so
        // the address has to be chosen up front. Reserve a free loopback port
        // rather than guessing one that may already be taken.
        let addr = free_localhost_addr();

        let mut req = ReqSocket::new(Tcp::default());
        let connection_attempt = tokio::spawn(async move {
            req.connect(addr).await.unwrap();
            req
        });

        // Give the req socket time to start connecting while nothing is listening yet.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let mut rep = RepSocket::new(Tcp::default());
        rep.bind(addr).await.unwrap();

        let req = connection_attempt.await.unwrap();

        tokio::spawn(async move {
            let req = rep.next().await.unwrap();
            debug!("Message: {:?}", req.msg());
            req.respond(Bytes::from("world")).unwrap();
        });

        let res = req.request(Bytes::from("hello")).await.unwrap();
        assert_eq!(res, Bytes::from("world"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_auth() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut rep = RepSocket::new(Tcp::default()).with_connection_hook(ServerHook::accept_all());
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::new(Tcp::default())
            .with_connection_hook(ClientHook::new(Bytes::from("REQ")));
        req.connect(rep.local_addr().unwrap()).await.unwrap();
        info!("Connected to rep");

        tokio::spawn(async move {
            loop {
                let req = rep.next().await.unwrap();
                debug!("Received request");
                req.respond(Bytes::from("hello")).unwrap();
            }
        });

        let n_reqs = 1000;
        let msgs = random_payloads(n_reqs, 512);

        let start = std::time::Instant::now();
        for msg in msgs {
            let res = req.request(msg).await.unwrap();
            assert_eq!(res, Bytes::from("hello"));
        }
        let elapsed = start.elapsed();
        info!("{} reqs in {:?}", n_reqs, elapsed);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn rep_max_connections() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut rep =
            RepSocket::with_options(Tcp::default(), RepOptions::default().with_max_clients(1));
        rep.bind(localhost()).await.unwrap();
        let addr = rep.local_addr().unwrap();

        let mut req1 = ReqSocket::new(Tcp::default());
        req1.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert_eq!(rep.stats().active_clients(), 1);

        // A second client must be refused: the active count stays at the limit.
        let mut req2 = ReqSocket::new(Tcp::default());
        req2.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert_eq!(rep.stats().active_clients(), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_basic_reqrep_with_compression() {
        let mut rep = RepSocket::with_options(
            Tcp::default(),
            RepOptions::default().with_min_compress_size(0),
        )
        .with_compressor(SnappyCompressor);
        // OS-assigned loopback port. A hard-coded port makes this test fail
        // whenever that port is busy or tests run in parallel.
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::with_options(
            Tcp::default(),
            ReqOptions::default().with_min_compress_size(0),
        )
        .with_compressor(GzipCompressor::new(6));
        req.connect(rep.local_addr().unwrap()).await.unwrap();

        tokio::spawn(async move {
            let req = rep.next().await.unwrap();
            assert_eq!(req.msg(), &Bytes::from("hello"));
            req.respond(Bytes::from("world")).unwrap();
        });

        let res: Bytes = req.request(Bytes::from("hello")).await.unwrap();
        assert_eq!(res, Bytes::from("world"));
    }
}