1use std::{marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc};
2
3use arc_swap::Guard;
4use bytes::Bytes;
5use rustc_hash::FxHashMap;
6use tokio::{
7 net::{ToSocketAddrs, lookup_host},
8 sync::{mpsc, mpsc::error::TrySendError, oneshot},
9};
10use tokio_util::codec::Framed;
11
12use msg_common::span::WithSpan;
13use msg_transport::{Address, MeteredIo, Transport};
14use msg_wire::{compression::Compressor, reqrep};
15
16use super::{ReqError, ReqOptions};
17use crate::{
18 ConnectionHook, ConnectionHookErased, ConnectionState, DRIVER_ID, ExponentialBackoff,
19 ReqMessage, SendCommand,
20 req::{
21 SocketState,
22 conn_manager::{ConnCtl, ConnManager},
23 driver::ReqDriver,
24 stats::ReqStats,
25 },
26 stats::SocketStats,
27};
28use std::sync::atomic::Ordering;
29
30pub struct ReqSocket<T: Transport<A>, A: Address> {
32 to_driver: Option<mpsc::Sender<SendCommand>>,
34 transport: Option<T>,
36 options: Arc<ReqOptions>,
38 state: SocketState<T::Stats>,
40 hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
42 compressor: Option<Arc<dyn Compressor>>,
46 _marker: PhantomData<A>,
48}
49
50impl<T> ReqSocket<T, SocketAddr>
51where
52 T: Transport<SocketAddr>,
53{
54 pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> {
56 let mut addrs = lookup_host(addr).await?;
57 let endpoint = addrs.next().ok_or(ReqError::NoValidEndpoints)?;
58
59 self.try_connect(endpoint).await
60 }
61
62 pub fn connect_sync(&mut self, addr: SocketAddr) {
65 let transport = self.transport.take().expect("Transport has been moved already");
67 let conn_state = ConnectionState::Inactive {
70 addr,
71 backoff: ExponentialBackoff::from(&self.options.conn),
72 };
73
74 self.spawn_driver(addr, transport, conn_state)
75 }
76}
77
78impl<T> ReqSocket<T, PathBuf>
79where
80 T: Transport<PathBuf>,
81{
82 pub async fn connect(&mut self, addr: impl Into<PathBuf>) -> Result<(), ReqError> {
84 self.try_connect(addr.into().clone()).await
85 }
86}
87
88impl<T, A> ReqSocket<T, A>
89where
90 T: Transport<A>,
91 A: Address,
92{
93 pub fn new(transport: T) -> Self {
94 Self::with_options(transport, ReqOptions::balanced())
95 }
96
97 pub fn with_options(transport: T, options: ReqOptions) -> Self {
98 Self {
99 to_driver: None,
100 transport: Some(transport),
101 options: Arc::new(options),
102 state: SocketState::default(),
103 hook: None,
104 compressor: None,
105 _marker: PhantomData,
106 }
107 }
108
109 pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
111 self.compressor = Some(Arc::new(compressor));
112 self
113 }
114
115 pub fn with_connection_hook<H>(mut self, hook: H) -> Self
124 where
125 H: ConnectionHook<T::Io>,
126 {
127 assert!(self.transport.is_some(), "cannot set connection hook after driver has started");
128 self.hook = Some(Arc::new(hook));
129 self
130 }
131
132 pub fn stats(&self) -> &SocketStats<ReqStats> {
134 &self.state.stats
135 }
136
137 pub fn transport_stats(&self) -> Guard<Arc<T::Stats>> {
139 self.state.transport_stats.load()
140 }
141
142 pub async fn request(&self, message: Bytes) -> Result<Bytes, ReqError> {
143 let (response_tx, response_rx) = oneshot::channel();
144
145 let msg = ReqMessage::new(message);
146
147 self.to_driver
148 .as_ref()
149 .ok_or(ReqError::SocketClosed)?
150 .try_send(SendCommand::new(WithSpan::current(msg), response_tx))
151 .map_err(|err| match err {
152 TrySendError::Full(_) => ReqError::HighWaterMarkReached,
153 TrySendError::Closed(_) => ReqError::SocketClosed,
154 })?;
155
156 response_rx.await.map_err(|_| ReqError::SocketClosed)?
157 }
158
159 pub async fn try_connect(&mut self, endpoint: A) -> Result<(), ReqError> {
162 let mut transport = self.transport.take().expect("transport has been moved already");
164
165 let conn_state = if self.options.blocking_connect {
166 let io = transport
167 .connect(endpoint.clone())
168 .await
169 .map_err(|e| ReqError::Connect(Box::new(e)))?;
170
171 let metered = MeteredIo::new(io, Arc::clone(&self.state.transport_stats));
172 let framed = Framed::new(metered, reqrep::Codec::new());
173
174 ConnectionState::Active { channel: framed }
175 } else {
176 ConnectionState::Inactive {
179 addr: endpoint.clone(),
180 backoff: ExponentialBackoff::from(&self.options.conn),
181 }
182 };
183
184 self.spawn_driver(endpoint, transport, conn_state);
185
186 Ok(())
187 }
188
189 fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl<T::Io, T::Stats, A>) {
191 let (to_driver, from_socket) = mpsc::channel(self.options.max_queue_size);
192
193 let timeout_check_interval = tokio::time::interval(self.options.timeout / 10);
194
195 let pending_requests = FxHashMap::default();
198
199 let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed);
200 let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint);
201
202 let linger_timer = self.options.write_buffer_linger.map(|duration| {
203 let mut timer = tokio::time::interval(duration);
204 timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
205 timer
206 });
207
208 let conn_manager = ConnManager::new(
210 self.options.conn.clone(),
211 transport,
212 endpoint,
213 conn_ctl,
214 Arc::clone(&self.state.transport_stats),
215 self.hook.take(),
216 span.clone(),
217 );
218
219 let driver: ReqDriver<T, A> = ReqDriver {
221 options: Arc::clone(&self.options),
222 socket_state: self.state.clone(),
223 id_counter: 0,
224 from_socket,
225 conn_manager,
226 linger_timer,
227 pending_requests,
228 timeout_check_interval,
229 pending_egress: None,
230 compressor: self.compressor.clone(),
231 id,
232 span,
233 };
234
235 tokio::spawn(driver);
237
238 self.to_driver = Some(to_driver);
239 }
240}