1use std::{
2 collections::VecDeque, fmt::Debug, future::{poll_fn, Future}, sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc, LazyLock,
5 }, task::Poll, time::{Duration, Instant}
6};
7
8use tokio::sync::{
9 mpsc::{channel, error::{SendError, TryRecvError, TrySendError}, Permit, Receiver, Sender},
10 oneshot,
11};
12use tokio_util::sync::CancellationToken;
13use tracing::Instrument;
14
15pub struct SequencedBroadcast<T> {
16 new_client_tx: Sender<NewClient<T>>,
17 metrics: Arc<SequencedBroadcastMetrics>,
18 shutdown: CancellationToken,
19 worker_loops: Arc<AtomicU64>,
20 closed: oneshot::Receiver<()>,
21}
22
23pub struct SequencedSender<T> {
24 next_seq: u64,
25 send: Sender<(u64, T)>,
26}
27
28pub struct SequencedReceiver<T> {
29 next_seq: u64,
30 receiver: Receiver<(u64, T)>,
31}
32
/// Counters and gauges describing the state of a broadcast worker.
#[derive(Debug, Default)]
pub struct SequencedBroadcastMetrics {
    /// Sequence of the oldest message retained; rewritten whenever the worker trims its queue.
    pub oldest_sequence: AtomicU64,
    /// Sequence the next message accepted from upstream will carry.
    pub next_sequence: AtomicU64,
    /// Number of subscriber registrations rejected because of their requested sequence.
    pub new_client_drop_count: AtomicU64,
    /// Number of subscriber registrations accepted.
    pub new_client_accept_count: AtomicU64,
    /// Number of subscribers currently tracked as lagging.
    pub lagging_subs_gauge: AtomicU64,
    /// Number of subscribers currently attached to the worker.
    pub active_subs_gauge: AtomicU64,
    /// Lowest next-sequence across subscribers (the queue's next sequence when there are none).
    pub min_sub_sequence_gauge: AtomicU64,
    /// Number of subscribers the worker has removed.
    pub disconnect_count: AtomicU64,
    /// NOTE(review): the worker in this file counts its loops on a separate counter and never
    /// writes this field; confirm who is expected to populate it.
    pub worker_loops: AtomicU64,
}
45
46impl SequencedBroadcastMetrics {
47 pub fn update(&self, other: &Self) {
48 self.oldest_sequence.store(other.oldest_sequence.load(Ordering::Acquire), Ordering::Release);
49 self.next_sequence.store(other.next_sequence.load(Ordering::Acquire), Ordering::Release);
50 self.new_client_drop_count.store(other.new_client_drop_count.load(Ordering::Acquire), Ordering::Release);
51 self.new_client_accept_count.store(other.new_client_accept_count.load(Ordering::Acquire), Ordering::Release);
52 self.lagging_subs_gauge.store(other.lagging_subs_gauge.load(Ordering::Acquire), Ordering::Release);
53 self.active_subs_gauge.store(other.active_subs_gauge.load(Ordering::Acquire), Ordering::Release);
54 self.min_sub_sequence_gauge.store(other.min_sub_sequence_gauge.load(Ordering::Acquire), Ordering::Release);
55 self.disconnect_count.store(other.disconnect_count.load(Ordering::Acquire), Ordering::Release);
56 self.worker_loops.store(other.worker_loops.load(Ordering::Acquire), Ordering::Release);
57 }
58}
59
60struct Subscriber<T> {
61 id: u64,
62 next_sequence: u64,
63 tx: Sender<(u64, T)>,
64 allow_drop: bool,
65 lag_started_at: Option<Instant>,
66 pending: Option<T>,
67}
68
/// Tunables for a broadcast worker.
#[derive(Debug, Clone)]
pub struct SequencedBroadcastSettings {
    /// Capacity of the channel created for each subscriber.
    pub subscriber_channel_len: usize,
    /// A droppable subscriber that is more than this many messages behind the newest queued
    /// message starts being tracked as lagging. Must be at least `lag_end_threshold`.
    pub lag_start_threshold: u64,
    /// A lagging subscriber stops being tracked as lagging once it is at most this many
    /// messages behind the newest queued message.
    pub lag_end_threshold: u64,
    /// How long a subscriber may stay lagging before the worker disconnects it.
    pub max_time_lag: Duration,
    /// Number of most recent messages kept in the queue even after every subscriber has
    /// consumed them.
    pub min_history: u64,
}
77
78struct Worker<T> {
79 rx: Receiver<(u64, T)>,
80 next_rx: Option<(u64, T)>,
81 rx_closed: bool,
82 rx_full: bool,
83
84 next_client_rx: Receiver<NewClient<T>>,
85 next_client: Option<NewClient<T>>,
86 next_client_closed: bool,
87
88 next_sub_id: u64,
89 subscribers: Vec<Subscriber<T>>,
90 queue: VecDeque<(u64, T)>,
91 next_queue_seq: u64,
92 metrics: Arc<SequencedBroadcastMetrics>,
93 settings: SequencedBroadcastSettings,
94 shutdown: CancellationToken,
95 worker_loops: Arc<AtomicU64>,
96 closed: oneshot::Sender<()>,
97}
98
99impl Default for SequencedBroadcastSettings {
100 fn default() -> Self {
101 SequencedBroadcastSettings {
102 subscriber_channel_len: 32,
103 lag_start_threshold: 1024 * 8,
104 lag_end_threshold: 1024 * 4,
105 max_time_lag: Duration::from_secs(2),
106 min_history: 2048,
107 }
108 }
109}
110
111impl<T> SequencedSender<T> {
112 pub fn new(next_seq: u64, send: Sender<(u64, T)>) -> Self {
113 SequencedSender { next_seq, send }
114 }
115
116 pub fn is_closed(&self) -> bool {
117 self.send.is_closed()
118 }
119
120 pub async fn closed(&self) {
121 self.send.closed().await
122 }
123
124 pub async fn safe_send(&mut self, seq: u64, item: T) -> Result<(), SequencedSenderError<T>> {
125 self._send(Some(seq), item).await
126 }
127
128 pub async fn send(&mut self, item: T) -> Result<(), SequencedSenderError<T>> {
129 self._send(None, item).await
130 }
131
132 pub fn try_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
133 match self.send.try_send((self.next_seq, item)) {
134 Ok(()) => {
135 self.next_seq += 1;
136 Ok(())
137 }
138 Err(TrySendError::Full(err)) => Err(TrySendError::Full(err.1)),
139 Err(TrySendError::Closed(err)) => Err(TrySendError::Closed(err.1)),
140 }
141 }
142
143 pub async fn reserve(&mut self) -> Result<SequencedSenderPermit<T>, SendError<()>> {
144 let permit = self.send.reserve().await?;
145
146 Ok(SequencedSenderPermit {
147 next_seq: &mut self.next_seq,
148 permit,
149 })
150 }
151
152 async fn _send(&mut self, seq: Option<u64>, item: T) -> Result<(), SequencedSenderError<T>> {
153 if let Some(seq) = seq {
154 if seq != self.next_seq {
155 return Err(SequencedSenderError::InvalidSequence(self.next_seq, item));
156 }
157 }
158
159 if let Err(error) = self.send.send((self.next_seq, item)).await {
160 return Err(SequencedSenderError::ChannelClosed(error.0.1));
161 }
162
163 self.next_seq += 1;
164 Ok(())
165 }
166
167 pub fn seq(&self) -> u64 {
168 self.next_seq
169 }
170}
171
172pub struct SequencedSenderPermit<'a, T> {
173 next_seq: &'a mut u64,
174 permit: Permit<'a, (u64, T)>,
175}
176
177impl<'a, T> SequencedSenderPermit<'a, T> {
178 pub fn send(self, item: T) {
179 let seq = *self.next_seq;
180 self.permit.send((seq, item));
181 *self.next_seq = seq + 1;
182 }
183}
184
185impl<T> SequencedReceiver<T> {
186 pub fn new(next_seq: u64, receiver: Receiver<(u64, T)>) -> Self {
187 SequencedReceiver {
188 next_seq,
189 receiver
190 }
191 }
192
193 pub fn is_closed(&self) -> bool {
194 self.receiver.is_closed()
195 }
196
197 pub async fn recv(&mut self) -> Option<(u64, T)> {
198 let (seq, action) = self.receiver.recv().await?;
199 if self.next_seq != seq {
200 panic!("expected sequence: {} but got: {}", self.next_seq, seq);
201 }
202 self.next_seq += 1;
203 Some((seq, action))
204 }
205
206 pub fn try_recv(&mut self) -> Result<(u64, T), TryRecvError> {
207 match self.receiver.try_recv() {
208 Ok((seq, action)) => {
209 if self.next_seq != seq {
210 panic!("expected sequence: {} but got: {}", self.next_seq, seq);
211 }
212 self.next_seq += 1;
213 Ok((seq, action))
214 }
215 Err(error) => Err(error)
216 }
217 }
218
219 pub fn unbundle(self) -> (u64, Receiver<(u64, T)>) {
220 (self.next_seq, self.receiver)
221 }
222
223 pub fn next_seq(&self) -> u64 {
224 self.next_seq
225 }
226}
227
/// Failure returned by `SequencedSender::send` and `SequencedSender::safe_send`.
///
/// Both variants hand the unsent item back to the caller.
#[derive(Debug, PartialEq, Eq)]
pub enum SequencedSenderError<T> {
    /// The caller-supplied sequence did not match; carries the expected sequence and the item.
    InvalidSequence(u64, T),
    /// The channel is closed; carries the item.
    ChannelClosed(T),
}
233
234impl<T> SequencedSenderError<T> {
235 pub fn into_inner(self) -> T {
236 match self {
237 Self::InvalidSequence(_, v) => v,
238 Self::ChannelClosed(v) => v,
239 }
240 }
241}
242
243impl<T: Send + Clone + 'static> SequencedBroadcast<T> {
244 pub fn new(next_seq: u64, settings: SequencedBroadcastSettings) -> (Self, SequencedSender<T>) {
245 let (tx, rx) = channel(1024);
246 let tx = SequencedSender::new(next_seq, tx);
247 let rx = SequencedReceiver::new(next_seq, rx);
248
249 (
250 Self::new2(rx, settings),
251 tx
252 )
253 }
254
255 pub fn new2(receiver: SequencedReceiver<T>, settings: SequencedBroadcastSettings) -> Self {
256 let queue_cap = 2 * (
257 (settings.lag_start_threshold as usize)
258 .next_power_of_two()
259 .max(1024)
260 ).max((settings.min_history as usize).next_power_of_two());
261
262 assert!(settings.lag_end_threshold <= settings.lag_start_threshold);
263
264 let (client_tx, client_rx) = channel(32);
265
266 let metrics = Arc::new(SequencedBroadcastMetrics {
267 oldest_sequence: AtomicU64::new(receiver.next_seq),
268 next_sequence: AtomicU64::new(receiver.next_seq),
269 ..Default::default()
270 });
271
272 let shutdown = CancellationToken::new();
273 let current_span = tracing::Span::current();
274 let (closed_tx, closed_rx) = oneshot::channel();
275
276 let worker_loops = Arc::new(AtomicU64::new(0));
277
278 tokio::spawn(
279 Worker {
280 rx: receiver.receiver,
281 next_rx: None,
282 rx_full: false,
283 rx_closed: false,
284
285 next_client_rx: client_rx,
286 next_client: None,
287 next_client_closed: false,
288
289 next_sub_id: 1,
290 subscribers: Vec::with_capacity(32),
291 queue: VecDeque::with_capacity(queue_cap),
292 next_queue_seq: receiver.next_seq,
293 metrics: metrics.clone(),
294 settings,
295 shutdown: shutdown.clone(),
296 worker_loops: worker_loops.clone(),
297 closed: closed_tx,
298 }
299 .start()
300 .instrument(current_span),
301 );
302
303 Self {
304 new_client_tx: client_tx,
305 metrics,
306 shutdown,
307 worker_loops,
308 closed: closed_rx,
309 }
310 }
311
312 pub async fn add_client(
313 &self,
314 next_sequence: u64,
315 allow_drop: bool,
316 ) -> Result<SequencedReceiver<T>, NewClientError> {
317 let (tx, rx) = oneshot::channel();
318
319 self.new_client_tx
320 .send(NewClient {
321 response: tx,
322 allow_drop,
323 next_sequence,
324 })
325 .await
326 .expect("Failed to queue new subscriber, worker crashed");
327
328 rx.await.expect("worker closed")
329 }
330
331 pub fn metrics_ref(&self) -> &SequencedBroadcastMetrics {
332 &self.metrics
333 }
334
335 pub fn metrics(&self) -> Arc<SequencedBroadcastMetrics> {
336 self.metrics.clone()
337 }
338
339 pub fn worker_loops(&self) -> u64 {
340 self.worker_loops.load(Ordering::Relaxed)
341 }
342
343 pub fn shutdown(self) -> oneshot::Receiver<()> {
344 self.shutdown.cancel();
345 self.closed
346 }
347
348 pub async fn shutdown_wait(self) {
349 self.shutdown().await.unwrap();
350 }
351
352 pub fn closed(self) -> oneshot::Receiver<()> {
353 self.closed
354 }
355}
356
357struct NewClient<T> {
358 response: oneshot::Sender<Result<SequencedReceiver<T>, NewClientError>>,
359 next_sequence: u64,
360 allow_drop: bool,
361}
362
/// Reason a subscriber registration was rejected.
#[derive(Debug)]
pub enum NewClientError {
    /// The requested sequence `seq` is beyond `max`, the next sequence the worker will queue.
    SequenceTooFarAhead { seq: u64, max: u64 },
    /// The requested sequence `seq` is older than `min`, the oldest sequence still retained.
    SequenceTooFarBehind { seq: u64, min: u64 },
}
368
369impl<T> Debug for NewClient<T> {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 write!(
372 f,
373 "NewClient {{ next_sequence: {}, allow_drop: {} }}",
374 self.next_sequence, self.allow_drop
375 )
376 }
377}
378
/// Source of process-wide unique worker ids, starting at 1.
///
/// `AtomicU64::new` is const, so the counter needs neither lazy initialisation nor an `Arc`.
static WORKER_ID: AtomicU64 = AtomicU64::new(1);
380
381impl<T: Send + Clone + 'static> Worker<T> {
382 async fn start(mut self) {
383 let id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
384 tracing::info!(id, "{}|SequencedBroadcastWorker Started", id);
385 let start = Instant::now();
386
387 self._start(id).await;
388 let elapsed = start.elapsed();
389 let iter = self.worker_loops.load(Ordering::Relaxed);
390
391 tracing::info!(id, ?elapsed, iter, "{}|SequencedBroadcastWorker Stopped", id);
392
393 let _ = self.closed.send(());
394 }
395
396 async fn _start(&mut self, id: u64) {
397 loop {
398 self.worker_loops.fetch_add(1, Ordering::Relaxed);
399 tokio::task::yield_now().await;
400
401 if self.next_client.is_none() {
402 self.next_client = match self.next_client_rx.try_recv() {
403 Ok(item) => Some(item),
404 Err(TryRecvError::Empty) => None,
405 Err(TryRecvError::Disconnected) => {
406 self.next_client_closed = true;
407 None
408 }
409 };
410 }
411
412 if self.shutdown.is_cancelled() {
413 tracing::info!("{}|Stopping worker due to shutdown", id);
414 break;
415 }
416
417 if !self.next_client_closed {
419 let mut max_per_loop = 32;
420 let min_allowed_seq = self
421 .queue
422 .front()
423 .map(|i| i.0)
424 .unwrap_or(self.next_queue_seq);
425
426 while let Some(new) = self.next_client.take() {
427 self.next_client = self.next_client_rx.try_recv().ok();
428
429 if new.next_sequence < min_allowed_seq
431 || self.next_queue_seq < new.next_sequence
432 {
433 self.metrics
434 .new_client_drop_count
435 .fetch_add(1, Ordering::Relaxed);
436
437 if new.next_sequence < min_allowed_seq {
438 tracing::info!(
439 "{}|Subscriber rejected, seq({}) < min_allowed({})",
440 id,
441 new.next_sequence,
442 min_allowed_seq
443 );
444
445 let _ = new.response.send(Err(NewClientError::SequenceTooFarBehind {
446 seq: new.next_sequence,
447 min: min_allowed_seq
448 }));
449 } else {
450 tracing::info!(
451 "{}|Subscriber rejected, max_seq({}) < seq({})",
452 id,
453 self.next_queue_seq,
454 new.next_sequence
455 );
456
457 let _ = new.response.send(Err(NewClientError::SequenceTooFarAhead {
458 seq: new.next_sequence,
459 max: self.next_queue_seq
460 }));
461 }
462
463 continue;
464 }
465
466 self.metrics
467 .new_client_accept_count
468 .fetch_add(1, Ordering::Relaxed);
469
470 let (tx, rx) = channel(self.settings.subscriber_channel_len);
472 let rx = SequencedReceiver::<T> {
473 receiver: rx,
474 next_seq: new.next_sequence,
475 };
476
477 if new.response.send(Ok(rx)).is_ok() {
478 let sub_id = self.next_sub_id;
479 self.next_sub_id += 1;
480
481 tracing::info!(
482 "{}|Subscriber({}): Added, allow_drop: {}, next_sequence: {}, min_allowed_seq: {}",
483 id, sub_id, new.allow_drop, new.next_sequence, min_allowed_seq,
484 );
485
486 self.subscribers.push(Subscriber {
487 id: sub_id,
488 allow_drop: new.allow_drop,
489 next_sequence: new.next_sequence,
490 pending: None,
491 tx,
492 lag_started_at: None,
493 });
494 } else {
495 tracing::warn!("{}|New subscriber accepted but receiver dropped", id);
496 }
497
498 if max_per_loop == 0 {
500 break;
501 }
502
503 max_per_loop -= 1;
504 }
505 }
506
507 'fill_rx: {
509 if self.next_rx.is_none() {
510 self.next_rx = match self.rx.try_recv() {
511 Ok(msg) => Some(msg),
512 Err(TryRecvError::Disconnected) => {
513 self.rx_closed = true;
514 None
515 }
516 Err(TryRecvError::Empty) => None,
517 };
518 }
519
520 let mut remaining_msg_count = self.rx_space().min(1024);
521 if remaining_msg_count == 0 {
522 if !self.rx_full {
523 self.rx_full = true;
524 assert_eq!(self.queue.len(), self.queue.capacity());
525 tracing::info!("{}|Reached queue capacity {}", id, self.queue.len());
526 }
527
528 break 'fill_rx;
529 }
530
531 if self.rx_full {
532 tracing::info!("{}|Space returned to queue {}/{}", id, self.rx_space(), self.queue.len());
533 self.rx_full = false;
534 }
535
536 while let Some((seq, item)) = self.next_rx.take() {
537 self.next_rx = self.rx.try_recv().ok();
538
539 assert_eq!(seq, self.next_queue_seq, "sequence is invalid");
540 self.queue.push_back((seq, item));
541 self.next_queue_seq += 1;
542
543 remaining_msg_count -= 1;
544 if remaining_msg_count == 0 {
545 break;
546 }
547 }
548 }
549
550 self.metrics
551 .next_sequence
552 .store(self.next_queue_seq, Ordering::Relaxed);
553
554 let oldest_queue_sequence = self
555 .queue
556 .front()
557 .map(|v| v.0)
558 .unwrap_or(self.next_queue_seq);
559
560 let max_seq = oldest_queue_sequence + self.queue.len() as u64;
561 let lag_start_seq = max_seq.max(self.settings.lag_start_threshold) - self.settings.lag_start_threshold;
562 let lag_end_seq = lag_start_seq.max(max_seq.max(self.settings.lag_end_threshold) - self.settings.lag_end_threshold);
563
564 let mut min_sub_sequence_calc = self.next_queue_seq;
565 let mut earliest_lag_start_at_calc: Option<Instant> = None;
566
567 let mut i = 0;
568 'next_sub: while i < self.subscribers.len() {
569 let sub = &mut self.subscribers[i];
570
571 if (sub.allow_drop && sub.next_sequence < oldest_queue_sequence) || sub.tx.is_closed() {
573 if sub.tx.is_closed() {
574 tracing::info!("{}|Subscriber({}): channel closed, dropping", sub.id, id);
575 } else {
576 tracing::warn!(
577 "{}|Subscriber({}): lag behind available data ({} < {}), dropping",
578 id,
579 sub.id,
580 sub.next_sequence,
581 oldest_queue_sequence
582 );
583 }
584
585 if sub.lag_started_at.is_some() {
586 self.metrics
587 .lagging_subs_gauge
588 .fetch_sub(1, Ordering::Relaxed);
589 }
590
591 self.metrics
592 .disconnect_count
593 .fetch_add(1, Ordering::Relaxed);
594
595 self.subscribers.swap_remove(i);
596 continue 'next_sub;
597 }
598
599 let mut offset = {
601 assert!(sub.next_sequence >= oldest_queue_sequence);
602 let offset = (sub.next_sequence - oldest_queue_sequence) as usize;
603 assert!(sub.next_sequence <= self.next_queue_seq, "sub cannot be ahead of queue sequence");
604 assert!(offset <= self.queue.len(), "sub cannot be ahead of queue sequence");
605 offset
606 };
607
608 if sub.pending.is_none() {
610 if self.queue.len() == offset {
612 i += 1;
613 continue 'next_sub;
614 }
615
616 let (seq, item) = self.queue.get(offset).unwrap();
618 assert_eq!(*seq, sub.next_sequence);
619 sub.pending = Some(item.clone());
620 }
621
622 while let Some(next) = sub.pending.take() {
624 match sub.tx.try_send((sub.next_sequence, next)) {
625 Ok(_) => {
626 sub.next_sequence += 1;
627 offset += 1;
628
629 if self.queue.len() == offset {
630 break;
631 }
632
633 let (seq, item) = self.queue.get(offset).unwrap();
634 assert_eq!(*seq, sub.next_sequence);
635 sub.pending = Some(item.clone());
636 }
637 Err(TrySendError::Closed(_)) => break,
638 Err(TrySendError::Full((_seq, item))) => {
639 sub.pending = Some(item);
640 break;
641 }
642 }
643 }
644
645 if sub.allow_drop {
646 if lag_end_seq <= sub.next_sequence {
647 if let Some(lag_start) = sub.lag_started_at.take() {
648 tracing::info!(
649 "{}|Subscriber({}): caught up after {:?}",
650 id,
651 sub.id,
652 lag_start.elapsed()
653 );
654
655 self.metrics
656 .lagging_subs_gauge
657 .fetch_sub(1, Ordering::Relaxed);
658 }
659 }
660 else if sub.next_sequence < lag_start_seq {
661 if let Some(lag_start) = &sub.lag_started_at {
662 let lag_duration = lag_start.elapsed();
663
664 if self.settings.max_time_lag < lag_duration {
665 tracing::info!(
666 "{}|Subscriber({}): lag too high for too long ({:?}), dropping",
667 id,
668 sub.id,
669 lag_duration,
670 );
671
672 self.metrics
673 .lagging_subs_gauge
674 .fetch_sub(1, Ordering::Relaxed);
675
676 self.metrics
677 .disconnect_count
678 .fetch_add(1, Ordering::Relaxed);
679
680 self.subscribers.swap_remove(i);
681 continue 'next_sub;
682 }
683 } else {
684 sub.lag_started_at = Some(Instant::now());
685
686 tracing::info!(
687 "{}|Subscriber({}): lag started thresh({}) < lag({})",
688 id,
689 sub.id,
690 self.settings.lag_start_threshold,
691 max_seq - sub.next_sequence,
692 );
693
694 self.metrics
695 .lagging_subs_gauge
696 .fetch_add(1, Ordering::Relaxed);
697 }
698 }
699 }
700
701 if let Some(lag_started_at) = &sub.lag_started_at {
702 earliest_lag_start_at_calc = match earliest_lag_start_at_calc {
703 Some(v) if v.lt(lag_started_at) => Some(v),
704 _ => sub.lag_started_at
705 };
706 }
707
708 min_sub_sequence_calc = min_sub_sequence_calc.min(sub.next_sequence);
709 i += 1;
710 }
711
712 let min_sub_sequence = min_sub_sequence_calc;
713
714 self.metrics
715 .active_subs_gauge
716 .store(self.subscribers.len() as u64, Ordering::Relaxed);
717
718 self.metrics
719 .min_sub_sequence_gauge
720 .store(min_sub_sequence, Ordering::Relaxed);
721
722 {
724 let keep_seq = min_sub_sequence.min(max_seq.max(self.settings.min_history) - self.settings.min_history);
725
726 if oldest_queue_sequence < keep_seq {
727 let remove_count = keep_seq - oldest_queue_sequence;
728 if remove_count != 0 {
729 let _ = self.queue.drain(0..remove_count as usize);
730 }
731
732 self.metrics
733 .oldest_sequence
734 .store(oldest_queue_sequence + remove_count, Ordering::Relaxed);
735 }
736 }
737
738 if self.rx_closed && min_sub_sequence == max_seq {
739 tracing::info!("{}|RX closed and all subscribers caught up, shutting down worker", id);
740 return;
741 }
742
743 if self.next_client_closed && self.subscribers.is_empty() {
744 tracing::info!("{}|no subscribers and next_client_rx closed, shutting down worker", id);
745 return;
746 }
747
748 let rx_blocked = self.next_rx.is_none() && !self.rx_closed;
749 let next_timeout = earliest_lag_start_at_calc.map(|early| {
750 let now = Instant::now();
751 let expire = early + self.settings.max_time_lag;
752 (expire.max(now) - now).max(Duration::from_millis(100))
753 });
754
755 {
757 if !rx_blocked && 0 < self.rx_space() {
759 tracing::trace!("{}|have more rx", id);
760 continue;
761 }
762
763 if self.next_client.is_some() {
765 tracing::trace!("{}|have next client", id);
766 continue;
767 }
768 }
769
770 let mut timeout_fut = next_timeout.map(|duration| tokio::time::sleep(duration));
771 let mut pending_tx = Vec::new();
772 let new_client_rx = &mut self.next_client_rx;
773 let new_msg_rx = &mut self.rx;
774 let next_rx = &mut self.next_rx;
775 let next_client = &mut self.next_client;
776
777 for sub in &mut self.subscribers {
778 if sub.pending.is_some() {
779 pending_tx.push((sub.tx.reserve(), &mut sub.pending, &mut sub.next_sequence));
780 }
781 }
782
783 poll_fn(|cx| {
784 if let Some(timeout) = &mut timeout_fut {
785 if unsafe { std::pin::Pin::new_unchecked(timeout) }.poll(cx).is_ready() {
786 tracing::trace!("{}|poll: max lag timer reached", id);
787 return Poll::Ready(());
788 }
789 }
790
791 if rx_blocked {
792 if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_msg_rx) }.poll_recv(cx) {
793 assert!(next_rx.is_none());
794
795 *next_rx = item;
796 if next_rx.is_some() {
797 tracing::trace!("{}|poll: new RX available", id);
798 } else {
799 tracing::trace!("{}|poll: RX closed", id);
800 }
801
802 return Poll::Ready(());
803 }
804 }
805
806 if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_client_rx) }.poll_recv(cx) {
807 tracing::trace!("{}|poll: new client", id);
808
809 assert!(next_client.is_none());
810 *next_client = item;
811 return Poll::Ready(());
812 }
813
814 let mut sent = false;
815 for (reserve, pending, next_sequence) in &mut pending_tx {
816 let reserve = unsafe { std::pin::Pin::new_unchecked(reserve) };
817
818 match reserve.poll(cx) {
819 Poll::Ready(Ok(slot)) => {
820 let item = pending.take().expect("pending missing");
821 let seq = **next_sequence;
822 slot.send((seq, item));
823 **next_sequence = seq + 1;
824
825 sent = true;
826 }
827 Poll::Ready(Err(_)) => {
828 sent = true;
829 }
830 Poll::Pending => {}
831 }
832 }
833
834 if sent {
835 tracing::trace!("{}|poll: subscriber message sent", id);
836 return Poll::Ready(());
837 }
838
839 Poll::Pending
840 }).await;
841 }
842 }
843
844 fn rx_space(&self) -> usize {
845 self.queue.capacity() - self.queue.len()
846 }
847}
848
#[cfg(test)]
mod test {
    use super::*;

