1use std::{fmt, time::Duration};
2
3use bytes::Bytes;
4use thiserror::Error;
5
6mod driver;
7use driver::SubDriver;
8
9mod session;
10
11mod socket;
12pub use socket::*;
13
14mod stats;
15use stats::SubStats;
16
17mod stream;
18
19use msg_transport::Address;
20use msg_wire::pubsub;
21
22use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, stats::SocketStats};
23
24#[derive(Debug, Error)]
26pub enum SubError {
27 #[error("IO error: {0:?}")]
28 Io(#[from] std::io::Error),
29 #[error("Wire protocol error: {0:?}")]
30 Wire(#[from] pubsub::Error),
31 #[error("Socket closed")]
32 SocketClosed,
33 #[error("Command channel full")]
34 ChannelFull,
35 #[error("Could not find any valid endpoints")]
36 NoValidEndpoints,
37 #[error("Reserved topic 'MSG' cannot be used")]
38 ReservedTopic,
39}
40
41#[derive(Debug)]
42enum Command<A: Address> {
43 Subscribe { topic: String },
45 Unsubscribe { topic: String },
47 Connect { endpoint: A },
49 Disconnect { endpoint: A },
51 Shutdown,
53}
54
/// Configuration for the subscriber socket.
///
/// Build with [`SubOptions::default`] and refine with the `with_*` setters.
#[derive(Debug, Clone)]
pub struct SubOptions {
    /// Capacity of the ingress queue (presumably incoming messages — confirm).
    ingress_queue_size: usize,
    /// Size of the read buffer (presumably in bytes — confirm).
    read_buffer_size: usize,
    /// Initial backoff delay (presumably between reconnect attempts — confirm).
    initial_backoff: Duration,
    /// Maximum number of retry attempts; NOTE(review): `None` likely means unlimited — confirm.
    retry_attempts: Option<usize>,
}
67
68impl SubOptions {
69 pub fn with_ingress_queue_size(mut self, ingress_queue_size: usize) -> Self {
75 self.ingress_queue_size = ingress_queue_size;
76 self
77 }
78
79 pub fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
83 self.read_buffer_size = read_buffer_size;
84 self
85 }
86
87 pub fn with_initial_backoff(mut self, initial_backoff: Duration) -> Self {
89 self.initial_backoff = initial_backoff;
90 self
91 }
92
93 pub fn with_retry_attempts(mut self, retry_attempts: usize) -> Self {
96 self.retry_attempts = Some(retry_attempts);
97 self
98 }
99}
100
101impl Default for SubOptions {
102 fn default() -> Self {
103 Self {
104 ingress_queue_size: DEFAULT_QUEUE_SIZE,
105 read_buffer_size: 8192,
106 initial_backoff: Duration::from_millis(100),
107 retry_attempts: Some(24),
108 }
109 }
110}
111
112#[derive(Clone)]
115pub struct PubMessage<A: Address> {
116 source: A,
119 topic: String,
121 payload: Bytes,
123}
124
125impl<A: Address> fmt::Debug for PubMessage<A> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("PubMessage")
128 .field("source", &self.source)
129 .field("topic", &self.topic)
130 .field("payload_size", &self.payload.len())
131 .finish()
132 }
133}
134
135impl<A: Address> PubMessage<A> {
136 pub fn new(source: A, topic: String, payload: Bytes) -> Self {
137 Self { source, topic, payload }
138 }
139
140 #[inline]
141 pub fn source(&self) -> &A {
142 &self.source
143 }
144
145 #[inline]
146 pub fn topic(&self) -> &str {
147 &self.topic
148 }
149
150 #[inline]
151 pub fn payload(&self) -> &Bytes {
152 &self.payload
153 }
154
155 #[inline]
156 pub fn into_payload(self) -> Bytes {
157 self.payload
158 }
159}
160
161#[derive(Debug)]
163pub(crate) struct SocketState<A: Address> {
164 pub(crate) stats: SocketStats<SubStats<A>>,
165}
166
167impl<A: Address> Default for SocketState<A> {
168 fn default() -> Self {
169 Self { stats: SocketStats::default() }
170 }
171}
172
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use msg_transport::tcp::Tcp;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::TcpListener,
    };
    use tokio_stream::StreamExt;
    use tracing::{Instrument, info, info_span};

    use super::*;

    /// Spawns a one-shot TCP echo server on an ephemeral loopback port and returns its address.
    ///
    /// The server accepts one connection, performs a single read of up to 1024 bytes, and
    /// writes those bytes straight back.
    async fn spawn_listener() -> SocketAddr {
        // Bind to IPv4 loopback instead of `[::]`. The unspecified IPv6 address fails to bind
        // on hosts without IPv6, and connecting to an unspecified address (which `local_addr`
        // would return) is not portable across platforms. Loopback is always connectable.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

        let addr = listener.local_addr().unwrap();

        tokio::spawn(
            async move {
                let (mut socket, _) = listener.accept().await.unwrap();

                let mut buf = [0u8; 1024];
                let b = socket.read(&mut buf).await.unwrap();
                let read = &buf[..b];

                info!("Received bytes: {:?}", read);
                socket.write_all(read).await.unwrap();
                socket.flush().await.unwrap();
            }
            .instrument(info_span!("listener")),
        );

        addr
    }

    /// The echo server mirrors back whatever the subscriber sends when subscribing to `HELLO`.
    /// The test expects that echo to surface as a message with topic `MSG.SUB.HELLO`.
    #[tokio::test]
    async fn test_sub() {
        let _ = tracing_subscriber::fmt::try_init();
        let mut socket = socket::SubSocket::new(Tcp::default());

        let addr = spawn_listener().await;
        socket.connect(addr).await.unwrap();
        socket.subscribe("HELLO".to_string()).await.unwrap();

        let mirror = socket.next().await.unwrap();
        assert_eq!("MSG.SUB.HELLO", mirror.topic());
    }
}
222}