marigold_impl/
keep_first_n.rs1use async_trait::async_trait;
2use binary_heap_plus::BinaryHeap;
3use futures::stream::Stream;
4use futures::stream::StreamExt;
5use std::cmp::Ordering;
6#[cfg(any(feature = "tokio", feature = "async-std"))]
7use std::ops::Deref;
8use tracing::instrument;
9
10#[async_trait]
11pub trait KeepFirstN<T, F>
12where
13 F: Fn(&T, &T) -> Ordering,
14{
15 async fn keep_first_n(
18 self,
19 n: usize,
20 sorted_by: F,
21 ) -> futures::stream::Iter<std::vec::IntoIter<T>>;
22}
23
/// Runtime-backed implementation (enabled with the `tokio` or `async-std`
/// feature): the heap is prepared here and the actual work is delegated to
/// `impl_keep_first_n`.
#[cfg(any(feature = "tokio", feature = "async-std"))]
#[async_trait]
impl<SInput, T, F> KeepFirstN<T, F> for SInput
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
{
    /// Builds a heap ordered by the *reverse* of `sorted_by`, so that the
    /// heap's top element is the smallest retained item, then hands the
    /// stream, the heap and the comparator over to `impl_keep_first_n`.
    #[instrument(skip(self, sorted_by))]
    async fn keep_first_n(
        self,
        n: usize,
        sorted_by: F,
    ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
        // `F: Copy`, so the comparator can be moved into this closure and
        // still be passed along below.
        let reversed = move |lhs: &T, rhs: &T| sorted_by(lhs, rhs).reverse();
        let heap = BinaryHeap::with_capacity_by(n, reversed);
        impl_keep_first_n(self, heap, n, sorted_by).await
    }
}
43
/// Drains `sinput` and resolves to a stream over its `n` largest items, as
/// ranked by `sorted_by`.
///
/// * `sinput` - the stream to consume.
/// * `first_n` - heap that collects the retained items. The caller in this
///   file passes an empty heap ordered by the reverse of `sorted_by`, so
///   `peek()` yields the smallest retained item.
/// * `n` - maximum number of items to retain. `n == 0` yields an empty
///   stream without polling `sinput`.
/// * `sorted_by` - comparator used to decide whether a new item outranks the
///   smallest retained one.
///
/// Once the heap is full, every remaining item is examined in a task spawned
/// through `crate::async_runtime::spawn`, with at most `num_cpus::get() * 4`
/// tasks in flight.
///
/// # Panics
///
/// Panics if a reference to the shared heap is still alive after all tasks
/// have been awaited.
#[cfg(any(feature = "tokio", feature = "async-std"))]
async fn impl_keep_first_n<SInput, T, F, FReversed>(
    mut sinput: SInput,
    mut first_n: BinaryHeap<T, binary_heap_plus::FnComparator<FReversed>>,
    n: usize,
    sorted_by: F,
) -> futures::stream::Iter<std::vec::IntoIter<T>>
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
    FReversed: Fn(&T, &T) -> std::cmp::Ordering + Clone + Send + 'static,
{
    // Nothing can be kept; without this guard the `peek()` below would find
    // an empty heap and panic.
    if n == 0 {
        return futures::stream::iter(Vec::new().into_iter());
    }

    // Phase 1: fill the heap with the first `n` items. If the stream ends
    // early, everything seen so far is the answer.
    while first_n.len() < n {
        match sinput.next().await {
            Some(item) => first_n.push(item),
            None => return futures::stream::iter(first_n.into_sorted_vec().into_iter()),
        }
    }

    // The heap holds at least `n >= 1` items here, so `peek()` is `Some`.
    let initial_smallest = first_n
        .peek()
        .expect("heap is non-empty after the fill phase")
        .to_owned();
    let first_n_mutex = std::sync::Arc::new(parking_lot::Mutex::new(first_n));
    // Cached copy of the heap's top element so most items can be rejected
    // with a read lock only, without contending on the heap mutex.
    let smallest_kept = std::sync::Arc::new(parking_lot::RwLock::new(initial_smallest));
    {
        // Scoped so that every `Arc` clone captured by the stream adaptor is
        // dropped before `Arc::try_unwrap` below.
        let first_n_arc = first_n_mutex.clone();
        let smallest_kept_arc = smallest_kept.clone();
        let mut ongoing_tasks = sinput
            .map(move |item| {
                let first_n_arc = first_n_arc.clone();
                let smallest_kept_arc = smallest_kept_arc.clone();
                crate::async_runtime::spawn(async move {
                    // Fast path: the cached minimum may be stale, but it is
                    // only ever replaced by the heap's current top, so it is
                    // safe to use as a pre-filter.
                    if sorted_by(smallest_kept_arc.read().deref(), &item) != Ordering::Less {
                        return;
                    }
                    let mut update_first_n = first_n_arc.lock();
                    // Re-check against the real minimum while holding the
                    // heap lock: another task may have raised it since the
                    // fast-path read, and evicting it for a smaller item
                    // would corrupt the result.
                    let outranks_smallest = update_first_n
                        .peek()
                        .map_or(false, |smallest| sorted_by(smallest, &item) == Ordering::Less);
                    if !outranks_smallest {
                        return;
                    }
                    update_first_n.pop();
                    update_first_n.push(item);
                    let new_smallest = update_first_n
                        .peek()
                        .expect("heap is non-empty right after a push")
                        .to_owned();
                    // Written while the heap lock is still held, so the cache
                    // is updated in the same order as the heap.
                    *smallest_kept_arc.write() = new_smallest;
                })
            })
            .buffer_unordered(num_cpus::get() * 4);
        // Drive every spawned task to completion; the join results are
        // discarded.
        while let Some(_task) = ongoing_tasks.next().await {}
    }
    futures::stream::iter(
        std::sync::Arc::try_unwrap(first_n_mutex)
            .expect("Dangling references to mutex")
            .into_inner()
            .into_sorted_vec()
            .into_iter(),
    )
}
110
111#[async_trait]
112#[cfg(not(any(feature = "tokio", feature = "async-std")))]
113impl<SInput, T, F> KeepFirstN<T, F> for SInput
114where
115 SInput: Stream<Item = T> + Send + Unpin,
116 T: Clone + Send + std::marker::Sync,
117 F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + 'static,
118{
119 #[instrument(skip(self, sorted_by))]
120 async fn keep_first_n(
121 mut self,
122 n: usize,
123 sorted_by: F,
124 ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
125 let mut first_n = BinaryHeap::with_capacity_by(n, |a, b| match sorted_by(a, b) {
127 Ordering::Less => Ordering::Greater,
128 Ordering::Equal => Ordering::Equal,
129 Ordering::Greater => Ordering::Less,
130 });
131
132 while first_n.len() < n {
133 if let Some(item) = self.next().await {
134 first_n.push(item);
135 } else {
136 break;
137 }
138 }
139
140 if first_n.len() < n {
142 return futures::stream::iter(first_n.into_sorted_vec().into_iter());
143 }
144
145 let first_n_mutex = parking_lot::Mutex::new(first_n);
148 let smallest_kept =
149 parking_lot::RwLock::new(first_n_mutex.lock().peek().unwrap().to_owned());
150
151 self.for_each_concurrent(
152 256,
153 |item| async {
154 if sorted_by(&*smallest_kept.read(), &item) == Ordering::Less {
155 let mut first_n_mut = first_n_mutex.lock();
156 first_n_mut.pop();
157 first_n_mut.push(item);
158 let mut update_smallest_kept = smallest_kept.write();
159 *update_smallest_kept = first_n_mut.peek().unwrap().to_owned();
160 }
161 },
162 )
163 .await;
164
165 futures::stream::iter(first_n_mutex.into_inner().into_sorted_vec().into_iter())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::KeepFirstN;
    use futures::stream::StreamExt;

    /// Chains two `keep_first_n` calls: the first keeps five items ranked by
    /// parity, the second keeps the two numerically largest of those.
    #[tokio::test]
    async fn keep_first_n() {
        // Ranking by `x % 2` puts odd numbers above even ones.
        let ranked_by_parity = futures::stream::iter(1..10)
            .keep_first_n(5, |a, b| (a % 2).cmp(&(b % 2)))
            .await;

        // Plain numeric ordering over the survivors of the first stage.
        let top_two = ranked_by_parity
            .keep_first_n(2, |a, b| a.cmp(b))
            .await
            .collect::<Vec<_>>()
            .await;

        assert_eq!(top_two, vec![9, 7]);
    }
}