use std::{
sync::{Arc, atomic::AtomicUsize},
time::Duration,
};
use arc_swap::ArcSwap;
use bytes::Bytes;
use thiserror::Error;
use tokio::sync::oneshot;
use msg_common::{constants::KiB, span::WithSpan};
use msg_wire::{
compression::{CompressionType, Compressor},
reqrep,
};
mod conn_manager;
mod driver;
mod socket;
mod stats;
pub use socket::*;
use crate::{Profile, stats::SocketStats};
use stats::ReqStats;
use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE};
/// Crate-wide atomic counter, starting at 0.
///
/// NOTE(review): presumably incremented to hand out a unique id per request driver; its users
/// live outside this file — confirm.
pub(crate) static DRIVER_ID: AtomicUsize = AtomicUsize::new(0);
/// Errors reported by the request socket, including those delivered to callers through
/// [`SendCommand::response`].
#[derive(Debug, Error)]
pub enum ReqError {
    /// An I/O error; converted automatically from [`std::io::Error`] (e.g. via `?`).
    #[error("IO error: {0:?}")]
    Io(#[from] std::io::Error),
    /// A wire-protocol error from `msg_wire::reqrep`; converted automatically via `From`.
    #[error("Wire protocol error: {0:?}")]
    Wire(#[from] reqrep::Error),
    /// The socket was closed.
    #[error("Socket closed")]
    SocketClosed,
    /// The request timed out. Presumably governed by [`ReqOptions::timeout`] — confirm in the
    /// driver.
    #[error("Request timed out")]
    Timeout,
    /// No valid endpoint could be connected to.
    #[error("Could not connect to any valid endpoints")]
    NoValidEndpoints,
    /// Connecting to the target endpoint failed; wraps the underlying boxed error.
    #[error("Failed to connect to the target endpoint: {0:?}")]
    Connect(Box<dyn std::error::Error + Send + Sync>),
    /// The high-water mark was reached. NOTE(review): likely tied to
    /// [`ReqOptions::max_pending_requests`] (its builder names the argument `hwm`) — confirm.
    #[error("High-water mark reached")]
    HighWaterMarkReached,
}
/// A request message bundled with the one-shot channel its outcome is reported on.
#[derive(Debug)]
pub struct SendCommand {
    /// The span-carrying request message to send.
    pub message: WithSpan<ReqMessage>,
    /// Receives either the response payload or a [`ReqError`].
    pub response: oneshot::Sender<Result<Bytes, ReqError>>,
}

impl SendCommand {
    /// Pairs `message` with the `response` sender that will be given its result.
    pub fn new(
        message: WithSpan<ReqMessage>,
        response: oneshot::Sender<Result<Bytes, ReqError>>,
    ) -> Self {
        Self { response, message }
    }
}
/// Settings that control how connections are (re)established.
#[derive(Debug, Clone)]
pub struct ConnOptions {
    /// Backoff delay between connection attempts (default: 200 ms). How it is applied is decided
    /// outside this file.
    pub backoff_duration: Duration,
    /// Retry limit for connection attempts (default: `Some(9)`). NOTE(review): the meaning of
    /// `None` (presumably unlimited retries) is defined outside this file — confirm.
    pub retry_attempts: Option<usize>,
}

impl Default for ConnOptions {
    /// A 200 ms backoff with at most 9 retry attempts.
    fn default() -> Self {
        const BACKOFF_MILLIS: u64 = 200;
        const MAX_RETRIES: usize = 9;

        Self {
            retry_attempts: Some(MAX_RETRIES),
            backoff_duration: Duration::from_millis(BACKOFF_MILLIS),
        }
    }
}
/// Configuration for a request socket.
///
/// Build one from a [`Profile`] with [`ReqOptions::new`], or start from
/// [`ReqOptions::default`], then adjust individual fields with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReqOptions {
    /// Connection settings (backoff duration and retry attempts).
    pub conn: ConnOptions,
    /// Request timeout (default: 5 s). Presumably the deadline after which a pending request
    /// fails with [`ReqError::Timeout`] — confirm in the driver.
    pub timeout: Duration,
    /// Whether connecting is blocking (default: `false`). NOTE(review): the blocking behavior
    /// itself is implemented outside this file.
    pub blocking_connect: bool,
    /// Size threshold for compression (default: `DEFAULT_BUFFER_SIZE`). Presumably payloads
    /// below it are sent uncompressed — confirm.
    pub min_compress_size: usize,
    /// Write buffer size (default: `DEFAULT_BUFFER_SIZE`; the profile presets use 8, 32 and
    /// 256 `KiB`).
    pub write_buffer_size: usize,
    /// How long buffered writes may linger (default: 100 µs). NOTE(review): `None` presumably
    /// disables lingering — confirm.
    pub write_buffer_linger: Option<Duration>,
    /// Maximum queue size (default: `DEFAULT_QUEUE_SIZE`).
    pub max_queue_size: usize,
    /// Maximum number of pending requests, i.e. the high-water mark (default:
    /// `DEFAULT_QUEUE_SIZE`).
    pub max_pending_requests: usize,
}
/// Preset constructors. The presets differ only in write buffering; every other field comes
/// from [`ReqOptions::default`].
impl ReqOptions {
    /// Returns the preset options that match `profile`.
    pub fn new(profile: Profile) -> Self {
        let preset: fn() -> Self = match profile {
            Profile::Latency => Self::low_latency,
            Profile::Throughput => Self::high_throughput,
            Profile::Balanced => Self::balanced,
        };
        preset()
    }

