1use std::{
2 collections::VecDeque, fmt::Debug, future::{poll_fn, Future}, sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc, LazyLock,
5 }, task::Poll, time::{Duration, Instant}
6};
7
8use arc_metrics::{IntCounter, IntGauge};
9use tokio::sync::{
10 mpsc::{channel, error::{SendError, TryRecvError, TrySendError}, Permit, Receiver, Sender},
11 oneshot,
12};
13use tokio_util::sync::CancellationToken;
14use tracing::Instrument;
15
/// Handle to a spawned broadcast worker.
///
/// The handle is used to register subscribers, read metrics and request shutdown; the
/// worker itself runs as a separate tokio task spawned by `new2`.
pub struct SequencedBroadcast<T> {
    /// Sends subscription requests to the worker.
    new_client_tx: Sender<NewClient<T>>,
    /// Metrics shared with the worker.
    metrics: Arc<SequencedBroadcastMetrics>,
    /// Cancelled by `shutdown` to ask the worker to stop.
    shutdown: CancellationToken,
    /// Loop-iteration counter incremented by the worker.
    worker_loops: Arc<AtomicU64>,
    /// Receives a value once the worker has finished running.
    closed: oneshot::Receiver<()>,
}
23
/// Sending half of a channel whose items are tagged with consecutive sequence numbers.
pub struct SequencedSender<T> {
    /// Sequence number that will be attached to the next item sent.
    next_seq: u64,
    /// Underlying channel carrying `(sequence, item)` pairs.
    send: Sender<(u64, T)>,
}
28
/// Receiving half of a sequenced channel; checks that items arrive with consecutive
/// sequence numbers.
pub struct SequencedReceiver<T> {
    /// Sequence number the next received item is expected to carry.
    next_seq: u64,
    /// Underlying channel carrying `(sequence, item)` pairs.
    receiver: Receiver<(u64, T)>,
}
33
/// Counters and gauges published by the broadcast worker.
#[derive(Default, Debug)]
pub struct SequencedBroadcastMetrics {
    /// Oldest sequence still held in the worker's history queue; initialised to the
    /// starting sequence and updated when history is trimmed.
    pub oldest_sequence: IntGauge,
    /// Sequence number the worker expects for the next incoming message.
    pub next_sequence: IntGauge,
    /// Number of subscription requests rejected because of their requested sequence.
    pub new_client_drop_count: IntCounter,
    /// Number of subscription requests that passed the sequence check.
    pub new_client_accept_count: IntCounter,
    /// Incremented when a subscriber is first marked as lagging.
    pub lagging_subs_gauge: IntGauge,
    /// Number of subscribers registered with the worker, refreshed every worker loop.
    pub active_subs_gauge: IntGauge,
    /// Lowest `next_sequence` among subscribers (the queue's next sequence when there
    /// are none), refreshed every worker loop.
    pub min_sub_sequence_gauge: IntGauge,
    /// Number of subscribers the worker has removed.
    pub disconnect_count: IntCounter,
    /// NOTE(review): nothing in this file updates this counter; loop iterations are
    /// tracked by the separate atomic behind `SequencedBroadcast::worker_loops` — confirm
    /// whether this field is still meant to be used.
    pub worker_loops: IntCounter,
}
46
/// Worker-side state for one registered subscriber.
struct Subscriber<T> {
    /// Identifier assigned by the worker, used in log messages.
    id: u64,
    /// Sequence number of the next message to hand to this subscriber.
    next_sequence: u64,
    /// Channel feeding the subscriber's `SequencedReceiver`.
    tx: Sender<(u64, T)>,
    /// Whether the worker may disconnect this subscriber when it falls too far behind.
    allow_drop: bool,
    /// When the subscriber was marked as lagging, if it currently is.
    lag_started_at: Option<Instant>,
    /// Clone of the next message, held while it waits for room in `tx`.
    pending: Option<T>,
}
55
/// Tuning parameters for the broadcast worker.
#[derive(Debug, Clone)]
pub struct SequencedBroadcastSettings {
    /// Capacity of the channel created for each subscriber.
    pub subscriber_channel_len: usize,
    /// A droppable subscriber is marked as lagging once it is more than this many
    /// sequences behind the newest queued message.
    pub lag_start_threshold: u64,
    /// A lagging subscriber is considered caught up once it is within this many sequences
    /// of the newest queued message. Must not exceed `lag_start_threshold` (asserted when
    /// the broadcast is created).
    pub lag_end_threshold: u64,
    /// How long a subscriber may stay marked as lagging before it is dropped.
    pub max_time_lag: Duration,
    /// Number of most recent messages kept in the queue even when every subscriber has
    /// already received them.
    pub min_history: u64,
}
64
/// State owned by the background task that fans messages out to subscribers.
struct Worker<T> {
    /// Upstream channel delivering the sequenced messages to broadcast.
    rx: Receiver<(u64, T)>,
    /// Message already taken from `rx` but not yet pushed onto `queue`.
    next_rx: Option<(u64, T)>,
    /// Set once `rx` reports that it is disconnected.
    rx_closed: bool,
    /// True while the queue is at capacity; used to log the transition only once.
    rx_full: bool,

