// marigold_impl/keep_first_n.rs

1use async_trait::async_trait;
2use binary_heap_plus::BinaryHeap;
3use futures::stream::Stream;
4use futures::stream::StreamExt;
5use std::cmp::Ordering;
6#[cfg(any(feature = "tokio", feature = "async-std"))]
7use std::ops::Deref;
8use tracing::instrument;
9
10#[async_trait]
11pub trait KeepFirstN<T, F>
12where
13    F: Fn(&T, &T) -> Ordering,
14{
15    /// Takes the largest N values according to the sorted function, returned in descending order
16    /// (max first). Exhausts the stream.
17    async fn keep_first_n(
18        self,
19        n: usize,
20        sorted_by: F,
21    ) -> futures::stream::Iter<std::vec::IntoIter<T>>;
22}
23
/// `KeepFirstN` for builds where the `tokio` or `async-std` feature is enabled.
///
/// This impl only builds the heap; the selection itself is done by `impl_keep_first_n`.
#[cfg(any(feature = "tokio", feature = "async-std"))]
#[async_trait]
impl<SInput, T, F> KeepFirstN<T, F> for SInput
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
{
    /// Keeps the `n` greatest items of the stream according to `sorted_by`, greatest first.
    #[instrument(skip(self, sorted_by))]
    async fn keep_first_n(
        self,
        n: usize,
        sorted_by: F,
    ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
        // Use the reverse ordering so that the smallest value is always the first to pop.
        // `F` is `Copy`, so the closure takes its own copy and `sorted_by` stays usable below.
        let first_n = BinaryHeap::with_capacity_by(n, move |a, b| sorted_by(a, b).reverse());
        impl_keep_first_n(self, first_n, n, sorted_by).await
    }
}
43
/// Internal logic for `keep_first_n` on builds with an async runtime feature.
///
/// This is a separate generic function so that the full type of the binary heap can be named:
/// the heap's comparator is a closure that reverses the caller's `sorted_by`, and the
/// `FReversed` type parameter stands in for that otherwise unnameable closure type.
///
/// * `sinput` - stream to select from; it is consumed until it ends.
/// * `first_n` - heap that collects the kept items. Its comparator must be the reverse of
///   `sorted_by`, so that the smallest kept item is the one at the top of the heap.
/// * `n` - maximum number of items to keep.
/// * `sorted_by` - the caller's comparison function.
///
/// Returns the kept items as a stream, greatest first according to `sorted_by`. With `n == 0`
/// the input is still drained and an empty stream is returned.
///
/// # Panics
///
/// Panics with "Dangling references to mutex" if the shared heap is still referenced elsewhere
/// after all spawned tasks have been awaited.
#[cfg(any(feature = "tokio", feature = "async-std"))]
async fn impl_keep_first_n<SInput, T, F, FReversed>(
    mut sinput: SInput,
    mut first_n: BinaryHeap<T, binary_heap_plus::FnComparator<FReversed>>,
    n: usize,
    sorted_by: F,
) -> futures::stream::Iter<std::vec::IntoIter<T>>
where
    SInput: Stream<Item = T> + Send + Unpin + std::marker::Sync + 'static,
    T: Clone + Send + std::marker::Sync + std::fmt::Debug + 'static,
    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + std::marker::Copy + 'static,
    FReversed: Fn(&T, &T) -> std::cmp::Ordering + Clone + Send + 'static,
{
    // Iterate through values in a single thread until we have seen n values.
    while first_n.len() < n {
        match sinput.next().await {
            Some(item) => first_n.push(item),
            // The stream ended before reaching n values, so everything seen is kept.
            None => return futures::stream::iter(first_n.into_sorted_vec().into_iter()),
        }
    }

    // The heap is only empty here when nothing may be kept (n == 0). There is no smallest
    // kept value to compare against, so just drain the stream and return nothing.
    let initial_smallest = match first_n.peek().cloned() {
        Some(smallest) => smallest,
        None => {
            while sinput.next().await.is_some() {}
            return futures::stream::iter(first_n.into_sorted_vec().into_iter());
        }
    };

    // Otherwise, we can check each remaining value in the stream against the smallest
    // kept value, updating the kept values only when a keepable value is found. This
    // is done by spawning tasks, which can be parallelized by multithreaded runtimes.
    let first_n_mutex = std::sync::Arc::new(parking_lot::Mutex::new(first_n));
    let smallest_kept = std::sync::Arc::new(parking_lot::RwLock::new(initial_smallest));
    {
        let first_n_arc = first_n_mutex.clone();
        let smallest_kept_arc = smallest_kept.clone();
        let mut ongoing_tasks = sinput
            .map(move |item| {
                let first_n_arc = first_n_arc.clone();
                let smallest_kept_arc = smallest_kept_arc.clone();
                crate::async_runtime::spawn(async move {
                    // Cheap pre-filter that avoids taking the heap lock for most items. The
                    // cached minimum may be stale, but it only ever grows, so an item that
                    // is rejected here would be rejected against the current minimum too.
                    if sorted_by(smallest_kept_arc.read().deref(), &item) != Ordering::Less {
                        return;
                    }

                    let mut update_first_n = first_n_arc.lock();
                    // Re-check under the heap lock: another task may have raised the minimum
                    // since the pre-filter ran, and replacing blindly would then evict a
                    // value that is greater than this item.
                    let still_keepable = update_first_n
                        .peek()
                        .map_or(false, |smallest| sorted_by(smallest, &item) == Ordering::Less);
                    if still_keepable {
                        update_first_n.pop();
                        update_first_n.push(item);
                        // Publish the new minimum while still holding the heap lock so that
                        // updates to the cached value happen in the same order as heap updates.
                        if let Some(new_smallest) = update_first_n.peek() {
                            *smallest_kept_arc.write() = new_smallest.to_owned();
                        }
                    }
                })
            })
            .buffer_unordered(num_cpus::get() * 4);
        // Wait for every spawned task; the values the task handles resolve to are ignored.
        while let Some(_task) = ongoing_tasks.next().await {}
    }
    futures::stream::iter(
        std::sync::Arc::try_unwrap(first_n_mutex)
            .expect("Dangling references to mutex")
            .into_inner()
            .into_sorted_vec()
            .into_iter(),
    )
}
110
111#[async_trait]
112#[cfg(not(any(feature = "tokio", feature = "async-std")))]
113impl<SInput, T, F> KeepFirstN<T, F> for SInput
114where
115    SInput: Stream<Item = T> + Send + Unpin,
116    T: Clone + Send + std::marker::Sync,
117    F: Fn(&T, &T) -> Ordering + std::marker::Send + std::marker::Sync + 'static,
118{
119    #[instrument(skip(self, sorted_by))]
120    async fn keep_first_n(
121        mut self,
122        n: usize,
123        sorted_by: F,
124    ) -> futures::stream::Iter<std::vec::IntoIter<T>> {
125        // use the reverse ordering so that the smallest value is always the first to pop.
126        let mut first_n = BinaryHeap::with_capacity_by(n, |a, b| match sorted_by(a, b) {
127            Ordering::Less => Ordering::Greater,
128            Ordering::Equal => Ordering::Equal,
129            Ordering::Greater => Ordering::Less,
130        });
131
132        while first_n.len() < n {
133            if let Some(item) = self.next().await {
134                first_n.push(item);
135            } else {
136                break;
137            }
138        }
139
140        // If we have exhausted the stream before reaching n values, we can exit early.
141        if first_n.len() < n {
142            return futures::stream::iter(first_n.into_sorted_vec().into_iter());
143        }
144
145        // Otherwise, we can check each remaining value in the stream against the smallest
146        // kept value, updating the kept values only when a keepable value is found.
147        let first_n_mutex = parking_lot::Mutex::new(first_n);
148        let smallest_kept =
149            parking_lot::RwLock::new(first_n_mutex.lock().peek().unwrap().to_owned());
150
151        self.for_each_concurrent(
152            /* arbitrarily set concurrency limit */ 256,
153            |item| async {
154                if sorted_by(&*smallest_kept.read(), &item) == Ordering::Less {
155                    let mut first_n_mut = first_n_mutex.lock();
156                    first_n_mut.pop();
157                    first_n_mut.push(item);
158                    let mut update_smallest_kept = smallest_kept.write();
159                    *update_smallest_kept = first_n_mut.peek().unwrap().to_owned();
160                }
161            },
162        )
163        .await;
164
165        futures::stream::iter(first_n_mutex.into_inner().into_sorted_vec().into_iter())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::KeepFirstN;
    use futures::stream::StreamExt;

    /// Chains two selections: first by parity, then by value.
    #[tokio::test]
    async fn keep_first_n() {
        // Ranking by parity puts odd numbers above even ones, so the five greatest items of
        // 1..10 are its odd numbers.
        let odd_numbers = futures::stream::iter(1..10)
            .keep_first_n(5, |a, b| (a % 2).cmp(&(b % 2)))
            .await;

        // Of those, keep the two largest by value.
        let largest_two = odd_numbers.keep_first_n(2, |a, b| a.cmp(b)).await;

        let collected = largest_two.collect::<Vec<_>>().await;
        assert_eq!(collected, vec![9, 7]);
    }
}