1use std::{
2 collections::VecDeque, fmt::Debug, future::{poll_fn, Future}, sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc, LazyLock,
5 }, task::Poll, time::{Duration, Instant}
6};
7
8use arc_metrics::{IntCounter, IntGauge};
9use tokio::sync::{
10 mpsc::{channel, error::{SendError, TryRecvError, TrySendError}, Permit, Receiver, Sender},
11 oneshot,
12};
13use tokio_util::sync::CancellationToken;
14use tracing::Instrument;
15
16pub struct SequencedBroadcast<T> {
17 new_client_tx: Sender<NewClient<T>>,
18 metrics: Arc<SequencedBroadcastMetrics>,
19 shutdown: CancellationToken,
20 worker_loops: Arc<AtomicU64>,
21 closed: oneshot::Receiver<()>,
22}
23
24pub struct SequencedSender<T> {
25 next_seq: u64,
26 send: Sender<(u64, T)>,
27}
28
29pub struct SequencedReceiver<T> {
30 next_seq: u64,
31 receiver: Receiver<(u64, T)>,
32}
33
34#[derive(Default, Debug)]
35pub struct SequencedBroadcastMetrics {
36 pub oldest_sequence: IntGauge,
37 pub next_sequence: IntGauge,
38 pub new_client_drop_count: IntCounter,
39 pub new_client_accept_count: IntCounter,
40 pub lagging_subs_gauge: IntGauge,
41 pub active_subs_gauge: IntGauge,
42 pub min_sub_sequence_gauge: IntGauge,
43 pub disconnect_count: IntCounter,
44 pub worker_loops: IntCounter,
45}
46
47struct Subscriber<T> {
48 id: u64,
49 next_sequence: u64,
50 tx: Sender<(u64, T)>,
51 allow_drop: bool,
52 lag_started_at: Option<Instant>,
53 pending: Option<T>,
54}
55
/// Tuning knobs for a broadcast worker.
#[derive(Clone, Debug)]
pub struct SequencedBroadcastSettings {
    /// Capacity of the channel created for each subscriber.
    pub subscriber_channel_len: usize,
    /// A droppable subscriber more than this many messages behind the newest one is
    /// flagged as lagging.
    pub lag_start_threshold: u64,
    /// A lagging subscriber within this many messages of the newest one is considered
    /// caught up again. Must not exceed `lag_start_threshold`.
    pub lag_end_threshold: u64,
    /// How long a subscriber may stay flagged as lagging before it is disconnected.
    pub max_time_lag: Duration,
    /// Number of most recent messages the worker keeps even when no subscriber needs them.
    pub min_history: u64,
}
64
65struct Worker<T> {
66 rx: Receiver<(u64, T)>,
67 next_rx: Option<(u64, T)>,
68 rx_closed: bool,
69 rx_full: bool,
70
71 next_client_rx: Receiver<NewClient<T>>,
72 next_client: Option<NewClient<T>>,
73 next_client_closed: bool,
74
75 next_sub_id: u64,
76 subscribers: Vec<Subscriber<T>>,
77 queue: VecDeque<(u64, T)>,
78 next_queue_seq: u64,
79 metrics: Arc<SequencedBroadcastMetrics>,
80 settings: SequencedBroadcastSettings,
81 shutdown: CancellationToken,
82 worker_loops: Arc<AtomicU64>,
83 closed: oneshot::Sender<()>,
84}
85
86impl Default for SequencedBroadcastSettings {
87 fn default() -> Self {
88 SequencedBroadcastSettings {
89 subscriber_channel_len: 32,
90 lag_start_threshold: 1024 * 8,
91 lag_end_threshold: 1024 * 4,
92 max_time_lag: Duration::from_secs(2),
93 min_history: 2048,
94 }
95 }
96}
97
98impl<T> SequencedSender<T> {
99 pub fn new(next_seq: u64, send: Sender<(u64, T)>) -> Self {
100 SequencedSender { next_seq, send }
101 }
102
103 pub fn is_closed(&self) -> bool {
104 self.send.is_closed()
105 }
106
107 pub async fn closed(&self) {
108 self.send.closed().await
109 }
110
111 pub async fn safe_send(&mut self, seq: u64, item: T) -> Result<(), SequencedSenderError<T>> {
112 self._send(Some(seq), item).await
113 }
114
115 pub async fn send(&mut self, item: T) -> Result<(), SequencedSenderError<T>> {
116 self._send(None, item).await
117 }
118
119 pub fn try_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
120 match self.send.try_send((self.next_seq, item)) {
121 Ok(()) => {
122 self.next_seq += 1;
123 Ok(())
124 }
125 Err(TrySendError::Full(err)) => Err(TrySendError::Full(err.1)),
126 Err(TrySendError::Closed(err)) => Err(TrySendError::Closed(err.1)),
127 }
128 }
129
130 pub async fn reserve(&mut self) -> Result<SequencedSenderPermit<T>, SendError<()>> {
131 let permit = self.send.reserve().await?;
132
133 Ok(SequencedSenderPermit {
134 next_seq: &mut self.next_seq,
135 permit,
136 })
137 }
138
139 async fn _send(&mut self, seq: Option<u64>, item: T) -> Result<(), SequencedSenderError<T>> {
140 if let Some(seq) = seq {
141 if seq != self.next_seq {
142 return Err(SequencedSenderError::InvalidSequence(self.next_seq, item));
143 }
144 }
145
146 if let Err(error) = self.send.send((self.next_seq, item)).await {
147 return Err(SequencedSenderError::ChannelClosed(error.0.1));
148 }
149
150 self.next_seq += 1;
151 Ok(())
152 }
153
154 pub fn seq(&self) -> u64 {
155 self.next_seq
156 }
157}
158
159pub struct SequencedSenderPermit<'a, T> {
160 next_seq: &'a mut u64,
161 permit: Permit<'a, (u64, T)>,
162}
163
164impl<T> SequencedSenderPermit<'_, T> {
165 pub fn send(self, item: T) {
166 let seq = *self.next_seq;
167 self.permit.send((seq, item));
168 *self.next_seq = seq + 1;
169 }
170}
171
172impl<T> SequencedReceiver<T> {
173 pub fn new(next_seq: u64, receiver: Receiver<(u64, T)>) -> Self {
174 SequencedReceiver {
175 next_seq,
176 receiver
177 }
178 }
179
180 pub fn is_closed(&self) -> bool {
181 self.receiver.is_closed()
182 }
183
184 pub async fn recv(&mut self) -> Option<(u64, T)> {
185 let (seq, action) = self.receiver.recv().await?;
186 if self.next_seq != seq {
187 panic!("expected sequence: {} but got: {}", self.next_seq, seq);
188 }
189 self.next_seq += 1;
190 Some((seq, action))
191 }
192
193 pub fn try_recv(&mut self) -> Result<(u64, T), TryRecvError> {
194 match self.receiver.try_recv() {
195 Ok((seq, action)) => {
196 if self.next_seq != seq {
197 panic!("expected sequence: {} but got: {}", self.next_seq, seq);
198 }
199 self.next_seq += 1;
200 Ok((seq, action))
201 }
202 Err(error) => Err(error)
203 }
204 }
205
206 pub fn unbundle(self) -> (u64, Receiver<(u64, T)>) {
207 (self.next_seq, self.receiver)
208 }
209
210 pub fn next_seq(&self) -> u64 {
211 self.next_seq
212 }
213}
214
/// Error returned by `SequencedSender::send` / `safe_send`; every variant hands the
/// unsent item back to the caller.
#[derive(PartialEq, Eq)]
pub enum SequencedSenderError<T> {
    /// The caller-supplied sequence did not match; carries the sequence the sender
    /// expected and the unsent item.
    InvalidSequence(u64, T),
    /// The channel is closed; carries the unsent item.
    ChannelClosed(T),
}