    /// Incoming subscription requests.
    next_client_rx: Receiver<NewClient<T>>,
    /// Subscription request already taken from `next_client_rx` but not yet processed.
    next_client: Option<NewClient<T>>,
    /// Set once `next_client_rx` reports that it is disconnected.
    next_client_closed: bool,

    /// Identifier that will be given to the next accepted subscriber.
    next_sub_id: u64,
    /// Currently registered subscribers.
    subscribers: Vec<Subscriber<T>>,
    /// Retained messages as `(sequence, item)`, oldest first.
    queue: VecDeque<(u64, T)>,
    /// Sequence number expected for the next message read from `rx`.
    next_queue_seq: u64,
    /// Metrics shared with the `SequencedBroadcast` handle.
    metrics: Arc<SequencedBroadcastMetrics>,
    /// Tuning parameters.
    settings: SequencedBroadcastSettings,
    /// Checked every loop; the worker stops once it is cancelled.
    shutdown: CancellationToken,
    /// Loop-iteration counter shared with the `SequencedBroadcast` handle.
    worker_loops: Arc<AtomicU64>,
    /// Signalled when the worker has stopped.
    closed: oneshot::Sender<()>,
}
85
86impl Default for SequencedBroadcastSettings {
87 fn default() -> Self {
88 SequencedBroadcastSettings {
89 subscriber_channel_len: 32,
90 lag_start_threshold: 1024 * 8,
91 lag_end_threshold: 1024 * 4,
92 max_time_lag: Duration::from_secs(2),
93 min_history: 2048,
94 }
95 }
96}
97
98impl<T> SequencedSender<T> {
99 pub fn new(next_seq: u64, send: Sender<(u64, T)>) -> Self {
100 SequencedSender { next_seq, send }
101 }
102
103 pub fn is_closed(&self) -> bool {
104 self.send.is_closed()
105 }
106
107 pub async fn closed(&self) {
108 self.send.closed().await
109 }
110
111 pub async fn safe_send(&mut self, seq: u64, item: T) -> Result<(), SequencedSenderError<T>> {
112 self._send(Some(seq), item).await
113 }
114
115 pub async fn send(&mut self, item: T) -> Result<(), SequencedSenderError<T>> {
116 self._send(None, item).await
117 }
118
119 pub fn try_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
120 match self.send.try_send((self.next_seq, item)) {
121 Ok(()) => {
122 self.next_seq += 1;
123 Ok(())
124 }
125 Err(TrySendError::Full(err)) => Err(TrySendError::Full(err.1)),
126 Err(TrySendError::Closed(err)) => Err(TrySendError::Closed(err.1)),
127 }
128 }
129
130 pub async fn reserve(&mut self) -> Result<SequencedSenderPermit<T>, SendError<()>> {
131 let permit = self.send.reserve().await?;
132
133 Ok(SequencedSenderPermit {
134 next_seq: &mut self.next_seq,
135 permit,
136 })
137 }
138
139 async fn _send(&mut self, seq: Option<u64>, item: T) -> Result<(), SequencedSenderError<T>> {
140 if let Some(seq) = seq {
141 if seq != self.next_seq {
142 return Err(SequencedSenderError::InvalidSequence(self.next_seq, item));
143 }
144 }
145
146 if let Err(error) = self.send.send((self.next_seq, item)).await {
147 return Err(SequencedSenderError::ChannelClosed(error.0.1));
148 }
149
150 self.next_seq += 1;
151 Ok(())
152 }
153
154 pub fn seq(&self) -> u64 {
155 self.next_seq
156 }
157}
158
/// Reserved channel slot obtained from `SequencedSender::reserve`.
pub struct SequencedSenderPermit<'a, T> {
    /// The sender's sequence counter, advanced when the permit is used.
    next_seq: &'a mut u64,
    /// Reserved slot in the underlying channel.
    permit: Permit<'a, (u64, T)>,
}
163
164impl<T> SequencedSenderPermit<'_, T> {
165 pub fn send(self, item: T) {
166 let seq = *self.next_seq;
167 self.permit.send((seq, item));
168 *self.next_seq = seq + 1;
169 }
170}
171
172impl<T> SequencedReceiver<T> {
173 pub fn new(next_seq: u64, receiver: Receiver<(u64, T)>) -> Self {
174 SequencedReceiver {
175 next_seq,
176 receiver
177 }
178 }
179
180 pub fn is_closed(&self) -> bool {
181 self.receiver.is_closed()
182 }
183
184 pub async fn recv(&mut self) -> Option<(u64, T)> {
185 let (seq, action) = self.receiver.recv().await?;
186 if self.next_seq != seq {
187 panic!("expected sequence: {} but got: {}", self.next_seq, seq);
188 }
189 self.next_seq += 1;
190 Some((seq, action))
191 }
192
193 pub fn try_recv(&mut self) -> Result<(u64, T), TryRecvError> {
194 match self.receiver.try_recv() {
195 Ok((seq, action)) => {
196 if self.next_seq != seq {
197 panic!("expected sequence: {} but got: {}", self.next_seq, seq);
198 }
199 self.next_seq += 1;
200 Ok((seq, action))
201 }
202 Err(error) => Err(error)
203 }
204 }
205
206 pub fn unbundle(self) -> (u64, Receiver<(u64, T)>) {
207 (self.next_seq, self.receiver)
208 }
209
210 pub fn next_seq(&self) -> u64 {
211 self.next_seq
212 }
213}
214
/// Failure returned by `SequencedSender::send` / `SequencedSender::safe_send`; every
/// variant hands the unsent item back to the caller.
#[derive(Debug, PartialEq, Eq)]
pub enum SequencedSenderError<T> {
    /// The caller-supplied sequence did not match; carries the sequence that was expected
    /// and the unsent item.
    InvalidSequence(u64, T),
    /// The channel is closed; carries the unsent item.
    ChannelClosed(T),
}

