sequenced_broadcast/
lib.rs

1use std::{
2    collections::VecDeque, fmt::Debug, future::{poll_fn, Future}, sync::{
3        atomic::{AtomicU64, Ordering},
4        Arc, LazyLock,
5    }, task::Poll, time::{Duration, Instant}
6};
7
8use tokio::sync::{
9    mpsc::{channel, error::{SendError, TryRecvError, TrySendError}, Permit, Receiver, Sender},
10    oneshot,
11};
12use tokio_util::sync::CancellationToken;
13use tracing::Instrument;
14
15pub struct SequencedBroadcast<T> {
16    new_client_tx: Sender<NewClient<T>>,
17    metrics: Arc<SequencedBroadcastMetrics>,
18    shutdown: CancellationToken,
19    worker_loops: Arc<AtomicU64>,
20    closed: oneshot::Receiver<()>,
21}
22
23pub struct SequencedSender<T> {
24    next_seq: u64,
25    send: Sender<(u64, T)>,
26}
27
28pub struct SequencedReceiver<T> {
29    next_seq: u64,
30    receiver: Receiver<(u64, T)>,
31}
32
/// Counters and gauges published by a broadcast worker.
///
/// All fields are plain atomics so they can be read from any thread while the worker runs.
#[derive(Debug, Default)]
pub struct SequencedBroadcastMetrics {
    /// Sequence of the oldest item still retained in the worker's queue.
    pub oldest_sequence: AtomicU64,
    /// Sequence the next item pulled from the input channel will carry.
    pub next_sequence: AtomicU64,
    /// Number of subscription requests rejected because their start sequence was out of range.
    pub new_client_drop_count: AtomicU64,
    /// Number of subscription requests whose start sequence was in range.
    pub new_client_accept_count: AtomicU64,
    /// Number of drop-allowed subscribers currently flagged as lagging.
    pub lagging_subs_gauge: AtomicU64,
    /// Number of subscribers currently attached.
    pub active_subs_gauge: AtomicU64,
    /// Smallest next-sequence among subscribers (the next input sequence when there are none).
    pub min_sub_sequence_gauge: AtomicU64,
    /// Number of subscribers removed (closed channel or excessive lag).
    pub disconnect_count: AtomicU64,
    /// NOTE(review): not written by the worker in this file; only copied by `update`.
    pub worker_loops: AtomicU64,
}

