sequenced_broadcast/
lib.rs

1use std::{
2    collections::VecDeque, fmt::Debug, future::{poll_fn, Future}, sync::{
3        atomic::{AtomicU64, Ordering},
4        Arc, LazyLock,
5    }, task::Poll, time::{Duration, Instant}
6};
7
8use arc_metrics::{IntCounter, IntGauge};
9use tokio::sync::{
10    mpsc::{channel, error::{SendError, TryRecvError, TrySendError}, Permit, Receiver, Sender},
11    oneshot,
12};
13use tokio_util::sync::CancellationToken;
14use tracing::Instrument;
15
16pub struct SequencedBroadcast<T> {
17    new_client_tx: Sender<NewClient<T>>,
18    metrics: Arc<SequencedBroadcastMetrics>,
19    shutdown: CancellationToken,
20    worker_loops: Arc<AtomicU64>,
21    closed: oneshot::Receiver<()>,
22}
23
24pub struct SequencedSender<T> {
25    next_seq: u64,
26    send: Sender<(u64, T)>,
27}
28
29pub struct SequencedReceiver<T> {
30    next_seq: u64,
31    receiver: Receiver<(u64, T)>,
32}
33
34#[derive(Default, Debug)]
35pub struct SequencedBroadcastMetrics {
36    pub oldest_sequence: IntGauge,
37    pub next_sequence: IntGauge,
38    pub new_client_drop_count: IntCounter,
39    pub new_client_accept_count: IntCounter,
40    pub lagging_subs_gauge: IntGauge,
41    pub active_subs_gauge: IntGauge,
42    pub min_sub_sequence_gauge: IntGauge,
43    pub disconnect_count: IntCounter,
44    pub worker_loops: IntCounter,
45}
46
47struct Subscriber<T> {
48    id: u64,
49    next_sequence: u64,
50    tx: Sender<(u64, T)>,
51    allow_drop: bool,
52    lag_started_at: Option<Instant>,
53    pending: Option<T>,
54}
55
/// Tuning knobs for a [`SequencedBroadcast`].
#[derive(Debug, Clone)]
pub struct SequencedBroadcastSettings {
    /// Capacity of each subscriber's channel.
    pub subscriber_channel_len: usize,
    /// A droppable subscriber starts lagging once it is more than this many items
    /// behind the newest item.
    pub lag_start_threshold: u64,
    /// A lagging subscriber stops lagging once it is within this many items of the
    /// newest item. Must not exceed `lag_start_threshold`.
    pub lag_end_threshold: u64,
    /// How long a droppable subscriber may keep lagging (measured from when lagging
    /// began) before it is disconnected.
    pub max_time_lag: Duration,
    /// Minimum number of recent items kept in history even after every subscriber
    /// has consumed them.
    pub min_history: u64,
}
64
65struct Worker<T> {
66    rx: Receiver<(u64, T)>,
67    next_rx: Option<(u64, T)>,
68    rx_closed: bool,
69    rx_full: bool,
70
71    next_client_rx: Receiver<NewClient<T>>,
72    next_client: Option<NewClient<T>>,
73    next_client_closed: bool,
74
75    next_sub_id: u64,
76    subscribers: Vec<Subscriber<T>>,
77    queue: VecDeque<(u64, T)>,
78    next_queue_seq: u64,
79    metrics: Arc<SequencedBroadcastMetrics>,
80    settings: SequencedBroadcastSettings,
81    shutdown: CancellationToken,
82    worker_loops: Arc<AtomicU64>,
83    closed: oneshot::Sender<()>,
84}
85
86impl Default for SequencedBroadcastSettings {
87    fn default() -> Self {
88        SequencedBroadcastSettings {
89            subscriber_channel_len: 32,
90            lag_start_threshold: 1024 * 8,
91            lag_end_threshold: 1024 * 4,
92            max_time_lag: Duration::from_secs(2),
93            min_history: 2048,
94        }
95    }
96}
97
98impl<T> SequencedSender<T> {
99    pub fn new(next_seq: u64, send: Sender<(u64, T)>) -> Self {
100        SequencedSender { next_seq, send }
101    }
102
103    pub fn is_closed(&self) -> bool {
104        self.send.is_closed()
105    }
106
107    pub async fn closed(&self) {
108        self.send.closed().await
109    }
110
111    pub async fn safe_send(&mut self, seq: u64, item: T) -> Result<(), SequencedSenderError<T>> {
112        self._send(Some(seq), item).await
113    }
114
115    pub async fn send(&mut self, item: T) -> Result<(), SequencedSenderError<T>> {
116        self._send(None, item).await
117    }
118
119    pub fn try_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
120        match self.send.try_send((self.next_seq, item)) {
121            Ok(()) => {
122                self.next_seq += 1;
123                Ok(())
124            }
125            Err(TrySendError::Full(err)) => Err(TrySendError::Full(err.1)),
126            Err(TrySendError::Closed(err)) => Err(TrySendError::Closed(err.1)),
127        }
128    }
129
130    pub async fn reserve(&mut self) -> Result<SequencedSenderPermit<T>, SendError<()>> {
131        let permit = self.send.reserve().await?;
132
133        Ok(SequencedSenderPermit {
134            next_seq: &mut self.next_seq,
135            permit,
136        })
137    }
138
139    async fn _send(&mut self, seq: Option<u64>, item: T) -> Result<(), SequencedSenderError<T>> {
140        if let Some(seq) = seq {
141            if seq != self.next_seq {
142                return Err(SequencedSenderError::InvalidSequence(self.next_seq, item));
143            }
144        }
145
146        if let Err(error) = self.send.send((self.next_seq, item)).await {
147            return Err(SequencedSenderError::ChannelClosed(error.0.1));
148        }
149
150        self.next_seq += 1;
151        Ok(())
152    }
153
154    pub fn seq(&self) -> u64 {
155        self.next_seq
156    }
157}
158
159pub struct SequencedSenderPermit<'a, T> {
160    next_seq: &'a mut u64,
161    permit: Permit<'a, (u64, T)>,
162}
163
164impl<T> SequencedSenderPermit<'_, T> {
165    pub fn send(self, item: T) {
166        let seq = *self.next_seq;
167        self.permit.send((seq, item));
168        *self.next_seq = seq + 1;
169    }
170}
171
172impl<T> SequencedReceiver<T> {
173    pub fn new(next_seq: u64, receiver: Receiver<(u64, T)>) -> Self {
174        SequencedReceiver {
175            next_seq,
176            receiver
177        }
178    }
179
180    pub fn is_closed(&self) -> bool {
181        self.receiver.is_closed()
182    }
183
184    pub async fn recv(&mut self) -> Option<(u64, T)> {
185        let (seq, action) = self.receiver.recv().await?;
186        if self.next_seq != seq {
187            panic!("expected sequence: {} but got: {}", self.next_seq, seq);
188        }
189        self.next_seq += 1;
190        Some((seq, action))
191    }
192
193    pub fn try_recv(&mut self) -> Result<(u64, T), TryRecvError> {
194        match self.receiver.try_recv() {
195            Ok((seq, action)) => {
196                if self.next_seq != seq {
197                    panic!("expected sequence: {} but got: {}", self.next_seq, seq);
198                }
199                self.next_seq += 1;
200                Ok((seq, action))
201            }
202            Err(error) => Err(error)
203        }
204    }
205
206    pub fn unbundle(self) -> (u64, Receiver<(u64, T)>) {
207        (self.next_seq, self.receiver)
208    }
209
210    pub fn next_seq(&self) -> u64 {
211        self.next_seq
212    }
213}
214
/// Failure returned by [`SequencedSender`] sends; the rejected item is always handed back.
#[derive(Debug, PartialEq, Eq)]
pub enum SequencedSenderError<T> {
    /// The caller's sequence did not match; holds the sequence that was expected.
    InvalidSequence(u64, T),
    /// The receiving side has been dropped.
    ChannelClosed(T),
}