impl<T> SequencedSenderError<T> {
    /// Recovers the item that could not be sent, whichever variant this is.
    pub fn into_inner(self) -> T {
        let (Self::InvalidSequence(_, item) | Self::ChannelClosed(item)) = self;
        item
    }
}
229
230impl<T: Send + Clone + 'static> SequencedBroadcast<T> {
231 pub fn new(next_seq: u64, settings: SequencedBroadcastSettings) -> (Self, SequencedSender<T>) {
232 let (tx, rx) = channel(1024);
233 let tx = SequencedSender::new(next_seq, tx);
234 let rx = SequencedReceiver::new(next_seq, rx);
235
236 (
237 Self::new2(rx, settings),
238 tx
239 )
240 }
241
242 pub fn new2(receiver: SequencedReceiver<T>, settings: SequencedBroadcastSettings) -> Self {
243 let queue_cap = 2 * (
244 (settings.lag_start_threshold as usize)
245 .next_power_of_two()
246 .max(1024)
247 ).max((settings.min_history as usize).next_power_of_two());
248
249 assert!(settings.lag_end_threshold <= settings.lag_start_threshold);
250
251 let (client_tx, client_rx) = channel(32);
252
253 let metrics = Arc::new(SequencedBroadcastMetrics {
254 oldest_sequence: {
255 let i = IntGauge::default();
256 i.set(receiver.next_seq);
257 i
258 },
259 next_sequence: {
260 let i = IntGauge::default();
261 i.set(receiver.next_seq);
262 i
263 },
264 ..Default::default()
265 });
266
267 let shutdown = CancellationToken::new();
268 let current_span = tracing::Span::current();
269 let (closed_tx, closed_rx) = oneshot::channel();
270
271 let worker_loops = Arc::new(AtomicU64::new(0));
272
273 tokio::spawn(
274 Worker {
275 rx: receiver.receiver,
276 next_rx: None,
277 rx_full: false,
278 rx_closed: false,
279
280 next_client_rx: client_rx,
281 next_client: None,
282 next_client_closed: false,
283
284 next_sub_id: 1,
285 subscribers: Vec::with_capacity(32),
286 queue: VecDeque::with_capacity(queue_cap),
287 next_queue_seq: receiver.next_seq,
288 metrics: metrics.clone(),
289 settings,
290 shutdown: shutdown.clone(),
291 worker_loops: worker_loops.clone(),
292 closed: closed_tx,
293 }
294 .start()
295 .instrument(current_span),
296 );
297
298 Self {
299 new_client_tx: client_tx,
300 metrics,
301 shutdown,
302 worker_loops,
303 closed: closed_rx,
304 }
305 }
306
307 pub async fn add_client(
308 &self,
309 next_sequence: u64,
310 allow_drop: bool,
311 ) -> Result<SequencedReceiver<T>, NewClientError> {
312 let (tx, rx) = oneshot::channel();
313
314 self.new_client_tx
315 .send(NewClient {
316 response: tx,
317 allow_drop,
318 next_sequence,
319 })
320 .await
321 .expect("Failed to queue new subscriber, worker crashed");
322
323 rx.await.expect("worker closed")
324 }
325
326 pub fn metrics_ref(&self) -> &SequencedBroadcastMetrics {
327 &self.metrics
328 }
329
330 pub fn metrics(&self) -> Arc<SequencedBroadcastMetrics> {
331 self.metrics.clone()
332 }
333
334 pub fn worker_loops(&self) -> u64 {
335 self.worker_loops.load(Ordering::Relaxed)
336 }
337
338 pub fn shutdown(self) -> oneshot::Receiver<()> {
339 self.shutdown.cancel();
340 self.closed
341 }
342
343 pub async fn shutdown_wait(self) {
344 self.shutdown().await.unwrap();
345 }
346
347 pub fn closed(self) -> oneshot::Receiver<()> {
348 self.closed
349 }
350}
351
/// Subscription request sent from `SequencedBroadcast::add_client` to the worker.
struct NewClient<T> {
    /// Where the worker sends the new receiver, or the reason the request was rejected.
    response: oneshot::Sender<Result<SequencedReceiver<T>, NewClientError>>,
    /// First sequence number the subscriber wants to receive.
    next_sequence: u64,
    /// Whether the worker may disconnect the subscriber for lagging.
    allow_drop: bool,
}
357
/// Reason a subscription request was rejected by the worker.
#[derive(Debug)]
pub enum NewClientError {
    /// The requested sequence `seq` is greater than `max`, the next sequence the worker
    /// expects to receive.
    SequenceTooFarAhead { seq: u64, max: u64 },
    /// The requested sequence `seq` is older than `min`, the oldest sequence the worker
    /// still holds.
    SequenceTooFarBehind { seq: u64, min: u64 },
}
363
364impl<T> Debug for NewClient<T> {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 write!(
367 f,
368 "NewClient {{ next_sequence: {}, allow_drop: {} }}",
369 self.next_sequence, self.allow_drop
370 )
371 }
372}
373
// Process-wide counter handing out worker ids for log messages; the first worker gets 1.
// NOTE(review): `AtomicU64::new` is const, so a plain `static WORKER_ID: AtomicU64` would
// work without the `LazyLock`/`Arc` wrappers — consider simplifying.
static WORKER_ID: LazyLock<Arc<AtomicU64>> = LazyLock::new(|| Arc::new(AtomicU64::new(1)));
375
376impl<T: Send + Clone + 'static> Worker<T> {
377 async fn start(mut self) {
378 let id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
379 tracing::info!(id, "{}|SequencedBroadcastWorker Started", id);
380 let start = Instant::now();
381
382 self._start(id).await;
383 let elapsed = start.elapsed();
384 let iter = self.worker_loops.load(Ordering::Relaxed);
385
386 tracing::info!(id, ?elapsed, iter, "{}|SequencedBroadcastWorker Stopped", id);
387
388 let _ = self.closed.send(());
389 }
390
391 async fn _start(&mut self, id: u64) {
392 loop {
393 self.worker_loops.fetch_add(1, Ordering::Relaxed);
394 tokio::task::yield_now().await;
395
396 if self.next_client.is_none() {
397 self.next_client = match self.next_client_rx.try_recv() {
398 Ok(item) => Some(item),
399 Err(TryRecvError::Empty) => None,
400 Err(TryRecvError::Disconnected) => {
401 self.next_client_closed = true;
402 None
403 }
404 };
405 }
406
407 if self.shutdown.is_cancelled() {
408 tracing::info!("{}|Stopping worker due to shutdown", id);
409 break;
410 }
411
412 if !self.next_client_closed {
414 let mut max_per_loop = 32;
415 let min_allowed_seq = self
416 .queue
417 .front()
418 .map(|i| i.0)
419 .unwrap_or(self.next_queue_seq);
420
421 while let Some(new) = self.next_client.take() {
422 self.next_client = self.next_client_rx.try_recv().ok();
423
424 if new.next_sequence < min_allowed_seq
426 || self.next_queue_seq < new.next_sequence
427 {
428 self.metrics.new_client_drop_count.inc();
429
430 if new.next_sequence < min_allowed_seq {
431 tracing::info!(
432 "{}|Subscriber rejected, seq({}) < min_allowed({})",
433 id,
434 new.next_sequence,
435 min_allowed_seq
436 );
437
438 let _ = new.response.send(Err(NewClientError::SequenceTooFarBehind {
439 seq: new.next_sequence,
440 min: min_allowed_seq
441 }));
442 } else {
443 tracing::info!(
444 "{}|Subscriber rejected, max_seq({}) < seq({})",
445 id,
446 self.next_queue_seq,
447 new.next_sequence
448 );
449
450 let _ = new.response.send(Err(NewClientError::SequenceTooFarAhead {
451 seq: new.next_sequence,
452 max: self.next_queue_seq
453 }));
454 }
455
456 continue;
457 }
458
459 self.metrics.new_client_accept_count.inc();
460
461 let (tx, rx) = channel(self.settings.subscriber_channel_len);
463 let rx = SequencedReceiver::<T> {
464 receiver: rx,
465 next_seq: new.next_sequence,
466 };
467
468 if new.response.send(Ok(rx)).is_ok() {
469 let sub_id = self.next_sub_id;
470 self.next_sub_id += 1;
471
472 tracing::info!(
473 "{}|Subscriber({}): Added, allow_drop: {}, next_sequence: {}, min_allowed_seq: {}",
474 id, sub_id, new.allow_drop, new.next_sequence, min_allowed_seq,
475 );
476
477 self.subscribers.push(Subscriber {
478 id: sub_id,
479 allow_drop: new.allow_drop,
480 next_sequence: new.next_sequence,
481 pending: None,
482 tx,
483 lag_started_at: None,
484 });
485 } else {
486 tracing::warn!("{}|New subscriber accepted but receiver dropped", id);
487 }
488
489 if max_per_loop == 0 {
491 break;
492 }
493
494 max_per_loop -= 1;
495 }
496 }
497
498 'fill_rx: {
500 if self.next_rx.is_none() {
501 self.next_rx = match self.rx.try_recv() {
502 Ok(msg) => Some(msg),
503 Err(TryRecvError::Disconnected) => {
504 self.rx_closed = true;
505 None
506 }
507 Err(TryRecvError::Empty) => None,
508 };
509 }
510
511 let mut remaining_msg_count = self.rx_space().min(1024);
512 if remaining_msg_count == 0 {
513 if !self.rx_full {
514 self.rx_full = true;
515 assert_eq!(self.queue.len(), self.queue.capacity());
516 tracing::info!("{}|Reached queue capacity {}", id, self.queue.len());
517 }
518
519 break 'fill_rx;
520 }
521
522 if self.rx_full {
523 tracing::info!("{}|Space returned to queue {}/{}", id, self.rx_space(), self.queue.len());
524 self.rx_full = false;
525 }
526
527 while let Some((seq, item)) = self.next_rx.take() {
528 self.next_rx = self.rx.try_recv().ok();
529
530 assert_eq!(seq, self.next_queue_seq, "sequence is invalid");
531 self.queue.push_back((seq, item));
532 self.next_queue_seq += 1;
533
534 remaining_msg_count -= 1;
535 if remaining_msg_count == 0 {
536 break;
537 }
538 }
539 }
540
541 self.metrics.next_sequence.set(self.next_queue_seq);
542
543 let oldest_queue_sequence = self
544 .queue
545 .front()
546 .map(|v| v.0)
547 .unwrap_or(self.next_queue_seq);
548
549 let max_seq = oldest_queue_sequence + self.queue.len() as u64;
550 let lag_start_seq = max_seq.max(self.settings.lag_start_threshold) - self.settings.lag_start_threshold;
551 let lag_end_seq = lag_start_seq.max(max_seq.max(self.settings.lag_end_threshold) - self.settings.lag_end_threshold);
552
553 let mut min_sub_sequence_calc = self.next_queue_seq;
554 let mut earliest_lag_start_at_calc: Option<Instant> = None;
555
556 let mut i = 0;
557 'next_sub: while i < self.subscribers.len() {
558 let sub = &mut self.subscribers[i];
559
560 if (sub.allow_drop && sub.next_sequence < oldest_queue_sequence) || sub.tx.is_closed() {
562 if sub.tx.is_closed() {
563 tracing::info!("{}|Subscriber({}): channel closed, dropping", sub.id, id);
564 } else {
565 tracing::warn!(
566 "{}|Subscriber({}): lag behind available data ({} < {}), dropping",
567 id,
568 sub.id,
569 sub.next_sequence,
570 oldest_queue_sequence
571 );
572 }
573
574 if sub.lag_started_at.is_some() {
575 self.metrics.lagging_subs_gauge.dec();
576 }
577
578 self.metrics.disconnect_count.inc();
579
580 self.subscribers.swap_remove(i);
581 continue 'next_sub;
582 }
583
584 let mut offset = {
586 assert!(sub.next_sequence >= oldest_queue_sequence);
587 let offset = (sub.next_sequence - oldest_queue_sequence) as usize;
588 assert!(sub.next_sequence <= self.next_queue_seq, "sub cannot be ahead of queue sequence");
589 assert!(offset <= self.queue.len(), "sub cannot be ahead of queue sequence");
590 offset
591 };
592
593 if sub.pending.is_none() {
595 if self.queue.len() == offset {
597 i += 1;
598 continue 'next_sub;
599 }
600
601 let (seq, item) = self.queue.get(offset).unwrap();
603 assert_eq!(*seq, sub.next_sequence);
604 sub.pending = Some(item.clone());
605 }
606
607 while let Some(next) = sub.pending.take() {
609 match sub.tx.try_send((sub.next_sequence, next)) {
610 Ok(_) => {
611 sub.next_sequence += 1;
612 offset += 1;
613
614 if self.queue.len() == offset {
615 break;
616 }
617
618 let (seq, item) = self.queue.get(offset).unwrap();
619 assert_eq!(*seq, sub.next_sequence);
620 sub.pending = Some(item.clone());
621 }
622 Err(TrySendError::Closed(_)) => break,
623 Err(TrySendError::Full((_seq, item))) => {
624 sub.pending = Some(item);
625 break;
626 }
627 }
628 }
629
630 if sub.allow_drop {
631 if lag_end_seq <= sub.next_sequence {
632 if let Some(lag_start) = sub.lag_started_at.take() {
633 tracing::info!(
634 "{}|Subscriber({}): caught up after {:?}",
635 id,
636 sub.id,
637 lag_start.elapsed()
638 );
639
640 self.metrics.lagging_subs_gauge.inc();
641 }
642 }
643 else if sub.next_sequence < lag_start_seq {
644 if let Some(lag_start) = &sub.lag_started_at {
645 let lag_duration = lag_start.elapsed();
646
647 if self.settings.max_time_lag < lag_duration {
648 tracing::info!(
649 "{}|Subscriber({}): lag too high for too long ({:?}), dropping",
650 id,
651 sub.id,
652 lag_duration,
653 );
654
655 self.metrics.lagging_subs_gauge.dec();
656 self.metrics.disconnect_count.inc();
657
658 self.subscribers.swap_remove(i);
659 continue 'next_sub;
660 }
661 } else {
662 sub.lag_started_at = Some(Instant::now());
663
664 tracing::info!(
665 "{}|Subscriber({}): lag started thresh({}) < lag({})",
666 id,
667 sub.id,
668 self.settings.lag_start_threshold,
669 max_seq - sub.next_sequence,
670 );
671
672 self.metrics.lagging_subs_gauge.inc()
673 }
674 }
675 }
676
677 if let Some(lag_started_at) = &sub.lag_started_at {
678 earliest_lag_start_at_calc = match earliest_lag_start_at_calc {
679 Some(v) if v.lt(lag_started_at) => Some(v),
680 _ => sub.lag_started_at
681 };
682 }
683
684 min_sub_sequence_calc = min_sub_sequence_calc.min(sub.next_sequence);
685 i += 1;
686 }
687
688 let min_sub_sequence = min_sub_sequence_calc;
689
690 self.metrics.active_subs_gauge.set(self.subscribers.len() as u64);
691 self.metrics.min_sub_sequence_gauge.set(min_sub_sequence);
692
693 {
695 let keep_seq = min_sub_sequence.min(max_seq.max(self.settings.min_history) - self.settings.min_history);
696
697 if oldest_queue_sequence < keep_seq {
698 let remove_count = keep_seq - oldest_queue_sequence;
699 if remove_count != 0 {
700 let _ = self.queue.drain(0..remove_count as usize);
701 }
702
703 self.metrics.oldest_sequence.set(oldest_queue_sequence + remove_count);
704 }
705 }
706
707 if self.rx_closed && min_sub_sequence == max_seq {
708 tracing::info!("{}|RX closed and all subscribers caught up, shutting down worker", id);
709 return;
710 }
711
712 if self.next_client_closed && self.subscribers.is_empty() {
713 tracing::info!("{}|no subscribers and next_client_rx closed, shutting down worker", id);
714 return;
715 }
716
717 let rx_blocked = self.next_rx.is_none() && !self.rx_closed;
718 let next_timeout = earliest_lag_start_at_calc.map(|early| {
719 let now = Instant::now();
720 let expire = early + self.settings.max_time_lag;
721 (expire.max(now) - now).max(Duration::from_millis(100))
722 });
723
724 {
726 if !rx_blocked && 0 < self.rx_space() {
728 tracing::trace!("{}|have more rx", id);
729 continue;
730 }
731
732 if self.next_client.is_some() {
734 tracing::trace!("{}|have next client", id);
735 continue;
736 }
737 }
738
739 let mut timeout_fut = next_timeout.map(|duration| tokio::time::sleep(duration));
740 let mut pending_tx = Vec::new();
741 let new_client_rx = &mut self.next_client_rx;
742 let new_msg_rx = &mut self.rx;
743 let next_rx = &mut self.next_rx;
744 let next_client = &mut self.next_client;
745
746 for sub in &mut self.subscribers {
747 if sub.pending.is_some() {
748 pending_tx.push((sub.tx.reserve(), &mut sub.pending, &mut sub.next_sequence));
749 }
750 }
751
752 poll_fn(|cx| {
753 if let Some(timeout) = &mut timeout_fut {
754 if unsafe { std::pin::Pin::new_unchecked(timeout) }.poll(cx).is_ready() {
755 tracing::trace!("{}|poll: max lag timer reached", id);
756 return Poll::Ready(());
757 }
758 }
759
760 if rx_blocked {
761 if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_msg_rx) }.poll_recv(cx) {
762 assert!(next_rx.is_none());
763
764 *next_rx = item;
765 if next_rx.is_some() {
766 tracing::trace!("{}|poll: new RX available", id);
767 } else {
768 tracing::trace!("{}|poll: RX closed", id);
769 }
770
771 return Poll::Ready(());
772 }
773 }
774
775 if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_client_rx) }.poll_recv(cx) {
776 tracing::trace!("{}|poll: new client", id);
777
778 assert!(next_client.is_none());
779 *next_client = item;
780 return Poll::Ready(());
781 }
782
783 let mut sent = false;
784 for (reserve, pending, next_sequence) in &mut pending_tx {
785 let reserve = unsafe { std::pin::Pin::new_unchecked(reserve) };
786
787 match reserve.poll(cx) {
788 Poll::Ready(Ok(slot)) => {
789 let item = pending.take().expect("pending missing");
790 let seq = **next_sequence;
791 slot.send((seq, item));
792 **next_sequence = seq + 1;
793
794 sent = true;
795 }
796 Poll::Ready(Err(_)) => {
797 sent = true;
798 }
799 Poll::Pending => {}
800 }
801 }
802
803 if sent {
804 tracing::trace!("{}|poll: subscriber message sent", id);
805 return Poll::Ready(());
806 }
807
808 Poll::Pending
809 }).await;
810 }
811 }
812
813 fn rx_space(&self) -> usize {
814 self.queue.capacity() - self.queue.len()
815 }
816}
817
#[cfg(test)]
mod test {
    use super::*;

