Skip to main content

msg_socket/req/
socket.rs

1use std::{marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc};
2
3use arc_swap::Guard;
4use bytes::Bytes;
5use rustc_hash::FxHashMap;
6use tokio::{
7    net::{ToSocketAddrs, lookup_host},
8    sync::{mpsc, mpsc::error::TrySendError, oneshot},
9};
10use tokio_util::codec::Framed;
11
12use msg_common::span::WithSpan;
13use msg_transport::{Address, MeteredIo, Transport};
14use msg_wire::{compression::Compressor, reqrep};
15
16use super::{ReqError, ReqOptions};
17use crate::{
18    ConnectionHook, ConnectionHookErased, ConnectionState, DRIVER_ID, ExponentialBackoff,
19    ReqMessage, SendCommand,
20    req::{
21        SocketState,
22        conn_manager::{ConnCtl, ConnManager},
23        driver::ReqDriver,
24        stats::ReqStats,
25    },
26    stats::SocketStats,
27};
28use std::sync::atomic::Ordering;
29
30/// The request socket.
31pub struct ReqSocket<T: Transport<A>, A: Address> {
32    /// Command channel to the backend task.
33    to_driver: Option<mpsc::Sender<SendCommand>>,
34    /// The socket transport.
35    transport: Option<T>,
36    /// Options for the socket. These are shared with the backend task.
37    options: Arc<ReqOptions>,
38    /// Socket state. This is shared with the backend task.
39    state: SocketState<T::Stats>,
40    /// Optional connection hook.
41    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
42    /// Optional message compressor. This is shared with the backend task.
43    // NOTE: for now we're using dynamic dispatch, since using generics here
44    // complicates the API a lot. We can always change this later for perf reasons.
45    compressor: Option<Arc<dyn Compressor>>,
46    /// Marker for the address type.
47    _marker: PhantomData<A>,
48}
49
50impl<T> ReqSocket<T, SocketAddr>
51where
52    T: Transport<SocketAddr>,
53{
54    /// Connects to the target address with the default options.
55    pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> {
56        let mut addrs = lookup_host(addr).await?;
57        let endpoint = addrs.next().ok_or(ReqError::NoValidEndpoints)?;
58
59        self.try_connect(endpoint).await
60    }
61
62    /// Starts connecting to a resolved socket address. This is essentially a [`Self::connect`]
63    /// variant that doesn't error or block due to DNS resolution and blocking connect.
64    pub fn connect_sync(&mut self, addr: SocketAddr) {
65        // TODO: Don't panic, return error
66        let transport = self.transport.take().expect("Transport has been moved already");
67        // We initialize the connection as inactive, and let it be activated
68        // by the backend task as soon as the driver is spawned.
69        let conn_state = ConnectionState::Inactive {
70            addr,
71            backoff: ExponentialBackoff::from(&self.options.conn),
72        };
73
74        self.spawn_driver(addr, transport, conn_state)
75    }
76}
77
78impl<T> ReqSocket<T, PathBuf>
79where
80    T: Transport<PathBuf>,
81{
82    /// Connects to the target path with the default options.
83    pub async fn connect(&mut self, addr: impl Into<PathBuf>) -> Result<(), ReqError> {
84        self.try_connect(addr.into().clone()).await
85    }
86}
87
88impl<T, A> ReqSocket<T, A>
89where
90    T: Transport<A>,
91    A: Address,
92{
93    pub fn new(transport: T) -> Self {
94        Self::with_options(transport, ReqOptions::balanced())
95    }
96
97    pub fn with_options(transport: T, options: ReqOptions) -> Self {
98        Self {
99            to_driver: None,
100            transport: Some(transport),
101            options: Arc::new(options),
102            state: SocketState::default(),
103            hook: None,
104            compressor: None,
105            _marker: PhantomData,
106        }
107    }
108
109    /// Sets the message compressor for this socket.
110    pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
111        self.compressor = Some(Arc::new(compressor));
112        self
113    }
114
115    /// Sets the connection hook for this socket.
116    ///
117    /// The connection hook is called after connecting to the server, before the connection
118    /// is used for request/reply communication.
119    ///
120    /// # Panics
121    ///
122    /// Panics if the driver has already been started (i.e., after calling `connect`).
123    pub fn with_connection_hook<H>(mut self, hook: H) -> Self
124    where
125        H: ConnectionHook<T::Io>,
126    {
127        assert!(self.transport.is_some(), "cannot set connection hook after driver has started");
128        self.hook = Some(Arc::new(hook));
129        self
130    }
131
132    /// Returns the socket stats.
133    pub fn stats(&self) -> &SocketStats<ReqStats> {
134        &self.state.stats
135    }
136
137    /// Get the latest transport-level stats snapshot.
138    pub fn transport_stats(&self) -> Guard<Arc<T::Stats>> {
139        self.state.transport_stats.load()
140    }
141
142    pub async fn request(&self, message: Bytes) -> Result<Bytes, ReqError> {
143        let (response_tx, response_rx) = oneshot::channel();
144
145        let msg = ReqMessage::new(message);
146
147        self.to_driver
148            .as_ref()
149            .ok_or(ReqError::SocketClosed)?
150            .try_send(SendCommand::new(WithSpan::current(msg), response_tx))
151            .map_err(|err| match err {
152                TrySendError::Full(_) => ReqError::HighWaterMarkReached,
153                TrySendError::Closed(_) => ReqError::SocketClosed,
154            })?;
155
156        response_rx.await.map_err(|_| ReqError::SocketClosed)?
157    }
158
159    /// Tries to connect to the target endpoint with the default options.
160    /// A ReqSocket can only be connected to a single address.
161    pub async fn try_connect(&mut self, endpoint: A) -> Result<(), ReqError> {
162        // TODO: Don't panic, return error
163        let mut transport = self.transport.take().expect("transport has been moved already");
164
165        let conn_state = if self.options.blocking_connect {
166            let io = transport
167                .connect(endpoint.clone())
168                .await
169                .map_err(|e| ReqError::Connect(Box::new(e)))?;
170
171            let metered = MeteredIo::new(io, Arc::clone(&self.state.transport_stats));
172            let framed = Framed::new(metered, reqrep::Codec::new());
173
174            ConnectionState::Active { channel: framed }
175        } else {
176            // We initialize the connection as inactive, and let it be activated
177            // by the backend task as soon as the driver is spawned.
178            ConnectionState::Inactive {
179                addr: endpoint.clone(),
180                backoff: ExponentialBackoff::from(&self.options.conn),
181            }
182        };
183
184        self.spawn_driver(endpoint, transport, conn_state);
185
186        Ok(())
187    }
188
189    /// Internal method to initialize and spawn the driver.
190    fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl<T::Io, T::Stats, A>) {
191        let (to_driver, from_socket) = mpsc::channel(self.options.max_queue_size);
192
193        let timeout_check_interval = tokio::time::interval(self.options.timeout / 10);
194
195        // TODO: we should limit the amount of active outgoing requests, and that should be the
196        // capacity. If we do this, we'll never have to re-allocate.
197        let pending_requests = FxHashMap::default();
198
199        let id = DRIVER_ID.fetch_add(1, Ordering::Relaxed);
200        let span = tracing::info_span!(parent: None, "req_driver", id = format!("req-{}", id), addr = ?endpoint);
201
202        let linger_timer = self.options.write_buffer_linger.map(|duration| {
203            let mut timer = tokio::time::interval(duration);
204            timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
205            timer
206        });
207
208        // Create connection manager
209        let conn_manager = ConnManager::new(
210            self.options.conn.clone(),
211            transport,
212            endpoint,
213            conn_ctl,
214            Arc::clone(&self.state.transport_stats),
215            self.hook.take(),
216            span.clone(),
217        );
218
219        // Create the socket backend
220        let driver: ReqDriver<T, A> = ReqDriver {
221            options: Arc::clone(&self.options),
222            socket_state: self.state.clone(),
223            id_counter: 0,
224            from_socket,
225            conn_manager,
226            linger_timer,
227            pending_requests,
228            timeout_check_interval,
229            pending_egress: None,
230            compressor: self.compressor.clone(),
231            id,
232            span,
233        };
234
235        // Spawn the backend task
236        tokio::spawn(driver);
237
238        self.to_driver = Some(to_driver);
239    }
240}