1use std::{
2 sync::{Arc, atomic::AtomicUsize},
3 time::Duration,
4};
5
6use arc_swap::ArcSwap;
7use bytes::Bytes;
8use thiserror::Error;
9use tokio::sync::oneshot;
10
11use msg_common::{constants::KiB, span::WithSpan};
12use msg_wire::{
13 compression::{CompressionType, Compressor},
14 reqrep,
15};
16
17mod conn_manager;
18mod driver;
19mod socket;
20mod stats;
21pub use socket::*;
22
23use crate::{Profile, stats::SocketStats};
24use stats::ReqStats;
25
26use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE};
27
/// Process-wide counter for driver identifiers, starting at zero.
///
/// NOTE(review): the code that advances this counter is not in this file — presumably each
/// driver takes the next value (e.g. via `fetch_add`) to obtain a unique id; confirm in `driver`.
pub(crate) static DRIVER_ID: AtomicUsize = AtomicUsize::new(0_usize);
29
30#[derive(Debug, Error)]
32pub enum ReqError {
33 #[error("IO error: {0:?}")]
34 Io(#[from] std::io::Error),
35 #[error("Wire protocol error: {0:?}")]
36 Wire(#[from] reqrep::Error),
37 #[error("Socket closed")]
38 SocketClosed,
39 #[error("Request timed out")]
40 Timeout,
41 #[error("Could not connect to any valid endpoints")]
42 NoValidEndpoints,
43 #[error("Failed to connect to the target endpoint: {0:?}")]
44 Connect(Box<dyn std::error::Error + Send + Sync>),
45 #[error("High-water mark reached")]
46 HighWaterMarkReached,
47}
48
49#[derive(Debug)]
51pub struct SendCommand {
52 pub message: WithSpan<ReqMessage>,
54 pub response: oneshot::Sender<Result<Bytes, ReqError>>,
56}
57
58impl SendCommand {
59 pub fn new(
61 message: WithSpan<ReqMessage>,
62 response: oneshot::Sender<Result<Bytes, ReqError>>,
63 ) -> Self {
64 Self { message, response }
65 }
66}
67
/// Connection-level options: backoff delay and retry budget.
#[derive(Clone, Debug)]
pub struct ConnOptions {
    /// Delay used when backing off between connection attempts.
    ///
    /// NOTE(review): whether this delay grows between attempts is decided by the connection
    /// manager, which is not visible here.
    pub backoff_duration: Duration,
    /// Maximum number of connection retry attempts.
    ///
    /// NOTE(review): `None` presumably means "retry indefinitely" — confirm in `conn_manager`.
    pub retry_attempts: Option<usize>,
}

impl Default for ConnOptions {
    /// A 200 ms backoff and up to 9 retry attempts.
    fn default() -> Self {
        let backoff_duration = Duration::from_millis(200);
        let retry_attempts = Some(9);
        Self { backoff_duration, retry_attempts }
    }
}
90
91#[derive(Debug, Clone)]
93pub struct ReqOptions {
94 pub conn: ConnOptions,
96 pub timeout: Duration,
98 pub blocking_connect: bool,
100 pub min_compress_size: usize,
103 pub write_buffer_size: usize,
105 pub write_buffer_linger: Option<Duration>,
107 pub max_queue_size: usize,
111 pub max_pending_requests: usize,
116}
117
118impl ReqOptions {
119 pub fn new(profile: Profile) -> Self {
121 match profile {
122 Profile::Latency => Self::low_latency(),
123 Profile::Throughput => Self::high_throughput(),
124 Profile::Balanced => Self::balanced(),
125 }
126 }
127
128 pub fn low_latency() -> Self {
130 Self {
131 write_buffer_size: 8 * KiB as usize,
132 write_buffer_linger: Some(Duration::from_micros(50)),
133 ..Default::default()
134 }
135 }
136
137 pub fn high_throughput() -> Self {
139 Self {
140 write_buffer_size: 256 * KiB as usize,
141 write_buffer_linger: Some(Duration::from_micros(200)),
142 ..Default::default()
143 }
144 }
145
146 pub fn balanced() -> Self {
148 Self {
149 write_buffer_size: 32 * KiB as usize,
150 write_buffer_linger: Some(Duration::from_micros(100)),
151 ..Default::default()
152 }
153 }
154}
155
156impl ReqOptions {
157 pub fn with_timeout(mut self, timeout: Duration) -> Self {
159 self.timeout = timeout;
160 self
161 }
162
163 pub fn with_blocking_connect(mut self) -> Self {
165 self.blocking_connect = true;
166 self
167 }
168
169 pub fn with_backoff_duration(mut self, backoff_duration: Duration) -> Self {
171 self.conn.backoff_duration = backoff_duration;
172 self
173 }
174
175 pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self {
179 self.conn.retry_attempts = Some(retry_attempts);
180 self
181 }
182
183 pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
189 self.min_compress_size = min_compress_size;
190 self
191 }
192
193 pub fn with_write_buffer_size(mut self, size: usize) -> Self {
198 self.write_buffer_size = size;
199 self
200 }
201
202 pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
207 self.write_buffer_linger = duration;
208 self
209 }
210
211 pub fn with_max_queue_size(mut self, size: usize) -> Self {
217 self.max_queue_size = size;
218 self
219 }
220
221 pub fn with_max_pending_requests(mut self, hwm: usize) -> Self {
227 self.max_pending_requests = hwm;
228 self
229 }
230}
231
232impl Default for ReqOptions {
233 fn default() -> Self {
234 Self {
235 conn: ConnOptions::default(),
236 timeout: Duration::from_secs(5),
237 blocking_connect: false,
238 min_compress_size: DEFAULT_BUFFER_SIZE,
239 write_buffer_size: DEFAULT_BUFFER_SIZE,
240 write_buffer_linger: Some(Duration::from_micros(100)),
241 max_queue_size: DEFAULT_QUEUE_SIZE,
242 max_pending_requests: DEFAULT_QUEUE_SIZE,
243 }
244 }
245}
246
247#[derive(Debug, Clone)]
249pub struct ReqMessage {
250 compression_type: CompressionType,
251 payload: Bytes,
252}
253
254impl ReqMessage {
255 pub fn new(payload: Bytes) -> Self {
256 Self {
257 compression_type: CompressionType::None,
260 payload,
261 }
262 }
263
264 #[inline]
265 pub fn payload(&self) -> &Bytes {
266 &self.payload
267 }
268
269 #[inline]
270 pub fn into_payload(self) -> Bytes {
271 self.payload
272 }
273
274 #[inline]
275 pub fn into_wire(self, id: u32) -> reqrep::Message {
276 reqrep::Message::new(id, self.compression_type as u8, self.payload)
277 }
278
279 #[inline]
280 pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), ReqError> {
281 self.payload = compressor.compress(&self.payload)?;
282 self.compression_type = compressor.compression_type();
283
284 Ok(())
285 }
286}
287
288#[derive(Debug, Default)]
291pub(crate) struct SocketState<S: Default> {
292 pub(crate) stats: Arc<SocketStats<ReqStats>>,
294 pub(crate) transport_stats: Arc<ArcSwap<S>>,
297}
298
299impl<S: Default> Clone for SocketState<S> {
301 fn clone(&self) -> Self {
302 Self { stats: Arc::clone(&self.stats), transport_stats: self.transport_stats.clone() }
303 }
304}