Skip to main content

msg_socket/rep/
socket.rs

1use std::{
2    net::SocketAddr,
3    path::PathBuf,
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use futures::{Stream, stream::FuturesUnordered};
10use tokio::{
11    net::{ToSocketAddrs, lookup_host},
12    sync::mpsc,
13    task::{JoinHandle, JoinSet},
14};
15use tokio_stream::StreamMap;
16use tracing::{debug, warn};
17
18use crate::{
19    ConnectionHook, ConnectionHookErased, DEFAULT_QUEUE_SIZE, RepOptions, Request,
20    rep::{RepError, SocketState, driver::RepDriver},
21};
22
23use msg_transport::{Address, Transport};
24use msg_wire::compression::Compressor;
25
26use super::stats::RepStats;
27
28/// A reply socket. This socket implements [`Stream`] and yields incoming [`Request`]s.
29pub struct RepSocket<T: Transport<A>, A: Address> {
30    /// The reply socket options, shared with the driver.
31    options: Arc<RepOptions>,
32    /// The reply socket state, shared with the driver.
33    state: Arc<SocketState>,
34    /// Receiver from the socket driver.
35    from_driver: Option<mpsc::Receiver<Request<A>>>,
36    /// The transport used by this socket. This value is temporary and will be moved
37    /// to the driver task once the socket is bound.
38    transport: Option<T>,
39    /// Optional connection hook.
40    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
41    /// The local address this socket is bound to.
42    local_addr: Option<A>,
43    /// Optional message compressor.
44    compressor: Option<Arc<dyn Compressor>>,
45    /// A sender channel for [`Transport::Control`] changes.
46    control_tx: Option<mpsc::Sender<T::Control>>,
47
48    /// Internal task representing a running [`RepDriver`].
49    _driver_task: Option<JoinHandle<Result<(), RepError>>>,
50}
51
52impl<T> RepSocket<T, SocketAddr>
53where
54    T: Transport<SocketAddr>,
55{
56    /// Binds the socket to the given socket address.
57    pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), RepError> {
58        let addrs = lookup_host(addr).await?;
59        self.try_bind(addrs.collect()).await
60    }
61}
62
63impl<T> RepSocket<T, PathBuf>
64where
65    T: Transport<PathBuf>,
66{
67    /// Binds the socket to the given path.
68    pub async fn bind(&mut self, path: impl Into<PathBuf>) -> Result<(), RepError> {
69        let addr = path.into().clone();
70        self.try_bind(vec![addr]).await
71    }
72}
73
74impl<T, A> RepSocket<T, A>
75where
76    T: Transport<A>,
77    A: Address,
78{
79    /// Creates a new reply socket with the default [`RepOptions`].
80    pub fn new(transport: T) -> Self {
81        Self::with_options(transport, RepOptions::balanced())
82    }
83
84    /// Sets the options for this socket.
85    pub fn with_options(transport: T, options: RepOptions) -> Self {
86        Self {
87            from_driver: None,
88            local_addr: None,
89            transport: Some(transport),
90            options: Arc::new(options),
91            state: Arc::new(SocketState::default()),
92            hook: None,
93            compressor: None,
94            control_tx: None,
95            _driver_task: None,
96        }
97    }
98
99    /// Sets the message compressor for this socket.
100    pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
101        self.compressor = Some(Arc::new(compressor));
102        self
103    }
104
105    /// Sets the connection hook for this socket.
106    ///
107    /// The connection hook is called when a new connection is accepted, before the connection
108    /// is used for request/reply communication.
109    ///
110    /// # Panics
111    ///
112    /// Panics if the socket has already been bound (driver started).
113    pub fn with_connection_hook<H>(mut self, hook: H) -> Self
114    where
115        H: ConnectionHook<T::Io>,
116    {
117        assert!(self.transport.is_some(), "cannot set connection hook after socket has been bound");
118        self.hook = Some(Arc::new(hook));
119        self
120    }
121
122    /// Binds the socket to the given address. This spawns the socket driver task.
123    pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), RepError> {
124        let (to_socket, from_backend) = mpsc::channel(DEFAULT_QUEUE_SIZE);
125        let (control_tx, control_rx) = mpsc::channel(DEFAULT_QUEUE_SIZE);
126
127        let mut transport = self.transport.take().expect("transport has been moved already");
128
129        for addr in addresses {
130            match transport.bind(addr.clone()).await {
131                Ok(_) => break,
132                Err(e) => {
133                    warn!(?e, ?addr, "failed to bind");
134                    continue;
135                }
136            }
137        }
138
139        let Some(local_addr) = transport.local_addr() else {
140            return Err(RepError::NoValidEndpoints);
141        };
142
143        let span = tracing::info_span!(parent: None, "rep_driver", ?local_addr);
144
145        span.in_scope(|| {
146            debug!("listening");
147        });
148
149        let backend = RepDriver {
150            transport,
151            options: Arc::clone(&self.options),
152            state: Arc::clone(&self.state),
153            peer_states: StreamMap::with_capacity(self.options.max_clients.unwrap_or(64)),
154            to_socket,
155            hook: self.hook.take(),
156            hook_tasks: JoinSet::new(),
157            compressor: self.compressor.take(),
158            conn_tasks: FuturesUnordered::new(),
159            control_rx,
160            span,
161        };
162
163        self._driver_task = Some(tokio::spawn(backend));
164        self.local_addr = Some(local_addr);
165        self.from_driver = Some(from_backend);
166        self.control_tx = Some(control_tx);
167
168        Ok(())
169    }
170
171    /// Returns the statistics for this socket.
172    pub fn stats(&self) -> &RepStats {
173        &self.state.stats.specific
174    }
175
176    /// Returns the local address this socket is bound to. `None` if the socket is not bound.
177    pub fn local_addr(&self) -> Option<&A> {
178        self.local_addr.as_ref()
179    }
180
181    /// Returns the next request from the socket using an unpinned interface.
182    pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<Request<A>>> {
183        Pin::new(self).poll_next(cx)
184    }
185
186    /// Issue a [`Transport::Control`] change to the underlying transport.
187    pub async fn control(
188        &mut self,
189        control: T::Control,
190    ) -> Result<(), mpsc::error::SendError<T::Control>> {
191        let Some(tx) = self.control_tx.as_mut() else {
192            tracing::warn!("calling control on a non-bound socket, this is a no-op");
193            return Ok(());
194        };
195        tx.send(control).await
196    }
197}
198
199impl<T, A> Stream for RepSocket<T, A>
200where
201    T: Transport<A>,
202    A: Address,
203{
204    type Item = Request<A>;
205
206    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207        self.get_mut().from_driver.as_mut().expect("Inactive socket").poll_recv(cx)
208    }
209}