// atomr_streams/overflow.rs
//! Overflow strategies for bounded buffers.
//!
//! Implemented as a helper that drains a source into a bounded in-memory
//! queue (a `VecDeque` behind a `Mutex`, with a `Notify` for wakeups) and
//! applies the chosen drop/fail/backpressure policy when the queue is full.
//! Mirrors the upstream `OverflowStrategy` enum.
use std::sync::Arc;

use futures::stream::StreamExt;
use parking_lot::Mutex;
use tokio::sync::Notify;

use crate::source::Source;
/// Policy applied when a bounded buffer receives an element while full.
///
/// Mirrors the upstream `OverflowStrategy` enum; `#[non_exhaustive]` so new
/// policies can be added without a breaking change for downstream matchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum OverflowStrategy {
    /// Propagate backpressure by awaiting channel capacity.
    Backpressure,
    /// Drop the oldest buffered element to make room.
    DropHead,
    /// Drop the newest produced element (the one that would overflow).
    DropNew,
    /// Drop the newest buffered element.
    DropTail,
    /// Drop every buffered element when overflow happens.
    DropBuffer,
    /// Fail the stream on overflow.
    Fail,
}
32pub(crate) fn apply<T: Send + 'static>(
33    source: Source<T>,
34    size: usize,
35    strategy: OverflowStrategy,
36) -> Source<T> {
37    let cap = size.max(1);
38    let state: Arc<Mutex<BufferState<T>>> = Arc::new(Mutex::new(BufferState::default()));
39    let notify = Arc::new(Notify::new());
40    let state_p = Arc::clone(&state);
41    let notify_p = Arc::clone(&notify);
42    let mut inner = source.into_boxed();
43    tokio::spawn(async move {
44        while let Some(item) = inner.next().await {
45            let mut item_opt = Some(item);
46            let mut overflowed = false;
47            {
48                let mut guard = state_p.lock();
49                if guard.items.len() >= cap {
50                    match strategy {
51                        OverflowStrategy::DropHead => {
52                            guard.items.pop_front();
53                            guard.items.push_back(item_opt.take().unwrap());
54                        }
55                        OverflowStrategy::DropTail => {
56                            guard.items.pop_back();
57                            guard.items.push_back(item_opt.take().unwrap());
58                        }
59                        OverflowStrategy::DropNew => {
60                            item_opt = None;
61                        }
62                        OverflowStrategy::DropBuffer => {
63                            guard.items.clear();
64                            guard.items.push_back(item_opt.take().unwrap());
65                        }
66                        OverflowStrategy::Fail => {
67                            guard.failed = true;
68                            guard.complete = true;
69                            drop(guard);
70                            notify_p.notify_waiters();
71                            return;
72                        }
73                        OverflowStrategy::Backpressure => {
74                            overflowed = true;
75                        }
76                    }
77                } else {
78                    guard.items.push_back(item_opt.take().unwrap());
79                }
80            }
81            if overflowed {
82                while let Some(item) = item_opt.take() {
83                    notify_p.notified().await;
84                    let mut g = state_p.lock();
85                    if g.items.len() < cap {
86                        g.items.push_back(item);
87                        break;
88                    } else {
89                        item_opt = Some(item);
90                    }
91                }
92            }
93            notify_p.notify_one();
94        }
95        state_p.lock().complete = true;
96        notify_p.notify_waiters();
97    });
98
99    let out = futures::stream::unfold((state, notify), |(state, notify)| async move {
100        loop {
101            {
102                let mut guard = state.lock();
103                if guard.failed {
104                    return None;
105                }
106                if let Some(v) = guard.items.pop_front() {
107                    notify.notify_one();
108                    return Some((v, (state.clone(), notify.clone())));
109                }
110                if guard.complete {
111                    return None;
112                }
113            }
114            notify.notified().await;
115        }
116    })
117    .boxed();
118    Source { inner: out }
119}
120
/// Shared producer/consumer state guarded by the buffer's mutex.
struct BufferState<T> {
    // Buffered elements awaiting the consumer.
    items: std::collections::VecDeque<T>,
    // Set once the producer has finished (inner stream exhausted, or Fail hit).
    complete: bool,
    // Set when `OverflowStrategy::Fail` triggered; the consumer ends early.
    failed: bool,
}

// Implemented by hand: deriving `Default` would impose a `T: Default` bound
// even though no `T` value is ever constructed here.
impl<T> Default for BufferState<T> {
    fn default() -> Self {
        BufferState {
            items: std::collections::VecDeque::new(),
            complete: false,
            failed: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    #[tokio::test]
    async fn buffer_backpressure_forwards_all_elements() {
        // Backpressure must be lossless: every element reaches the sink, in order.
        let src = Source::from_iter(1..=100_i32);
        let buffered = src.buffer(8, OverflowStrategy::Backpressure);
        let out = Sink::collect(buffered).await;
        assert_eq!(out.len(), 100);
        assert_eq!(out[0], 1);
        assert_eq!(out[99], 100);
    }

    #[tokio::test]
    async fn buffer_drop_new_limits_output() {
        // Fast producer, slow consumer: with DropNew and size=1, elements that
        // arrive while the buffer is full are discarded. The exact count is
        // timing-dependent, but the first element always fits the empty buffer
        // and buffered items are drained before completion, so at least one
        // element must be delivered — and never more than were produced.
        let src = Source::from_iter(0..1_000_i32);
        let buffered = src.buffer(1, OverflowStrategy::DropNew);
        let mut count = 0usize;
        let out = buffered.into_boxed();
        use futures::StreamExt;
        tokio::pin!(out);
        while out.next().await.is_some() {
            count += 1;
        }
        // BUGFIX: the original `assert!(count <= 1_000)` was vacuously true
        // (only 1_000 items exist); also require a lower bound.
        assert!((1..=1_000).contains(&count));
    }
}
163}