use std::{marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc};
use arc_swap::Guard;
use bytes::Bytes;
use rustc_hash::FxHashMap;
use tokio::{
net::{ToSocketAddrs, lookup_host},
sync::{mpsc, mpsc::error::TrySendError, oneshot},
};
use tokio_util::codec::Framed;
use msg_common::span::WithSpan;
use msg_transport::{Address, MeteredIo, Transport};
use msg_wire::{compression::Compressor, reqrep};
use super::{ReqError, ReqOptions};
use crate::{
ConnectionHook, ConnectionHookErased, ConnectionState, DRIVER_ID, ExponentialBackoff,
ReqMessage, SendCommand,
req::{
SocketState,
conn_manager::{ConnCtl, ConnManager},
driver::ReqDriver,
stats::ReqStats,
},
stats::SocketStats,
};
use std::sync::atomic::Ordering;
/// Client side of a request/reply socket.
///
/// Requests issued through [`ReqSocket::request`] are forwarded over an in-process `mpsc`
/// channel to a background driver task. That task is spawned by the socket's connect methods.
pub struct ReqSocket<T: Transport<A>, A: Address> {
    /// Sending half of the command channel to the driver task; `None` until a driver is spawned.
    to_driver: Option<mpsc::Sender<SendCommand>>,
    /// Transport used to reach the endpoint; taken (left as `None`) by a connect call.
    transport: Option<T>,
    /// Socket options, shared (via `Arc`) with the driver task.
    options: Arc<ReqOptions>,
    /// Request and transport statistics; a clone of this state is handed to the driver.
    state: SocketState<T::Stats>,
    /// Optional connection hook, moved into the connection manager when the driver is spawned.
    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
    /// Optional compressor, cloned into the driver when it is spawned.
    compressor: Option<Arc<dyn Compressor>>,
    /// Ties the socket to its address type `A` without storing a value of that type.
    _marker: PhantomData<A>,
}
impl<T> ReqSocket<T, SocketAddr>
where
    T: Transport<SocketAddr>,
{
    /// Resolves `addr` and connects to the first socket address the lookup yields.
    ///
    /// The actual connect is delegated to [`ReqSocket::try_connect`]. Whether that waits for the
    /// connection to be established depends on `ReqOptions::blocking_connect`.
    ///
    /// # Errors
    ///
    /// Returns the resolution error if the DNS lookup fails. Returns
    /// [`ReqError::NoValidEndpoints`] if the lookup produces no addresses. Otherwise returns
    /// whatever [`ReqSocket::try_connect`] returns.
    ///
    /// # Panics
    ///
    /// Panics if the transport was already taken by an earlier connect call.
    pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> {
        let first_endpoint = lookup_host(addr)
            .await?
            .next()
            .ok_or(ReqError::NoValidEndpoints)?;

        self.try_connect(first_endpoint).await
    }

    /// Spawns the driver for `addr` without dialing first.
    ///
    /// The driver starts with an *inactive* connection state (plus a fresh exponential backoff
    /// built from the connection options). Establishing the connection is presumably left to
    /// the driver's connection manager; confirm in `ConnManager`.
    ///
    /// # Panics
    ///
    /// Panics if the transport was already taken by an earlier connect call. Also panics when
    /// called outside a Tokio runtime, because the driver is started with `tokio::spawn`.
    pub fn connect_sync(&mut self, addr: SocketAddr) {
        let Some(transport) = self.transport.take() else {
            panic!("Transport has been moved already");
        };

        let backoff = ExponentialBackoff::from(&self.options.conn);
        self.spawn_driver(addr, transport, ConnectionState::Inactive { addr, backoff });
    }
}
impl<T> ReqSocket<T, PathBuf>
where
    T: Transport<PathBuf>,
{
    /// Connects to the endpoint at the filesystem path `addr`.
    ///
    /// The meaning of the path depends on the transport `T`; presumably it is an IPC / Unix
    /// socket path. This delegates to [`ReqSocket::try_connect`].
    ///
    /// # Errors
    ///
    /// Returns whatever [`ReqSocket::try_connect`] returns.
    ///
    /// # Panics
    ///
    /// Panics if the transport was already taken by an earlier connect call.
    pub async fn connect(&mut self, addr: impl Into<PathBuf>) -> Result<(), ReqError> {
        // `into()` already yields an owned `PathBuf`; no clone is needed.
        self.try_connect(addr.into()).await
    }
}
impl<T, A> ReqSocket<T, A>
where
T: Transport<A>,
A: Address,
{
pub fn new(transport: T) -> Self {
Self::with_options(transport, ReqOptions::balanced())
}
pub fn with_options(transport: T, options: ReqOptions) -> Self {
Self {
to_driver: None,
transport: Some(transport),
options: Arc::new(options),
state: SocketState::default(),
hook: None,
compressor: None,
_marker: PhantomData,
}
}
pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
self.compressor = Some(Arc::new(compressor));
self
}
pub fn with_connection_hook<H>(mut self, hook: H) -> Self
where
H: ConnectionHook<T::Io>,
{
assert!(self.transport.is_some(), "cannot set connection hook after driver has started");
self.hook = Some(Arc::new(hook));
self
}
pub fn stats(&self) -> &SocketStats<ReqStats> {
&self.state.stats
}
pub fn transport_stats(&self) -> Guard<Arc<T::Stats>> {
self.state.transport_stats.load()
}
pub async fn request(&self, message: Bytes) -> Result<Bytes, ReqError> {
let (response_tx, response_rx) = oneshot::channel();
let msg = ReqMessage::new(message);
self.to_driver
.as_ref()
.ok_or(ReqError::SocketClosed)?
.try_send(SendCommand::new(WithSpan::current(msg), response_tx))
.map_err(|err| match err {
TrySendError::Full(_) => ReqError::HighWaterMarkReached,
TrySendError::Closed(_) => ReqError::SocketClosed,
})?;
response_rx.await.map_err(|_| ReqError::SocketClosed)?
}
pub async fn try_connect(&mut self, endpoint: A) -> Result<(), ReqError> {
let mut transport = self.transport.take().expect("transport has been moved already");
let conn_state = if self.options.blocking_connect {
let io = transport
.connect(endpoint.clone())
.await
.map_err(|e| ReqError::Connect(Box::new(e)))?;
let metered = MeteredIo::new(io, Arc::clone(&self.state.transport_stats));
let framed = Framed::new(metered, reqrep::Codec::new());
ConnectionState::Active { channel: framed }
} else {
ConnectionState::Inactive {
addr: endpoint.clone(),
backoff: ExponentialBackoff::from(&self.options.conn),
}
};
self.spawn_driver(endpoint, transport, conn_state);
Ok(())
}
fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl<T::Io, T::Stats, A>) {
let (to_driver, from_socket) = mpsc::channel(self.options.max_queue_size);
let timeout_check_interval = tokio::time::interval(self.options.timeout / 10);
let pending_requests = FxHashMap::default();
let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed);
let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint);
let linger_timer = self.options.write_buffer_linger.map(|duration| {
let mut timer = tokio::time::interval(duration);
timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
timer
});
let conn_manager = ConnManager::new(
self.options.conn.clone(),
transport,
endpoint,
conn_ctl,
Arc::clone(&self.state.transport_stats),
self.hook.take(),
span.clone(),
);
let driver: ReqDriver<T, A> = ReqDriver {
options: Arc::clone(&self.options),
socket_state: self.state.clone(),
id_counter: 0,
from_socket,
conn_manager,
linger_timer,
pending_requests,
timeout_check_interval,
pending_egress: None,
compressor: self.compressor.clone(),
id,
span,
};
tokio::spawn(driver);
self.to_driver = Some(to_driver);
}
}