    /// Installs a tracing subscriber that writes to the test output; the result is
    /// ignored so repeated calls are harmless.
    pub fn setup_logging() {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();
    }

    /// An explicit `shutdown()` stops the worker within 100ms even though a subscriber
    /// and the sender are still alive.
    #[tokio::test]
    async fn subscribers_shutdown_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        let client = subs.add_client(0, true).await.unwrap();
        tx.send("Hello World").await.unwrap();

        let close_wait = subs.shutdown();

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");

        drop(client);
        drop(tx);
    }

    /// With no subscribers, dropping the broadcast handle and the sender lets the worker
    /// stop on its own.
    #[tokio::test]
    async fn subscribers_close_no_subs_test() {
        setup_logging();

        let close_wait = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            subs.closed()
        };

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// The active-subscriber gauge follows subscribers being added and dropped.
    #[tokio::test]
    async fn subscribers_updates_active_metric_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
        tx.send("Hello World").await.unwrap();

        let mut client_1 = subs.add_client(0, true).await.unwrap();
        let mut client_2 = subs.add_client(0, true).await.unwrap();
        let msg = client_1.recv().await.unwrap();
        assert_eq!((0, "Hello World"), msg);

        assert_eq!(2, subs.metrics_ref().active_subs_gauge.load());
        drop(client_1);

        tx.send("Test2").await.unwrap();
        assert_eq!((0, "Hello World"), client_2.recv().await.unwrap());
        assert_eq!((1, "Test2"), client_2.recv().await.unwrap());

        // Give the worker a moment to notice the dropped subscriber.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(1, subs.metrics_ref().active_subs_gauge.load());
    }

    /// After the handle and sender are dropped, the worker stops once the only
    /// subscriber has received everything and gone away.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// Same as above, but the sender stays alive: the worker must still stop once the
    /// handle is gone and the last subscriber has been dropped.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_tx_alive_test() {
        setup_logging();

        let (close_wait, mut client, tx) = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client, tx)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");

        drop(tx);
    }

    /// The worker also stops when the last subscriber is dropped before it has read
    /// every message.
    #[tokio::test]
    async fn subscribers_close_sub_not_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            tx.send("Hello World 2").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// Subscribers joining at an earlier sequence are replayed history, then all
    /// subscribers receive new messages.
    #[tokio::test]
    async fn subscribers_catchup_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub_1 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_1.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_1.recv().await.unwrap());

        let mut sub_2 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_2.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_2.recv().await.unwrap());

        let mut sub_3 = subs.add_client(1, true).await.unwrap();
        assert_eq!((1, "What the heck"), sub_3.recv().await.unwrap());

        tx.send("Hehe").await.unwrap();
        assert_eq!((2, "Hehe"), sub_1.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_2.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_3.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// A single subscriber receives 16Ki messages in order, starting at a non-zero
    /// sequence, and its channel ends after the sender is dropped.
    #[tokio::test]
    async fn sequenced_broadcast_simple_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<u64>::new(10, SequencedBroadcastSettings::default());

        let mut client = subs.add_client(10, true).await.unwrap();
        tracing::info!("client added");

        let read_task = tokio::spawn(async move {
            let mut i = 0;
            let mut seq = 10;

            while let Some(msg) = client.recv().await {
                assert_eq!(msg, (seq, i));
                i += 1;
                seq += 1;
            }

            i
        });

        let count = 1024 * 16;

        for i in 0..count {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let total = read_task.await.unwrap();
        assert_eq!(total, count);

        subs.shutdown_wait().await;
    }

    /// Subscription requests are accepted for sequences within the retained range and
    /// rejected outside of it.
    #[tokio::test]
    async fn subscribers_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(10, SequencedBroadcastSettings::default());
        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub = subs.add_client(10, true).await.unwrap();
        assert_eq!((10, "Hello WOrld"), sub.recv().await.unwrap());
        assert_eq!((11, "What the heck"), sub.recv().await.unwrap());

        assert!(subs.add_client(10, true).await.is_ok());
        assert!(subs.add_client(11, true).await.is_ok());
        assert!(subs.add_client(12, true).await.is_ok());
        assert!(subs.add_client(13, true).await.is_err());
        assert!(subs.add_client(9, true).await.is_err());

        tx.send("Butts").await.unwrap();
        assert_eq!((12, "Butts"), sub.recv().await.unwrap());

        tokio::time::sleep(Duration::from_millis(1)).await;

        tracing::info!("Metrics: {:?}", subs.metrics_ref());

        subs.shutdown_wait().await;
    }

    /// A subscriber added with `allow_drop = false` is not disconnected for being slow:
    /// the sender blocks instead, and resumes once the subscriber reads again.
    #[tokio::test]
    async fn subscribers_dont_drop_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<i64>::new(
            1,
            SequencedBroadcastSettings {
                max_time_lag: Duration::from_millis(100),
                ..Default::default()
            },
        );

        let mut sub = subs.add_client(1, false).await.unwrap();

        // Send until a send blocks for a full second.
        tokio::time::timeout(Duration::from_secs(3), async {
            loop {
                if tokio::time::timeout(Duration::from_secs(1), tx.send(1))
                    .await
                    .is_err()
                {
                    break;
                }
            }
        }).await.expect("client must have been dropped as can still send tx");

        tracing::info!("tx filled");

        assert!(tokio::time::timeout(Duration::from_secs(1), tx.send(1))
            .await
            .is_err());
        assert_eq!((1, 1), sub.recv().await.unwrap());

        assert!(
            tokio::time::timeout(Duration::from_millis(10), tx.send(1000))
                .await
                .is_ok()
        );

        let sub_mut = &mut sub;

        tokio::time::timeout(Duration::from_millis(100), async move {
            loop {
                let (_, num) = sub_mut.recv().await.unwrap();
                if num == 1000 {
                    break;
                }
            }
        })
        .await
        .unwrap();

        assert!(
            tokio::time::timeout(Duration::from_millis(10), tx.send(2000))
                .await
                .is_ok()
        );
        assert_eq!(2000, sub.recv().await.unwrap().1);

        subs.shutdown_wait().await;
    }

    /// With no subscribers, one million messages are consumed within a second, and a
    /// subscriber joining afterwards can still read the most recent message.
    #[tokio::test]
    async fn subscribers_no_clients_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(1, SequencedBroadcastSettings::default());
        let (subs, mut tx) = tokio::time::timeout(Duration::from_secs(1), async move {
            for _ in 0..1_000_000 {
                tx.send("Hello World").await.unwrap();
            }

            tracing::info!("Sent 1M messages");

            while tx.seq() != subs.metrics_ref().next_sequence.load() {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            tracing::info!("All 1M messages have been processed");

            (subs, tx)
        })
        .await
        .unwrap();

        let seq = tx.seq();
        tracing::info!("Seq: {}", seq);

        let mut sub = subs.add_client(seq - 1, true).await.unwrap();
        tx.send("Test").await.unwrap();

        assert_eq!((seq - 1, "Hello World"), sub.recv().await.unwrap());
        assert_eq!((seq, "Test"), sub.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// 32 subscribers keep up with five seconds of continuous sending and all observe
    /// every message in order.
    #[tokio::test]
    async fn continious_send_send_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<u64>::new(1, SequencedBroadcastSettings {
            min_history: 1024,
            lag_end_threshold: 128,
            lag_start_threshold: 512,
            max_time_lag: Duration::from_secs(20),
            ..Default::default()
        });

        let mut read_tasks = vec![];
        for _ in 0..32 {
            let mut client = subs.add_client(1, true).await.unwrap();

            let read_task = tokio::spawn(async move {
                let mut next = 1;
                while let Some((seq, num)) = client.recv().await {
                    assert_eq!(seq, num);
                    assert_eq!(seq, next);
                    next = seq + 1;
                }
                next
            });

            read_tasks.push(read_task);
        }

        let start = Instant::now();
        let mut end = 1;
        while start.elapsed() < Duration::from_secs(5) {
            tokio::time::timeout(Duration::from_secs(1), tx.send(end)).await
                .expect("timeout sending message")
                .expect("failed to send message");

            end += 1;
        }

        drop(tx);

        for read_task in read_tasks {
            let count = tokio::time::timeout(Duration::from_secs(1), read_task).await
                .expect("timeout waiting for rx task to close")
                .expect("rx task crashed");

            assert_eq!(count, end);
        }
    }

    /// A slow droppable subscriber is disconnected after lagging for too long, while a
    /// fast subscriber receives everything. The expected last value seen by the slow
    /// subscriber depends on the timings configured below.
    #[tokio::test]
    async fn subscribers_drops_slow_sub_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<i64>::new(
            1,
            SequencedBroadcastSettings {
                max_time_lag: Duration::from_secs(1),
                subscriber_channel_len: 4,
                lag_start_threshold: 64,
                lag_end_threshold: 32,
                ..Default::default()
            },
        );

        let mut fast_client = subs.add_client(1, true).await.unwrap();
        let mut slow_client = subs.add_client(1, true).await.unwrap();

        let send_task = tokio::spawn(async move {
            let mut i = 0;
            for _ in 0..1_000 {
                tokio::time::sleep(Duration::from_millis(5)).await;
                i += 1;
                tx.send(i).await.unwrap();
            }

            tracing::info!("Done sending");
            drop(subs);

            i
        });

        let fast_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some(recv) = fast_client.recv().await {
                last = Some(recv.1);
            }
            tracing::info!("Fast Done: {:?}", last);
            last.unwrap()
        });

        let slow_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some(recv) = slow_client.recv().await {
                last = Some(recv.1);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            tracing::info!("Slow done: {:?}", last);
            last.unwrap()
        });

        let sent_i = send_task.await.unwrap();
        let fast_recv_i = fast_recv_task.await.unwrap();
        let slow_recv_i = slow_recv_task.await.unwrap();

        assert_eq!(sent_i, 1000);
        assert_eq!(fast_recv_i, 1000);
        assert_eq!(slow_recv_i, 19);
    }
}