1use std::{io, time::Duration};
2
3use bytes::Bytes;
4use msg_common::constants::KiB;
5use msg_transport::Address;
6use thiserror::Error;
7use tokio::sync::oneshot;
8
9mod driver;
10mod socket;
11pub use socket::*;
12
13mod stats;
14use stats::RepStats;
15
16use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, Profile, stats::SocketStats};
17
18#[derive(Debug, Error)]
20pub enum RepError {
21 #[error("IO error: {0:?}")]
22 Io(#[from] std::io::Error),
23 #[error("Wire protocol error: {0:?}")]
24 Wire(#[from] msg_wire::reqrep::Error),
25 #[error("Socket closed")]
26 SocketClosed,
27 #[error("Could not connect to any valid endpoints")]
28 NoValidEndpoints,
29}
30
31impl RepError {
32 pub fn is_connection_reset(&self) -> bool {
33 match self {
34 Self::Io(e) | Self::Wire(msg_wire::reqrep::Error::Io(e)) => {
35 e.kind() == io::ErrorKind::ConnectionReset
36 }
37 _ => false,
38 }
39 }
40}
41
42pub struct RepOptions {
44 pub(crate) max_clients: Option<usize>,
46 pub(crate) min_compress_size: usize,
50 pub(crate) write_buffer_size: usize,
52 pub(crate) write_buffer_linger: Option<Duration>,
54 pub(crate) max_pending_responses: usize,
59}
60
61impl Default for RepOptions {
62 fn default() -> Self {
63 Self {
64 max_clients: None,
65 min_compress_size: DEFAULT_BUFFER_SIZE,
66 write_buffer_size: DEFAULT_BUFFER_SIZE,
67 write_buffer_linger: Some(Duration::from_micros(100)),
68 max_pending_responses: DEFAULT_QUEUE_SIZE,
69 }
70 }
71}
72
73impl RepOptions {
74 pub fn new(profile: Profile) -> Self {
76 match profile {
77 Profile::Latency => Self::low_latency(),
78 Profile::Throughput => Self::high_throughput(),
79 Profile::Balanced => Self::balanced(),
80 }
81 }
82
83 pub fn low_latency() -> Self {
85 Self {
86 write_buffer_size: 8 * KiB as usize,
87 write_buffer_linger: Some(Duration::from_micros(50)),
88 ..Default::default()
89 }
90 }
91
92 pub fn high_throughput() -> Self {
94 Self {
95 write_buffer_size: 256 * KiB as usize,
96 write_buffer_linger: Some(Duration::from_micros(200)),
97 ..Default::default()
98 }
99 }
100
101 pub fn balanced() -> Self {
103 Self {
104 write_buffer_size: 32 * KiB as usize,
105 write_buffer_linger: Some(Duration::from_micros(100)),
106 ..Default::default()
107 }
108 }
109}
110
111impl RepOptions {
112 pub fn with_max_clients(mut self, max_clients: usize) -> Self {
114 self.max_clients = Some(max_clients);
115 self
116 }
117
118 pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
123 self.min_compress_size = min_compress_size;
124 self
125 }
126
127 pub fn with_write_buffer_size(mut self, size: usize) -> Self {
132 self.write_buffer_size = size;
133 self
134 }
135
136 pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
141 self.write_buffer_linger = duration;
142 self
143 }
144
145 pub fn with_max_pending_responses(mut self, hwm: usize) -> Self {
151 self.max_pending_responses = hwm;
152 self
153 }
154}
155
156#[derive(Debug, Default)]
158pub(crate) struct SocketState {
159 pub(crate) stats: SocketStats<RepStats>,
160}
161
/// A request received by the reply socket, bundled with the channel used to answer it.
///
/// Answer it by consuming it with [`Request::respond`].
// NOTE: field declaration order is also drop order; keep it stable.
pub struct Request<A: Address> {
    /// Address the request was received from.
    source: A,
    /// Compression type tag of the incoming message.
    // NOTE(review): presumably identifies the codec the payload was compressed with — confirm in the driver.
    compression_type: u8,
    /// One-shot channel over which the response payload is handed back.
    response: oneshot::Sender<Bytes>,
    /// The request payload.
    msg: Bytes,
}
173
174impl<A: Address> Request<A> {
175 pub fn source(&self) -> &A {
177 &self.source
178 }
179
180 pub fn msg(&self) -> &Bytes {
182 &self.msg
183 }
184
185 pub fn respond(self, response: Bytes) -> Result<(), RepError> {
187 self.response.send(response).map_err(|_| RepError::SocketClosed)
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, time::Duration};

    use futures::StreamExt;
    use msg_transport::tcp::Tcp;
    use msg_wire::compression::{GzipCompressor, SnappyCompressor};

    use rand::Rng;
    use tracing::{debug, info};

    use crate::{
        ReqOptions,
        hooks::token::{ClientHook, ServerHook},
        req::ReqSocket,
    };

    use super::*;

    /// Loopback address with an OS-assigned port (port 0), so tests that run in parallel
    /// never fight over a fixed port.
    fn localhost() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    /// Returns a loopback address whose port is free right now.
    ///
    /// Used by tests that must know the server address *before* the server binds. The
    /// OS picks the port through a short-lived listener that is immediately released.
    /// Another process could still grab the port in between, but that is far less likely
    /// than colliding with a randomly chosen port.
    fn free_local_addr() -> String {
        let listener = std::net::TcpListener::bind(localhost()).unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        format!("127.0.0.1:{port}")
    }

    /// Generates `n` random 512-byte payloads.
    fn random_payloads(n: usize) -> Vec<Bytes> {
        let mut rng = rand::rng();
        (0..n)
            .map(|_| {
                let mut buf = vec![0u8; 512];
                rng.fill(&mut buf[..]);
                Bytes::from(buf)
            })
            .collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_simple() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep = RepSocket::new(Tcp::default());
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::new(Tcp::default());
        req.connect(rep.local_addr().unwrap()).await.unwrap();

        tokio::spawn(async move {
            loop {
                let req = rep.next().await.unwrap();

                req.respond(Bytes::from("hello")).unwrap();
            }
        });

        let n_reqs = 1000;
        let msg_vec = random_payloads(n_reqs);

        let start = std::time::Instant::now();
        for msg in msg_vec {
            let _res = req.request(msg).await.unwrap();
        }
        let elapsed = start.elapsed();
        info!("{} reqs in {:?}", n_reqs, elapsed);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_durable() {
        let _ = tracing_subscriber::fmt::try_init();
        let addr = free_local_addr();

        let mut req = ReqSocket::new(Tcp::default());
        let endpoint = addr.clone();
        // Start connecting before the reply socket exists; the request socket is expected
        // to reach it once it comes up.
        let connection_attempt = tokio::spawn(async move {
            req.connect(endpoint).await.unwrap();

            req
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        let mut rep = RepSocket::new(Tcp::default());
        rep.bind(addr).await.unwrap();

        let req = connection_attempt.await.unwrap();

        tokio::spawn(async move {
            let req = rep.next().await.unwrap();
            debug!("Message: {:?}", req.msg());

            req.respond(Bytes::from("world")).unwrap();
        });

        let _ = req.request(Bytes::from("hello")).await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_auth() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep = RepSocket::new(Tcp::default()).with_connection_hook(ServerHook::accept_all());
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::new(Tcp::default())
            .with_connection_hook(ClientHook::new(Bytes::from("REQ")));

        req.connect(rep.local_addr().unwrap()).await.unwrap();

        info!("Connected to rep");

        tokio::spawn(async move {
            loop {
                let req = rep.next().await.unwrap();
                debug!("Received request");

                req.respond(Bytes::from("hello")).unwrap();
            }
        });

        let n_reqs = 1000;
        let msg_vec = random_payloads(n_reqs);

        let start = std::time::Instant::now();
        for msg in msg_vec {
            let _res = req.request(msg).await.unwrap();
        }
        let elapsed = start.elapsed();
        info!("{} reqs in {:?}", n_reqs, elapsed);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn rep_max_connections() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep =
            RepSocket::with_options(Tcp::default(), RepOptions::default().with_max_clients(1));
        rep.bind("127.0.0.1:0").await.unwrap();
        let addr = rep.local_addr().unwrap();

        let mut req1 = ReqSocket::new(Tcp::default());
        req1.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert_eq!(rep.stats().active_clients(), 1);

        let mut req2 = ReqSocket::new(Tcp::default());
        req2.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert_eq!(rep.stats().active_clients(), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_basic_reqrep_with_compression() {
        let mut rep = RepSocket::with_options(
            Tcp::default(),
            RepOptions::default().with_min_compress_size(0),
        )
        .with_compressor(SnappyCompressor);

        // Use an OS-assigned loopback port. A hard-coded port flakes when tests run
        // concurrently, and connecting to 0.0.0.0 is not portable.
        rep.bind(localhost()).await.unwrap();
        let addr = rep.local_addr().unwrap();

        let mut req = ReqSocket::with_options(
            Tcp::default(),
            ReqOptions::default().with_min_compress_size(0),
        )
        .with_compressor(GzipCompressor::new(6));

        req.connect(addr).await.unwrap();

        // Return the received payload so it is checked on the test task. An assert inside
        // a detached task would not fail the test directly.
        let server = tokio::spawn(async move {
            let req = rep.next().await.unwrap();
            let received = req.msg().clone();
            req.respond(Bytes::from("world")).unwrap();
            received
        });

        let res: Bytes = req.request(Bytes::from("hello")).await.unwrap();
        assert_eq!(res, Bytes::from("world"));
        assert_eq!(server.await.unwrap(), Bytes::from("hello"));
    }
}
373}