impl SequencedBroadcastMetrics {
    /// Copies every value from `other` into `self`, one field at a time.
    ///
    /// Each field is loaded with `Acquire` and stored with `Release`. The snapshot is not
    /// atomic as a whole.
    pub fn update(&self, other: &Self) {
        let pairs: [(&AtomicU64, &AtomicU64); 9] = [
            (&self.oldest_sequence, &other.oldest_sequence),
            (&self.next_sequence, &other.next_sequence),
            (&self.new_client_drop_count, &other.new_client_drop_count),
            (&self.new_client_accept_count, &other.new_client_accept_count),
            (&self.lagging_subs_gauge, &other.lagging_subs_gauge),
            (&self.active_subs_gauge, &other.active_subs_gauge),
            (&self.min_sub_sequence_gauge, &other.min_sub_sequence_gauge),
            (&self.disconnect_count, &other.disconnect_count),
            (&self.worker_loops, &other.worker_loops),
        ];

        for (dst, src) in pairs {
            dst.store(src.load(Ordering::Acquire), Ordering::Release);
        }
    }
}
59
60struct Subscriber<T> {
61    id: u64,
62    next_sequence: u64,
63    tx: Sender<(u64, T)>,
64    allow_drop: bool,
65    lag_started_at: Option<Instant>,
66    pending: Option<T>,
67}
68
/// Tuning knobs for a `SequencedBroadcast` worker.
#[derive(Debug, Clone)]
pub struct SequencedBroadcastSettings {
    /// Capacity of each subscriber's channel.
    pub subscriber_channel_len: usize,
    /// Distance (in sequences) behind the newest queued item at which a drop-allowed
    /// subscriber is flagged as lagging.
    pub lag_start_threshold: u64,
    /// Distance behind the newest queued item at which a lagging subscriber counts as caught
    /// up again. Must not exceed `lag_start_threshold`.
    pub lag_end_threshold: u64,
    /// How long a drop-allowed subscriber may stay beyond `lag_start_threshold` before it is
    /// disconnected.
    pub max_time_lag: Duration,
    /// Minimum number of recent items kept in the queue (capacity permitting), so new
    /// subscribers can start slightly in the past.
    pub min_history: u64,
}
77
78struct Worker<T> {
79    rx: Receiver<(u64, T)>,
80    next_rx: Option<(u64, T)>,
81    rx_closed: bool,
82    rx_full: bool,
83
84    next_client_rx: Receiver<NewClient<T>>,
85    next_client: Option<NewClient<T>>,
86    next_client_closed: bool,
87
88    next_sub_id: u64,
89    subscribers: Vec<Subscriber<T>>,
90    queue: VecDeque<(u64, T)>,
91    next_queue_seq: u64,
92    metrics: Arc<SequencedBroadcastMetrics>,
93    settings: SequencedBroadcastSettings,
94    shutdown: CancellationToken,
95    worker_loops: Arc<AtomicU64>,
96    closed: oneshot::Sender<()>,
97}
98
99impl Default for SequencedBroadcastSettings {
100    fn default() -> Self {
101        SequencedBroadcastSettings {
102            subscriber_channel_len: 32,
103            lag_start_threshold: 1024 * 8,
104            lag_end_threshold: 1024 * 4,
105            max_time_lag: Duration::from_secs(2),
106            min_history: 2048,
107        }
108    }
109}
110
111impl<T> SequencedSender<T> {
112    pub fn new(next_seq: u64, send: Sender<(u64, T)>) -> Self {
113        SequencedSender { next_seq, send }
114    }
115
116    pub fn is_closed(&self) -> bool {
117        self.send.is_closed()
118    }
119
120    pub async fn closed(&self) {
121        self.send.closed().await
122    }
123
124    pub async fn safe_send(&mut self, seq: u64, item: T) -> Result<(), SequencedSenderError<T>> {
125        self._send(Some(seq), item).await
126    }
127
128    pub async fn send(&mut self, item: T) -> Result<(), SequencedSenderError<T>> {
129        self._send(None, item).await
130    }
131
132    pub fn try_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
133        match self.send.try_send((self.next_seq, item)) {
134            Ok(()) => {
135                self.next_seq += 1;
136                Ok(())
137            }
138            Err(TrySendError::Full(err)) => Err(TrySendError::Full(err.1)),
139            Err(TrySendError::Closed(err)) => Err(TrySendError::Closed(err.1)),
140        }
141    }
142
143    pub async fn reserve(&mut self) -> Result<SequencedSenderPermit<T>, SendError<()>> {
144        let permit = self.send.reserve().await?;
145
146        Ok(SequencedSenderPermit {
147            next_seq: &mut self.next_seq,
148            permit,
149        })
150    }
151
152    async fn _send(&mut self, seq: Option<u64>, item: T) -> Result<(), SequencedSenderError<T>> {
153        if let Some(seq) = seq {
154            if seq != self.next_seq {
155                return Err(SequencedSenderError::InvalidSequence(self.next_seq, item));
156            }
157        }
158
159        if let Err(error) = self.send.send((self.next_seq, item)).await {
160            return Err(SequencedSenderError::ChannelClosed(error.0.1));
161        }
162
163        self.next_seq += 1;
164        Ok(())
165    }
166
167    pub fn seq(&self) -> u64 {
168        self.next_seq
169    }
170}
171
172pub struct SequencedSenderPermit<'a, T> {
173    next_seq: &'a mut u64,
174    permit: Permit<'a, (u64, T)>,
175}
176
177impl<'a, T> SequencedSenderPermit<'a, T> {
178    pub fn send(self, item: T) {
179        let seq = *self.next_seq;
180        self.permit.send((seq, item));
181        *self.next_seq = seq + 1;
182    }
183}
184
185impl<T> SequencedReceiver<T> {
186    pub fn new(next_seq: u64, receiver: Receiver<(u64, T)>) -> Self {
187        SequencedReceiver {
188            next_seq,
189            receiver
190        }
191    }
192
193    pub fn is_closed(&self) -> bool {
194        self.receiver.is_closed()
195    }
196
197    pub async fn recv(&mut self) -> Option<(u64, T)> {
198        let (seq, action) = self.receiver.recv().await?;
199        if self.next_seq != seq {
200            panic!("expected sequence: {} but got: {}", self.next_seq, seq);
201        }
202        self.next_seq += 1;
203        Some((seq, action))
204    }
205
206    pub fn try_recv(&mut self) -> Result<(u64, T), TryRecvError> {
207        match self.receiver.try_recv() {
208            Ok((seq, action)) => {
209                if self.next_seq != seq {
210                    panic!("expected sequence: {} but got: {}", self.next_seq, seq);
211                }
212                self.next_seq += 1;
213                Ok((seq, action))
214            }
215            Err(error) => Err(error)
216        }
217    }
218
219    pub fn unbundle(self) -> (u64, Receiver<(u64, T)>) {
220        (self.next_seq, self.receiver)
221    }
222
223    pub fn next_seq(&self) -> u64 {
224        self.next_seq
225    }
226}
227
/// Failure of a `SequencedSender` send; always hands the unsent item back.
#[derive(Debug, PartialEq, Eq)]
pub enum SequencedSenderError<T> {
    /// `safe_send` was given a sequence other than the expected one.
    /// Carries `(expected sequence, item)`.
    InvalidSequence(u64, T),
    /// The receiving side of the channel is closed. Carries the item.
    ChannelClosed(T),
}

impl<T> SequencedSenderError<T> {
    /// Recovers the item that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::InvalidSequence(_, v) | Self::ChannelClosed(v) => v,
        }
    }
}

// `Display` + `Error` let callers propagate this with `?` into boxed/dynamic errors.
// The item is deliberately not printed, so no bound on `T` is needed for `Display`.
impl<T> std::fmt::Display for SequencedSenderError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSequence(expected, _) => {
                write!(f, "invalid sequence, expected {expected}")
            }
            Self::ChannelClosed(_) => f.write_str("channel closed"),
        }
    }
}