impl<T> SequencedSenderError<T> {
    /// Recovers the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::InvalidSequence(_, v) => v,
            Self::ChannelClosed(v) => v,
        }
    }
}

impl<T> std::fmt::Display for SequencedSenderError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSequence(expected, _) => {
                write!(f, "invalid sequence, expected {}", expected)
            }
            Self::ChannelClosed(_) => f.write_str("channel closed"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` / error-reporting crates.
impl<T: std::fmt::Debug> std::error::Error for SequencedSenderError<T> {}
229
230impl<T: Send + Clone + 'static> SequencedBroadcast<T> {
231    pub fn new(next_seq: u64, settings: SequencedBroadcastSettings) -> (Self, SequencedSender<T>) {
232        let (tx, rx) = channel(1024);
233        let tx = SequencedSender::new(next_seq, tx);
234        let rx = SequencedReceiver::new(next_seq, rx);
235
236        (
237            Self::new2(rx, settings),
238            tx
239        )
240    }
241
242    pub fn new2(receiver: SequencedReceiver<T>, settings: SequencedBroadcastSettings) -> Self {
243        let queue_cap = 2 * (
244            (settings.lag_start_threshold as usize)
245            .next_power_of_two()
246            .max(1024)
247        ).max((settings.min_history as usize).next_power_of_two());
248
249        assert!(settings.lag_end_threshold <= settings.lag_start_threshold);
250
251        let (client_tx, client_rx) = channel(32);
252
253        let metrics = Arc::new(SequencedBroadcastMetrics {
254            oldest_sequence: {
255                let i = IntGauge::default();
256                i.set(receiver.next_seq);
257                i
258            },
259            next_sequence: {
260                let i = IntGauge::default();
261                i.set(receiver.next_seq);
262                i
263            },
264            ..Default::default()
265        });
266
267        let shutdown = CancellationToken::new();
268        let current_span = tracing::Span::current();
269        let (closed_tx, closed_rx) = oneshot::channel();
270
271        let worker_loops = Arc::new(AtomicU64::new(0));
272
273        tokio::spawn(
274            Worker {
275                rx: receiver.receiver,
276                next_rx: None,
277                rx_full: false,
278                rx_closed: false,
279
280                next_client_rx: client_rx,
281                next_client: None,
282                next_client_closed: false,
283
284                next_sub_id: 1,
285                subscribers: Vec::with_capacity(32),
286                queue: VecDeque::with_capacity(queue_cap),
287                next_queue_seq: receiver.next_seq,
288                metrics: metrics.clone(),
289                settings,
290                shutdown: shutdown.clone(),
291                worker_loops: worker_loops.clone(),
292                closed: closed_tx,
293            }
294            .start()
295            .instrument(current_span),
296        );
297
298        Self {
299            new_client_tx: client_tx,
300            metrics,
301            shutdown,
302            worker_loops,
303            closed: closed_rx,
304        }
305    }
306
307    pub async fn add_client(
308        &self,
309        next_sequence: u64,
310        allow_drop: bool,
311    ) -> Result<SequencedReceiver<T>, NewClientError> {
312        let (tx, rx) = oneshot::channel();
313
314        self.new_client_tx
315            .send(NewClient {
316                response: tx,
317                allow_drop,
318                next_sequence,
319            })
320            .await
321            .expect("Failed to queue new subscriber, worker crashed");
322
323        rx.await.expect("worker closed")
324    }
325
326    pub fn metrics_ref(&self) -> &SequencedBroadcastMetrics {
327        &self.metrics
328    }
329
330    pub fn metrics(&self) -> Arc<SequencedBroadcastMetrics> {
331        self.metrics.clone()
332    }
333
334    pub fn worker_loops(&self) -> u64 {
335        self.worker_loops.load(Ordering::Relaxed)
336    }
337
338    pub fn shutdown(self) -> oneshot::Receiver<()> {
339        self.shutdown.cancel();
340        self.closed
341    }
342
343    pub async fn shutdown_wait(self) {
344        self.shutdown().await.unwrap();
345    }
346
347    pub fn closed(self) -> oneshot::Receiver<()> {
348        self.closed
349    }
350}
351
352struct NewClient<T> {
353    response: oneshot::Sender<Result<SequencedReceiver<T>, NewClientError>>,
354    next_sequence: u64,
355    allow_drop: bool,
356}
357
/// Reason a subscription request was rejected by the worker.
#[derive(Debug)]
pub enum NewClientError {
    /// `seq` is newer than `max`, the highest start sequence currently allowed.
    SequenceTooFarAhead { seq: u64, max: u64 },
    /// `seq` is older than `min`, the oldest sequence still retained.
    SequenceTooFarBehind { seq: u64, min: u64 },
}

impl std::fmt::Display for NewClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SequenceTooFarAhead { seq, max } => {
                write!(f, "sequence {} is too far ahead, max allowed is {}", seq, max)
            }
            Self::SequenceTooFarBehind { seq, min } => {
                write!(f, "sequence {} is too far behind, min allowed is {}", seq, min)
            }
        }
    }
}

// Allows `?` into `Box<dyn Error>` and integration with error-reporting crates.
impl std::error::Error for NewClientError {}
363
364impl<T> Debug for NewClient<T> {
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        write!(
367            f,
368            "NewClient {{ next_sequence: {}, allow_drop: {} }}",
369            self.next_sequence, self.allow_drop
370        )
371    }
372}
373
/// Process-wide counter used to give each worker a distinct id for its log lines.
static WORKER_ID: LazyLock<Arc<AtomicU64>> =
    LazyLock::new(|| Arc::new(AtomicU64::new(1)));
375
376impl<T: Send + Clone + 'static> Worker<T> {
377    async fn start(mut self) {
378        let id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
379        tracing::info!(id, "{}|SequencedBroadcastWorker Started", id);
380        let start = Instant::now();
381
382        self._start(id).await;
383        let elapsed = start.elapsed();
384        let iter = self.worker_loops.load(Ordering::Relaxed);
385
386        tracing::info!(id, ?elapsed, iter, "{}|SequencedBroadcastWorker Stopped", id);
387
388        let _ = self.closed.send(());
389    }
390
391    async fn _start(&mut self, id: u64) {
392        loop {
393            self.worker_loops.fetch_add(1, Ordering::Relaxed);
394            tokio::task::yield_now().await;
395
396            if self.next_client.is_none() {
397                self.next_client = match self.next_client_rx.try_recv() {
398                    Ok(item) => Some(item),
399                    Err(TryRecvError::Empty) => None,
400                    Err(TryRecvError::Disconnected) => {
401                        self.next_client_closed = true;
402                        None
403                    }
404                };
405            }
406
407            if self.shutdown.is_cancelled() {
408                tracing::info!("{}|Stopping worker due to shutdown", id);
409                break;
410            }
411
412            /* accept new clients */
413            if !self.next_client_closed {
414                let mut max_per_loop = 32;
415                let min_allowed_seq = self
416                    .queue
417                    .front()
418                    .map(|i| i.0)
419                    .unwrap_or(self.next_queue_seq);
420
421                while let Some(new) = self.next_client.take() {
422                    self.next_client = self.next_client_rx.try_recv().ok();
423
424                    /* Sequence in valid range */
425                    if new.next_sequence < min_allowed_seq
426                        || self.next_queue_seq < new.next_sequence
427                    {
428                        self.metrics.new_client_drop_count.inc();
429
430                        if new.next_sequence < min_allowed_seq {
431                            tracing::info!(
432                                "{}|Subscriber rejected, seq({}) < min_allowed({})",
433                                id,
434                                new.next_sequence,
435                                min_allowed_seq
436                            );
437
438                            let _ = new.response.send(Err(NewClientError::SequenceTooFarBehind {
439                                seq: new.next_sequence,
440                                min: min_allowed_seq
441                            }));
442                        } else {
443                            tracing::info!(
444                                "{}|Subscriber rejected, max_seq({}) < seq({})",
445                                id,
446                                self.next_queue_seq,
447                                new.next_sequence
448                            );
449
450                            let _ = new.response.send(Err(NewClientError::SequenceTooFarAhead {
451                                seq: new.next_sequence,
452                                max: self.next_queue_seq
453                            }));
454                        }
455
456                        continue;
457                    }
458
459                    self.metrics.new_client_accept_count.inc();
460
461                    /* Send Receiver to subscribers */
462                    let (tx, rx) = channel(self.settings.subscriber_channel_len);
463                    let rx = SequencedReceiver::<T> {
464                        receiver: rx,
465                        next_seq: new.next_sequence,
466                    };
467
468                    if new.response.send(Ok(rx)).is_ok() {
469                        let sub_id = self.next_sub_id;
470                        self.next_sub_id += 1;
471
472                        tracing::info!(
473                            "{}|Subscriber({}): Added, allow_drop: {}, next_sequence: {}, min_allowed_seq: {}",
474                            id, sub_id, new.allow_drop, new.next_sequence, min_allowed_seq,
475                        );
476
477                        self.subscribers.push(Subscriber {
478                            id: sub_id,
479                            allow_drop: new.allow_drop,
480                            next_sequence: new.next_sequence,
481                            pending: None,
482                            tx,
483                            lag_started_at: None,
484                        });
485                    } else {
486                        tracing::warn!("{}|New subscriber accepted but receiver dropped", id);
487                    }
488
489                    /* ensure we don't block getting new clients */
490                    if max_per_loop == 0 {
491                        break;
492                    }
493
494                    max_per_loop -= 1;
495                }
496            }
497
498            /* fill queue with available data from rx */
499            'fill_rx: {
500                if self.next_rx.is_none() {
501                    self.next_rx = match self.rx.try_recv() {
502                        Ok(msg) => Some(msg),
503                        Err(TryRecvError::Disconnected) => {
504                            self.rx_closed = true;
505                            None
506                        }
507                        Err(TryRecvError::Empty) => None,
508                    };
509                }
510
511                let mut remaining_msg_count = self.rx_space().min(1024);
512                if remaining_msg_count == 0 {
513                    if !self.rx_full {
514                        self.rx_full = true;
515                        assert_eq!(self.queue.len(), self.queue.capacity());
516                        tracing::info!("{}|Reached queue capacity {}", id, self.queue.len());
517                    }
518
519                    break 'fill_rx;
520                }
521
522                if self.rx_full {
523                    tracing::info!("{}|Space returned to queue {}/{}", id, self.rx_space(), self.queue.len());
524                    self.rx_full = false;
525                }
526
527                while let Some((seq, item)) = self.next_rx.take() {
528                    self.next_rx = self.rx.try_recv().ok();
529
530                    assert_eq!(seq, self.next_queue_seq, "sequence is invalid");
531                    self.queue.push_back((seq, item));
532                    self.next_queue_seq += 1;
533
534                    remaining_msg_count -= 1;
535                    if remaining_msg_count == 0 {
536                        break;
537                    }
538                }
539            }
540
541            self.metrics.next_sequence.set(self.next_queue_seq);
542
543            let oldest_queue_sequence = self
544                .queue
545                .front()
546                .map(|v| v.0)
547                .unwrap_or(self.next_queue_seq);
548
549            let max_seq = oldest_queue_sequence + self.queue.len() as u64;
550            let lag_start_seq = max_seq.max(self.settings.lag_start_threshold) - self.settings.lag_start_threshold;
551            let lag_end_seq = lag_start_seq.max(max_seq.max(self.settings.lag_end_threshold) - self.settings.lag_end_threshold);
552
553            let mut min_sub_sequence_calc = self.next_queue_seq;
554            let mut earliest_lag_start_at_calc: Option<Instant> = None;
555
556            let mut i = 0;
557            'next_sub: while i < self.subscribers.len() {
558                let sub = &mut self.subscribers[i];
559
560                /* make sure sub is still valid */
561                if (sub.allow_drop && sub.next_sequence < oldest_queue_sequence) || sub.tx.is_closed() {
562                    if sub.tx.is_closed() {
563                        tracing::info!("{}|Subscriber({}): channel closed, dropping", sub.id, id);
564                    } else {
565                        tracing::warn!(
566                            "{}|Subscriber({}): lag behind available data ({} < {}), dropping",
567                            id,
568                            sub.id,
569                            sub.next_sequence,
570                            oldest_queue_sequence
571                        );
572                    }
573
574                    if sub.lag_started_at.is_some() {
575                        self.metrics.lagging_subs_gauge.dec();
576                    }
577
578                    self.metrics.disconnect_count.inc();
579
580                    self.subscribers.swap_remove(i);
581                    continue 'next_sub;
582                }
583
584                /* write_to_sub */
585                let mut offset = {
586                    assert!(sub.next_sequence >= oldest_queue_sequence);
587                    let offset = (sub.next_sequence - oldest_queue_sequence) as usize;
588                    assert!(sub.next_sequence <= self.next_queue_seq, "sub cannot be ahead of queue sequence");
589                    assert!(offset <= self.queue.len(), "sub cannot be ahead of queue sequence");
590                    offset
591                };
592
593                /* prep next message to send */
594                if sub.pending.is_none() {
595                    /* fully caught up */
596                    if self.queue.len() == offset {
597                        i += 1;
598                        continue 'next_sub;
599                    }
600
601                    /* make next item pending */
602                    let (seq, item) = self.queue.get(offset).unwrap();
603                    assert_eq!(*seq, sub.next_sequence);
604                    sub.pending = Some(item.clone());
605                }
606
607                /* send as much as possible */
608                while let Some(next) = sub.pending.take() {
609                    match sub.tx.try_send((sub.next_sequence, next)) {
610                        Ok(_) => {
611                            sub.next_sequence += 1;
612                            offset += 1;
613
614                            if self.queue.len() == offset {
615                                break;
616                            }
617
618                            let (seq, item) = self.queue.get(offset).unwrap();
619                            assert_eq!(*seq, sub.next_sequence);
620                            sub.pending = Some(item.clone());
621                        }
622                        Err(TrySendError::Closed(_)) => break,
623                        Err(TrySendError::Full((_seq, item))) => {
624                            sub.pending = Some(item);
625                            break;
626                        }
627                    }
628                }
629
630                if sub.allow_drop {
631                    if lag_end_seq <= sub.next_sequence {
632                        if let Some(lag_start) = sub.lag_started_at.take() {
633                            tracing::info!(
634                                "{}|Subscriber({}): caught up after {:?}",
635                                id,
636                                sub.id,
637                                lag_start.elapsed()
638                            );
639
640                            self.metrics.lagging_subs_gauge.inc();
641                        }
642                    }
643                    else if sub.next_sequence < lag_start_seq {
644                        if let Some(lag_start) = &sub.lag_started_at {
645                            let lag_duration = lag_start.elapsed();
646
647                            if self.settings.max_time_lag < lag_duration {
648                                tracing::info!(
649                                    "{}|Subscriber({}): lag too high for too long ({:?}), dropping",
650                                    id,
651                                    sub.id,
652                                    lag_duration,
653                                );
654
655                                self.metrics.lagging_subs_gauge.dec();
656                                self.metrics.disconnect_count.inc();
657
658                                self.subscribers.swap_remove(i);
659                                continue 'next_sub;
660                            }
661                        } else {
662                            sub.lag_started_at = Some(Instant::now());
663
664                            tracing::info!(
665                                "{}|Subscriber({}): lag started thresh({}) < lag({})",
666                                id,
667                                sub.id,
668                                self.settings.lag_start_threshold,
669                                max_seq - sub.next_sequence,
670                            );
671
672                            self.metrics.lagging_subs_gauge.inc()
673                        }
674                    }
675                }
676
677                if let Some(lag_started_at) = &sub.lag_started_at {
678                    earliest_lag_start_at_calc = match earliest_lag_start_at_calc {
679                        Some(v) if v.lt(lag_started_at) => Some(v),
680                        _ => sub.lag_started_at
681                    };
682                }
683
684                min_sub_sequence_calc = min_sub_sequence_calc.min(sub.next_sequence);
685                i += 1;
686            }
687
688            let min_sub_sequence = min_sub_sequence_calc;
689
690            self.metrics.active_subs_gauge.set(self.subscribers.len() as u64);
691            self.metrics.min_sub_sequence_gauge.set(min_sub_sequence);
692
693            /* trim rx queue */
694            {
695                let keep_seq = min_sub_sequence.min(max_seq.max(self.settings.min_history) - self.settings.min_history);
696
697                if oldest_queue_sequence < keep_seq {
698                    let remove_count = keep_seq - oldest_queue_sequence;
699                    if remove_count != 0 {
700                        let _ = self.queue.drain(0..remove_count as usize);
701                    }
702
703                    self.metrics.oldest_sequence.set(oldest_queue_sequence + remove_count);
704                }
705            }
706
707            if self.rx_closed && min_sub_sequence == max_seq {
708                tracing::info!("{}|RX closed and all subscribers caught up, shutting down worker", id);
709                return;
710            }
711
712            if self.next_client_closed && self.subscribers.is_empty() {
713                tracing::info!("{}|no subscribers and next_client_rx closed, shutting down worker", id);
714                return;
715            }
716
717            let rx_blocked = self.next_rx.is_none() && !self.rx_closed;
718            let next_timeout = earliest_lag_start_at_calc.map(|early| {
719                let now = Instant::now();
720                let expire = early + self.settings.max_time_lag;
721                (expire.max(now) - now).max(Duration::from_millis(100))
722            });
723
724            /* see if there's more work available without waiting */
725            {
726                /* update RX */
727                if !rx_blocked && 0 < self.rx_space() {
728                    tracing::trace!("{}|have more rx", id);
729                    continue;
730                }
731
732                /* new client available */
733                if self.next_client.is_some() {
734                    tracing::trace!("{}|have next client", id);
735                    continue;
736                }
737            }
738
739            let mut timeout_fut = next_timeout.map(|duration| tokio::time::sleep(duration));
740            let mut pending_tx = Vec::new();
741            let new_client_rx = &mut self.next_client_rx;
742            let new_msg_rx = &mut self.rx;
743            let next_rx = &mut self.next_rx;
744            let next_client = &mut self.next_client;
745
746            for sub in &mut self.subscribers {
747                if sub.pending.is_some() {
748                    pending_tx.push((sub.tx.reserve(), &mut sub.pending, &mut sub.next_sequence));
749                }
750            }
751
752            poll_fn(|cx| {
753                if let Some(timeout) = &mut timeout_fut {
754                    if unsafe { std::pin::Pin::new_unchecked(timeout) }.poll(cx).is_ready() {
755                        tracing::trace!("{}|poll: max lag timer reached", id);
756                        return Poll::Ready(());
757                    }
758                }
759
760                if rx_blocked {
761                    if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_msg_rx) }.poll_recv(cx) {
762                        assert!(next_rx.is_none());
763
764                        *next_rx = item;
765                        if next_rx.is_some() {
766                            tracing::trace!("{}|poll: new RX available", id);
767                        } else {
768                            tracing::trace!("{}|poll: RX closed", id);
769                        }
770
771                        return Poll::Ready(());
772                    }
773                }
774
775                if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_client_rx) }.poll_recv(cx) {
776                    tracing::trace!("{}|poll: new client", id);
777
778                    assert!(next_client.is_none());
779                    *next_client = item;
780                    return Poll::Ready(());
781                }
782
783                let mut sent = false;
784                for (reserve, pending, next_sequence) in &mut pending_tx {
785                    let reserve = unsafe { std::pin::Pin::new_unchecked(reserve) };
786
787                    match reserve.poll(cx) {
788                        Poll::Ready(Ok(slot)) => {
789                            let item = pending.take().expect("pending missing");
790                            let seq = **next_sequence;
791                            slot.send((seq, item));
792                            **next_sequence = seq + 1;
793                            
794                            sent = true;
795                        }
796                        Poll::Ready(Err(_)) => {
797                            sent = true;
798                        }
799                        Poll::Pending => {}
800                    }
801                }
802
803                if sent {
804                    tracing::trace!("{}|poll: subscriber message sent", id);
805                    return Poll::Ready(());
806                }
807
808                Poll::Pending
809            }).await;
810        }
811    }
812
813    fn rx_space(&self) -> usize {
814        self.queue.capacity() - self.queue.len()
815    }
816}
817
#[cfg(test)]
mod test {
    use super::*;

    /// Installs a `tracing` subscriber that writes through the test harness's
    /// captured output. Every test calls this; `try_init` fails once a global
    /// subscriber is already installed, and that error is deliberately ignored.
    pub fn setup_logging() {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();
    }

    /// Asserts that `close_wait` resolves within 100ms and that it resolves to
    /// `Ok` (an `Err` is reported as the close handler having been dropped
    /// before it sent the close notification).
    async fn expect_closed<T, E: Debug>(close_wait: impl Future<Output = Result<T, E>>) {
        tokio::time::timeout(Duration::from_millis(100), close_wait)
            .await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// An explicit `shutdown()` must complete while a client and the sender are
    /// still alive.
    #[tokio::test]
    async fn subscribers_shutdown_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        let client = subs.add_client(0, true).await.unwrap();
        tx.send("Hello World").await.unwrap();

        expect_closed(subs.shutdown()).await;

        // Keep the client and sender alive until shutdown has completed, so
        // the close is attributable to `shutdown()` rather than to drops.
        drop(client);
        drop(tx);
    }

    /// Dropping the broadcast and its sender with no subscribers must close it.
    #[tokio::test]
    async fn subscribers_close_no_subs_test() {
        setup_logging();

        let close_wait = {
            let (subs, mut tx) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            subs.closed()
        };

        expect_closed(close_wait).await;
    }

    /// `active_subs_gauge` tracks added clients and drops back after one of
    /// them is dropped.
    #[tokio::test]
    async fn subscribers_updates_active_metric_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
        tx.send("Hello World").await.unwrap();

        let mut client_1 = subs.add_client(0, true).await.unwrap();
        let mut client_2 = subs.add_client(0, true).await.unwrap();
        let msg = client_1.recv().await.unwrap();
        assert_eq!((0, "Hello World"), msg);

        assert_eq!(2, subs.metrics_ref().active_subs_gauge.load());
        drop(client_1);

        tx.send("Test2").await.unwrap();
        assert_eq!((0, "Hello World"), client_2.recv().await.unwrap());
        assert_eq!((1, "Test2"), client_2.recv().await.unwrap());

        // Give the background worker a moment to observe the dropped client.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(1, subs.metrics_ref().active_subs_gauge.load());
    }

    /// Closes once the broadcast is dropped and the only client, fully caught
    /// up, is dropped.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        expect_closed(close_wait).await;
    }

    /// Same as `subscribers_close_sub_caught_up_test`, but the sender is kept
    /// alive until after the close has been observed.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_tx_alive_test() {
        setup_logging();

        let (close_wait, mut client, tx) = {
            let (subs, mut tx) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client, tx)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        expect_closed(close_wait).await;

        drop(tx);
    }

    /// Closes even if the only client is dropped before reading everything.
    #[tokio::test]
    async fn subscribers_close_sub_not_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) =
                SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            tx.send("Hello World 2").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        expect_closed(close_wait).await;
    }

    /// Clients added late replay history from their requested sequence and then
    /// receive new messages live.
    #[tokio::test]
    async fn subscribers_catchup_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub_1 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_1.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_1.recv().await.unwrap());

        let mut sub_2 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_2.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_2.recv().await.unwrap());

        let mut sub_3 = subs.add_client(1, true).await.unwrap();
        assert_eq!((1, "What the heck"), sub_3.recv().await.unwrap());

        tx.send("Hehe").await.unwrap();
        assert_eq!((2, "Hehe"), sub_1.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_2.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_3.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// A single client receives every message, in order, with contiguous
    /// sequence numbers starting at the broadcast's initial sequence.
    #[tokio::test]
    async fn sequenced_broadcast_simple_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<u64>::new(10, SequencedBroadcastSettings::default());

        let mut client = subs.add_client(10, true).await.unwrap();
        tracing::info!("client added");

        let read_task = tokio::spawn(async move {
            let mut i = 0;
            let mut seq = 10;

            while let Some(msg) = client.recv().await {
                assert_eq!(msg, (seq, i));
                i += 1;
                seq += 1;
            }

            i
        });

        let count = 1024 * 16;

        for i in 0..count {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let total = read_task.await.unwrap();
        assert_eq!(total, count);

        subs.shutdown_wait().await;
    }

    /// `add_client` accepts sequences from the oldest retained one up to the
    /// next one to be sent, and rejects anything outside that window.
    #[tokio::test]
    async fn subscribers_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(10, SequencedBroadcastSettings::default());
        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub = subs.add_client(10, true).await.unwrap();
        assert_eq!((10, "Hello WOrld"), sub.recv().await.unwrap());
        assert_eq!((11, "What the heck"), sub.recv().await.unwrap());

        assert!(subs.add_client(10, true).await.is_ok());
        assert!(subs.add_client(11, true).await.is_ok());
        assert!(subs.add_client(12, true).await.is_ok());
        assert!(subs.add_client(13, true).await.is_err());
        assert!(subs.add_client(9, true).await.is_err());

        tx.send("Butts").await.unwrap();
        assert_eq!((12, "Butts"), sub.recv().await.unwrap());

        tokio::time::sleep(Duration::from_millis(1)).await;

        tracing::info!("Metrics: {:?}", subs.metrics_ref());

        subs.shutdown_wait().await;
    }

    /// A client added with `allow_drop = false` is never dropped. Once the
    /// pipeline fills, the sender is back-pressured, and it resumes as soon as
    /// the client reads.
    #[tokio::test]
    async fn subscribers_dont_drop_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<i64>::new(
            1,
            SequencedBroadcastSettings {
                max_time_lag: Duration::from_millis(100),
                ..Default::default()
            },
        );

        let mut sub = subs.add_client(1, false).await.unwrap();

        // Keep sending until a single send blocks for a full second, i.e. until
        // the pipeline towards the (non-reading) client is full.
        tokio::time::timeout(Duration::from_secs(3), async {
            loop {
                if tokio::time::timeout(Duration::from_secs(1), tx.send(1))
                    .await
                    .is_err()
                {
                    break;
                }
            }
        })
        .await
        .expect("tx never filled up; the non-droppable client must have been dropped");

        tracing::info!("tx filled");

        assert!(tokio::time::timeout(Duration::from_secs(1), tx.send(1))
            .await
            .is_err());
        assert_eq!((1, 1), sub.recv().await.unwrap());

        assert!(
            tokio::time::timeout(Duration::from_millis(10), tx.send(1000))
                .await
                .is_ok()
        );

        // Drain the backlog until the marker message shows up.
        tokio::time::timeout(Duration::from_millis(100), async {
            loop {
                let (_, num) = sub.recv().await.unwrap();
                if num == 1000 {
                    break;
                }
            }
        })
        .await
        .unwrap();

        assert!(
            tokio::time::timeout(Duration::from_millis(10), tx.send(2000))
                .await
                .is_ok()
        );
        assert_eq!(2000, sub.recv().await.unwrap().1);

        subs.shutdown_wait().await;
    }

    /// With no clients attached, the sender must never block. History must
    /// still be retained so that a late client can start just behind the head.
    #[tokio::test]
    async fn subscribers_no_clients_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(1, SequencedBroadcastSettings::default());
        let (subs, mut tx) = tokio::time::timeout(Duration::from_secs(1), async move {
            for _ in 0..1_000_000 {
                tx.send("Hello World").await.unwrap();
            }

            tracing::info!("Sent 1M messages");

            while tx.seq() != subs.metrics_ref().next_sequence.load() {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            tracing::info!("All 1M messages have been processed");

            (subs, tx)
        })
        .await
        .expect("timed out sending and processing 1M messages");

        let seq = tx.seq();
        tracing::info!("Seq: {}", seq);

        let mut sub = subs.add_client(seq - 1, true).await.unwrap();
        tx.send("Test").await.unwrap();

        assert_eq!((seq - 1, "Hello World"), sub.recv().await.unwrap());
        assert_eq!((seq, "Test"), sub.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// Many concurrent readers must each see every message, gap-free and in
    /// order, during 5 seconds of continuous sending.
    #[tokio::test]
    async fn continuous_send_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<u64>::new(
            1,
            SequencedBroadcastSettings {
                min_history: 1024,
                lag_end_threshold: 128,
                lag_start_threshold: 512,
                max_time_lag: Duration::from_secs(20),
                ..Default::default()
            },
        );

        let mut read_tasks = vec![];
        for _ in 0..32 {
            let mut client = subs.add_client(1, true).await.unwrap();

            let read_task = tokio::spawn(async move {
                let mut next = 1;
                while let Some((seq, num)) = client.recv().await {
                    assert_eq!(seq, num);
                    assert_eq!(seq, next);
                    next = seq + 1;
                }
                next
            });

            read_tasks.push(read_task);
        }

        let start = Instant::now();
        let mut end = 1;
        while start.elapsed() < Duration::from_secs(5) {
            tokio::time::timeout(Duration::from_secs(1), tx.send(end))
                .await
                .expect("timeout sending message")
                .expect("failed to send message");

            end += 1;
        }

        drop(tx);

        for read_task in read_tasks {
            let count = tokio::time::timeout(Duration::from_secs(1), read_task)
                .await
                .expect("timeout waiting for rx task to close")
                .expect("rx task crashed");

            // Each reader reports one past the last sequence it saw.
            assert_eq!(count, end);
        }
    }

    /// A droppable client that lags too far behind for too long is dropped,
    /// while a fast client on the same broadcast receives everything.
    #[tokio::test]
    async fn subscribers_drops_slow_sub_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<i64>::new(
            1,
            SequencedBroadcastSettings {
                max_time_lag: Duration::from_secs(1),
                subscriber_channel_len: 4,
                lag_start_threshold: 64,
                lag_end_threshold: 32,
                ..Default::default()
            },
        );

        let mut fast_client = subs.add_client(1, true).await.unwrap();
        let mut slow_client = subs.add_client(1, true).await.unwrap();

        let send_task = tokio::spawn(async move {
            let mut i = 0;

            // ~5 seconds of sending: 1000 messages, one every 5ms.
            for _ in 0..1_000 {
                tokio::time::sleep(Duration::from_millis(5)).await;
                i += 1;
                tx.send(i).await.unwrap();
            }

            tracing::info!("Done sending");
            drop(subs);

            i
        });

        let fast_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some(recv) = fast_client.recv().await {
                last = Some(recv.1);
            }
            tracing::info!("Fast Done: {:?}", last);
            last.unwrap()
        });

        let slow_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some(recv) = slow_client.recv().await {
                last = Some(recv.1);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            tracing::info!("Slow done: {:?}", last);
            last.unwrap()
        });

        let sent_i = send_task.await.unwrap();
        let fast_recv_i = fast_recv_task.await.unwrap();
        let slow_recv_i = slow_recv_task.await.unwrap();

        assert_eq!(sent_i, 1000);
        assert_eq!(fast_recv_i, 1000);
        // The slow client reads one message per 100ms, so it is expected to
        // fall past `lag_start_threshold` and then be dropped after
        // `max_time_lag`, having seen only around 19 messages. The exact count
        // depends on timer granularity and scheduler load, so only assert that
        // it was cut off long before the end of the stream.
        assert!(
            slow_recv_i < 100,
            "slow client should have been dropped early, but received up to {}",
            slow_recv_i
        );
    }
}