Skip to main content

msg_socket/sub/
mod.rs

1use std::{fmt, time::Duration};
2
3use bytes::Bytes;
4use thiserror::Error;
5
6mod driver;
7use driver::SubDriver;
8
9mod session;
10
11mod socket;
12pub use socket::*;
13
14mod stats;
15use stats::SubStats;
16
17mod stream;
18
19use msg_transport::Address;
20use msg_wire::pubsub;
21
22use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, stats::SocketStats};
23
24/// Errors that can occur when using a subscriber socket.
25#[derive(Debug, Error)]
26pub enum SubError {
27    #[error("IO error: {0:?}")]
28    Io(#[from] std::io::Error),
29    #[error("Wire protocol error: {0:?}")]
30    Wire(#[from] pubsub::Error),
31    #[error("Socket closed")]
32    SocketClosed,
33    #[error("Command channel full")]
34    ChannelFull,
35    #[error("Could not find any valid endpoints")]
36    NoValidEndpoints,
37    #[error("Reserved topic 'MSG' cannot be used")]
38    ReservedTopic,
39}
40
41#[derive(Debug)]
42enum Command<A: Address> {
43    /// Subscribe to a topic.
44    Subscribe { topic: String },
45    /// Unsubscribe from a topic.
46    Unsubscribe { topic: String },
47    /// Connect to a publisher socket.
48    Connect { endpoint: A },
49    /// Disconnect from a publisher socket.
50    Disconnect { endpoint: A },
51    /// Shut down the driver.
52    Shutdown,
53}
54
/// Tunable settings for a subscriber socket.
///
/// Start from `SubOptions::default()` and adjust individual settings with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct SubOptions {
    /// Upper bound on buffered incoming messages; once a slow consumer lets this fill up,
    /// messages are dropped.
    ingress_queue_size: usize,
    /// Size of the read buffer allocated for every session.
    read_buffer_size: usize,
    /// Backoff used before the first reconnection attempt to a publisher.
    initial_backoff: Duration,
    /// Cap on the number of retry attempts; `None` means retry forever.
    retry_attempts: Option<usize>,
}
67
68impl SubOptions {
69    /// Sets the ingress queue size. This is the maximum amount of incoming messages that will be
70    /// buffered. If the consumer cannot keep up with the incoming messages, messages will start
71    /// being dropped.
72    ///
73    /// Default: [`DEFAULT_QUEUE_SIZE`]
74    pub fn with_ingress_queue_size(mut self, ingress_queue_size: usize) -> Self {
75        self.ingress_queue_size = ingress_queue_size;
76        self
77    }
78
79    /// Sets the read buffer size. This sets the size of the read buffer for each session.
80    ///
81    /// Default: [`DEFAULT_BUFFER_SIZE`]
82    pub fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
83        self.read_buffer_size = read_buffer_size;
84        self
85    }
86
87    /// Set the initial backoff for reconnecting to a publisher.
88    pub fn with_initial_backoff(mut self, initial_backoff: Duration) -> Self {
89        self.initial_backoff = initial_backoff;
90        self
91    }
92
93    /// Sets the maximum number of retry attempts. If `None`, the connection will retry
94    /// indefinitely.
95    pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self {
96        self.retry_attempts = Some(retry_attempts);
97        self
98    }
99}
100
101impl Default for SubOptions {
102    fn default() -> Self {
103        Self {
104            ingress_queue_size: DEFAULT_QUEUE_SIZE,
105            read_buffer_size: 8192,
106            initial_backoff: Duration::from_millis(100),
107            retry_attempts: Some(24),
108        }
109    }
110}
111
112/// A message received from a publisher.
113/// Includes the source, topic, and payload.
114#[derive(Clone)]
115pub struct PubMessage<A: Address> {
116    /// The source address of the publisher. We need this because
117    /// a subscriber can connect to multiple publishers.
118    source: A,
119    /// The topic of the message.
120    topic: String,
121    /// The message payload.
122    payload: Bytes,
123}
124
125impl<A: Address> fmt::Debug for PubMessage<A> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("PubMessage")
128            .field("source", &self.source)
129            .field("topic", &self.topic)
130            .field("payload_size", &self.payload.len())
131            .finish()
132    }
133}
134
135impl<A: Address> PubMessage<A> {
136    pub fn new(source: A, topic: String, payload: Bytes) -> Self {
137        Self { source, topic, payload }
138    }
139
140    #[inline]
141    pub fn source(&self) -> &A {
142        &self.source
143    }
144
145    #[inline]
146    pub fn topic(&self) -> &str {
147        &self.topic
148    }
149
150    #[inline]
151    pub fn payload(&self) -> &Bytes {
152        &self.payload
153    }
154
155    #[inline]
156    pub fn into_payload(self) -> Bytes {
157        self.payload
158    }
159}
160
161/// The subscriber socket state, shared between the backend task and the socket frontend.
162#[derive(Debug)]
163pub(crate) struct SocketState<A: Address> {
164    pub(crate) stats: SocketStats<SubStats<A>>,
165}
166
167impl<A: Address> Default for SocketState<A> {
168    fn default() -> Self {
169        Self { stats: SocketStats::default() }
170    }
171}
172
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use msg_transport::tcp::Tcp;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_stream::StreamExt;
    use tracing::{Instrument, info, info_span};

    use super::*;

    /// Spawns a one-shot TCP "mirror" server on an ephemeral loopback port and returns its
    /// address. The server accepts a single connection, performs one read of up to 1024
    /// bytes, and writes those bytes straight back.
    async fn spawn_listener() -> SocketAddr {
        // Bind to the IPv4 loopback address rather than `[::]`. IPv6 may be disabled in
        // CI/container environments, and connecting to a wildcard (unspecified) address is
        // not portable across platforms.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

        let addr = listener.local_addr().unwrap();

        tokio::spawn(
            async move {
                let (mut socket, _) = listener.accept().await.unwrap();

                let mut buf = [0u8; 1024];
                let b = socket.read(&mut buf).await.unwrap();
                let read = &buf[..b];

                info!("Received bytes: {:?}", read);
                socket.write_all(read).await.unwrap();
                socket.flush().await.unwrap();
            }
            .instrument(info_span!("listener")),
        );

        addr
    }

    /// Subscribes to `HELLO` against the mirror server. The server echoes back whatever the
    /// subscriber sent, and the test expects it to arrive as a message on topic
    /// `MSG.SUB.HELLO` (presumably the subscriber's own subscription frame).
    #[tokio::test]
    async fn test_sub() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut socket = socket::SubSocket::new(Tcp::default());

        let addr = spawn_listener().await;
        socket.connect(addr).await.unwrap();
        socket.subscribe("HELLO".to_string()).await.unwrap();

        let mirror = socket.next().await.unwrap();
        assert_eq!("MSG.SUB.HELLO", mirror.topic());
    }
}