impl<T: Debug> std::error::Error for SequencedSenderError<T> {}
242
243impl<T: Send + Clone + 'static> SequencedBroadcast<T> {
244    pub fn new(next_seq: u64, settings: SequencedBroadcastSettings) -> (Self, SequencedSender<T>) {
245        let (tx, rx) = channel(1024);
246        let tx = SequencedSender::new(next_seq, tx);
247        let rx = SequencedReceiver::new(next_seq, rx);
248
249        (
250            Self::new2(rx, settings),
251            tx
252        )
253    }
254
255    pub fn new2(receiver: SequencedReceiver<T>, settings: SequencedBroadcastSettings) -> Self {
256        let queue_cap = 2 * (
257            (settings.lag_start_threshold as usize)
258            .next_power_of_two()
259            .max(1024)
260        ).max((settings.min_history as usize).next_power_of_two());
261
262        assert!(settings.lag_end_threshold <= settings.lag_start_threshold);
263
264        let (client_tx, client_rx) = channel(32);
265
266        let metrics = Arc::new(SequencedBroadcastMetrics {
267            oldest_sequence: AtomicU64::new(receiver.next_seq),
268            next_sequence: AtomicU64::new(receiver.next_seq),
269            ..Default::default()
270        });
271
272        let shutdown = CancellationToken::new();
273        let current_span = tracing::Span::current();
274        let (closed_tx, closed_rx) = oneshot::channel();
275
276        let worker_loops = Arc::new(AtomicU64::new(0));
277
278        tokio::spawn(
279            Worker {
280                rx: receiver.receiver,
281                next_rx: None,
282                rx_full: false,
283                rx_closed: false,
284
285                next_client_rx: client_rx,
286                next_client: None,
287                next_client_closed: false,
288
289                next_sub_id: 1,
290                subscribers: Vec::with_capacity(32),
291                queue: VecDeque::with_capacity(queue_cap),
292                next_queue_seq: receiver.next_seq,
293                metrics: metrics.clone(),
294                settings,
295                shutdown: shutdown.clone(),
296                worker_loops: worker_loops.clone(),
297                closed: closed_tx,
298            }
299            .start()
300            .instrument(current_span),
301        );
302
303        Self {
304            new_client_tx: client_tx,
305            metrics,
306            shutdown,
307            worker_loops,
308            closed: closed_rx,
309        }
310    }
311
312    pub async fn add_client(
313        &self,
314        next_sequence: u64,
315        allow_drop: bool,
316    ) -> Result<SequencedReceiver<T>, NewClientError> {
317        let (tx, rx) = oneshot::channel();
318
319        self.new_client_tx
320            .send(NewClient {
321                response: tx,
322                allow_drop,
323                next_sequence,
324            })
325            .await
326            .expect("Failed to queue new subscriber, worker crashed");
327
328        rx.await.expect("worker closed")
329    }
330
331    pub fn metrics_ref(&self) -> &SequencedBroadcastMetrics {
332        &self.metrics
333    }
334
335    pub fn metrics(&self) -> Arc<SequencedBroadcastMetrics> {
336        self.metrics.clone()
337    }
338
339    pub fn worker_loops(&self) -> u64 {
340        self.worker_loops.load(Ordering::Relaxed)
341    }
342
343    pub fn shutdown(self) -> oneshot::Receiver<()> {
344        self.shutdown.cancel();
345        self.closed
346    }
347
348    pub async fn shutdown_wait(self) {
349        self.shutdown().await.unwrap();
350    }
351
352    pub fn closed(self) -> oneshot::Receiver<()> {
353        self.closed
354    }
355}
356
357struct NewClient<T> {
358    response: oneshot::Sender<Result<SequencedReceiver<T>, NewClientError>>,
359    next_sequence: u64,
360    allow_drop: bool,
361}
362
/// Reason a subscription request was rejected.
#[derive(Debug)]
pub enum NewClientError {
    /// The requested start sequence is beyond the next sequence to arrive (`max`).
    SequenceTooFarAhead { seq: u64, max: u64 },
    /// The requested start sequence is older than the oldest retained item (`min`).
    SequenceTooFarBehind { seq: u64, min: u64 },
}

// `Display` + `Error` make this usable with `?` and boxed/dynamic error types.
impl std::fmt::Display for NewClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SequenceTooFarAhead { seq, max } => {
                write!(f, "sequence {seq} is ahead of the maximum allowed sequence {max}")
            }
            Self::SequenceTooFarBehind { seq, min } => {
                write!(f, "sequence {seq} is behind the minimum allowed sequence {min}")
            }
        }
    }
}