    /// Installs a tracing subscriber that writes to the test output; the error returned when
    /// one is already installed is ignored.
    pub fn setup_logging() {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();
    }

    /// Waits up to 100ms for the worker to report that it stopped.
    async fn expect_worker_closed(close_wait: oneshot::Receiver<()>) {
        tokio::time::timeout(Duration::from_millis(100), close_wait)
            .await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// An explicit shutdown stops the worker even while a subscriber and the sender are alive.
    #[tokio::test]
    async fn subscribers_shutdown_test() {
        setup_logging();

        let (broadcast, mut sender) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        let subscriber = broadcast.add_client(0, true).await.unwrap();
        sender.send("Hello World").await.unwrap();

        expect_worker_closed(broadcast.shutdown()).await;

        drop(subscriber);
        drop(sender);
    }

    /// Dropping the handle and the sender with no subscribers stops the worker.
    #[tokio::test]
    async fn subscribers_close_no_subs_test() {
        setup_logging();

        let close_wait = {
            let (broadcast, mut sender) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            sender.send("Hello World").await.unwrap();
            broadcast.closed()
        };

        expect_worker_closed(close_wait).await;
    }

    /// The worker stops after the only subscriber has read everything and gone away.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_test() {
        setup_logging();

        let (close_wait, mut subscriber) = {
            let (broadcast, mut sender) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            sender.send("Hello World").await.unwrap();
            let subscriber = broadcast.add_client(0, true).await.unwrap();
            (broadcast.closed(), subscriber)
        };

        assert_eq!((0, "Hello World"), subscriber.recv().await.unwrap());
        drop(subscriber);

        expect_worker_closed(close_wait).await;
    }

