Skip to main content

marigold_impl/
keep_first_n.rs

1use async_trait::async_trait;
2use binary_heap_plus::BinaryHeap;
3use futures::stream::Stream;
4use futures::stream::StreamExt;
5use std::cmp::Ordering;
6use tracing::instrument;
7
#[async_trait]
/// Selection of the `n` greatest items of a stream under a caller-supplied ordering.
///
/// In this file the trait is implemented for stream types whose `Item` is `T`; `F` is the
/// comparison function used to rank two items.
pub trait KeepFirstN<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Takes the largest N values according to the sorted function, returned in descending order
    /// (max first). Exhausts the stream.
    ///
    /// * `n` - maximum number of items to keep. When the stream yields fewer than `n` items,
    ///   every item is returned.
    /// * `sorted_by` - comparison used to rank items; an item that compares `Greater` than
    ///   another is considered larger and is kept in preference to it.
    ///
    /// The result is an already-materialized stream (`futures::stream::Iter` over a `Vec`
    /// iterator), so the selection is complete by the time this future resolves.
    async fn keep_first_n(
        self,
        n: usize,
        sorted_by: F,
    ) -> futures::stream::Iter<std::vec::IntoIter<T>>;
}
21
/// Implementation used when an async runtime (`tokio` or `async-std`) is enabled: once the
/// first `n` items have been collected, the remaining items are checked in spawned tasks.
#[cfg(any(feature = "tokio", feature = "async-std"))]
#[async_trait]
impl<SInput, T, F> KeepFirstN<T, F> for SInput
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
{
    /// Keeps the `n` largest items according to `sorted_by`, largest first, exhausting the
    /// stream. Items that compare `Equal` are ranked by stream position (earlier wins), both
    /// for deciding which items are kept and for the order in which they are returned.
    ///
    /// `n == 0` drains the stream and returns an empty result.
    #[instrument(skip(self, sorted_by))]
    async fn keep_first_n(
        self,
        n: usize,
        sorted_by: F,
    ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
        impl_keep_first_n(self, n, sorted_by).await
    }
}

/// Drops the stream indices from already-ranked `(index, item)` pairs and wraps the items in
/// the stream type returned by `keep_first_n`.
#[cfg(any(feature = "tokio", feature = "async-std"))]
fn into_item_stream<T>(ranked: Vec<(usize, T)>) -> futures::stream::Iter<std::vec::IntoIter<T>> {
    futures::stream::iter(
        ranked
            .into_iter()
            .map(|(_idx, item)| item)
            .collect::<Vec<_>>()
            .into_iter(),
    )
}