impl std::error::Error for NewClientError {}
368
369impl<T> Debug for NewClient<T> {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        write!(
372            f,
373            "NewClient {{ next_sequence: {}, allow_drop: {} }}",
374            self.next_sequence, self.allow_drop
375        )
376    }
377}
378
/// Source of per-worker identifiers used in log messages; the first worker gets id 1.
///
/// `AtomicU64::new` is `const`, so no lazy initialisation or heap allocation is needed.
static WORKER_ID: AtomicU64 = AtomicU64::new(1);
380
381impl<T: Send + Clone + 'static> Worker<T> {
382    async fn start(mut self) {
383        let id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
384        tracing::info!(id, "{}|SequencedBroadcastWorker Started", id);
385        let start = Instant::now();
386
387        self._start(id).await;
388        let elapsed = start.elapsed();
389        let iter = self.worker_loops.load(Ordering::Relaxed);
390
391        tracing::info!(id, ?elapsed, iter, "{}|SequencedBroadcastWorker Stopped", id);
392
393        let _ = self.closed.send(());
394    }
395
396    async fn _start(&mut self, id: u64) {
397        loop {
398            self.worker_loops.fetch_add(1, Ordering::Relaxed);
399            tokio::task::yield_now().await;
400
401            if self.next_client.is_none() {
402                self.next_client = match self.next_client_rx.try_recv() {
403                    Ok(item) => Some(item),
404                    Err(TryRecvError::Empty) => None,
405                    Err(TryRecvError::Disconnected) => {
406                        self.next_client_closed = true;
407                        None
408                    }
409                };
410            }
411
412            if self.shutdown.is_cancelled() {
413                tracing::info!("{}|Stopping worker due to shutdown", id);
414                break;
415            }
416
417            /* accept new clients */
418            if !self.next_client_closed {
419                let mut max_per_loop = 32;
420                let min_allowed_seq = self
421                    .queue
422                    .front()
423                    .map(|i| i.0)
424                    .unwrap_or(self.next_queue_seq);
425
426                while let Some(new) = self.next_client.take() {
427                    self.next_client = self.next_client_rx.try_recv().ok();
428
429                    /* Sequence in valid range */
430                    if new.next_sequence < min_allowed_seq
431                        || self.next_queue_seq < new.next_sequence
432                    {
433                        self.metrics
434                            .new_client_drop_count
435                            .fetch_add(1, Ordering::Relaxed);
436
437                        if new.next_sequence < min_allowed_seq {
438                            tracing::info!(
439                                "{}|Subscriber rejected, seq({}) < min_allowed({})",
440                                id,
441                                new.next_sequence,
442                                min_allowed_seq
443                            );
444
445                            let _ = new.response.send(Err(NewClientError::SequenceTooFarBehind {
446                                seq: new.next_sequence,
447                                min: min_allowed_seq
448                            }));
449                        } else {
450                            tracing::info!(
451                                "{}|Subscriber rejected, max_seq({}) < seq({})",
452                                id,
453                                self.next_queue_seq,
454                                new.next_sequence
455                            );
456
457                            let _ = new.response.send(Err(NewClientError::SequenceTooFarAhead {
458                                seq: new.next_sequence,
459                                max: self.next_queue_seq
460                            }));
461                        }
462
463                        continue;
464                    }
465
466                    self.metrics
467                        .new_client_accept_count
468                        .fetch_add(1, Ordering::Relaxed);
469
470                    /* Send Receiver to subscribers */
471                    let (tx, rx) = channel(self.settings.subscriber_channel_len);
472                    let rx = SequencedReceiver::<T> {
473                        receiver: rx,
474                        next_seq: new.next_sequence,
475                    };
476
477                    if new.response.send(Ok(rx)).is_ok() {
478                        let sub_id = self.next_sub_id;
479                        self.next_sub_id += 1;
480
481                        tracing::info!(
482                            "{}|Subscriber({}): Added, allow_drop: {}, next_sequence: {}, min_allowed_seq: {}",
483                            id, sub_id, new.allow_drop, new.next_sequence, min_allowed_seq,
484                        );
485
486                        self.subscribers.push(Subscriber {
487                            id: sub_id,
488                            allow_drop: new.allow_drop,
489                            next_sequence: new.next_sequence,
490                            pending: None,
491                            tx,
492                            lag_started_at: None,
493                        });
494                    } else {
495                        tracing::warn!("{}|New subscriber accepted but receiver dropped", id);
496                    }
497
498                    /* ensure we don't block getting new clients */
499                    if max_per_loop == 0 {
500                        break;
501                    }
502
503                    max_per_loop -= 1;
504                }
505            }
506
507            /* fill queue with available data from rx */
508            'fill_rx: {
509                if self.next_rx.is_none() {
510                    self.next_rx = match self.rx.try_recv() {
511                        Ok(msg) => Some(msg),
512                        Err(TryRecvError::Disconnected) => {
513                            self.rx_closed = true;
514                            None
515                        }
516                        Err(TryRecvError::Empty) => None,
517                    };
518                }
519
520                let mut remaining_msg_count = self.rx_space().min(1024);
521                if remaining_msg_count == 0 {
522                    if !self.rx_full {
523                        self.rx_full = true;
524                        assert_eq!(self.queue.len(), self.queue.capacity());
525                        tracing::info!("{}|Reached queue capacity {}", id, self.queue.len());
526                    }
527
528                    break 'fill_rx;
529                }
530
531                if self.rx_full {
532                    tracing::info!("{}|Space returned to queue {}/{}", id, self.rx_space(), self.queue.len());
533                    self.rx_full = false;
534                }
535
536                while let Some((seq, item)) = self.next_rx.take() {
537                    self.next_rx = self.rx.try_recv().ok();
538
539                    assert_eq!(seq, self.next_queue_seq, "sequence is invalid");
540                    self.queue.push_back((seq, item));
541                    self.next_queue_seq += 1;
542
543                    remaining_msg_count -= 1;
544                    if remaining_msg_count == 0 {
545                        break;
546                    }
547                }
548            }
549
550            self.metrics
551                .next_sequence
552                .store(self.next_queue_seq, Ordering::Relaxed);
553
554            let oldest_queue_sequence = self
555                .queue
556                .front()
557                .map(|v| v.0)
558                .unwrap_or(self.next_queue_seq);
559
560            let max_seq = oldest_queue_sequence + self.queue.len() as u64;
561            let lag_start_seq = max_seq.max(self.settings.lag_start_threshold) - self.settings.lag_start_threshold;
562            let lag_end_seq = lag_start_seq.max(max_seq.max(self.settings.lag_end_threshold) - self.settings.lag_end_threshold);
563
564            let mut min_sub_sequence_calc = self.next_queue_seq;
565            let mut earliest_lag_start_at_calc: Option<Instant> = None;
566
567            let mut i = 0;
568            'next_sub: while i < self.subscribers.len() {
569                let sub = &mut self.subscribers[i];
570
571                /* make sure sub is still valid */
572                if (sub.allow_drop && sub.next_sequence < oldest_queue_sequence) || sub.tx.is_closed() {
573                    if sub.tx.is_closed() {
574                        tracing::info!("{}|Subscriber({}): channel closed, dropping", sub.id, id);
575                    } else {
576                        tracing::warn!(
577                            "{}|Subscriber({}): lag behind available data ({} < {}), dropping",
578                            id,
579                            sub.id,
580                            sub.next_sequence,
581                            oldest_queue_sequence
582                        );
583                    }
584
585                    if sub.lag_started_at.is_some() {
586                        self.metrics
587                            .lagging_subs_gauge
588                            .fetch_sub(1, Ordering::Relaxed);
589                    }
590
591                    self.metrics
592                        .disconnect_count
593                        .fetch_add(1, Ordering::Relaxed);
594
595                    self.subscribers.swap_remove(i);
596                    continue 'next_sub;
597                }
598
599                /* write_to_sub */
600                let mut offset = {
601                    assert!(sub.next_sequence >= oldest_queue_sequence);
602                    let offset = (sub.next_sequence - oldest_queue_sequence) as usize;
603                    assert!(sub.next_sequence <= self.next_queue_seq, "sub cannot be ahead of queue sequence");
604                    assert!(offset <= self.queue.len(), "sub cannot be ahead of queue sequence");
605                    offset
606                };
607
608                /* prep next message to send */
609                if sub.pending.is_none() {
610                    /* fully caught up */
611                    if self.queue.len() == offset {
612                        i += 1;
613                        continue 'next_sub;
614                    }
615
616                    /* make next item pending */
617                    let (seq, item) = self.queue.get(offset).unwrap();
618                    assert_eq!(*seq, sub.next_sequence);
619                    sub.pending = Some(item.clone());
620                }
621
622                /* send as much as possible */
623                while let Some(next) = sub.pending.take() {
624                    match sub.tx.try_send((sub.next_sequence, next)) {
625                        Ok(_) => {
626                            sub.next_sequence += 1;
627                            offset += 1;
628
629                            if self.queue.len() == offset {
630                                break;
631                            }
632
633                            let (seq, item) = self.queue.get(offset).unwrap();
634                            assert_eq!(*seq, sub.next_sequence);
635                            sub.pending = Some(item.clone());
636                        }
637                        Err(TrySendError::Closed(_)) => break,
638                        Err(TrySendError::Full((_seq, item))) => {
639                            sub.pending = Some(item);
640                            break;
641                        }
642                    }
643                }
644
645                if sub.allow_drop {
646                    if lag_end_seq <= sub.next_sequence {
647                        if let Some(lag_start) = sub.lag_started_at.take() {
648                            tracing::info!(
649                                "{}|Subscriber({}): caught up after {:?}",
650                                id,
651                                sub.id,
652                                lag_start.elapsed()
653                            );
654
655                            self.metrics
656                                .lagging_subs_gauge
657                                .fetch_sub(1, Ordering::Relaxed);
658                        }
659                    }
660                    else if sub.next_sequence < lag_start_seq {
661                        if let Some(lag_start) = &sub.lag_started_at {
662                            let lag_duration = lag_start.elapsed();
663
664                            if self.settings.max_time_lag < lag_duration {
665                                tracing::info!(
666                                    "{}|Subscriber({}): lag too high for too long ({:?}), dropping",
667                                    id,
668                                    sub.id,
669                                    lag_duration,
670                                );
671
672                                self.metrics
673                                    .lagging_subs_gauge
674                                    .fetch_sub(1, Ordering::Relaxed);
675
676                                self.metrics
677                                    .disconnect_count
678                                    .fetch_add(1, Ordering::Relaxed);
679
680                                self.subscribers.swap_remove(i);
681                                continue 'next_sub;
682                            }
683                        } else {
684                            sub.lag_started_at = Some(Instant::now());
685
686                            tracing::info!(
687                                "{}|Subscriber({}): lag started thresh({}) < lag({})",
688                                id,
689                                sub.id,
690                                self.settings.lag_start_threshold,
691                                max_seq - sub.next_sequence,
692                            );
693
694                            self.metrics
695                                .lagging_subs_gauge
696                                .fetch_add(1, Ordering::Relaxed);
697                            }
698                    }
699                }
700
701                if let Some(lag_started_at) = &sub.lag_started_at {
702                    earliest_lag_start_at_calc = match earliest_lag_start_at_calc {
703                        Some(v) if v.lt(lag_started_at) => Some(v),
704                        _ => sub.lag_started_at
705                    };
706                }
707
708                min_sub_sequence_calc = min_sub_sequence_calc.min(sub.next_sequence);
709                i += 1;
710            }
711
712            let min_sub_sequence = min_sub_sequence_calc;
713
714            self.metrics
715                .active_subs_gauge
716                .store(self.subscribers.len() as u64, Ordering::Relaxed);
717
718            self.metrics
719                .min_sub_sequence_gauge
720                .store(min_sub_sequence, Ordering::Relaxed);
721
722            /* trim rx queue */
723            {
724                let keep_seq = min_sub_sequence.min(max_seq.max(self.settings.min_history) - self.settings.min_history);
725
726                if oldest_queue_sequence < keep_seq {
727                    let remove_count = keep_seq - oldest_queue_sequence;
728                    if remove_count != 0 {
729                        let _ = self.queue.drain(0..remove_count as usize);
730                    }
731
732                    self.metrics
733                        .oldest_sequence
734                        .store(oldest_queue_sequence + remove_count, Ordering::Relaxed);
735                }
736            }
737
738            if self.rx_closed && min_sub_sequence == max_seq {
739                tracing::info!("{}|RX closed and all subscribers caught up, shutting down worker", id);
740                return;
741            }
742
743            if self.next_client_closed && self.subscribers.is_empty() {
744                tracing::info!("{}|no subscribers and next_client_rx closed, shutting down worker", id);
745                return;
746            }
747
748            let rx_blocked = self.next_rx.is_none() && !self.rx_closed;
749            let next_timeout = earliest_lag_start_at_calc.map(|early| {
750                let now = Instant::now();
751                let expire = early + self.settings.max_time_lag;
752                (expire.max(now) - now).max(Duration::from_millis(100))
753            });
754
755            /* see if there's more work available without waiting */
756            {
757                /* update RX */
758                if !rx_blocked && 0 < self.rx_space() {
759                    tracing::trace!("{}|have more rx", id);
760                    continue;
761                }
762
763                /* new client available */
764                if self.next_client.is_some() {
765                    tracing::trace!("{}|have next client", id);
766                    continue;
767                }
768            }
769
770            let mut timeout_fut = next_timeout.map(|duration| tokio::time::sleep(duration));
771            let mut pending_tx = Vec::new();
772            let new_client_rx = &mut self.next_client_rx;
773            let new_msg_rx = &mut self.rx;
774            let next_rx = &mut self.next_rx;
775            let next_client = &mut self.next_client;
776
777            for sub in &mut self.subscribers {
778                if sub.pending.is_some() {
779                    pending_tx.push((sub.tx.reserve(), &mut sub.pending, &mut sub.next_sequence));
780                }
781            }
782
783            poll_fn(|cx| {
784                if let Some(timeout) = &mut timeout_fut {
785                    if unsafe { std::pin::Pin::new_unchecked(timeout) }.poll(cx).is_ready() {
786                        tracing::trace!("{}|poll: max lag timer reached", id);
787                        return Poll::Ready(());
788                    }
789                }
790
791                if rx_blocked {
792                    if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_msg_rx) }.poll_recv(cx) {
793                        assert!(next_rx.is_none());
794
795                        *next_rx = item;
796                        if next_rx.is_some() {
797                            tracing::trace!("{}|poll: new RX available", id);
798                        } else {
799                            tracing::trace!("{}|poll: RX closed", id);
800                        }
801
802                        return Poll::Ready(());
803                    }
804                }
805
806                if let Poll::Ready(item) = unsafe { std::pin::Pin::new_unchecked(&mut *new_client_rx) }.poll_recv(cx) {
807                    tracing::trace!("{}|poll: new client", id);
808
809                    assert!(next_client.is_none());
810                    *next_client = item;
811                    return Poll::Ready(());
812                }
813
814                let mut sent = false;
815                for (reserve, pending, next_sequence) in &mut pending_tx {
816                    let reserve = unsafe { std::pin::Pin::new_unchecked(reserve) };
817
818                    match reserve.poll(cx) {
819                        Poll::Ready(Ok(slot)) => {
820                            let item = pending.take().expect("pending missing");
821                            let seq = **next_sequence;
822                            slot.send((seq, item));
823                            **next_sequence = seq + 1;
824                            
825                            sent = true;
826                        }
827                        Poll::Ready(Err(_)) => {
828                            sent = true;
829                        }
830                        Poll::Pending => {}
831                    }
832                }
833
834                if sent {
835                    tracing::trace!("{}|poll: subscriber message sent", id);
836                    return Poll::Ready(());
837                }
838
839                Poll::Pending
840            }).await;
841        }
842    }
843
844    fn rx_space(&self) -> usize {
845        self.queue.capacity() - self.queue.len()
846    }
847}
848
#[cfg(test)]
mod test {
    use super::*;