    /// Same as above, but the sender is kept alive until after the worker has stopped.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_tx_alive_test() {
        setup_logging();

        let (close_wait, mut subscriber, sender) = {
            let (broadcast, mut sender) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            sender.send("Hello World").await.unwrap();
            let subscriber = broadcast.add_client(0, true).await.unwrap();
            (broadcast.closed(), subscriber, sender)
        };

        assert_eq!((0, "Hello World"), subscriber.recv().await.unwrap());
        drop(subscriber);

        expect_worker_closed(close_wait).await;

        drop(sender);
    }

    /// The worker stops even when the departing subscriber left a message unread.
    #[tokio::test]
    async fn subscribers_close_sub_not_caught_up_test() {
        setup_logging();

        let (close_wait, mut subscriber) = {
            let (broadcast, mut sender) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            sender.send("Hello World").await.unwrap();
            sender.send("Hello World 2").await.unwrap();
            let subscriber = broadcast.add_client(0, true).await.unwrap();
            (broadcast.closed(), subscriber)
        };

        assert_eq!((0, "Hello World"), subscriber.recv().await.unwrap());
        drop(subscriber);

        expect_worker_closed(close_wait).await;
    }

    /// Subscribers joining at older sequences replay history before receiving live messages.
    #[tokio::test]
    async fn subscribers_catchup_test() {
        setup_logging();

        let (broadcast, mut sender) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        sender.send("Hello WOrld").await.unwrap();
        sender.send("What the heck").await.unwrap();

        let mut first = broadcast.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), first.recv().await.unwrap());
        assert_eq!((1, "What the heck"), first.recv().await.unwrap());

        let mut second = broadcast.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), second.recv().await.unwrap());
        assert_eq!((1, "What the heck"), second.recv().await.unwrap());

        let mut third = broadcast.add_client(1, true).await.unwrap();
        assert_eq!((1, "What the heck"), third.recv().await.unwrap());

        sender.send("Hehe").await.unwrap();
        assert_eq!((2, "Hehe"), first.recv().await.unwrap());
        assert_eq!((2, "Hehe"), second.recv().await.unwrap());
        assert_eq!((2, "Hehe"), third.recv().await.unwrap());

        broadcast.shutdown_wait().await;
    }

    /// A single subscriber receives every one of 16384 messages, in order.
    #[tokio::test]
    async fn sequenced_broadcast_simple_test() {
        setup_logging();

        let (broadcast, mut sender) =
            SequencedBroadcast::<u64>::new(10, SequencedBroadcastSettings::default());

        let mut subscriber = broadcast.add_client(10, true).await.unwrap();
        tracing::info!("client added");

        let reader = tokio::spawn(async move {
            let mut received: u64 = 0;

            while let Some(msg) = subscriber.recv().await {
                // Sequences start at 10 while payloads start at 0.
                assert_eq!(msg, (10 + received, received));
                received += 1;
            }

            received
        });

        let count = 1024 * 16;

        for value in 0..count {
            sender.send(value).await.unwrap();
        }
        drop(sender);

        let total = reader.await.unwrap();
        assert_eq!(total, count);

        broadcast.shutdown_wait().await;
    }

    /// Registration is accepted only for sequences within the retained range.
    #[tokio::test]
    async fn subscribers_test() {
        setup_logging();

        let (broadcast, mut sender) =
            SequencedBroadcast::<&'static str>::new(10, SequencedBroadcastSettings::default());
        sender.send("Hello WOrld").await.unwrap();
        sender.send("What the heck").await.unwrap();

        let mut subscriber = broadcast.add_client(10, true).await.unwrap();
        assert_eq!((10, "Hello WOrld"), subscriber.recv().await.unwrap());
        assert_eq!((11, "What the heck"), subscriber.recv().await.unwrap());

        assert!(broadcast.add_client(10, true).await.is_ok());
        assert!(broadcast.add_client(11, true).await.is_ok());
        assert!(broadcast.add_client(12, true).await.is_ok());
        assert!(broadcast.add_client(13, true).await.is_err());
        assert!(broadcast.add_client(9, true).await.is_err());

        sender.send("Butts").await.unwrap();
        assert_eq!((12, "Butts"), subscriber.recv().await.unwrap());

        tokio::time::sleep(Duration::from_millis(1)).await;

        tracing::info!("Metrics: {:?}", broadcast.metrics_ref());

        broadcast.shutdown_wait().await;
    }

    /// A subscriber registered with `allow_drop = false` applies back-pressure instead of
    /// being disconnected.
    #[tokio::test]
    async fn subscribers_dont_drop_test() {
        setup_logging();

        let settings = SequencedBroadcastSettings {
            max_time_lag: Duration::from_millis(100),
            ..Default::default()
        };
        let (broadcast, mut sender) = SequencedBroadcast::<i64>::new(1, settings);

        let mut subscriber = broadcast.add_client(1, false).await.unwrap();

        // Send until a send blocks for a full second, i.e. every buffer is full.
        tokio::time::timeout(Duration::from_secs(3), async {
            while tokio::time::timeout(Duration::from_secs(1), sender.send(1))
                .await
                .is_ok()
            {}
        })
        .await
        .expect("client must have been dropped as can still send tx");

        tracing::info!("tx filled");

        assert!(tokio::time::timeout(Duration::from_secs(1), sender.send(1))
            .await
            .is_err());
        assert_eq!((1, 1), subscriber.recv().await.unwrap());

        assert!(
            tokio::time::timeout(Duration::from_millis(10), sender.send(1000))
                .await
                .is_ok()
        );

        let draining = &mut subscriber;

        tokio::time::timeout(Duration::from_millis(100), async move {
            loop {
                let (_, value) = draining.recv().await.unwrap();
                if value == 1000 {
                    break;
                }
            }
        })
        .await
        .unwrap();

        assert!(
            tokio::time::timeout(Duration::from_millis(10), sender.send(2000))
                .await
                .is_ok()
        );
        assert_eq!(2000, subscriber.recv().await.unwrap().1);

        broadcast.shutdown_wait().await;
    }

    /// With no subscribers the worker keeps consuming, and a late subscriber can still join
    /// at the most recent message.
    #[tokio::test]
    async fn subscribers_no_clients_test() {
        setup_logging();

        let (broadcast, mut sender) =
            SequencedBroadcast::<&'static str>::new(1, SequencedBroadcastSettings::default());
        let (broadcast, mut sender) = tokio::time::timeout(Duration::from_secs(1), async move {
            for _ in 0..1_000_000 {
                sender.send("Hello World").await.unwrap();
            }

            tracing::info!("Sent 1M messages");

            while sender.seq() != broadcast.metrics_ref().next_sequence.load(Ordering::Relaxed) {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            tracing::info!("All 1M messages have been processed");

            (broadcast, sender)
        })
        .await
        .unwrap();

        let seq = sender.seq();
        tracing::info!("Seq: {}", seq);

        let mut subscriber = broadcast.add_client(seq - 1, true).await.unwrap();
        sender.send("Test").await.unwrap();

        assert_eq!((seq - 1, "Hello World"), subscriber.recv().await.unwrap());
        assert_eq!((seq, "Test"), subscriber.recv().await.unwrap());

        broadcast.shutdown_wait().await;
    }

    /// 32 subscribers keep up with five seconds of continuous sending and each sees every
    /// message.
    #[tokio::test]
    async fn continious_send_send_test() {
        setup_logging();

        let settings = SequencedBroadcastSettings {
            min_history: 1024,
            lag_end_threshold: 128,
            lag_start_threshold: 512,
            max_time_lag: Duration::from_secs(20),
            ..Default::default()
        };
        let (broadcast, mut sender) = SequencedBroadcast::<u64>::new(1, settings);

        let mut readers = vec![];
        for _ in 0..32 {
            let mut subscriber = broadcast.add_client(1, true).await.unwrap();

            readers.push(tokio::spawn(async move {
                let mut expected = 1;
                while let Some((seq, value)) = subscriber.recv().await {
                    assert_eq!(seq, value);
                    assert_eq!(seq, expected);
                    expected = seq + 1;
                }
                expected
            }));
        }

        let started = Instant::now();
        let mut next_value = 1;
        while started.elapsed() < Duration::from_secs(5) {
            tokio::time::timeout(Duration::from_secs(1), sender.send(next_value))
                .await
                .expect("timeout sending message")
                .expect("failed to send message");

            next_value += 1;
        }

        drop(sender);

        for reader in readers {
            let seen = tokio::time::timeout(Duration::from_secs(1), reader)
                .await
                .expect("timeout waiting for rx task to close")
                .expect("rx task crashed");

            assert_eq!(seen, next_value);
        }
    }

    /// A slow droppable subscriber is disconnected while a fast one receives everything.
    #[tokio::test]
    async fn subscribers_drops_slow_sub_test() {
        setup_logging();

        let settings = SequencedBroadcastSettings {
            max_time_lag: Duration::from_secs(1),
            subscriber_channel_len: 4,
            lag_start_threshold: 64,
            lag_end_threshold: 32,
            ..Default::default()
        };
        let (broadcast, mut sender) = SequencedBroadcast::<i64>::new(1, settings);

        let mut fast_client = broadcast.add_client(1, true).await.unwrap();
        let mut slow_client = broadcast.add_client(1, true).await.unwrap();

        let send_task = tokio::spawn(async move {
            let mut last_sent = 0;
            for value in 1..=1_000 {
                tokio::time::sleep(Duration::from_millis(5)).await;
                sender.send(value).await.unwrap();
                last_sent = value;
            }

            tracing::info!("Done sending");
            drop(broadcast);

            last_sent
        });

        let fast_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some((_, value)) = fast_client.recv().await {
                last = Some(value);
            }
            tracing::info!("Fast Done: {:?}", last);
            last.unwrap()
        });

        let slow_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some((_, value)) = slow_client.recv().await {
                last = Some(value);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            tracing::info!("Slow done: {:?}", last);
            last.unwrap()
        });

        let sent_i = send_task.await.unwrap();
        let fast_recv_i = fast_recv_task.await.unwrap();
        let slow_recv_i = slow_recv_task.await.unwrap();

        assert_eq!(sent_i, 1000);
        assert_eq!(fast_recv_i, 1000);
        assert_eq!(slow_recv_i, 19);
    }
}