Skip to main content

msg_socket/rep/
mod.rs

1use std::{io, time::Duration};
2
3use bytes::Bytes;
4use msg_common::constants::KiB;
5use msg_transport::Address;
6use thiserror::Error;
7use tokio::sync::oneshot;
8
9mod driver;
10mod socket;
11pub use socket::*;
12
13mod stats;
14use stats::RepStats;
15
16use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, Profile, stats::SocketStats};
17
18/// Errors that can occur when using a reply socket.
19#[derive(Debug, Error)]
20pub enum RepError {
21    #[error("IO error: {0:?}")]
22    Io(#[from] std::io::Error),
23    #[error("Wire protocol error: {0:?}")]
24    Wire(#[from] msg_wire::reqrep::Error),
25    #[error("Socket closed")]
26    SocketClosed,
27    #[error("Could not connect to any valid endpoints")]
28    NoValidEndpoints,
29}
30
31impl RepError {
32    pub fn is_connection_reset(&self) -> bool {
33        match self {
34            Self::Io(e) | Self::Wire(msg_wire::reqrep::Error::Io(e)) => {
35                e.kind() == io::ErrorKind::ConnectionReset
36            }
37            _ => false,
38        }
39    }
40}
41
/// The reply socket options.
///
/// Start from [`RepOptions::default`] or a profile preset such as [`RepOptions::new`], then adjust
/// individual settings with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RepOptions {
    /// The maximum number of concurrent clients. `None` sets no explicit cap.
    pub(crate) max_clients: Option<usize>,
    /// Minimum payload size in bytes for compression to be used.
    ///
    /// If the payload is smaller than this threshold, it will not be compressed.
    pub(crate) min_compress_size: usize,
    /// The size of the write buffer in bytes.
    pub(crate) write_buffer_size: usize,
    /// The maximum duration between flushes to the underlying transport. With `None`, the
    /// buffer is only flushed when it is full.
    pub(crate) write_buffer_linger: Option<Duration>,
    /// High-water mark for pending responses per peer.
    ///
    /// When this limit is reached, new requests will not be read from the underlying connection
    /// until pending responses are fulfilled.
    pub(crate) max_pending_responses: usize,
}
60
61impl Default for RepOptions {
62    fn default() -> Self {
63        Self {
64            max_clients: None,
65            min_compress_size: DEFAULT_BUFFER_SIZE,
66            write_buffer_size: DEFAULT_BUFFER_SIZE,
67            write_buffer_linger: Some(Duration::from_micros(100)),
68            max_pending_responses: DEFAULT_QUEUE_SIZE,
69        }
70    }
71}
72
73impl RepOptions {
74    /// Creates new options based on the given profile.
75    pub fn new(profile: Profile) -> Self {
76        match profile {
77            Profile::Latency => Self::low_latency(),
78            Profile::Throughput => Self::high_throughput(),
79            Profile::Balanced => Self::balanced(),
80        }
81    }
82
83    /// Creates options optimized for low latency.
84    pub fn low_latency() -> Self {
85        Self {
86            write_buffer_size: 8 * KiB as usize,
87            write_buffer_linger: Some(Duration::from_micros(50)),
88            ..Default::default()
89        }
90    }
91
92    /// Creates options optimized for high throughput.
93    pub fn high_throughput() -> Self {
94        Self {
95            write_buffer_size: 256 * KiB as usize,
96            write_buffer_linger: Some(Duration::from_micros(200)),
97            ..Default::default()
98        }
99    }
100
101    /// Creates options optimized for a balanced trade-off between latency and throughput.
102    pub fn balanced() -> Self {
103        Self {
104            write_buffer_size: 32 * KiB as usize,
105            write_buffer_linger: Some(Duration::from_micros(100)),
106            ..Default::default()
107        }
108    }
109}
110
111impl RepOptions {
112    /// Sets the number of maximum concurrent clients.
113    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
114        self.max_clients = Some(max_clients);
115        self
116    }
117
118    /// Sets the minimum payload size for compression.
119    /// If the payload is smaller than this value, it will not be compressed.
120    ///
121    /// Default: [`DEFAULT_BUFFER_SIZE`]
122    pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
123        self.min_compress_size = min_compress_size;
124        self
125    }
126
127    /// Sets the size (max capacity) of the write buffer in bytes. When the buffer is full, it will
128    /// be flushed to the underlying transport.
129    ///
130    /// Default: [`DEFAULT_BUFFER_SIZE`]
131    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
132        self.write_buffer_size = size;
133        self
134    }
135
136    /// Sets the linger duration for the write buffer. If `None`, the write buffer will only be
137    /// flushed when the buffer is full.
138    ///
139    /// Default: 100µs
140    pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
141        self.write_buffer_linger = duration;
142        self
143    }
144
145    /// Sets the high-water mark for pending responses per peer. When this limit is reached,
146    /// new requests will not be read from the underlying connection until pending
147    /// responses are fulfilled.
148    ///
149    /// Default: [`DEFAULT_QUEUE_SIZE`]
150    pub fn with_max_pending_responses(mut self, hwm: usize) -> Self {
151        self.max_pending_responses = hwm;
152        self
153    }
154}
155
156/// The request socket state, shared between the backend task and the socket.
157#[derive(Debug, Default)]
158pub(crate) struct SocketState {
159    pub(crate) stats: SocketStats<RepStats>,
160}
161
162/// A request received by the socket.
163pub struct Request<A: Address> {
164    /// The source address of the request.
165    source: A,
166    /// The compression type used for the request payload
167    compression_type: u8,
168    /// The oneshot channel to respond to the request.
169    response: oneshot::Sender<Bytes>,
170    /// The message payload.
171    msg: Bytes,
172}
173
174impl<A: Address> Request<A> {
175    /// Returns the source address of the request.
176    pub fn source(&self) -> &A {
177        &self.source
178    }
179
180    /// Returns a reference to the message.
181    pub fn msg(&self) -> &Bytes {
182        &self.msg
183    }
184
185    /// Responds to the request.
186    pub fn respond(self, response: Bytes) -> Result<(), RepError> {
187        self.response.send(response).map_err(|_| RepError::SocketClosed)
188    }
189}
190
#[cfg(test)]
mod tests {
    use std::{
        net::{SocketAddr, TcpListener},
        time::Duration,
    };

