1use std::{io, time::Duration};
2
3use bytes::Bytes;
4use msg_common::constants::KiB;
5use thiserror::Error;
6
7mod driver;
8
9mod session;
10
11mod socket;
12pub use socket::*;
13
14mod stats;
15use crate::{Profile, stats::SocketStats};
16use stats::PubStats;
17
18mod trie;
19
20use msg_wire::{
21 compression::{CompressionType, Compressor},
22 pubsub,
23};
24
/// Default value for [`PubOptions`]' high-water mark (2^10 = 1024), used by
/// `PubOptions::default()`.
///
/// NOTE(review): presumably a per-session bound on queued outgoing messages —
/// confirm against the session/driver implementation.
const DEFAULT_HWM: usize = 1 << 10;
27
/// Errors produced by the publisher side of the pub/sub socket.
#[derive(Debug, Error)]
pub enum PubError {
    /// An underlying I/O error; converted automatically from [`io::Error`] via `?`.
    #[error("IO error: {0:?}")]
    Io(#[from] io::Error),
    /// A wire-protocol error; converted automatically via `?`.
    ///
    /// NOTE(review): this wraps `msg_wire::reqrep::Error` even though this is
    /// the pub/sub module — confirm whether a pubsub-specific wire error type
    /// was intended before changing it (doing so would alter the `From` impl).
    #[error("Wire protocol error: {0:?}")]
    Wire(#[from] msg_wire::reqrep::Error),
    /// The socket has been closed.
    #[error("Socket closed")]
    SocketClosed,
    /// The topic already exists.
    #[error("Topic already exists")]
    TopicExists,
    /// A topic that is not known; the payload is presumably the topic name.
    #[error("Unknown topic: {0}")]
    UnknownTopic(String),
    /// None of the supplied endpoints could be connected to.
    #[error("Could not connect to any valid endpoints")]
    NoValidEndpoints,
}
44
/// Configuration for a publisher socket.
///
/// Build one with [`PubOptions::default`], a preset such as
/// [`PubOptions::balanced`], or [`PubOptions::new`] with a `Profile`, then
/// refine it with the `with_*` builder methods.
///
/// `Clone`/`PartialEq`/`Eq` are derived so one configuration can be reused
/// across several sockets and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubOptions {
    /// Maximum number of connected subscribers, or `None` (the default) for
    /// presumably no limit.
    max_clients: Option<usize>,
    /// High-water mark; defaults to `DEFAULT_HWM`.
    /// NOTE(review): presumably a bound on queued outgoing messages — confirm.
    high_water_mark: usize,
    /// Size of the write buffer in bytes (presets size it in KiB).
    pub write_buffer_size: usize,
    /// How long buffered writes may linger before being flushed; `None`
    /// presumably disables lingering — confirm in the session writer.
    pub write_buffer_linger: Option<Duration>,
    /// Payload size threshold for compression.
    /// NOTE(review): presumably payloads smaller than this are sent
    /// uncompressed — confirm in the driver.
    min_compress_size: usize,
}
59
60impl PubOptions {
61 pub fn new(profile: Profile) -> Self {
63 match profile {
64 Profile::Balanced => Self::balanced(),
65 Profile::Latency => Self::low_latency(),
66 Profile::Throughput => Self::high_throughput(),
67 }
68 }
69}
70
71impl PubOptions {
72 pub fn low_latency() -> Self {
74 Self {
75 write_buffer_size: 8 * KiB as usize,
76 write_buffer_linger: Some(Duration::from_micros(50)),
77 ..Default::default()
78 }
79 }
80
81 pub fn high_throughput() -> Self {
83 Self {
84 write_buffer_size: 256 * KiB as usize,
85 write_buffer_linger: Some(Duration::from_micros(200)),
86 ..Default::default()
87 }
88 }
89
90 pub fn balanced() -> Self {
92 Self {
93 write_buffer_size: 32 * KiB as usize,
94 write_buffer_linger: Some(Duration::from_micros(100)),
95 ..Default::default()
96 }
97 }
98}
99
100impl PubOptions {
101 pub fn with_max_clients(mut self, max_clients: usize) -> Self {
103 self.max_clients = Some(max_clients);
104 self
105 }
106
107 pub fn with_high_water_mark(mut self, hwm: usize) -> Self {
110 self.high_water_mark = hwm;
111 self
112 }
113
114 pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
117 self.min_compress_size = min_compress_size;
118 self
119 }
120
121 pub fn with_write_buffer_size(mut self, size: usize) -> Self {
126 self.write_buffer_size = size;
127 self
128 }
129
130 pub fn with_write_buffer_linger(mut self, duration: Option<Duration>) -> Self {
135 self.write_buffer_linger = duration;
136 self
137 }
138}
139
140impl Default for PubOptions {
141 fn default() -> Self {
142 Self {
143 max_clients: None,
144 high_water_mark: DEFAULT_HWM,
145 min_compress_size: 8192,
146 write_buffer_size: 8192,
147 write_buffer_linger: Some(Duration::from_micros(100)),
148 }
149 }
150}
151
/// A message to be published on a topic, prior to wire encoding.
///
/// The payload can be replaced in place by a compressed version via
/// [`PubMessage::compress`], which also records the compression type applied.
#[derive(Debug, Clone)]
pub struct PubMessage {
    /// Compression applied to `payload`; `CompressionType::None` when built by `new`.
    compression_type: CompressionType,
    /// Topic the message is published on.
    topic: String,
    /// Message body, possibly compressed.
    payload: Bytes,
}
163
164#[allow(unused)]
165impl PubMessage {
166 pub fn new(topic: String, payload: Bytes) -> Self {
167 Self {
168 compression_type: CompressionType::None,
171 topic,
172 payload,
173 }
174 }
175
176 #[inline]
177 pub fn topic(&self) -> &str {
178 &self.topic
179 }
180
181 #[inline]
182 pub fn payload(&self) -> &Bytes {
183 &self.payload
184 }
185
186 #[inline]
187 pub fn into_payload(self) -> Bytes {
188 self.payload
189 }
190
191 #[inline]
192 pub fn into_wire(self, seq: u32) -> pubsub::Message {
193 pubsub::Message::new(
194 seq,
195 Bytes::from(self.topic),
196 self.payload,
197 self.compression_type as u8,
198 )
199 }
200
201 #[inline]
202 pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), io::Error> {
203 self.payload = compressor.compress(&self.payload)?;
204 self.compression_type = compressor.compression_type();
205
206 Ok(())
207 }
208}
209
/// Crate-internal state for a publisher socket.
///
/// Currently holds only the socket statistics (`SocketStats` parameterised
/// with `PubStats`).
/// NOTE(review): presumably shared between `PubSocket` and its driver — confirm.
#[derive(Debug, Default)]
pub(crate) struct SocketState {
    /// Publisher statistics.
    pub(crate) stats: SocketStats<PubStats>,
}
215
// End-to-end tests pairing a `PubSocket` with one or more `SubSocket`s over
// real TCP and QUIC transports.
//
// NOTE(review): these tests use fixed `sleep`s to let connections and
// subscriptions settle, so they are timing-sensitive and may flake on slow CI.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::StreamExt;
    use msg_transport::{quic::Quic, tcp::Tcp};
    use msg_wire::compression::GzipCompressor;
    use tracing::info;

    use crate::{
        SubOptions, SubSocket,
        hooks::token::{ClientHook, ServerHook},
    };

    use super::*;

    // One TCP subscriber receives a single message on its subscribed topic.
    #[tokio::test]
    async fn pubsub_simple() {
        // Ignore the error when another test already installed the global subscriber.
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default());

        let mut sub_socket = SubSocket::with_options(Tcp::default(), SubOptions::default());

        // Port 0 lets the OS pick a free port; read it back via `local_addr`.
        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub_socket.connect(addr).await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        // Presumably gives the subscription time to reach the publisher.
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    // Same as `pubsub_simple`, but with a token connection hook on both sides (TCP).
    #[tokio::test]
    async fn pubsub_auth_tcp() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket =
            PubSocket::new(Tcp::default()).with_connection_hook(ServerHook::accept_all());

        let mut sub_socket = SubSocket::new(Tcp::default())
            .with_connection_hook(ClientHook::new(Bytes::from("client1")));

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub_socket.connect(addr).await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    // Same as `pubsub_auth_tcp`, but over the QUIC transport.
    #[tokio::test]
    async fn pubsub_auth_quic() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket =
            PubSocket::new(Quic::default()).with_connection_hook(ServerHook::accept_all());

        let mut sub_socket = SubSocket::new(Quic::default())
            .with_connection_hook(ClientHook::new(Bytes::from("client1")));

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub_socket.connect(addr).await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    // A single publish is delivered to every subscriber of the topic.
    #[tokio::test]
    async fn pubsub_many() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default());

        let mut sub1 = SubSocket::new(Tcp::default());

        let mut sub2 = SubSocket::new(Tcp::default());

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub1.connect(addr).await.unwrap();
        sub2.connect(addr).await.unwrap();
        sub1.subscribe("HELLO".to_string()).await.unwrap();
        sub2.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap();

        let msg = sub1.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());

        let msg = sub2.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    // With a gzip compressor on the publisher, subscribers still receive the
    // original payload bytes.
    // NOTE(review): the payload is well under the default 8192-byte
    // `min_compress_size`, so this may not actually exercise compression —
    // confirm how the threshold is applied.
    #[tokio::test]
    async fn pubsub_many_compressed() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default()).with_compressor(GzipCompressor::new(6));

        let mut sub1 = SubSocket::new(Tcp::default());

        let mut sub2 = SubSocket::new(Tcp::default());

        pub_socket.bind("0.0.0.0:0").await.unwrap();
        let addr = pub_socket.local_addr().unwrap();

        sub1.connect(addr).await.unwrap();
        sub2.connect(addr).await.unwrap();
        sub1.subscribe("HELLO".to_string()).await.unwrap();
        sub2.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        let original_msg = Bytes::from(
            "WOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOORLD",
        );

        pub_socket.publish("HELLO".to_string(), original_msg.clone()).await.unwrap();

        let msg = sub1.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!(original_msg, msg.payload());

        let msg = sub2.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!(original_msg, msg.payload());
    }

    // The subscriber connects and subscribes before the publisher is bound,
    // then still receives a message once the publisher comes up (TCP).
    // NOTE(review): uses the fixed port 6662 and connects to `0.0.0.0`, which
    // works on Linux/macOS but is not portable; `127.0.0.1` would be safer.
    #[tokio::test]
    async fn pubsub_durable_tcp() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Tcp::default());

        let mut sub_socket = SubSocket::new(Tcp::default());

        sub_socket.connect("0.0.0.0:6662").await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(500)).await;

        pub_socket.bind("0.0.0.0:6662").await.unwrap();
        // Presumably leaves time for the subscriber to reconnect and resubscribe.
        tokio::time::sleep(Duration::from_millis(2000)).await;

        pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    // QUIC variant of `pubsub_durable_tcp`; it shares port number 6662, but
    // presumably over UDP rather than TCP, so the two tests do not collide.
    #[tokio::test]
    async fn pubsub_durable_quic() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket = PubSocket::new(Quic::default());

        let mut sub_socket = SubSocket::new(Quic::default());

        sub_socket.connect("0.0.0.0:6662").await.unwrap();
        sub_socket.subscribe("HELLO".to_string()).await.unwrap();
        tokio::time::sleep(Duration::from_millis(1000)).await;

        pub_socket.bind("0.0.0.0:6662").await.unwrap();
        tokio::time::sleep(Duration::from_millis(2000)).await;

        pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap();

        let msg = sub_socket.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!("HELLO", msg.topic());
        assert_eq!("WORLD", msg.payload());
    }

    // With `max_clients = 1`, a second subscriber does not raise the publisher's
    // active client count above one.
    #[tokio::test]
    async fn pubsub_max_clients() {
        let _ = tracing_subscriber::fmt::try_init();

        let mut pub_socket =
            PubSocket::with_options(Tcp::default(), PubOptions::default().with_max_clients(1));

        pub_socket.bind("0.0.0.0:0").await.unwrap();

        let mut sub1 = SubSocket::with_options(Tcp::default(), SubOptions::default());

        let mut sub2 = SubSocket::with_options(Tcp::default(), SubOptions::default());

        let addr = pub_socket.local_addr().unwrap();

        sub1.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(pub_socket.stats().active_clients(), 1);
        sub2.connect(addr).await.unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(pub_socket.stats().active_clients(), 1);
    }
}