    /// Installs a `tracing` fmt subscriber that writes through the test harness's captured
    /// output. The `try_init` result is deliberately ignored: only the first test to run can
    /// install the global subscriber, and later calls must not fail.
    pub fn setup_logging() {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();
    }

    /// `shutdown()` issued while a client and the sender are still alive must resolve its
    /// close signal within 100ms.
    #[tokio::test]
    async fn subscribers_shutdown_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        let client = subs.add_client(0, true).await.unwrap();
        tx.send("Hello World").await.unwrap();

        let close_wait = subs.shutdown();

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");

        drop(client);
        drop(tx);
    }

    /// With no clients ever attached, letting the broadcast and sender go out of scope must
    /// fire the `closed()` signal within 100ms.
    #[tokio::test]
    async fn subscribers_close_no_subs_test() {
        setup_logging();

        let close_wait = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            subs.closed()
        };

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// A client that has read every published message and is then dropped (sender already
    /// gone) must let `closed()` resolve within 100ms.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// Same as `subscribers_close_sub_caught_up_test`, but the sender is kept alive until
    /// after `closed()` has resolved.
    #[tokio::test]
    async fn subscribers_close_sub_caught_up_tx_alive_test() {
        setup_logging();

        let (close_wait, mut client, tx) = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client, tx)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");

        drop(tx);
    }

    /// A client that reads only the first of two messages and is then dropped must still let
    /// `closed()` resolve within 100ms.
    #[tokio::test]
    async fn subscribers_close_sub_not_caught_up_test() {
        setup_logging();

        let (close_wait, mut client) = {
            let (subs, mut tx) = SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());
            tx.send("Hello World").await.unwrap();
            tx.send("Hello World 2").await.unwrap();
            let client = subs.add_client(0, true).await.unwrap();
            (subs.closed(), client)
        };

        assert_eq!((0, "Hello World"), client.recv().await.unwrap());
        drop(client);

        tokio::time::timeout(Duration::from_millis(100), close_wait).await
            .expect("timeout waiting for close")
            .expect("close handler dropped before send");
    }

    /// Clients added after messages were published replay history from their requested
    /// start sequence, then all receive the next live message with the same sequence number.
    #[tokio::test]
    async fn subscribers_catchup_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(0, SequencedBroadcastSettings::default());

        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub_1 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_1.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_1.recv().await.unwrap());

        let mut sub_2 = subs.add_client(0, true).await.unwrap();
        assert_eq!((0, "Hello WOrld"), sub_2.recv().await.unwrap());
        assert_eq!((1, "What the heck"), sub_2.recv().await.unwrap());

        let mut sub_3 = subs.add_client(1, true).await.unwrap();
        assert_eq!((1, "What the heck"), sub_3.recv().await.unwrap());

        tx.send("Hehe").await.unwrap();
        assert_eq!((2, "Hehe"), sub_1.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_2.recv().await.unwrap());
        assert_eq!((2, "Hehe"), sub_3.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// A single client starting at sequence 10 receives all 16Ki messages in order, and its
    /// stream ends once the sender is dropped.
    #[tokio::test]
    async fn sequenced_broadcast_simple_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<u64>::new(10, SequencedBroadcastSettings::default());

        let mut client = subs.add_client(10, true).await.unwrap();
        tracing::info!("client added");

        let read_task = tokio::spawn(async move {
            let mut i = 0;
            let mut seq = 10;

            while let Some(msg) = client.recv().await {
                assert_eq!(msg, (seq, i));
                i += 1;
                seq += 1;
            }

            i
        });

        let count = 1024 * 16;

        for i in 0..count {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let total = read_task.await.unwrap();
        assert_eq!(total, count);

        subs.shutdown_wait().await;
    }

    /// `add_client` accepts start sequences from the oldest retained message (10) up to the
    /// next sequence to be published (12), and rejects 13 (in the future) and 9 (before the
    /// first sequence).
    #[tokio::test]
    async fn subscribers_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(10, SequencedBroadcastSettings::default());
        tx.send("Hello WOrld").await.unwrap();
        tx.send("What the heck").await.unwrap();

        let mut sub = subs.add_client(10, true).await.unwrap();
        assert_eq!((10, "Hello WOrld"), sub.recv().await.unwrap());
        assert_eq!((11, "What the heck"), sub.recv().await.unwrap());

        assert!(subs.add_client(10, true).await.is_ok());
        assert!(subs.add_client(11, true).await.is_ok());
        assert!(subs.add_client(12, true).await.is_ok());
        assert!(subs.add_client(13, true).await.is_err());
        assert!(subs.add_client(9, true).await.is_err());

        tx.send("Butts").await.unwrap();
        assert_eq!((12, "Butts"), sub.recv().await.unwrap());

        tokio::time::sleep(Duration::from_millis(1)).await;

        tracing::info!("Metrics: {:?}", subs.metrics_ref());

        subs.shutdown_wait().await;
    }

    /// A client added with `false` as the second `add_client` argument is not disconnected
    /// when it stops reading. Instead the sender eventually blocks, and reading frees room
    /// for further sends.
    ///
    /// NOTE(review): the meaning of that `false` argument is inferred from this test's
    /// behavior; confirm it against `add_client`'s signature.
    #[tokio::test]
    async fn subscribers_dont_drop_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<i64>::new(
            1,
            SequencedBroadcastSettings {
                max_time_lag: Duration::from_millis(100),
                ..Default::default()
            },
        );

        let mut sub = subs.add_client(1, false).await.unwrap();

        // Keep sending until a send blocks for over a second. If sends never block within 3s,
        // the non-reading client must have been dropped, which is the failure case.
        tokio::time::timeout(Duration::from_secs(3), async {
            loop {
                if tokio::time::timeout(Duration::from_secs(1), tx.send(1))
                    .await
                    .is_err()
                {
                    break;
                }
            }
        }).await.expect("client must have been dropped as can still send tx");

        tracing::info!("tx filled");

        assert!(tokio::time::timeout(Duration::from_secs(1), tx.send(1))
            .await
            .is_err());
        assert_eq!((1, 1), sub.recv().await.unwrap());

        // Reading one message should unblock the sender almost immediately.
        assert!(
            tokio::time::timeout(Duration::from_millis(10), tx.send(1000))
                .await
                .is_ok()
        );

        let sub_mut = &mut sub;

        tokio::time::timeout(Duration::from_millis(100), async move {
            loop {
                let (_, num) = sub_mut.recv().await.unwrap();
                if num == 1000 {
                    break;
                }
            }
        })
        .await
        .unwrap();

        assert!(
            tokio::time::timeout(Duration::from_millis(10), tx.send(2000))
                .await
                .is_ok()
        );
        assert_eq!(2000, sub.recv().await.unwrap().1);

        subs.shutdown_wait().await;
    }

    /// With no subscribers attached, one million messages are accepted and processed, and a
    /// client added afterwards can still replay the most recent message and then receive a
    /// new one.
    #[tokio::test]
    async fn subscribers_no_clients_test() {
        setup_logging();

        let (subs, mut tx) =
            SequencedBroadcast::<&'static str>::new(1, SequencedBroadcastSettings::default());
        // NOTE(review): 1M sends plus processing under a 1s budget may be tight in unoptimized
        // (debug) builds.
        let (subs, mut tx) = tokio::time::timeout(Duration::from_secs(1), async move {
            for _ in 0..1_000_000 {
                tx.send("Hello World").await.unwrap();
            }

            tracing::info!("Sent 1M messages");

            // Wait until the worker has caught up with everything the sender published.
            while tx.seq() != subs.metrics_ref().next_sequence.load(Ordering::Relaxed) {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            tracing::info!("All 1M messages have been processed");

            (subs, tx)
        })
        .await
        .unwrap();

        let seq = tx.seq();
        tracing::info!("Seq: {}", seq);

        let mut sub = subs.add_client(seq - 1, true).await.unwrap();
        tx.send("Test").await.unwrap();

        assert_eq!((seq - 1, "Hello World"), sub.recv().await.unwrap());
        assert_eq!((seq, "Test"), sub.recv().await.unwrap());

        subs.shutdown_wait().await;
    }

    /// 32 clients read while the sender publishes continuously for about 5 seconds. Every
    /// client must observe every sequence in order, with no gaps, and finish once the sender
    /// is dropped.
    #[tokio::test]
    async fn continuous_send_send_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<u64>::new(1, SequencedBroadcastSettings {
            min_history: 1024,
            lag_end_threshold: 128,
            lag_start_threshold: 512,
            max_time_lag: Duration::from_secs(20),
            ..Default::default()
        });

        let mut read_tasks = vec![];
        for _ in 0..32 {
            let mut client = subs.add_client(1, true).await.unwrap();

            let read_task = tokio::spawn(async move {
                let mut next = 1;
                while let Some((seq, num)) = client.recv().await {
                    assert_eq!(seq, num);
                    assert_eq!(seq, next);
                    next = seq + 1;
                }
                next
            });

            read_tasks.push(read_task);
        }

        let start = Instant::now();
        // `end` is both the next value to send and, after the loop, one past the last value
        // sent. That matches what each reader returns as its `next`.
        let mut end = 1;
        while start.elapsed() < Duration::from_secs(5) {
            tokio::time::timeout(Duration::from_secs(1), tx.send(end)).await
                .expect("timeout sending message")
                .expect("failed to send message");

            end += 1;
        }

        drop(tx);

        for read_task in read_tasks {
            let count = tokio::time::timeout(Duration::from_secs(1), read_task).await
                .expect("timeout waiting for rx task to close")
                .expect("rx task crashed");

            assert_eq!(count, end);
        }
    }

    /// A fast client receives all 1000 messages. A client that sleeps 100ms per message falls
    /// behind and is cut off early, so its last received value is far below the last sent.
    #[tokio::test]
    async fn subscribers_drops_slow_sub_test() {
        setup_logging();

        let (subs, mut tx) = SequencedBroadcast::<i64>::new(
            1,
            SequencedBroadcastSettings {
                max_time_lag: Duration::from_secs(1),
                subscriber_channel_len: 4,
                lag_start_threshold: 64,
                lag_end_threshold: 32,
                ..Default::default()
            },
        );

        let mut fast_client = subs.add_client(1, true).await.unwrap();
        let mut slow_client = subs.add_client(1, true).await.unwrap();

        let send_task = tokio::spawn(async move {
            let mut i = 0;
            // Roughly 5 seconds of sending: 1_000 iterations with a 5ms sleep each.

            for _ in 0..1_000 {
                tokio::time::sleep(Duration::from_millis(5)).await;
                i += 1;
                tx.send(i).await.unwrap();
            }

            tracing::info!("Done sending");
            drop(subs);

            i
        });

        let fast_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some(recv) = fast_client.recv().await {
                last = Some(recv.1);
            }
            tracing::info!("Fast Done: {:?}", last);
            last.unwrap()
        });

        let slow_recv_task = tokio::spawn(async move {
            let mut last = None;
            while let Some(recv) = slow_client.recv().await {
                last = Some(recv.1);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            tracing::info!("Slow done: {:?}", last);
            last.unwrap()
        });

        let sent_i = send_task.await.unwrap();
        let fast_recv_i = fast_recv_task.await.unwrap();
        let slow_recv_i = slow_recv_task.await.unwrap();

        assert_eq!(sent_i, 1000);
        assert_eq!(fast_recv_i, 1000);
        // NOTE(review): this exact value depends on wall-clock interplay between the 100ms
        // reader sleep, the 5ms send cadence, `max_time_lag` and the lag thresholds. It may
        // be flaky on a loaded machine.
        assert_eq!(slow_recv_i, 19);
    }
}