    use futures::StreamExt;
    use msg_transport::tcp::Tcp;
    use msg_wire::compression::{GzipCompressor, SnappyCompressor};

    use rand::Rng;
    use tracing::{debug, info};

    use crate::{
        ReqOptions,
        hooks::token::{ClientHook, ServerHook},
        req::ReqSocket,
    };

    use super::*;

    /// Loopback address with an OS-assigned port.
    fn localhost() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    /// Asks the OS for a currently unused loopback port.
    ///
    /// The probe listener is dropped immediately, so the port is only *likely* to still be free
    /// when the caller binds it. That is still far less collision-prone than a random port.
    fn free_local_port() -> u16 {
        TcpListener::bind(localhost())
            .and_then(|listener| listener.local_addr())
            .expect("failed to reserve a local port")
            .port()
    }

    /// Builds `n` random 512-byte payloads.
    fn random_payloads(n: usize) -> Vec<Bytes> {
        let mut rng = rand::rng();
        (0..n)
            .map(|_| {
                let mut buf = vec![0u8; 512];
                rng.fill(&mut buf[..]);
                Bytes::from(buf)
            })
            .collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_simple() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep = RepSocket::new(Tcp::default());
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::new(Tcp::default());
        req.connect(rep.local_addr().unwrap()).await.unwrap();

        tokio::spawn(async move {
            loop {
                let req = rep.next().await.unwrap();

                req.respond(Bytes::from("hello")).unwrap();
            }
        });

        let n_reqs = 1000;
        let payloads = random_payloads(n_reqs);

        let start = std::time::Instant::now();
        for msg in payloads {
            let res = req.request(msg).await.unwrap();
            assert_eq!(res, Bytes::from("hello"));
        }
        let elapsed = start.elapsed();
        info!("{} reqs in {:?}", n_reqs, elapsed);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_durable() {
        let _ = tracing_subscriber::fmt::try_init();
        // The client needs the address before the server binds, so reserve a free port up front
        // instead of guessing a random one that may already be taken.
        let addr = format!("127.0.0.1:{}", free_local_port());

        // Initialize the request socket (client side) with a transport
        let mut req = ReqSocket::new(Tcp::default());
        // Try to connect even though the server isn't up yet
        let endpoint = addr.clone();
        let connection_attempt = tokio::spawn(async move {
            req.connect(endpoint).await.unwrap();

            req
        });

        // Wait a moment to start the server
        tokio::time::sleep(Duration::from_millis(500)).await;
        let mut rep = RepSocket::new(Tcp::default());
        rep.bind(addr).await.unwrap();

        let req = connection_attempt.await.unwrap();

        tokio::spawn(async move {
            // Receive the request and respond with "world"
            // RepSocket implements `Stream`
            let req = rep.next().await.unwrap();
            debug!("Message: {:?}", req.msg());

            req.respond(Bytes::from("world")).unwrap();
        });

        let res = req.request(Bytes::from("hello")).await.unwrap();
        assert_eq!(res, Bytes::from("world"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn reqrep_auth() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep = RepSocket::new(Tcp::default()).with_connection_hook(ServerHook::accept_all());
        rep.bind(localhost()).await.unwrap();

        // Initialize socket with a client ID via connection hook.
        let mut req = ReqSocket::new(Tcp::default())
            .with_connection_hook(ClientHook::new(Bytes::from("REQ")));

        req.connect(rep.local_addr().unwrap()).await.unwrap();

        info!("Connected to rep");

        tokio::spawn(async move {
            loop {
                let req = rep.next().await.unwrap();
                debug!("Received request");

                req.respond(Bytes::from("hello")).unwrap();
            }
        });

        let n_reqs = 1000;
        let payloads = random_payloads(n_reqs);

        let start = std::time::Instant::now();
        for msg in payloads {
            let res = req.request(msg).await.unwrap();
            assert_eq!(res, Bytes::from("hello"));
        }
        let elapsed = start.elapsed();
        info!("{} reqs in {:?}", n_reqs, elapsed);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn rep_max_connections() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep =
            RepSocket::with_options(Tcp::default(), RepOptions::default().with_max_clients(1));
        rep.bind(localhost()).await.unwrap();
        let addr = rep.local_addr().unwrap();

        // Give the server time to register each connection before checking the stats.
        let mut req1 = ReqSocket::new(Tcp::default());
        req1.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert_eq!(rep.stats().active_clients(), 1);

        let mut req2 = ReqSocket::new(Tcp::default());
        req2.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
        assert_eq!(rep.stats().active_clients(), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_basic_reqrep_with_compression() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut rep = RepSocket::with_options(
            Tcp::default(),
            RepOptions::default().with_min_compress_size(0),
        )
        .with_compressor(SnappyCompressor);

        // Use an ephemeral loopback port, not a fixed one on 0.0.0.0. This lets parallel test
        // runs or other local services coexist, and avoids connecting to the unspecified
        // address, which is not portable.
        rep.bind(localhost()).await.unwrap();

        let mut req = ReqSocket::with_options(
            Tcp::default(),
            ReqOptions::default().with_min_compress_size(0),
        )
        .with_compressor(GzipCompressor::new(6));

        req.connect(rep.local_addr().unwrap()).await.unwrap();

        tokio::spawn(async move {
            let req = rep.next().await.unwrap();

            assert_eq!(req.msg(), &Bytes::from("hello"));
            req.respond(Bytes::from("world")).unwrap();
        });

        let res: Bytes = req.request(Bytes::from("hello")).await.unwrap();
        assert_eq!(res, Bytes::from("world"));
    }
}