// Hand-written so that `T` does not need to implement `Debug`; the payload is omitted.
impl<T> Debug for SequencedSenderError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSequence(seq, _) => write!(f, "InvalidSequence(seq: {})", seq),
            Self::ChannelClosed(_) => write!(f, "ChannelClosed"),
        }
    }
}

// Human-readable message; like `Debug`, it never prints the payload, so no bound on `T`.
impl<T> std::fmt::Display for SequencedSenderError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSequence(expected, _) => {
                write!(f, "invalid sequence, sender expects sequence {}", expected)
            }
            Self::ChannelClosed(_) => f.write_str("channel closed"),
        }
    }
}

// Lets the error flow through `Box<dyn Error>` and `?`-based error handling.
impl<T> std::error::Error for SequencedSenderError<T> {}

impl<T> SequencedSenderError<T> {
    /// Recovers the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::InvalidSequence(_, v) => v,
            Self::ChannelClosed(v) => v,
        }
    }
}
238
239impl<T: Send + Clone + 'static> SequencedBroadcast<T> {
240 pub fn new(next_seq: u64, settings: SequencedBroadcastSettings) -> (Self, SequencedSender<T>) {
241 let (tx, rx) = channel(1024);
242 let tx = SequencedSender::new(next_seq, tx);
243 let rx = SequencedReceiver::new(next_seq, rx);
244
245 (
246 Self::new2(rx, settings),
247 tx
248 )
249 }
250
251 pub fn new2(receiver: SequencedReceiver<T>, settings: SequencedBroadcastSettings) -> Self {
252 let queue_cap = 2 * (
253 (settings.lag_start_threshold as usize)
254 .next_power_of_two()
255 .max(1024)
256 ).max((settings.min_history as usize).next_power_of_two());
257
258 assert!(settings.lag_end_threshold <= settings.lag_start_threshold);
259
260 let (client_tx, client_rx) = channel(32);
261
262 let metrics = Arc::new(SequencedBroadcastMetrics {
263 oldest_sequence: {
264 let i = IntGauge::default();
265 i.set(receiver.next_seq);
266 i
267 },
268 next_sequence: {
269 let i = IntGauge::default();
270 i.set(receiver.next_seq);
271 i
272 },
273 ..Default::default()
274 });
275
276 let shutdown = CancellationToken::new();
277 let current_span = tracing::Span::current();
278 let (closed_tx, closed_rx) = oneshot::channel();
279
280 let worker_loops = Arc::new(AtomicU64::new(0));
281
282 tokio::spawn(
283 Worker {
284 rx: receiver.receiver,
285 next_rx: None,
286 rx_full: false,
287 rx_closed: false,
288
289 next_client_rx: client_rx,
290 next_client: None,
291 next_client_closed: false,
292
293 next_sub_id: 1,
294 subscribers: Vec::with_capacity(32),
295 queue: VecDeque::with_capacity(queue_cap),
296 next_queue_seq: receiver.next_seq,
297 metrics: metrics.clone(),
298 settings,
299 shutdown: shutdown.clone(),
300 worker_loops: worker_loops.clone(),
301 closed: closed_tx,
302 }
303 .start()
304 .instrument(current_span),
305 );
306
307 Self {
308 new_client_tx: client_tx,
309 metrics,
310 shutdown,
311 worker_loops,
312 closed: closed_rx,
313 }
314 }
315
316 pub async fn add_client(
317 &self,
318 next_sequence: u64,
319 allow_drop: bool,
320 ) -> Result<SequencedReceiver<T>, NewClientError> {
321 let (tx, rx) = oneshot::channel();
322
323 self.new_client_tx
324 .send(NewClient {
325 response: tx,
326 allow_drop,
327 next_sequence,
328 })
329 .await
330 .expect("Failed to queue new subscriber, worker crashed");
331
332 rx.await.expect("worker closed")
333 }
334
335 pub fn metrics_ref(&self) -> &SequencedBroadcastMetrics {
336 &self.metrics
337 }
338
339 pub fn metrics(&self) -> Arc<SequencedBroadcastMetrics> {
340 self.metrics.clone()
341 }
342
343 pub fn worker_loops(&self) -> u64 {
344 self.worker_loops.load(Ordering::Relaxed)
345 }
346
347 pub fn shutdown(self) -> oneshot::Receiver<()> {
348 self.shutdown.cancel();
349 self.closed
350 }
351
352 pub async fn shutdown_wait(self) {
353 self.shutdown().await.unwrap();
354 }
355
356 pub fn closed(self) -> oneshot::Receiver<()> {
357 self.closed
358 }
359}
360
361struct NewClient<T> {
362 response: oneshot::Sender<Result<SequencedReceiver<T>, NewClientError>>,
363 next_sequence: u64,
364 allow_drop: bool,
365}
366
/// Reason a subscription request was rejected.
#[derive(Debug)]
pub enum NewClientError {
    /// The requested sequence `seq` is beyond `max`, the next sequence the worker will accept.
    SequenceTooFarAhead { seq: u64, max: u64 },
    /// The requested sequence `seq` is older than `min`, the oldest sequence still retained.
    SequenceTooFarBehind { seq: u64, min: u64 },
}

