Skip to main content

msg_socket/req/mod.rs

1use std::{
2    sync::{Arc, atomic::AtomicUsize},
3    time::Duration,
4};
5
6use arc_swap::ArcSwap;
7use bytes::Bytes;
8use thiserror::Error;
9use tokio::sync::oneshot;
10
11use msg_common::{constants::KiB, span::WithSpan};
12use msg_wire::{
13    compression::{CompressionType, Compressor},
14    reqrep,
15};
16
17mod conn_manager;
18mod driver;
19mod socket;
20mod stats;
21pub use socket::*;
22
23use crate::{Profile, stats::SocketStats};
24use stats::ReqStats;
25
26use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE};
27
/// Process-wide atomic counter, starting at 0, shared by everything in this crate.
// NOTE(review): presumably incremented to hand out a unique ID per request driver — confirm
// against its users in `driver.rs` / `socket.rs`.
pub(crate) static DRIVER_ID: AtomicUsize = AtomicUsize::new(0);
29
30/// Errors that can occur when using a request socket.
31#[derive(Debug, Error)]
32pub enum ReqError {
33    #[error("IO error: {0:?}")]
34    Io(#[from] std::io::Error),
35    #[error("Wire protocol error: {0:?}")]
36    Wire(#[from] reqrep::Error),
37    #[error("Socket closed")]
38    SocketClosed,
39    #[error("Request timed out")]
40    Timeout,
41    #[error("Could not connect to any valid endpoints")]
42    NoValidEndpoints,
43    #[error("Failed to connect to the target endpoint: {0:?}")]
44    Connect(Box<dyn std::error::Error + Send + Sync>),
45    #[error("High-water mark reached")]
46    HighWaterMarkReached,
47}
48
/// A command to send a request message and wait for a response.
#[derive(Debug)]
pub struct SendCommand {
    /// The request message to send, carried together with its tracing span.
    pub message: WithSpan<ReqMessage>,
    /// The one-shot channel on which the peer's response (or a [`ReqError`]) is sent back.
    pub response: oneshot::Sender<Result<Bytes, ReqError>>,
}
57
58impl SendCommand {
59    /// Creates a new send command.
60    pub fn new(
61        message: WithSpan<ReqMessage>,
62        response: oneshot::Sender<Result<Bytes, ReqError>>,
63    ) -> Self {
64        Self { message, response }
65    }
66}
67
/// Options for the connection manager.
#[derive(Debug, Clone)]
pub struct ConnOptions {
    /// The backoff duration for the underlying transport on reconnections.
    ///
    /// Default: 200ms.
    pub backoff_duration: Duration,
    /// The maximum number of retry attempts. If `None`, the connection will retry indefinitely.
    ///
    /// Default: `Some(9)`.
    pub retry_attempts: Option<usize>,
}
76
77impl Default for ConnOptions {
78    fn default() -> Self {
79        Self {
80            // These values give a good default for most use cases.
81            //
82            // * formula: w_i = w_0 * 2^i
83            // * w_0 = 200ms, i = 0..9
84            // * worst-case total wait: sum(w_i) = 200ms * (2^9 - 1) = 102.2s
85            backoff_duration: Duration::from_millis(200),
86            retry_attempts: Some(9),
87        }
88    }
89}
90
/// The request socket options.
#[derive(Debug, Clone)]
pub struct ReqOptions {
    /// Options for the connection manager.
    pub conn: ConnOptions,
    /// Timeout duration for requests.
    ///
    /// Default: 5 seconds.
    pub timeout: Duration,
    /// Whether to block on the initial connection to the target.
    ///
    /// Default: `false`.
    pub blocking_connect: bool,
    /// Minimum payload size in bytes for compression to be used.
    /// If the payload is smaller than this threshold, it will not be compressed.
    ///
    /// Default: [`DEFAULT_BUFFER_SIZE`].
    pub min_compress_size: usize,
    /// The size of the write buffer in bytes.
    ///
    /// Default: [`DEFAULT_BUFFER_SIZE`].
    pub write_buffer_size: usize,
    /// The linger duration for the write buffer (how long to wait before flushing).
    /// If `None`, the write buffer is only flushed when full.
    ///
    /// Default: `Some(100µs)`.
    pub write_buffer_linger: Option<Duration>,
    /// The size of the channel buffer between the socket and the driver.
    /// This controls how many requests can be queued, on top of the current pending requests,
    /// before the socket returns [`ReqError::HighWaterMarkReached`].
    ///
    /// Default: [`DEFAULT_QUEUE_SIZE`].
    pub max_queue_size: usize,
    /// High-water mark for pending requests. When this limit is reached, new requests
    /// will not be processed and will be queued up to [`max_queue_size`](Self::max_queue_size)
    /// elements. Once both limits are reached, new requests will return
    /// [`ReqError::HighWaterMarkReached`].
    ///
    /// Default: [`DEFAULT_QUEUE_SIZE`].
    pub max_pending_requests: usize,
}
117
118impl ReqOptions {
119    /// Creates new options based on the given profile.
120    pub fn new(profile: Profile) -> Self {
121        match profile {
122            Profile::Latency => Self::low_latency(),
123            Profile::Throughput => Self::high_throughput(),
124            Profile::Balanced => Self::balanced(),
125        }
126    }
127
128    /// Creates options optimized for low latency.
129    pub fn low_latency() -> Self {
130        Self {
131            write_buffer_size: 8 * KiB as usize,
132            write_buffer_linger: Some(Duration::from_micros(50)),
133            ..Default::default()
134        }
135    }
136
137    /// Creates options optimized for high throughput.
138    pub fn high_throughput() -> Self {
139        Self {
140            write_buffer_size: 256 * KiB as usize,
141            write_buffer_linger: Some(Duration::from_micros(200)),
142            ..Default::default()
143        }
144    }
145
146    /// Creates options optimized for a balanced trade-off between latency and throughput.
147    pub fn balanced() -> Self {
148        Self {
149            write_buffer_size: 32 * KiB as usize,
150            write_buffer_linger: Some(Duration::from_micros(100)),
151            ..Default::default()
152        }
153    }
154}
155
156impl ReqOptions {
157    /// Sets the timeout for the socket.
158    pub fn with_timeout(mut self, timeout: Duration) -> Self {
159        self.timeout = timeout;
160        self
161    }
162
163    /// Enables blocking initial connections to the target.
164    pub fn with_blocking_connect(mut self) -> Self {
165        self.blocking_connect = true;
166        self
167    }
168
169    /// Sets the backoff duration for the socket.
170    pub fn with_backoff_duration(mut self, backoff_duration: Duration) -> Self {
171        self.conn.backoff_duration = backoff_duration;
172        self
173    }
174
175    /// Sets the maximum number of retry attempts.
176    ///
177    /// If `None`, all connections will be retried indefinitely.
178    pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self {
179        self.conn.retry_attempts = Some(retry_attempts);
180        self
181    }
182
183    /// Sets the minimum payload size in bytes for compression to be used.
184    ///
185    /// If the payload is smaller than this threshold, it will not be compressed.
186    ///
187    /// Default: [`DEFAULT_BUFFER_SIZE`]
188    pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
189        self.min_compress_size = min_compress_size;
190        self
191    }
192
193    /// Sets the size (max capacity) of the write buffer in bytes.
194    /// When the buffer is full, it will be flushed to the underlying transport.
195    ///
196    /// Default: [`DEFAULT_BUFFER_SIZE`]
197    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
198        self.write_buffer_size = size;
199        self
200    }
201
202    /// Sets the linger duration for the write buffer. If `None`, the write buffer will only be
203    /// flushed when the buffer is full.
204    ///
205    /// Default: 100µs
206    pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
207        self.write_buffer_linger = duration;
208        self
209    }
210
211    /// Sets the size of the channel buffer between the socket and the driver.
212    /// This controls how many requests can be queued, on top of the current pending requests,
213    /// before the socket returns [`ReqError::HighWaterMarkReached`].
214    ///
215    /// Default: [`DEFAULT_QUEUE_SIZE`]
216    pub fn with_max_queue_size(mut self, size: usize) -> Self {
217        self.max_queue_size = size;
218        self
219    }
220
221    /// Sets the high-water mark for pending requests. When this limit is reached, new requests
222    /// will not be processed and will be queued up to [`Self::with_max_queue_size`] elements.
223    /// Once both limits are reached, new requests will return [`ReqError::HighWaterMarkReached`].
224    ///
225    /// Default: [`DEFAULT_QUEUE_SIZE`]
226    pub fn with_max_pending_requests(mut self, hwm: usize) -> Self {
227        self.max_pending_requests = hwm;
228        self
229    }
230}
231
232impl Default for ReqOptions {
233    fn default() -> Self {
234        Self {
235            conn: ConnOptions::default(),
236            timeout: Duration::from_secs(5),
237            blocking_connect: false,
238            min_compress_size: DEFAULT_BUFFER_SIZE,
239            write_buffer_size: DEFAULT_BUFFER_SIZE,
240            write_buffer_linger: Some(Duration::from_micros(100)),
241            max_queue_size: DEFAULT_QUEUE_SIZE,
242            max_pending_requests: DEFAULT_QUEUE_SIZE,
243        }
244    }
245}
246
/// A message sent from a [`ReqSocket`] to the backend task.
#[derive(Debug, Clone)]
pub struct ReqMessage {
    /// Compression applied to `payload`; `CompressionType::None` until `compress` succeeds.
    compression_type: CompressionType,
    /// The message body, replaced in place by its compressed form in `compress`.
    payload: Bytes,
}
253
254impl ReqMessage {
255    pub fn new(payload: Bytes) -> Self {
256        Self {
257            // Initialize the compression type to None.
258            // The actual compression type will be set in the `compress` method.
259            compression_type: CompressionType::None,
260            payload,
261        }
262    }
263
264    #[inline]
265    pub fn payload(&self) -> &Bytes {
266        &self.payload
267    }
268
269    #[inline]
270    pub fn into_payload(self) -> Bytes {
271        self.payload
272    }
273
274    #[inline]
275    pub fn into_wire(self, id: u32) -> reqrep::Message {
276        reqrep::Message::new(id, self.compression_type as u8, self.payload)
277    }
278
279    #[inline]
280    pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), ReqError> {
281        self.payload = compressor.compress(&self.payload)?;
282        self.compression_type = compressor.compression_type();
283
284        Ok(())
285    }
286}
287
288/// The request socket state, shared between the backend task and the socket.
289/// Generic over the transport-level stats type.
290#[derive(Debug, Default)]
291pub(crate) struct SocketState<S: Default> {
292    /// The socket stats.
293    pub(crate) stats: Arc<SocketStats<ReqStats>>,
294    /// The transport-level stats. We wrap the inner stats in an `Arc`
295    /// for cheap clone on read.
296    pub(crate) transport_stats: Arc<ArcSwap<S>>,
297}
298
299// Manual clone implementation needed here because `S` is not `Clone`.
300impl<S: Default> Clone for SocketState<S> {
301    fn clone(&self) -> Self {
302        Self { stats: Arc::clone(&self.stats), transport_stats: self.transport_stats.clone() }
303    }
304}