noosphere_common/latency.rs

use instant::{Duration, Instant};

use futures_util::Stream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};

use crate::ConditionalSend;

/// A helper for observing when [Stream] throughput appears to have stalled
pub struct StreamLatencyGuard<S>
where
    S: Stream + Unpin,
    S::Item: ConditionalSend + 'static,
{
    inner: S,
    threshold: Duration,
    last_ready_time: Instant,
    tx: UnboundedSender<()>,
}

impl<S> StreamLatencyGuard<S>
where
    S: Stream + Unpin,
    S::Item: ConditionalSend + 'static,
{
    /// Wraps a [Stream] and provides an [UnboundedReceiver<()>] that will receive
    /// a message any time the wrapped [Stream] is pending for longer than the provided
    /// threshold [Duration].
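    ///
    /// A minimal usage sketch (illustrative only; it assumes this crate is available
    /// as `noosphere_common` and that a Tokio runtime is driving the futures):
    ///
    /// ```no_run
    /// use instant::Duration;
    /// use noosphere_common::StreamLatencyGuard;
    /// use tokio_stream::StreamExt;
    ///
    /// # async fn example() {
    /// let stream = tokio_stream::iter(vec![1u32, 2, 3]);
    ///
    /// let (guarded_stream, mut latency_signal) =
    ///     StreamLatencyGuard::wrap(stream, Duration::from_secs(1));
    ///
    /// // Consume the guarded stream as usual; meanwhile, `latency_signal.recv().await`
    /// // will yield whenever the wrapped stream stays pending beyond the threshold.
    /// tokio::pin!(guarded_stream);
    /// let _items = guarded_stream.collect::<Vec<u32>>().await;
    /// # }
    /// ```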
    pub fn wrap(stream: S, threshold: Duration) -> (Self, UnboundedReceiver<()>) {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>();
        (
            StreamLatencyGuard {
                inner: stream,
                threshold,
                last_ready_time: Instant::now(),
                tx,
            },
            rx,
        )
    }
}

impl<S> Stream for StreamLatencyGuard<S>
where
    S: Stream + Unpin,
    S::Item: ConditionalSend + 'static,
{
    type Item = S::Item;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let result = std::pin::pin!(&mut self.inner).poll_next(cx);

        if result.is_pending() {
            // Signal on every pending poll that occurs after the threshold has
            // elapsed since the inner stream last yielded a value
            if Instant::now() - self.last_ready_time > self.threshold {
                let _ = self.tx.send(());
            }
        } else if result.is_ready() {
            self.last_ready_time = Instant::now();
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use instant::Duration;
    use tokio::select;
    use tokio_stream::StreamExt;

    use crate::{helpers::wait, StreamLatencyGuard};

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test;

    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn it_does_not_impede_the_behavior_of_a_wrapped_stream() -> Result<()> {
        let stream = tokio_stream::iter(Vec::from([0u32; 1024]));

        let (guarded_stream, _latency_signal) =
            StreamLatencyGuard::wrap(stream, Duration::from_secs(1));

        tokio::pin!(guarded_stream);

        guarded_stream.collect::<Vec<u32>>().await;

        Ok(())
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
    async fn it_signals_when_a_stream_encounters_latency() -> Result<()> {
        let stream = Box::pin(futures_util::stream::unfold(0, |index| async move {
            match index {
                512 => {
                    for _ in 0..3 {
                        // Uh oh, latency! Note that `tokio::time::sleep` is observed to cooperate
                        // with the runtime, so we wait multiple times to ensure that the stream is
                        // actually polled multiple times
                        wait(1).await;
                    }
                    Some((index, index + 1))
                }
                _ if index < 1024 => Some((index, index + 1)),
                _ => None,
            }
        }));

        let (guarded_stream, mut latency_guard) =
            StreamLatencyGuard::wrap(stream, Duration::from_millis(100));

        tokio::pin!(guarded_stream);

        select! {
            _ = guarded_stream.collect::<Vec<u32>>() => {
                unreachable!("Latency guard should be hit first");
            },
            _ = latency_guard.recv() => ()
        }

        Ok(())
    }
}