atomr_streams/
overflow.rs1use std::sync::Arc;
8
9use futures::stream::StreamExt;
10use parking_lot::Mutex;
11use tokio::sync::Notify;
12
13use crate::source::Source;
14
/// Policy applied when an element arrives while the buffer is already full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum OverflowStrategy {
    /// Hold the incoming element and wait until the consumer has made room;
    /// nothing is dropped.
    Backpressure,
    /// Remove the element at the front of the buffer (the oldest one) and
    /// append the incoming element.
    DropHead,
    /// Discard the incoming element; the buffer is left untouched.
    DropNew,
    /// Remove the element at the back of the buffer (the most recently
    /// buffered one) and append the incoming element.
    DropTail,
    /// Clear the whole buffer, then append the incoming element.
    DropBuffer,
    /// Mark the buffer as failed and stop reading from upstream; the output
    /// stream then ends without yielding the elements still buffered.
    Fail,
}
31
32pub(crate) fn apply<T: Send + 'static>(
33 source: Source<T>,
34 size: usize,
35 strategy: OverflowStrategy,
36) -> Source<T> {
37 let cap = size.max(1);
38 let state: Arc<Mutex<BufferState<T>>> = Arc::new(Mutex::new(BufferState::default()));
39 let notify = Arc::new(Notify::new());
40 let state_p = Arc::clone(&state);
41 let notify_p = Arc::clone(¬ify);
42 let mut inner = source.into_boxed();
43 tokio::spawn(async move {
44 while let Some(item) = inner.next().await {
45 let mut item_opt = Some(item);
46 let mut overflowed = false;
47 {
48 let mut guard = state_p.lock();
49 if guard.items.len() >= cap {
50 match strategy {
51 OverflowStrategy::DropHead => {
52 guard.items.pop_front();
53 guard.items.push_back(item_opt.take().unwrap());
54 }
55 OverflowStrategy::DropTail => {
56 guard.items.pop_back();
57 guard.items.push_back(item_opt.take().unwrap());
58 }
59 OverflowStrategy::DropNew => {
60 item_opt = None;
61 }
62 OverflowStrategy::DropBuffer => {
63 guard.items.clear();
64 guard.items.push_back(item_opt.take().unwrap());
65 }
66 OverflowStrategy::Fail => {
67 guard.failed = true;
68 guard.complete = true;
69 drop(guard);
70 notify_p.notify_waiters();
71 return;
72 }
73 OverflowStrategy::Backpressure => {
74 overflowed = true;
75 }
76 }
77 } else {
78 guard.items.push_back(item_opt.take().unwrap());
79 }
80 }
81 if overflowed {
82 while let Some(item) = item_opt.take() {
83 notify_p.notified().await;
84 let mut g = state_p.lock();
85 if g.items.len() < cap {
86 g.items.push_back(item);
87 break;
88 } else {
89 item_opt = Some(item);
90 }
91 }
92 }
93 notify_p.notify_one();
94 }
95 state_p.lock().complete = true;
96 notify_p.notify_waiters();
97 });
98
99 let out = futures::stream::unfold((state, notify), |(state, notify)| async move {
100 loop {
101 {
102 let mut guard = state.lock();
103 if guard.failed {
104 return None;
105 }
106 if let Some(v) = guard.items.pop_front() {
107 notify.notify_one();
108 return Some((v, (state.clone(), notify.clone())));
109 }
110 if guard.complete {
111 return None;
112 }
113 }
114 notify.notified().await;
115 }
116 })
117 .boxed();
118 Source { inner: out }
119}
120
/// State shared between the producer task and the output stream of a buffered source.
struct BufferState<T> {
    /// Elements accepted from upstream that the consumer has not taken yet;
    /// the consumer pops from the front.
    items: std::collections::VecDeque<T>,
    /// Set by the producer once it will not add any more elements (upstream
    /// ended, or the `Fail` strategy triggered).
    complete: bool,
    /// Set when the `Fail` strategy triggered; the consumer then ends the
    /// output stream without draining `items`.
    failed: bool,
}
126
127impl<T> Default for BufferState<T> {
128 fn default() -> Self {
129 Self { items: std::collections::VecDeque::new(), complete: false, failed: false }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    /// Under `Backpressure` no element may be lost, even when the source is far
    /// longer than the buffer.
    #[tokio::test]
    async fn buffer_backpressure_forwards_all_elements() {
        let src = Source::from_iter(1..=100_i32);
        let buffered = src.buffer(8, OverflowStrategy::Backpressure);
        let out = Sink::collect(buffered).await;
        assert_eq!(out.len(), 100);
        assert_eq!(out[0], 1);
        assert_eq!(out[99], 100);
    }

    /// Under `DropNew` the number of delivered elements depends on how producer and
    /// consumer interleave, so only bounds are asserted: the very first element always
    /// finds an empty buffer and must be delivered, and nothing may be duplicated.
    #[tokio::test]
    async fn buffer_drop_new_limits_output() {
        use futures::StreamExt;

        let src = Source::from_iter(0..1_000_i32);
        let buffered = src.buffer(1, OverflowStrategy::DropNew);
        let mut count = 0usize;
        let out = buffered.into_boxed();
        tokio::pin!(out);
        while out.next().await.is_some() {
            count += 1;
        }
        assert!(count >= 1, "the first element always fits into the empty buffer");
        assert!(count <= 1_000);
    }
}