/// Internal logic for `keep_first_n` when an async runtime is available.
///
/// Every item is paired with its stream index so that items the caller's comparison reports
/// as `Equal` still have a total order: the item seen earlier in the stream ranks higher.
/// Because the order is total, the set of kept items does not depend on the order in which
/// the spawned tasks happen to run.
///
/// The kept items live in a heap whose comparator is the reverse of that ranking, so the
/// heap's root is always the lowest-ranked kept item, i.e. the next one to evict.
#[cfg(any(feature = "tokio", feature = "async-std"))]
async fn impl_keep_first_n<SInput, T, F>(
    sinput: SInput,
    n: usize,
    sorted_by: F,
) -> futures::stream::Iter<std::vec::IntoIter<T>>
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
{
    // Add indices to items for deterministic tie-breaking.
    let mut indexed_stream = sinput.enumerate();

    // Nothing can be kept, and the heap would stay empty (there is no "smallest kept" item to
    // compare against). Still drain the stream, since callers are told it gets exhausted.
    if n == 0 {
        while indexed_stream.next().await.is_some() {}
        return futures::stream::iter(Vec::<T>::new().into_iter());
    }

    // Rank by the caller's ordering; among equal items the LOWER index ranks HIGHER, so the
    // lowest-ranked element is the latest-seen one among the smallest values.
    let indexed_comparator = move |a: &(usize, T), b: &(usize, T)| {
        sorted_by(&a.1, &b.1).then_with(|| b.0.cmp(&a.0))
    };
    // Use the reverse ordering so that the lowest-ranked value is always the first to pop.
    let mut first_n =
        BinaryHeap::with_capacity_by(n, move |a, b| indexed_comparator(a, b).reverse());

    // Iterate through values in a single thread until we have seen n values. If the stream
    // ends first, everything seen so far is the answer.
    while first_n.len() < n {
        match indexed_stream.next().await {
            Some(indexed_item) => first_n.push(indexed_item),
            None => return into_item_stream(first_n.into_sorted_vec()),
        }
    }

    // Otherwise, check each remaining value in the stream against the lowest-ranked kept
    // value, updating the kept values only when a keepable value is found. This is done by
    // spawning tasks, which can be parallelized by multithreaded runtimes.
    //
    // `smallest_kept` is a cached copy of the heap root behind a read-write lock; it lets
    // tasks reject clearly inferior items without contending on the heap's mutex.
    let initial_smallest = first_n
        .peek()
        .cloned()
        .expect("heap holds exactly n > 0 items here");
    let first_n_mutex = std::sync::Arc::new(parking_lot::Mutex::new(first_n));
    let smallest_kept = std::sync::Arc::new(parking_lot::RwLock::new(initial_smallest));
    {
        let first_n_arc = first_n_mutex.clone();
        let smallest_kept_arc = smallest_kept.clone();
        let mut ongoing_tasks = indexed_stream
            .map(move |indexed_item| {
                let first_n_arc = first_n_arc.clone();
                let smallest_kept_arc = smallest_kept_arc.clone();
                crate::async_runtime::spawn(async move {
                    // Fast pre-check under the read lock. The smallest kept value only ever
                    // grows, so an item strictly below the cached value can never be kept.
                    {
                        let smallest = smallest_kept_arc.read();
                        if sorted_by(&smallest.1, &indexed_item.1) == Ordering::Greater {
                            return;
                        }
                    }

                    // The decisive check and the update happen under the same mutex, so two
                    // tasks cannot both decide to replace the same heap root.
                    let mut update_first_n = first_n_arc.lock();
                    let should_keep = {
                        let smallest = update_first_n
                            .peek()
                            .expect("heap always holds n > 0 items");
                        match sorted_by(&smallest.1, &indexed_item.1) {
                            Ordering::Less => true,
                            Ordering::Greater => false,
                            // When equal, prefer the item with the lower index (earlier in
                            // the stream); the root is the latest-seen of the equal items.
                            Ordering::Equal => indexed_item.0 < smallest.0,
                        }
                    };
                    if should_keep {
                        update_first_n.pop();
                        update_first_n.push(indexed_item);
                        let mut update_smallest_kept = smallest_kept_arc.write();
                        *update_smallest_kept = update_first_n
                            .peek()
                            .cloned()
                            .expect("heap always holds n > 0 items");
                    }
                })
            })
            .buffer_unordered(num_cpus::get() * 4);
        // Drive every spawned task to completion; the results themselves carry no data.
        while let Some(_task) = ongoing_tasks.next().await {}
    }

    // All clones of the Arc were owned by the task pipeline above, which has been dropped.
    let kept = std::sync::Arc::try_unwrap(first_n_mutex)
        .expect("Dangling references to mutex")
        .into_inner();
    into_item_stream(kept.into_sorted_vec())
}
160
161#[async_trait]
162#[cfg(not(any(feature = "tokio", feature = "async-std")))]
163impl<SInput, T, F> KeepFirstN<T, F> for SInput
164where
165    SInput: Stream<Item = T> + Send + Unpin,
166    T: Clone + Send + std::marker::Sync,
167    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + 'static,
168{
169    #[instrument(skip(self, sorted_by))]
170    async fn keep_first_n(
171        mut self,
172        n: usize,
173        sorted_by: F,
174    ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
175        // use the reverse ordering so that the smallest value is always the first to pop.
176        let mut first_n = BinaryHeap::with_capacity_by(n, |a, b| match sorted_by(a, b) {
177            Ordering::Less => Ordering::Greater,
178            Ordering::Equal => Ordering::Equal,
179            Ordering::Greater => Ordering::Less,
180        });
181
182        while first_n.len() < n {
183            if let Some(item) = self.next().await {
184                first_n.push(item);
185            } else {
186                break;
187            }
188        }
189
190        // If we have exhausted the stream before reaching n values, we can exit early.
191        if first_n.len() < n {
192            return futures::stream::iter(first_n.into_sorted_vec().into_iter());
193        }
194
195        // Otherwise, we can check each remaining value in the stream against the smallest
196        // kept value, updating the kept values only when a keepable value is found.
197        let first_n_mutex = parking_lot::Mutex::new(first_n);
198        let smallest_kept =
199            parking_lot::RwLock::new(first_n_mutex.lock().peek().unwrap().to_owned());
200
201        self.for_each_concurrent(
202            /* arbitrarily set concurrency limit */ 256,
203            |item| async {
204                if sorted_by(&*smallest_kept.read(), &item) == Ordering::Less {
205                    let mut first_n_mut = first_n_mutex.lock();
206                    first_n_mut.pop();
207                    first_n_mut.push(item);
208                    let mut update_smallest_kept = smallest_kept.write();
209                    *update_smallest_kept = first_n_mut.peek().unwrap().to_owned();
210                }
211            },
212        )
213        .await;
214
215        futures::stream::iter(first_n_mutex.into_inner().into_sorted_vec().into_iter())
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::KeepFirstN;
    use futures::stream::StreamExt;

    /// Chains two selections: first by parity, then by value.
    #[tokio::test]
    async fn keep_first_n() {
        // Ranking by `x % 2` puts odd numbers above even ones; 1..10 contains exactly five
        // odd numbers, so five slots end up holding the odds.
        let odd_numbers = futures::stream::iter(1..10)
            .keep_first_n(5, |a, b| (a % 2).cmp(&(b % 2)))
            .await;

        // Of those, keep the two largest by plain numeric order (largest first).
        let largest_two_odds = odd_numbers
            .keep_first_n(2, |a, b| a.cmp(b))
            .await
            .collect::<Vec<_>>()
            .await;

        assert_eq!(largest_two_odds, vec![9, 7]);
    }
}