impl std::fmt::Display for NewClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SequenceTooFarAhead { seq, max } => write!(
                f,
                "requested sequence {} is ahead of the newest available sequence {}",
                seq, max
            ),
            Self::SequenceTooFarBehind { seq, min } => write!(
                f,
                "requested sequence {} is older than the oldest retained sequence {}",
                seq, min
            ),
        }
    }
}

// Lets the error flow through `Box<dyn Error>` and `?`-based error handling.
impl std::error::Error for NewClientError {}
372
373impl<T> Debug for NewClient<T> {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 write!(
376 f,
377 "NewClient {{ next_sequence: {}, allow_drop: {} }}",
378 self.next_sequence, self.allow_drop
379 )
380 }
381}
382
/// Process-wide counter handing out worker ids for log messages; the first id is 1.
static WORKER_ID: LazyLock<Arc<AtomicU64>> = LazyLock::new(|| AtomicU64::new(1).into());
384
385impl<T: Send + Clone + 'static> Worker<T> {
386 async fn start(mut self) {
387 let id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
388 tracing::info!(id, "{}|SequencedBroadcastWorker Started", id);
389 let start = Instant::now();
390
391 self._start(id).await;
392 let elapsed = start.elapsed();
393 let iter = self.worker_loops.load(Ordering::Relaxed);
394
395 tracing::info!(id, ?elapsed, iter, "{}|SequencedBroadcastWorker Stopped", id);
396
397 let _ = self.closed.send(());
398 }
399
400 async fn _start(&mut self, id: u64) {
401 loop {
402 self.worker_loops.fetch_add(1, Ordering::Relaxed);
403 tokio::task::yield_now().await;
404
405 if self.next_client.is_none() {
406 self.next_client = match self.next_client_rx.try_recv() {
407 Ok(item) => Some(item),
408 Err(TryRecvError::Empty) => None,
409 Err(TryRecvError::Disconnected) => {
410 self.next_client_closed = true;
411 None
412 }
413 };
414 }
415
416 if self.shutdown.is_cancelled() {
417 tracing::info!("{}|Stopping worker due to shutdown", id);
418 break;
419 }
420
421 if !self.next_client_closed {
423 let mut max_per_loop = 32;
424 let min_allowed_seq = self
425 .queue
426 .front()
427 .map(|i| i.0)
428 .unwrap_or(self.next_queue_seq);
429
430 while let Some(new) = self.next_client.take() {
431 self.next_client = self.next_client_rx.try_recv().ok();
432
433 if new.next_sequence < min_allowed_seq
435 || self.next_queue_seq < new.next_sequence
436 {
437 self.metrics.new_client_drop_count.inc();
438
439 if new.next_sequence < min_allowed_seq {
440 tracing::info!(
441 "{}|Subscriber rejected, seq({}) < min_allowed({})",
442 id,
443 new.next_sequence,
444 min_allowed_seq
445 );
446
447 let _ = new.response.send(Err(NewClientError::SequenceTooFarBehind {
448 seq: new.next_sequence,
449 min: min_allowed_seq
450 }));
451 } else {
452 tracing::info!(
453 "{}|Subscriber rejected, max_seq({}) < seq({})",
454 id,
455 self.next_queue_seq,
456 new.next_sequence
457 );
458
459 let _ = new.response.send(Err(NewClientError::SequenceTooFarAhead {
460 seq: new.next_sequence,
461 max: self.next_queue_seq
462 }));
463 }
464
465 continue;
466 }
467
468 self.metrics.new_client_accept_count.inc();
469
470 let (tx, rx) = channel(self.settings.subscriber_channel_len);
472 let rx = SequencedReceiver::<T> {
473 receiver: rx,
474 next_seq: new.next_sequence,
475 };
476
477 if new.response.send(Ok(rx)).is_ok() {
478 let sub_id = self.next_sub_id;
479 self.next_sub_id += 1;
480
481 tracing::info!(
482 "{}|Subscriber({}): Added, allow_drop: {}, next_sequence: {}, min_allowed_seq: {}",
483 id, sub_id, new.allow_drop, new.next_sequence, min_allowed_seq,
484 );
485
486 self.subscribers.push(Subscriber {
487 id: sub_id,
488 allow_drop: new.allow_drop,
489 next_sequence: new.next_sequence,
490 pending: None,
491 tx,
492 lag_started_at: None,
493 });
494 } else {
495 tracing::warn!("{}|New subscriber accepted but receiver dropped", id);
496 }
497
498 if max_per_loop == 0 {
500 break;
501 }
502
503 max_per_loop -= 1;
504 }
505 }
506
507 'fill_rx: {
509 if self.next_rx.is_none() {
510 self.next_rx = match self.rx.try_recv() {
511 Ok(msg) => Some(msg),
512 Err(TryRecvError::Disconnected) => {
513 self.rx_closed = true;
514 None
515 }
516 Err(TryRecvError::Empty) => None,
517 };
518 }
519
520 let mut remaining_msg_count = self.rx_space().min(1024);
521 if remaining_msg_count == 0 {
522 if !self.rx_full {
523 self.rx_full = true;
524 assert_eq!(self.queue.len(), self.queue.capacity());
525 tracing::info!("{}|Reached queue capacity {}", id, self.queue.len());
526 }
527
528 break 'fill_rx;
529 }
530
531 if self.rx_full {
532 tracing::info!("{}|Space returned to queue {}/{}", id, self.rx_space(), self.queue.len());
533 self.rx_full = false;
534 }
535
536 while let Some((seq, item)) = self.next_rx.take() {
537 self.next_rx = self.rx.try_recv().ok();
538
539 assert_eq!(seq, self.next_queue_seq, "sequence is invalid");
540 self.queue.push_back((seq, item));
541 self.next_queue_seq += 1;
542
543 remaining_msg_count -= 1;
544 if remaining_msg_count == 0 {
545 break;
546 }
547 }
548 }
549
550 self.metrics.next_sequence.set(self.next_queue_seq);
551
552 let oldest_queue_sequence = self
553 .queue
554 .front()
555 .map(|v| v.0)
556 .unwrap_or(self.next_queue_seq);
557
558 let max_seq = oldest_queue_sequence + self.queue.len() as u64;
559 let lag_start_seq = max_seq.max(self.settings.lag_start_threshold) - self.settings.lag_start_threshold;
560 let lag_end_seq = lag_start_seq.max(max_seq.max(self.settings.lag_end_threshold) - self.settings.lag_end_threshold);
561
562 let mut min_sub_sequence_calc = self.next_queue_seq;
563 let mut earliest_lag_start_at_calc: Option<Instant> = None;
564
565 let mut i = 0;
566 'next_sub: while i < self.subscribers.len() {
567 let sub = &mut self.subscribers[i];
568
569 if (sub.allow_drop && sub.next_sequence < oldest_queue_sequence) || sub.tx.is_closed() {
571 if sub.tx.is_closed() {
572 tracing::info!("{}|Subscriber({}): channel closed, dropping", sub.id, id);
573 } else {
574 tracing::warn!(
575 "{}|Subscriber({}): lag behind available data ({} < {}), dropping",
576 id,
577 sub.id,
578 sub.next_sequence,
579 oldest_queue_sequence
580 );
581 }
582
583 if sub.lag_started_at.is_some() {
584 self.metrics.lagging_subs_gauge.dec();
585 }
586
587 self.metrics.disconnect_count.inc();
588
589 self.subscribers.swap_remove(i);
590 continue 'next_sub;
591 }
592
593 let mut offset = {
595 assert!(sub.next_sequence >= oldest_queue_sequence);
596 let offset = (sub.next_sequence - oldest_queue_sequence) as usize;
597 assert!(sub.next_sequence <= self.next_queue_seq, "sub cannot be ahead of queue sequence");
598 assert!(offset <= self.queue.len(), "sub cannot be ahead of queue sequence");
599 offset
600 };
601
602 if sub.pending.is_none() {
604 if self.queue.len() == offset {
606 i += 1;
607 continue 'next_sub;
608 }
609
610 let (seq, item) = self.queue.get(offset).unwrap();
612 assert_eq!(*seq, sub.next_sequence);
613 sub.pending = Some(item.clone());
614 }
615
616 while let Some(next) = sub.pending.take() {
618 match sub.tx.try_send((sub.next_sequence, next)) {
619 Ok(_) => {
620 sub.next_sequence += 1;
621 offset += 1;
622
623 if self.queue.len() == offset {
624 break;
625 }
626
627 let (seq, item) = self.queue.get(offset).unwrap();
628 assert_eq!(*seq, sub.next_sequence);
629 sub.pending = Some(item.clone());
630 }
631 Err(TrySendError::Closed(_)) => break,
632 Err(TrySendError::Full((_seq, item))) => {
633 sub.pending = Some(item);
634 break;
635 }
636 }
637 }
638
639 if sub.allow_drop {
640 if lag_end_seq <= sub.next_sequence {
641 if let Some(lag_start) = sub.lag_started_at.take() {
642 tracing::info!(
643 "{}|Subscriber({}): caught up after {:?}",
644 id,
645 sub.id,
646 lag_start.elapsed()
647 );
648
649 self.metrics.lagging_subs_gauge.inc();
650 }
651 }
652 else if sub.next_sequence < lag_start_seq {
653 if let Some(lag_start) = &sub.lag_started_at {
654 let lag_duration = lag_start.elapsed();
655
656 if self.settings.max_time_lag < lag_duration {
657 tracing::info!(
658 "{}|Subscriber({}): lag too high for too long ({:?}), dropping",
659 id,
660 sub.id,
661 lag_duration,
662 );
663
664 self.metrics.lagging_subs_gauge.dec();
665 self.metrics.disconnect_count.inc();
666
667 self.subscribers.swap_remove(i);
668 continue 'next_sub;
669 }
670 } else {
671 sub.lag_started_at = Some(Instant::now());
672
673 tracing::info!(
674 "{}|Subscriber({}): lag started thresh({}) < lag({})",
675 id,
676 sub.id,
677 self.settings.lag_start_threshold,
678 max_seq - sub.next_sequence,
679 );
680
681 self.metrics.lagging_subs_gauge.inc()
682 }
683 }
684 }
685
686 if let Some(lag_started_at) = &sub.lag_started_at {
687 earliest_lag_start_at_calc = match earliest_lag_start_at_calc {
688 Some(v) if v.lt(lag_started_at) => Some(v),
689 _ => sub.lag_started_at
690 };
691 }
692
693 min_sub_sequence_calc = min_sub_sequence_calc.min(sub.next_sequence);
694 i += 1;
695 }
696
697 let min_sub_sequence = min_sub_sequence_calc;
698
699 self.metrics.active_subs_gauge.set(self.subscribers.len() as u64);
700 self.metrics.min_sub_sequence_gauge.set(min_sub_sequence);
701
702 {
704 let keep_seq = min_sub_sequence.min(max_seq.max(self.settings.min_history) - self.settings.min_history);
705
706 if oldest_queue_sequence < keep_seq {
707 let remove_count = keep_seq - oldest_queue_sequence;
708 if remove_count != 0 {
709 let _ = self.queue.drain(0..remove_count as usize);
710 }
711
712 self.metrics.oldest_sequence.set(oldest_queue_sequence + remove_count);
713 }
714 }
715
716 if self.rx_closed && min_sub_sequence == max_seq {
717 tracing::info!("{}|RX closed and all subscribers caught up, shutting down worker", id);
718 return;
719 }
720
721 if self.next_client_closed && self.subscribers.is_empty() {
722 tracing::info!("{}|no subscribers and next_client_rx closed, shutting down worker", id);
723 return;
724 }
725
726 let rx_blocked = self.next_rx.is_none() && !self.rx_closed;
727 let next_timeout = earliest_lag_start_at_calc.map(|early| {
728 let now = Instant::now();
729 let expire = early + self.settings.max_time_lag;
730 (expire.max(now) - now).max(Duration::from_millis(100))
731 });
732
733 {
735 if !rx_blocked && 0 < self.rx_space() {
737 tracing::trace!("{}|have more rx", id);
738 continue;
739 }
740
741 if self.next_client.is_some() {
743 tracing::trace!("{}|have next client", id);
744 continue;
745 }
746 }
747
748 let mut timeout_fut = next_timeout.map(|duration| tokio::time::sleep(duration));
749 let mut pending_tx = Vec::new();
750 let new_client_rx = &mut self.next_client_rx;
751 let new_msg_rx = &mut self.rx;
752 let next_rx = &mut self.next_rx;
753 let next_client = &mut self.next_client;
754
755 for sub in &mut self.subscribers {
756 if sub.pending.is_some() {
757 pending_tx.push((sub.tx.reserve(), &mut sub.pending, &mut sub.next_sequence));
758 }
759 }
760
761 poll_fn(|cx| {
762 if let Some(timeout) = &mut timeout_fut {
763 if unsafe { std::pin::Pin::new_unchecked(timeout) }.poll(cx).is_ready() {
764 tracing::trace!("{}|poll: max lag timer reached", id);
765 return Poll::Ready(());
766 }
767 }
768
769 if rx_blocked {
770 if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_msg_rx) }.poll_recv(cx) {
771 assert!(next_rx.is_none());
772
773 *next_rx = item;
774 if next_rx.is_some() {
775 tracing::trace!("{}|poll: new RX available", id);
776 } else {
777 tracing::trace!("{}|poll: RX closed", id);
778 }
779
780 return Poll::Ready(());
781 }
782 }
783
784 if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_client_rx) }.poll_recv(cx) {
785 tracing::trace!("{}|poll: new client", id);
786
787 assert!(next_client.is_none());
788 *next_client = item;
789 return Poll::Ready(());
790 }
791
792 let mut sent = false;
793 for (reserve, pending, next_sequence) in &mut pending_tx {
794 let reserve = unsafe { std::pin::Pin::new_unchecked(reserve) };
795
796 match reserve.poll(cx) {
797 Poll::Ready(Ok(slot)) => {
798 let item = pending.take().expect("pending missing");
799 let seq = **next_sequence;
800 slot.send((seq, item));
801 **next_sequence = seq + 1;
802
803 sent = true;
804 }
805 Poll::Ready(Err(_)) => {
806 sent = true;
807 }
808 Poll::Pending => {}
809 }
810 }
811
812 if sent {
813 tracing::trace!("{}|poll: subscriber message sent", id);
814 return Poll::Ready(());
815 }
816
817 Poll::Pending
818 }).await;
819 }
820 }
821
822 fn rx_space(&self) -> usize {
823 self.queue.capacity() - self.queue.len()
824 }
825}
826
#[cfg(test)]
mod test {
    use super::*;