    /// Small (8 `KiB`) write buffer with a 50 µs linger.
    pub fn low_latency() -> Self {
        Self::write_path_preset(8, 50)
    }

    /// Large (256 `KiB`) write buffer with a 200 µs linger.
    pub fn high_throughput() -> Self {
        Self::write_path_preset(256, 200)
    }

    /// Medium (32 `KiB`) write buffer with a 100 µs linger.
    pub fn balanced() -> Self {
        Self::write_path_preset(32, 100)
    }

    /// Default options with a write buffer of `buffer_kib` × `KiB` and a linger of
    /// `linger_micros` microseconds.
    fn write_path_preset(buffer_kib: usize, linger_micros: u64) -> Self {
        Self {
            write_buffer_size: buffer_kib * KiB as usize,
            write_buffer_linger: Some(Duration::from_micros(linger_micros)),
            ..Default::default()
        }
    }
}
/// Chainable builder setters for [`ReqOptions`].
///
/// Each method takes `self` by value and returns the updated options. They are `#[must_use]`
/// because calling one without using its return value silently discards the change.
impl ReqOptions {
    /// Sets [`ReqOptions::timeout`].
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets [`ReqOptions::blocking_connect`] to `true`. There is no setter to turn it back off;
    /// omit this call to keep the default of `false`.
    #[must_use]
    pub fn with_blocking_connect(mut self) -> Self {
        self.blocking_connect = true;
        self
    }

    /// Sets [`ConnOptions::backoff_duration`] on the nested connection options.
    #[must_use]
    pub fn with_backoff_duration(mut self, backoff_duration: Duration) -> Self {
        self.conn.backoff_duration = backoff_duration;
        self
    }

    /// Sets [`ConnOptions::retry_attempts`] to `Some(retry_attempts)`.
    #[must_use]
    pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self {
        self.conn.retry_attempts = Some(retry_attempts);
        self
    }

    /// Sets the compression threshold [`ReqOptions::min_compress_size`].
    #[must_use]
    pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
        self.min_compress_size = min_compress_size;
        self
    }

