Skip to main content

msg_socket/pub/
socket.rs

1use std::{net::SocketAddr, path::PathBuf, sync::Arc};
2
3use bytes::Bytes;
4use futures::stream::FuturesUnordered;
5use tokio::{
6    net::{ToSocketAddrs, lookup_host},
7    sync::broadcast,
8    task::JoinSet,
9};
10use tracing::{debug, trace, warn};
11
12use super::{PubError, PubMessage, PubOptions, SocketState, driver::PubDriver, stats::PubStats};
13use crate::{ConnectionHook, ConnectionHookErased};
14
15use msg_transport::{Address, Transport};
16use msg_wire::compression::Compressor;
17
18/// A publisher socket. This is thread-safe and can be cloned.
19///
20/// Publisher sockets are used to publish messages under certain topics to multiple subscribers.
21///
22/// ## Session
23/// Per subscriber, the socket maintains a session. The session
24/// manages the underlying connection and all of its state, such as the topic subscriptions. It also
25/// manages a queue of messages to be transmitted on the connection.
26#[derive(Clone)]
27pub struct PubSocket<T: Transport<A>, A: Address> {
28    /// The reply socket options, shared with the driver.
29    options: Arc<PubOptions>,
30    /// The reply socket state, shared with the driver.
31    state: Arc<SocketState>,
32    /// The transport used by this socket. This value is temporary and will be moved
33    /// to the driver task once the socket is bound.
34    transport: Option<T>,
35    /// The broadcast channel to all active
36    /// [`SubscriberSession`](super::session::SubscriberSession)s.
37    to_sessions_bcast: Option<broadcast::Sender<PubMessage>>,
38    /// Optional connection hook.
39    hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
40    /// Optional message compressor.
41    // NOTE: for now we're using dynamic dispatch, since using generics here
42    // complicates the API a lot. We can always change this later for perf reasons.
43    compressor: Option<Arc<dyn Compressor>>,
44    /// The local address this socket is bound to.
45    local_addr: Option<A>,
46}
47
48impl<T> PubSocket<T, SocketAddr>
49where
50    T: Transport<SocketAddr>,
51{
52    /// Binds the socket to the given socket address.
53    ///
54    /// This method is only available for transports that support [`SocketAddr`] as address type,
55    /// like [`Tcp`](msg_transport::tcp::Tcp) and [`Quic`](msg_transport::quic::Quic).
56    pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> {
57        let addrs = lookup_host(addr).await?;
58        self.try_bind(addrs.collect()).await
59    }
60}
61
62impl<T> PubSocket<T, PathBuf>
63where
64    T: Transport<PathBuf>,
65{
66    /// Binds the socket to the given path.
67    ///
68    /// This method is only available for transports that support [`PathBuf`] as address type,
69    /// like [`Ipc`](msg_transport::ipc::Ipc).
70    pub async fn bind(&mut self, path: impl Into<PathBuf>) -> Result<(), PubError> {
71        self.try_bind(vec![path.into()]).await
72    }
73}
74
75impl<T, A> PubSocket<T, A>
76where
77    T: Transport<A>,
78    A: Address,
79{
80    /// Creates a new reply socket with the default [`PubOptions`].
81    pub fn new(transport: T) -> Self {
82        Self::with_options(transport, PubOptions::default())
83    }
84
85    /// Creates a new publisher socket with the given transport and options.
86    pub fn with_options(transport: T, options: PubOptions) -> Self {
87        Self {
88            local_addr: None,
89            to_sessions_bcast: None,
90            options: Arc::new(options),
91            transport: Some(transport),
92            state: Arc::new(SocketState::default()),
93            hook: None,
94            compressor: None,
95        }
96    }
97
98    /// Sets the message compressor for this socket.
99    pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
100        self.compressor = Some(Arc::new(compressor));
101        self
102    }
103
104    /// Sets the connection hook for this socket.
105    ///
106    /// The connection hook is called when a new connection is accepted, before the connection
107    /// is used for pub/sub communication.
108    ///
109    /// # Panics
110    ///
111    /// Panics if the socket has already been bound (driver started).
112    pub fn with_connection_hook<H>(mut self, hook: H) -> Self
113    where
114        H: ConnectionHook<T::Io>,
115    {
116        assert!(self.transport.is_some(), "cannot set connection hook after socket has been bound");
117        self.hook = Some(Arc::new(hook));
118        self
119    }
120
121    /// Binds the socket to the given addresses in order until one succeeds.
122    ///
123    /// This also spawns the socket driver task.
124    pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), PubError> {
125        let (to_sessions_bcast, from_socket_bcast) =
126            broadcast::channel(self.options.high_water_mark);
127
128        let mut transport = self.transport.take().expect("Transport has been moved already");
129
130        for addr in addresses {
131            match transport.bind(addr.clone()).await {
132                Ok(_) => break,
133                Err(e) => {
134                    warn!(err = ?e, "Failed to bind to {:?}, trying next address", addr);
135                    continue;
136                }
137            }
138        }
139
140        let Some(local_addr) = transport.local_addr() else {
141            return Err(PubError::NoValidEndpoints);
142        };
143
144        debug!("Listening on {:?}", local_addr);
145
146        let backend = PubDriver {
147            id_counter: 0,
148            transport,
149            options: Arc::clone(&self.options),
150            state: Arc::clone(&self.state),
151            hook: self.hook.take(),
152            hook_tasks: JoinSet::new(),
153            conn_tasks: FuturesUnordered::new(),
154            from_socket_bcast,
155        };
156
157        tokio::spawn(backend);
158
159        self.local_addr = Some(local_addr);
160        self.to_sessions_bcast = Some(to_sessions_bcast);
161
162        Ok(())
163    }
164
165    /// Publishes a message to the given topic. If the topic doesn't exist, this is a no-op.
166    pub async fn publish(&self, topic: impl Into<String>, message: Bytes) -> Result<(), PubError> {
167        let mut msg = PubMessage::new(topic.into(), message);
168
169        // We compress here since that way we only have to do it once.
170        // Compression is only done if the message is larger than the
171        // configured minimum payload size.
172        let len_before = msg.payload().len();
173        if len_before > self.options.min_compress_size &&
174            let Some(ref compressor) = self.compressor
175        {
176            msg.compress(compressor.as_ref())?;
177            trace!("Compressed message from {} to {} bytes", len_before, msg.payload().len());
178        }
179
180        // Broadcast the message directly to all active sessions.
181        if self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?.send(msg).is_err() {
182            debug!("No active subscriber sessions");
183        }
184
185        Ok(())
186    }
187
188    /// Publishes a message to the given topic, compressing the payload if a compressor is set.
189    /// If the topic doesn't exist, this is a no-op.
190    pub fn try_publish(&self, topic: String, message: Bytes) -> Result<(), PubError> {
191        let mut msg = PubMessage::new(topic, message);
192
193        // We compress here since that way we only have to do it once.
194        if let Some(ref compressor) = self.compressor {
195            let len_before = msg.payload().len();
196
197            // For relatively small messages, this takes <100us
198            msg.compress(compressor.as_ref())?;
199
200            debug!("Compressed message from {} to {} bytes", len_before, msg.payload().len(),);
201        }
202
203        // Broadcast the message directly to all active sessions.
204        if self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?.send(msg).is_err() {
205            debug!("No active subscriber sessions");
206        }
207
208        Ok(())
209    }
210
211    pub fn stats(&self) -> &PubStats {
212        &self.state.stats.specific
213    }
214
215    /// Returns the local address this socket is bound to. `None` if the socket is not bound.
216    pub fn local_addr(&self) -> Option<&A> {
217        self.local_addr.as_ref()
218    }
219}