    /// Installs a tracing subscriber that writes to the test output; the result of
    /// `try_init` is ignored, so calling this from every test is harmless.
    pub fn setup_logging() {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();
    }

    /// Builds a `&'static str` broadcast with default settings starting at `next_seq`.
    fn str_broadcast(
        next_seq: u64,
    ) -> (
        SequencedBroadcast<&'static str>,
        SequencedSender<&'static str>,
    ) {
        SequencedBroadcast::new(next_seq, SequencedBroadcastSettings::default())
    }

    /// Waits up to 100ms for the worker's close notification.
    async fn expect_closed(close_wait: oneshot::Receiver<()>) {
        tokio::time::timeout(Duration::from_millis(100), close_wait)
            .await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// An explicit shutdown stops the worker even while a sender and a subscriber exist.
    #[tokio::test]
    async fn subscribers_shutdown_test() {
        setup_logging();

        let (subs, mut tx) = str_broadcast(0);

        let client = subs.add_client(0, true).await.unwrap();
        tx.send("Hello World").await.unwrap();

        let close_wait = subs.shutdown();
        expect_closed(close_wait).await;

        drop(client);
        drop(tx);
    }

    /// Dropping the handle and the sender with no subscribers stops the worker.
    #[tokio::test]
    async fn subscribers_close_no_subs_test() {
        setup_logging();

        let close_wait = {
            let (subs, mut tx) = str_broadcast(0);
            tx.send("Hello World").await.unwrap();
            subs.closed()
        };

        expect_closed(close_wait).await;
    }

    /// The active-subscriber gauge follows subscribers joining and leaving.
    #[tokio::test]
    async fn subscribers_updates_active_metric_test() {
        setup_logging();

        let (subs, mut tx) = str_broadcast(0);
        tx.send("Hello World").await.unwrap();

        let mut client_1 = subs.add_client(0, true).await.unwrap();
        let mut client_2 = subs.add_client(0, true).await.unwrap();
        let first = client_1.recv().await.unwrap();
        assert_eq!((0, "Hello World"), first);

        assert_eq!(2, subs.metrics_ref().active_subs_gauge.load());
        drop(client_1);

        tx.send("Test2").await.unwrap();
        assert_eq!((0, "Hello World"), client_2.recv().await.unwrap());
        assert_eq!((1, "Test2"), client_2.recv().await.unwrap());

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(1, subs.metrics_ref().active_subs_gauge.load());
    }

    /// With handle and sender gone, the worker stops once the only subscriber has caught up
    /// and disconnected.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) = str_broadcast(0);
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        expect_closed(close_wait).await;
    }