    /// Sets [`ReqOptions::write_buffer_size`].
    #[must_use]
    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
        self.write_buffer_size = size;
        self
    }

    /// Sets [`ReqOptions::write_buffer_linger`]; pass `None` to clear it.
    #[must_use]
    pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
        self.write_buffer_linger = duration;
        self
    }

    /// Sets [`ReqOptions::max_queue_size`].
    #[must_use]
    pub fn with_max_queue_size(mut self, size: usize) -> Self {
        self.max_queue_size = size;
        self
    }

    /// Sets the pending-request high-water mark [`ReqOptions::max_pending_requests`].
    #[must_use]
    pub fn with_max_pending_requests(mut self, hwm: usize) -> Self {
        self.max_pending_requests = hwm;
        self
    }
}
impl Default for ReqOptions {
    /// Non-blocking connect, a 5 s timeout, and default connection options.
    /// `DEFAULT_BUFFER_SIZE` is used for both the compression threshold and the write buffer,
    /// which lingers for 100 µs. `DEFAULT_QUEUE_SIZE` is used for both queue limits.
    fn default() -> Self {
        let buffer_size = DEFAULT_BUFFER_SIZE;
        let queue_size = DEFAULT_QUEUE_SIZE;

        Self {
            conn: ConnOptions::default(),
            timeout: Duration::from_secs(5),
            blocking_connect: false,
            write_buffer_size: buffer_size,
            write_buffer_linger: Some(Duration::from_micros(100)),
            min_compress_size: buffer_size,
            max_queue_size: queue_size,
            max_pending_requests: queue_size,
        }
    }
}
/// A request payload together with the compression that has been applied to it.
#[derive(Debug, Clone)]
pub struct ReqMessage {
    /// Compression applied to `payload`. It is `CompressionType::None` until
    /// [`ReqMessage::compress`] succeeds.
    compression_type: CompressionType,
    /// The message body, compressed if `compression_type` says so.
    payload: Bytes,
}

impl ReqMessage {
    /// Wraps an uncompressed `payload`.
    pub fn new(payload: Bytes) -> Self {
        Self { payload, compression_type: CompressionType::None }
    }

    /// Borrows the payload. It is compressed if [`ReqMessage::compress`] has been applied.
    #[inline]
    pub fn payload(&self) -> &Bytes {
        let Self { payload, .. } = self;
        payload
    }

    /// Consumes the message and returns its payload, discarding the compression tag.
    #[inline]
    pub fn into_payload(self) -> Bytes {
        let Self { payload, .. } = self;
        payload
    }

    /// Converts the message into a `reqrep::Message` carrying request `id`. The compression
    /// type is passed to the wire message as a `u8`.
    #[inline]
    pub fn into_wire(self, id: u32) -> reqrep::Message {
        let Self { compression_type, payload } = self;
        reqrep::Message::new(id, compression_type as u8, payload)
    }

    /// Compresses the payload in place with `compressor` and records the compressor's type.
    ///
    /// # Errors
    ///
    /// Propagates the compressor's error, converted into [`ReqError`] via `?`. On failure the
    /// message is left unchanged.
    #[inline]
    pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), ReqError> {
        let compressed = compressor.compress(&self.payload)?;
        self.payload = compressed;
        self.compression_type = compressor.compression_type();
        Ok(())
    }
}
/// Socket state made of request statistics plus atomically swappable transport statistics.
#[derive(Debug, Default)]
pub(crate) struct SocketState<S: Default> {
    /// Socket-level request statistics.
    pub(crate) stats: Arc<SocketStats<ReqStats>>,
    /// Transport statistics behind an `ArcSwap`, so the current value can be replaced
    /// atomically. NOTE(review): presumably updated by the transport layer — confirm.
    pub(crate) transport_stats: Arc<ArcSwap<S>>,
}

/// Hand-written so that `S` need not be `Clone`. Both fields are `Arc`s, so a clone shares the
/// same underlying statistics with the original.
impl<S: Default> Clone for SocketState<S> {
    fn clone(&self) -> Self {
        let stats = Arc::clone(&self.stats);
        let transport_stats = Arc::clone(&self.transport_stats);
        Self { stats, transport_stats }
    }
}