1use async_trait::async_trait;
2use binary_heap_plus::BinaryHeap;
3use futures::stream::Stream;
4use futures::stream::StreamExt;
5use std::cmp::Ordering;
6use tracing::instrument;
7
8#[async_trait]
9pub trait KeepFirstN<T, F>
10where
11 F: Fn(&T, &T) -> Ordering,
12{
13 async fn keep_first_n(
16 self,
17 n: usize,
18 sorted_by: F,
19 ) -> futures::stream::Iter<std::vec::IntoIter<T>>;
20}
21
/// [`KeepFirstN`] for any stream when the `tokio` or `async-std` feature is
/// enabled; the work is delegated to `impl_keep_first_n`.
#[cfg(any(feature = "tokio", feature = "async-std"))]
#[async_trait]
impl<SInput, T, F> KeepFirstN<T, F> for SInput
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
{
    /// Forwards to `impl_keep_first_n`, which drains the stream and selects
    /// the retained items.
    ///
    /// * `n` - maximum number of items to retain.
    /// * `sorted_by` - comparison used to rank items.
    #[instrument(skip(self, sorted_by))]
    async fn keep_first_n(
        self,
        n: usize,
        sorted_by: F,
    ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
        // The helper still takes a heap argument but never reads it (its
        // parameter is `_first_n`), so pass a zero-capacity heap instead of
        // reserving room for `n` items that would be thrown away.
        let placeholder = BinaryHeap::with_capacity_by(0, move |a, b| sorted_by(a, b).reverse());
        impl_keep_first_n(self, placeholder, n, sorted_by).await
    }
}
41
/// Runtime-backed selection of the `n` greatest items of a stream.
///
/// Drains `sinput` and returns a stream over at most `n` items: the ones that
/// rank greatest under `sorted_by`, yielded from greatest to least. Items that
/// compare `Equal` are ranked by their position in `sinput`, the earlier item
/// ranking higher, so both the retained set and the output order are
/// independent of the order in which the spawned tasks run.
///
/// * `sinput` - stream to drain.
/// * `_first_n` - never read; retained so existing callers keep compiling.
/// * `n` - maximum number of items to retain; `0` returns an empty stream
///   without polling `sinput`.
/// * `sorted_by` - comparison used to rank items.
///
/// # Panics
///
/// Panics with "Dangling references to mutex" if another reference to the
/// shared heap is still alive after every spawned task has been awaited.
#[cfg(any(feature = "tokio", feature = "async-std"))]
async fn impl_keep_first_n<SInput, T, F, FReversed>(
    sinput: SInput,
    _first_n: BinaryHeap<T, binary_heap_plus::FnComparator<FReversed>>,
    n: usize,
    sorted_by: F,
) -> futures::stream::Iter<std::vec::IntoIter<T>>
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
    FReversed: Fn(&T, &T) -> std::cmp::Ordering + Clone + Send + 'static,
{
    // With `n == 0` the heap stays empty and the `peek()` calls further down
    // would panic, so answer with an empty stream right away.
    if n == 0 {
        return futures::stream::iter(Vec::<T>::new().into_iter());
    }

    let mut indexed_stream = sinput.enumerate();

    // Total order used everywhere below: `sorted_by` first, then stream
    // position with the EARLIER item ranking greater. Using one order for the
    // heap and for the keep/evict decision is what makes the outcome
    // independent of task scheduling. The closure only captures `sorted_by`
    // (which is `Copy`), so it is `Copy` itself and can be reused freely.
    let indexed_comparator = move |a: &(usize, T), b: &(usize, T)| {
        sorted_by(&a.1, &b.1).then_with(|| b.0.cmp(&a.0))
    };
    // Reversed comparator: `peek()` returns the lowest-ranked retained item.
    let mut first_n =
        BinaryHeap::with_capacity_by(n, move |a, b| indexed_comparator(a, b).reverse());

    // Fill the heap sequentially with the first `n` items.
    while first_n.len() < n {
        match indexed_stream.next().await {
            Some(indexed_item) => first_n.push(indexed_item),
            None => break,
        }
    }

    // The stream held fewer than `n` items: everything is retained.
    if first_n.len() < n {
        return futures::stream::iter(
            first_n
                .into_sorted_vec()
                .into_iter()
                .map(|(_idx, item)| item)
                .collect::<Vec<_>>()
                .into_iter(),
        );
    }

    let first_n_mutex = std::sync::Arc::new(parking_lot::Mutex::new(first_n));
    // Snapshot of the lowest-ranked retained item, readable without taking the
    // heap mutex. It may lag behind the heap, but the heap minimum only ever
    // grows, so rejecting against a stale snapshot is still correct.
    let smallest_kept = std::sync::Arc::new(parking_lot::RwLock::new(
        first_n_mutex
            .lock()
            .peek()
            .expect("heap holds n >= 1 items")
            .to_owned(),
    ));
    {
        let first_n_arc = first_n_mutex.clone();
        let smallest_kept_arc = smallest_kept.clone();
        let mut ongoing_tasks = indexed_stream
            .map(move |indexed_item| {
                let first_n_arc = first_n_arc.clone();
                let smallest_kept_arc = smallest_kept_arc.clone();
                crate::async_runtime::spawn(async move {
                    {
                        // Cheap rejection under the read lock only.
                        let smallest = smallest_kept_arc.read();
                        if indexed_comparator(&*smallest, &indexed_item) != Ordering::Less {
                            return;
                        }
                    }

                    // Re-check against the real heap minimum under the mutex,
                    // since other tasks may have raised it in the meantime.
                    let mut update_first_n = first_n_arc.lock();
                    let should_keep = {
                        let smallest = update_first_n
                            .peek()
                            .expect("heap holds n >= 1 items");
                        indexed_comparator(smallest, &indexed_item) == Ordering::Less
                    };
                    if should_keep {
                        update_first_n.pop();
                        update_first_n.push(indexed_item);
                        let mut update_smallest_kept = smallest_kept_arc.write();
                        *update_smallest_kept = update_first_n
                            .peek()
                            .expect("heap holds n >= 1 items")
                            .to_owned();
                    }
                })
            })
            .buffer_unordered(num_cpus::get() * 4);
        // Drive every spawned task to completion; task results are ignored.
        while let Some(_task) = ongoing_tasks.next().await {}
    }
    futures::stream::iter(
        std::sync::Arc::try_unwrap(first_n_mutex)
            .expect("Dangling references to mutex")
            .into_inner()
            .into_sorted_vec()
            .into_iter()
            .map(|(_idx, item)| item)
            .collect::<Vec<_>>()
            .into_iter(),
    )
}
160
161#[async_trait]
162#[cfg(not(any(feature = "tokio", feature = "async-std")))]
163impl<SInput, T, F> KeepFirstN<T, F> for SInput
164where
165 SInput: Stream<Item = T> + Send + Unpin,
166 T: Clone + Send + std::marker::Sync,
167 F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + 'static,
168{
169 #[instrument(skip(self, sorted_by))]
170 async fn keep_first_n(
171 mut self,
172 n: usize,
173 sorted_by: F,
174 ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
175 let mut first_n = BinaryHeap::with_capacity_by(n, |a, b| match sorted_by(a, b) {
177 Ordering::Less => Ordering::Greater,
178 Ordering::Equal => Ordering::Equal,
179 Ordering::Greater => Ordering::Less,
180 });
181
182 while first_n.len() < n {
183 if let Some(item) = self.next().await {
184 first_n.push(item);
185 } else {
186 break;
187 }
188 }
189
190 if first_n.len() < n {
192 return futures::stream::iter(first_n.into_sorted_vec().into_iter());
193 }
194
195 let first_n_mutex = parking_lot::Mutex::new(first_n);
198 let smallest_kept =
199 parking_lot::RwLock::new(first_n_mutex.lock().peek().unwrap().to_owned());
200
201 self.for_each_concurrent(
202 256,
203 |item| async {
204 if sorted_by(&*smallest_kept.read(), &item) == Ordering::Less {
205 let mut first_n_mut = first_n_mutex.lock();
206 first_n_mut.pop();
207 first_n_mut.push(item);
208 let mut update_smallest_kept = smallest_kept.write();
209 *update_smallest_kept = first_n_mut.peek().unwrap().to_owned();
210 }
211 },
212 )
213 .await;
214
215 futures::stream::iter(first_n_mutex.into_inner().into_sorted_vec().into_iter())
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::KeepFirstN;
    use futures::stream::StreamExt;

    /// Chains two selections: first keep five numbers of `1..10` ranked by
    /// parity, then keep the two largest of those.
    #[tokio::test]
    async fn keep_first_n() {
        let ranked_by_parity = futures::stream::iter(1..10)
            .keep_first_n(5, |a, b| (a % 2).cmp(&(b % 2)))
            .await;
        let two_largest = ranked_by_parity.keep_first_n(2, |a, b| a.cmp(b)).await;
        let kept = two_largest.collect::<Vec<_>>().await;

        assert_eq!(kept, vec![9, 7]);
    }
}