    /// Same as above, but the sender stays alive until after the worker has stopped.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_tx_alive_test() {
        setup_logging();

        let (close_wait, mut client, tx) = {
            let (subs, mut tx) = str_broadcast(0);
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client, tx)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        expect_closed(close_wait).await;

        drop(tx);
    }

    /// The worker also stops when the only subscriber disconnects before reading everything.
    #[tokio::test]
    async fn subscribers_close_sub_not_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) = str_broadcast(0);
            tx.send("Hello World").await.unwrap();
            tx.send("Hello World 2").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        expect_closed(close_wait).await;
    }

    /// Late subscribers replay history from the sequence they ask for.
    #[tokio::test]
    async fn subscribers_catchup_test() {
        setup_logging();

        let (subs, mut tx) = str_broadcast(0);

        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub_1 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_1.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_1.recv().await.unwrap());

        let mut sub_2 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_2.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_2.recv().await.unwrap());

        let mut sub_3 = subs.add_client(1, true).await.unwrap();
        assert_eq!((1, "What the heck"), sub_3.recv().await.unwrap());

        tx.send("Hehe").await.unwrap();
        assert_eq!((2, "Hehe"), sub_1.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_2.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_3.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// A single subscriber receives every message, in order, until the sender is dropped.
    #[tokio::test]
    async fn sequenced_broadcast_simple_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<u64>::new(10, SequencedBroadcastSettings::default());

        let mut client = subs.add_client(10, true).await.unwrap();
        tracing::info!("client added");

        let read_task = tokio::spawn(async move {
            let mut received = 0;
            let mut expected_seq = 10;

            while let Some(msg) = client.recv().await {
                assert_eq!(msg, (expected_seq, received));
                received += 1;
                expected_seq += 1;
            }

            received
        });

        let count = 1024 * 16;

        for value in 0..count {
            tx.send(value).await.unwrap();
        }
        drop(tx);

        let total = read_task.await.unwrap();
        assert_eq!(total, count);

        subs.shutdown_wait().await;
    }

    /// Subscription requests are accepted only inside the retained sequence window.
    #[tokio::test]
    async fn subscribers_test() {
        setup_logging();

        let (subs, mut tx) = str_broadcast(10);
        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub = subs.add_client(10, true).await.unwrap();
        assert_eq!((10, "Hello WOrld"), sub.recv().await.unwrap());
        assert_eq!((11, "What the heck"), sub.recv().await.unwrap());

        assert!(subs.add_client(10, true).await.is_ok());
        assert!(subs.add_client(11, true).await.is_ok());
        assert!(subs.add_client(12, true).await.is_ok());
        assert!(subs.add_client(13, true).await.is_err());
        assert!(subs.add_client(9, true).await.is_err());

        tx.send("Butts").await.unwrap();
        assert_eq!((12, "Butts"), sub.recv().await.unwrap());

        tokio::time::sleep(Duration::from_millis(1)).await;

        tracing::info!("Metrics: {:?}", subs.metrics_ref());

        subs.shutdown_wait().await;
    }

    /// A subscriber registered with `allow_drop = false` applies back-pressure to the
    /// sender instead of being disconnected.
    #[tokio::test]
    async fn subscribers_dont_drop_test() {
        setup_logging();

        let settings = SequencedBroadcastSettings {
            max_time_lag: Duration::from_millis(100),
            ..Default::default()
        };
        let (subs, mut tx) = SequencedBroadcast::<i64>::new(1, settings);

        let mut sub = subs.add_client(1, false).await.unwrap();

        // Keep sending until a send stalls for a full second.
        tokio::time::timeout(Duration::from_secs(3), async {
            while tokio::time::timeout(Duration::from_secs(1), tx.send(1))
                .await
                .is_ok()
            {}
        })
        .await
        .expect("client must have been dropped as can still send tx");

        tracing::info!("tx filled");

        let stalled = tokio::time::timeout(Duration::from_secs(1), tx.send(1)).await;
        assert!(stalled.is_err());
        assert_eq!((1, 1), sub.recv().await.unwrap());

        let marker_sent = tokio::time::timeout(Duration::from_millis(10), tx.send(1000)).await;
        assert!(marker_sent.is_ok());

        let sub_mut = &mut sub;

        tokio::time::timeout(Duration::from_millis(100), async move {
            loop {
                let (_, num) = sub_mut.recv().await.unwrap();
                if num == 1000 {
                    break;
                }
            }
        })
        .await
        .unwrap();

        let final_sent = tokio::time::timeout(Duration::from_millis(10), tx.send(2000)).await;
        assert!(final_sent.is_ok());
        assert_eq!(2000, sub.recv().await.unwrap().1);

        subs.shutdown_wait().await;
    }

    /// Without subscribers the worker keeps draining its input, and a later subscriber can
    /// still join near the head.
    #[tokio::test]
    async fn subscribers_no_clients_test() {
        setup_logging();

        let (subs, mut tx) = str_broadcast(1);
        let (subs, mut tx) = tokio::time::timeout(Duration::from_secs(1), async move {
            for _ in 0..1_000_000 {
                tx.send("Hello World").await.unwrap();
            }

            tracing::info!("Sent 1M messages");

            while tx.seq() != subs.metrics_ref().next_sequence.load() {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            tracing::info!("All 1M messages have been processed");

            (subs, tx)
        })
        .await
        .unwrap();

        let seq = tx.seq();
        tracing::info!("Seq: {}", seq);

        let mut sub = subs.add_client(seq - 1, true).await.unwrap();
        tx.send("Test").await.unwrap();

        assert_eq!((seq - 1, "Hello World"), sub.recv().await.unwrap());
        assert_eq!((seq, "Test"), sub.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// 32 subscribers all observe every message while the sender runs flat out for five
    /// seconds.
    #[tokio::test]
    async fn continious_send_send_test() {
        setup_logging();

        let settings = SequencedBroadcastSettings {
            min_history: 1024,
            lag_end_threshold: 128,
            lag_start_threshold: 512,
            max_time_lag: Duration::from_secs(20),
            ..Default::default()
        };
        let (subs, mut tx) = SequencedBroadcast::<u64>::new(1, settings);

        let mut read_tasks = vec![];
        for _ in 0..32 {
            let mut client = subs.add_client(1, true).await.unwrap();

            let read_task = tokio::spawn(async move {
                let mut next = 1;
                while let Some((seq, num)) = client.recv().await {
                    assert_eq!(seq, num);
                    assert_eq!(seq, next);
                    next = seq + 1;
                }
                next
            });

            read_tasks.push(read_task);
        }

        let start = Instant::now();
        let mut end = 1;
        while start.elapsed() < Duration::from_secs(5) {
            tokio::time::timeout(Duration::from_secs(1), tx.send(end))
                .await
                .expect("timeout sending message")
                .expect("failed to send message");

            end += 1;
        }

        drop(tx);

        for read_task in read_tasks {
            let count = tokio::time::timeout(Duration::from_secs(1), read_task)
                .await
                .expect("timeout waiting for rx task to close")
                .expect("rx task crashed");

            assert_eq!(count, end);
        }
    }

    /// A droppable subscriber that reads too slowly is disconnected while a fast one
    /// receives everything.
    #[tokio::test]
    async fn subscribers_drops_slow_sub_test() {
        setup_logging();

        let settings = SequencedBroadcastSettings {
            max_time_lag: Duration::from_secs(1),
            subscriber_channel_len: 4,
            lag_start_threshold: 64,
            lag_end_threshold: 32,
            ..Default::default()
        };
        let (subs, mut tx) = SequencedBroadcast::<i64>::new(1, settings);

        let mut fast_client = subs.add_client(1, true).await.unwrap();
        let mut slow_client = subs.add_client(1, true).await.unwrap();

        let send_task = tokio::spawn(async move {
            let mut i = 0;
            for _ in 0..1_000 {
                tokio::time::sleep(Duration::from_millis(5)).await;
                i += 1;
                tx.send(i).await.unwrap();
            }

            tracing::info!("Done sending");
            drop(subs);

            i
        });

        let fast_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some((_, value)) = fast_client.recv().await {
                last = Some(value);
            }
            tracing::info!("Fast Done: {:?}", last);
            last.unwrap()
        });

        let slow_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some((_, value)) = slow_client.recv().await {
                last = Some(value);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            tracing::info!("Slow done: {:?}", last);
            last.unwrap()
        });

        let sent_i = send_task.await.unwrap();
        let fast_recv_i = fast_recv_task.await.unwrap();
        let slow_recv_i = slow_recv_task.await.unwrap();

        assert_eq!(sent_i, 1000);
        assert_eq!(fast_recv_i, 1000);
        assert_eq!(slow_recv_i, 19);
    }
}