Skip to main content

msg_socket/pub/
mod.rs

1use std::{io, time::Duration};
2
3use bytes::Bytes;
4use msg_common::constants::KiB;
5use thiserror::Error;
6
7mod driver;
8
9mod session;
10
11mod socket;
12pub use socket::*;
13
14mod stats;
15use crate::{Profile, stats::SocketStats};
16use stats::PubStats;
17
18mod trie;
19
20use msg_wire::{
21    compression::{CompressionType, Compressor},
22    pubsub,
23};
24
/// The default high water mark for the socket.
///
/// Used by [`PubOptions::default`] as the initial `high_water_mark`, i.e. the number of outgoing
/// messages that can be buffered per session.
const DEFAULT_HWM: usize = 1024;
27
/// Errors that can occur when using a publisher socket.
///
/// [`io::Error`] and `msg_wire::reqrep::Error` values convert into this type automatically
/// through the `#[from]` attributes, so they can be propagated with `?`.
#[derive(Debug, Error)]
pub enum PubError {
    /// An underlying I/O operation failed.
    #[error("IO error: {0:?}")]
    Io(#[from] io::Error),
    /// A wire protocol error was reported by `msg_wire`.
    // NOTE(review): this wraps the req/rep wire error type even though this is the publisher
    // socket — confirm that this is intended rather than a pub/sub-specific error.
    #[error("Wire protocol error: {0:?}")]
    Wire(#[from] msg_wire::reqrep::Error),
    /// The socket has been closed.
    #[error("Socket closed")]
    SocketClosed,
    /// The topic already exists.
    #[error("Topic already exists")]
    TopicExists,
    /// The topic is not known; the payload is rendered as the topic in the error message.
    #[error("Unknown topic: {0}")]
    UnknownTopic(String),
    /// No valid endpoint could be connected to.
    #[error("Could not connect to any valid endpoints")]
    NoValidEndpoints,
}
44
/// Configuration options for a publisher socket.
///
/// Create a value with [`PubOptions::new`], one of the preset constructors
/// ([`PubOptions::low_latency`], [`PubOptions::high_throughput`], [`PubOptions::balanced`]) or
/// [`Default::default`], then adjust individual settings with the `with_*` builder methods.
#[derive(Debug)]
pub struct PubOptions {
    /// The maximum number of concurrent clients. `None` means no limit is configured.
    max_clients: Option<usize>,
    /// The maximum number of outgoing messages that can be buffered per session.
    high_water_mark: usize,
    /// The size of the write buffer in bytes.
    pub write_buffer_size: usize,
    /// The linger duration for the write buffer (how long to wait before flushing).
    pub write_buffer_linger: Option<Duration>,
    /// Minimum payload size in bytes for compression to be used. If the payload is smaller than
    /// this threshold, it will not be compressed.
    min_compress_size: usize,
}
59
60impl PubOptions {
61    /// Creates new options based on the given profile.
62    pub fn new(profile: Profile) -> Self {
63        match profile {
64            Profile::Balanced => Self::balanced(),
65            Profile::Latency => Self::low_latency(),
66            Profile::Throughput => Self::high_throughput(),
67        }
68    }
69}
70
71impl PubOptions {
72    /// Creates options optimized for low latency.
73    pub fn low_latency() -> Self {
74        Self {
75            write_buffer_size: 8 * KiB as usize,
76            write_buffer_linger: Some(Duration::from_micros(50)),
77            ..Default::default()
78        }
79    }
80
81    /// Creates options optimized for high throughput.
82    pub fn high_throughput() -> Self {
83        Self {
84            write_buffer_size: 256 * KiB as usize,
85            write_buffer_linger: Some(Duration::from_micros(200)),
86            ..Default::default()
87        }
88    }
89
90    /// Creates options optimized for a balanced trade-off between latency and throughput.
91    pub fn balanced() -> Self {
92        Self {
93            write_buffer_size: 32 * KiB as usize,
94            write_buffer_linger: Some(Duration::from_micros(100)),
95            ..Default::default()
96        }
97    }
98}
99
100impl PubOptions {
101    /// Sets the maximum number of concurrent clients.
102    pub fn with_max_clients(mut self, max_clients: usize) -> Self {
103        self.max_clients = Some(max_clients);
104        self
105    }
106
107    /// Sets the high-water mark per session. This is the amount of messages that can be buffered
108    /// per session before messages start being dropped.
109    pub fn with_high_water_mark(mut self, hwm: usize) -> Self {
110        self.high_water_mark = hwm;
111        self
112    }
113
114    /// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller
115    /// than this threshold, it will not be compressed.
116    pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
117        self.min_compress_size = min_compress_size;
118        self
119    }
120
121    /// Sets the size (max capacity) of the write buffer in bytes.
122    /// When the buffer is full, it will be flushed to the underlying transport.
123    ///
124    /// Default: 8KiB
125    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
126        self.write_buffer_size = size;
127        self
128    }
129
130    /// Sets the linger duration for the write buffer. If `None`, the write buffer will only be
131    /// flushed when the buffer is full.
132    ///
133    /// Default: 100µs
134    pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
135        self.write_buffer_linger = duration;
136        self
137    }
138}
139
140impl Default for PubOptions {
141    fn default() -> Self {
142        Self {
143            max_clients: None,
144            high_water_mark: DEFAULT_HWM,
145            min_compress_size: 8192,
146            write_buffer_size: 8192,
147            write_buffer_linger: Some(Duration::from_micros(100)),
148        }
149    }
150}
151
/// A message handled by the publisher socket.
/// Includes the topic, the payload, and the compression type applied to the payload.
#[derive(Debug, Clone)]
pub struct PubMessage {
    /// The compression type used for the message payload.
    compression_type: CompressionType,
    /// The topic of the message.
    topic: String,
    /// The message payload.
    payload: Bytes,
}
163
164#[allow(unused)]
165impl PubMessage {
166    pub fn new(topic: String, payload: Bytes) -> Self {
167        Self {
168            // Initialize the compression type to None.
169            // The actual compression type will be set in the `compress` method.
170            compression_type: CompressionType::None,
171            topic,
172            payload,
173        }
174    }
175
176    #[inline]
177    pub fn topic(&self) -> &str {
178        &self.topic
179    }
180
181    #[inline]
182    pub fn payload(&self) -> &Bytes {
183        &self.payload
184    }
185
186    #[inline]
187    pub fn into_payload(self) -> Bytes {
188        self.payload
189    }
190
191    #[inline]
192    pub fn into_wire(self, seq: u32) -> pubsub::Message {
193        pubsub::Message::new(
194            seq,
195            Bytes::from(self.topic),
196            self.payload,
197            self.compression_type as u8,
198        )
199    }
200
201    #[inline]
202    pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), io::Error> {
203        self.payload = compressor.compress(&self.payload)?;
204        self.compression_type = compressor.compression_type();
205
206        Ok(())
207    }
208}
209
/// The publisher socket state, shared between the backend task and the socket.
#[derive(Debug, Default)]
pub(crate) struct SocketState {
    /// Socket statistics, specialised with the publisher-specific [`PubStats`].
    pub(crate) stats: SocketStats<PubStats>,
}
215
// Integration-style tests that run a publisher and one or more subscribers against each other
// over real transports (TCP and QUIC).
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::StreamExt;
    use msg_transport::{quic::Quic, tcp::Tcp};
    use msg_wire::compression::GzipCompressor;
    use tracing::info;

    use crate::{
        SubOptions, SubSocket,
        hooks::token::{ClientHook, ServerHook},
    };

    use super::*;

    /// One publisher, one subscriber over TCP: a message published on a subscribed topic is
    /// received with the same topic and payload.
    #[tokio::test]
    async fn pubsub_simple() {
        // `try_init` fails if a subscriber is already installed (e.g. by another test); ignore.
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default());

        let mut sub_socket = SubSocket::with_options(Tcp::default(), SubOptions::default());

        // Port 0: the actual address is read back via `local_addr`.
        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub_socket.connect(addr).await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        // Presumably gives the subscription time to reach the publisher before publishing.
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    /// Same flow as `pubsub_simple`, but with connection hooks installed on both sides
    /// (an accept-all server hook and a client hook carrying the token `client1`) over TCP.
    #[tokio::test]
    async fn pubsub_auth_tcp() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket =
            PubSocket::new(Tcp::default()).with_connection_hook(ServerHook::accept_all());

        let mut sub_socket = SubSocket::new(Tcp::default())
            .with_connection_hook(ClientHook::new(Bytes::from("client1")));

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub_socket.connect(addr).await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    /// Same flow as `pubsub_auth_tcp`, but over the QUIC transport.
    #[tokio::test]
    async fn pubsub_auth_quic() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket =
            PubSocket::new(Quic::default()).with_connection_hook(ServerHook::accept_all());

        let mut sub_socket = SubSocket::new(Quic::default())
            .with_connection_hook(ClientHook::new(Bytes::from("client1")));

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub_socket.connect(addr).await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    /// One publisher, two subscribers on the same topic: both receive the published message.
    #[tokio::test]
    async fn pubsub_many() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default());

        let mut sub1 = SubSocket::new(Tcp::default());

        let mut sub2 = SubSocket::new(Tcp::default());

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub1.connect(addr).await.unwrap();
        sub2.connect(addr).await.unwrap();
        sub1.subscribe("HELLO".to_string()).await.unwrap();
        sub2.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap();

        let msg = sub1.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());

        let msg = sub2.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    /// Like `pubsub_many`, but the publisher is configured with a gzip compressor; both
    /// subscribers must still observe the original, uncompressed payload.
    #[tokio::test]
    async fn pubsub_many_compressed() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default()).with_compressor(GzipCompressor::new(6));

        let mut sub1 = SubSocket::new(Tcp::default());

        let mut sub2 = SubSocket::new(Tcp::default());

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub1.connect(addr).await.unwrap();
        sub2.connect(addr).await.unwrap();
        sub1.subscribe("HELLO".to_string()).await.unwrap();
        sub2.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        // NOTE(review): this payload is far shorter than the default `min_compress_size`
        // (8192 bytes), so the compression path may not actually be exercised — confirm.
        let original_msg = Bytes::from(
            "WOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOORLD",
        );

        pub_socket.publish("HELLO".to_string(), original_msg.clone()).await.unwrap();

        let msg = sub1.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!(original_msg, msg.payload());

        let msg = sub2.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!(original_msg, msg.payload());
    }

    /// The subscriber connects and subscribes before the publisher is bound; once the publisher
    /// comes up, a published message must still be delivered (TCP).
    #[tokio::test]
    async fn pubsub_durable_tcp() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default());

        let mut sub_socket = SubSocket::new(Tcp::default());

        // Try to connect and subscribe before the publisher is up
        // NOTE(review): fixed port 6662 — may clash with other processes on the test host.
        sub_socket.connect("0.0.0.0:6662").await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(500)).await;

        pub_socket.bind("0.0.0.0:6662").await.unwrap();
        // Presumably leaves time for the subscriber to (re)connect and re-send its subscription.
        tokio::time::sleep(Duration::from_millis(2000)).await;

        pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    /// Same scenario as `pubsub_durable_tcp`, but over the QUIC transport.
    #[tokio::test]
    async fn pubsub_durable_quic() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Quic::default());

        let mut sub_socket = SubSocket::new(Quic::default());

        // Try to connect and subscribe before the publisher is up
        // NOTE(review): same port number as the TCP durable test; presumably fine because QUIC
        // runs over UDP — confirm.
        sub_socket.connect("0.0.0.0:6662").await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1000)).await;

        pub_socket.bind("0.0.0.0:6662").await.unwrap();
        tokio::time::sleep(Duration::from_millis(2000)).await;

        pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    /// With `max_clients = 1`, a second subscriber connecting must not raise the publisher's
    /// active client count above one.
    #[tokio::test]
    async fn pubsub_max_clients() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket =
            PubSocket::with_options(Tcp::default(), PubOptions::default().with_max_clients(1));

        pub_socket.bind("0.0.0.0:0").await.unwrap();

        let mut sub1 = SubSocket::with_options(Tcp::default(), SubOptions::default());

        let mut sub2 = SubSocket::with_options(Tcp::default(), SubOptions::default());

        let addr = pub_socket.local_addr().unwrap();

        sub1.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(pub_socket.stats().active_clients(), 1);
        // The second connection attempt itself succeeds from the subscriber's point of view;
        // only the publisher-side client count is checked.
        sub2.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(pub_socket.stats().active_clients(), 1);
    }
}