Skip to main content

msg_socket/sub/
socket.rs

1use std::{
2    collections::HashSet,
3    net::SocketAddr,
4    path::PathBuf,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use futures::Stream;
11use rustc_hash::FxHashMap;
12use tokio::{
13    net::{ToSocketAddrs, lookup_host},
14    sync::mpsc,
15    task::JoinSet,
16};
17
18use msg_common::{IpAddrExt, JoinMap};
19use msg_transport::{Address, Transport};
20
21use crate::{
22    ConnectionHook, ConnectionHookErased,
23    sub::{
24        Command, DEFAULT_BUFFER_SIZE, PubMessage, SocketState, SubDriver, SubError, SubOptions,
25        stats::SubStats,
26    },
27};
28
29/// A subscriber socket. This socket implements [`Stream`] and yields incoming [`PubMessage`]s.
30pub struct SubSocket<T: Transport<A>, A: Address> {
31    /// Command channel to the socket driver.
32    to_driver: mpsc::Sender<Command<A>>,
33    /// Receiver channel from the socket driver.
34    from_driver: mpsc::Receiver<PubMessage<A>>,
35    /// Options for the socket. These are shared with the backend task.
36    #[allow(unused)]
37    options: Arc<SubOptions>,
38    /// The pending driver.
39    driver: Option<SubDriver<T, A>>,
40    /// Socket state. This is shared with the socket frontend. Contains the unified stats.
41    state: Arc<SocketState<A>>,
42    /// Optional connection hook.
43    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
44    /// Marker for the transport type.
45    _marker: std::marker::PhantomData<T>,
46}
47
48impl<T> SubSocket<T, SocketAddr>
49where
50    T: Transport<SocketAddr> + Send + Sync + Unpin + 'static,
51{
52    /// Connects to the given endpoint asynchronously.
53    pub async fn connect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
54        let mut addrs = lookup_host(endpoint).await?;
55        let mut endpoint = addrs.next().ok_or(SubError::NoValidEndpoints)?;
56
57        // Some transport implementations (e.g. Quinn) can't dial an unspecified
58        // IP address, so replace it with localhost.
59        if endpoint.ip().is_unspecified() {
60            endpoint.set_ip(endpoint.ip().as_localhost());
61        }
62
63        self.connect_inner(endpoint).await
64    }
65
66    /// Attempts to connect to the given endpoint immediately.
67    pub fn try_connect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
68        let addr = endpoint.into();
69        let mut endpoint: SocketAddr = addr.parse().map_err(|_| SubError::NoValidEndpoints)?;
70
71        // Some transport implementations (e.g. Quinn) can't dial an unspecified
72        // IP address, so replace it with localhost.
73        if endpoint.ip().is_unspecified() {
74            endpoint.set_ip(endpoint.ip().as_localhost());
75        }
76
77        self.try_connect_inner(endpoint)
78    }
79
80    /// Disconnects from the given endpoint asynchronously.
81    pub async fn disconnect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
82        let mut addrs = lookup_host(endpoint).await?;
83        let endpoint = addrs.next().ok_or(SubError::NoValidEndpoints)?;
84
85        self.disconnect_inner(endpoint).await
86    }
87
88    /// Attempts to disconnect from the given endpoint immediately.
89    pub fn try_disconnect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
90        let endpoint = endpoint.into();
91        let endpoint: SocketAddr = endpoint.parse().map_err(|_| SubError::NoValidEndpoints)?;
92
93        self.try_disconnect_inner(endpoint)
94    }
95}
96
97impl<T> SubSocket<T, PathBuf>
98where
99    T: Transport<PathBuf> + Send + Sync + Unpin + 'static,
100{
101    /// Connects to the given path asynchronously.
102    pub async fn connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
103        self.connect_inner(path.into()).await
104    }
105
106    /// Attempts to connect to the given path immediately.
107    pub fn try_connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
108        self.try_connect_inner(path.into())
109    }
110
111    /// Disconnects from the given path asynchronously.
112    pub async fn disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
113        self.disconnect_inner(path.into()).await
114    }
115
116    /// Attempts to disconnect from the given path immediately.
117    pub fn try_disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
118        self.try_disconnect_inner(path.into())
119    }
120}
121
122impl<T, A> SubSocket<T, A>
123where
124    T: Transport<A> + Send + Sync + Unpin + 'static,
125    A: Address,
126{
127    /// Creates a new subscriber socket with the default [`SubOptions`].
128    pub fn new(transport: T) -> Self {
129        Self::with_options(transport, SubOptions::default())
130    }
131
132    /// Creates a new subscriber socket with the given transport and options.
133    pub fn with_options(transport: T, options: SubOptions) -> Self {
134        let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
135        let (to_socket, from_driver) = mpsc::channel(options.ingress_queue_size);
136
137        let options = Arc::new(options);
138
139        let state = Arc::new(SocketState::default());
140
141        let mut publishers = FxHashMap::default();
142        publishers.reserve(32);
143
144        let driver = SubDriver {
145            options: Arc::clone(&options),
146            transport,
147            from_socket,
148            to_socket,
149            conn_tasks: JoinMap::new(),
150            hook_tasks: JoinSet::new(),
151            subscribed_topics: HashSet::with_capacity(32),
152            publishers,
153            state: Arc::clone(&state),
154            hook: None,
155        };
156
157        Self {
158            to_driver,
159            from_driver,
160            driver: Some(driver),
161            options,
162            state,
163            hook: None,
164            _marker: std::marker::PhantomData,
165        }
166    }
167
168    /// Sets the connection hook for this socket.
169    ///
170    /// The connection hook is called after connecting to each publisher, before the connection
171    /// is used for pub/sub communication.
172    ///
173    /// # Panics
174    ///
175    /// Panics if the driver has already been started (i.e., after calling `connect`).
176    pub fn with_connection_hook<H>(mut self, hook: H) -> Self
177    where
178        H: ConnectionHook<T::Io>,
179    {
180        let hook_arc: Arc<dyn ConnectionHookErased<T::Io>> = Arc::new(hook);
181
182        // The driver must exist (not yet spawned) to set the connection hook
183        let driver =
184            self.driver.as_mut().expect("cannot set connection hook after driver has started");
185        driver.hook = Some(hook_arc.clone());
186
187        self.hook = Some(hook_arc);
188        self
189    }
190
191    /// Asynchronously connects to the endpoint.
192    pub async fn connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
193        self.ensure_active_driver();
194        self.send_command(Command::Connect { endpoint }).await?;
195        Ok(())
196    }
197
198    /// Immediately send a connect command to the driver.
199    pub fn try_connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
200        self.ensure_active_driver();
201        self.try_send_command(Command::Connect { endpoint })?;
202        Ok(())
203    }
204
205    /// Asynchronously disconnects from the endpoint.
206    pub async fn disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
207        self.ensure_active_driver();
208        self.send_command(Command::Disconnect { endpoint }).await?;
209        Ok(())
210    }
211
212    /// Immediately send a disconnect command to the driver.
213    pub fn try_disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
214        self.ensure_active_driver();
215        self.try_send_command(Command::Disconnect { endpoint })?;
216        Ok(())
217    }
218
219    /// Subscribes to the given topic. This will subscribe to all connected publishers.
220    /// If the topic does not exist on a publisher, this will not return any data.
221    /// Any publishers that are connected after this call will also be subscribed to.
222    pub async fn subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
223        self.ensure_active_driver();
224
225        let topic = topic.into();
226        if topic.starts_with("MSG") {
227            return Err(SubError::ReservedTopic);
228        }
229
230        self.send_command(Command::Subscribe { topic }).await?;
231
232        Ok(())
233    }
234
235    /// Immediately send a subscribe command to the driver.
236    pub fn try_subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
237        self.ensure_active_driver();
238
239        let topic = topic.into();
240        if topic.starts_with("MSG") {
241            return Err(SubError::ReservedTopic);
242        }
243
244        self.try_send_command(Command::Subscribe { topic })?;
245
246        Ok(())
247    }
248
249    /// Unsubscribe from the given topic. This will unsubscribe from all connected publishers.
250    pub async fn unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
251        self.ensure_active_driver();
252
253        let topic = topic.into();
254        if topic.starts_with("MSG") {
255            return Err(SubError::ReservedTopic);
256        }
257
258        self.send_command(Command::Unsubscribe { topic }).await?;
259
260        Ok(())
261    }
262
263    /// Immediately send an unsubscribe command to the driver.
264    pub fn try_unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
265        self.ensure_active_driver();
266
267        let topic = topic.into();
268        if topic.starts_with("MSG") {
269            return Err(SubError::ReservedTopic);
270        }
271
272        self.try_send_command(Command::Unsubscribe { topic })?;
273
274        Ok(())
275    }
276
277    /// Sends a command to the driver, returning [`SubError::SocketClosed`] if the
278    /// driver has been dropped.
279    async fn send_command(&self, command: Command<A>) -> Result<(), SubError> {
280        self.to_driver.send(command).await.map_err(|_| SubError::SocketClosed)?;
281
282        Ok(())
283    }
284
285    fn try_send_command(&self, command: Command<A>) -> Result<(), SubError> {
286        use mpsc::error::TrySendError::*;
287        self.to_driver.try_send(command).map_err(|e| match e {
288            Full(_) => SubError::ChannelFull,
289            Closed(_) => SubError::SocketClosed,
290        })?;
291        Ok(())
292    }
293
294    /// Ensures that the driver task is running. This function will be called on every command,
295    /// which might be overkill, but it keeps the interface simple and is not in the hot path.
296    fn ensure_active_driver(&mut self) {
297        if let Some(driver) = self.driver.take() {
298            tokio::spawn(driver);
299        }
300    }
301
302    /// Returns the statistics specific to the subscriber socket.
303    pub fn stats(&self) -> &SubStats<A> {
304        &self.state.stats.specific
305    }
306}
307
308impl<T: Transport<A>, A: Address> Drop for SubSocket<T, A> {
309    fn drop(&mut self) {
310        // Try to tell the driver to gracefully shut down.
311        let _ = self.to_driver.try_send(Command::Shutdown);
312    }
313}
314
315impl<T: Transport<A> + Unpin, A: Address> Stream for SubSocket<T, A> {
316    type Item = PubMessage<A>;
317
318    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
319        self.from_driver.poll_recv(cx)
320    }
321}