orx_concurrent_iter/pullers/flattened_chunk_puller.rs
1use super::ChunkPuller;
2
3/// Flattened version of a [`ChunkPuller`] which conveniently implements [`Iterator`].
4///
5/// Similar to the regular chunk puller, a flattened chunk puller is created from and
6/// linked to and pulls its elements from a [`ConcurrentIter`].
7///
8/// It can be created by calling the [`flattened`] method on a chunk puller that is
9/// created by the [`chunk_puller`] method of a concurrent iterator.
10///
11/// [`ChunkPuller`]: crate::ChunkPuller
12/// [`ConcurrentIter`]: crate::ConcurrentIter
13/// [`chunk_puller`]: crate::ConcurrentIter::chunk_puller
14/// [`flattened`]: crate::ChunkPuller::flattened
15///
16/// # Examples
17///
18/// See the [`ItemPuller`] documentation for the notes on how the pullers bring the convenience of
19/// Iterator methods to concurrent programs, which is demonstrated by a 4-line implementation of the
20/// parallelized [`reduce`]. We can add the iteration-by-chunks optimization on top of this while
21/// keeping the implementation as simple and fitting 4-lines due to the fact that flattened chunk
22/// puller implements Iterator.
23///
24/// In the following code, the sums are computed by 8 threads while each thread pulls elements in
25/// chunks of 64.
26///
27/// ```
28/// use orx_concurrent_iter::*;
29///
30/// fn parallel_reduce<T, F>(
31///     num_threads: usize,
32///     chunk: usize,
33///     con_iter: impl ConcurrentIter<Item = T>,
34///     reduce: F,
35/// ) -> Option<T>
36/// where
37///     T: Send + Sync,
38///     F: Fn(T, T) -> T + Send + Sync,
39/// {
40///     std::thread::scope(|s| {
41///         (0..num_threads)
42///             .map(|_| s.spawn(|| con_iter.chunk_puller(chunk).flattened().reduce(&reduce))) // reduce inside each thread
43///             .filter_map(|x| x.join().unwrap()) // join threads
44///             .reduce(&reduce) // reduce thread results to final result
45///     })
46/// }
47///
48/// let sum = parallel_reduce(8, 64, (0..0).into_con_iter(), |a, b| a + b);
49/// assert_eq!(sum, None);
50///
51/// let n = 10_000;
52/// let data: Vec<_> = (0..n).collect();
53/// let sum = parallel_reduce(8, 64, data.con_iter().copied(), |a, b| a + b);
54/// assert_eq!(sum, Some(n * (n - 1) / 2));
55/// ```
56///
57/// [`reduce`]: Iterator::reduce
58/// [`ItemPuller`]: crate::ItemPuller
59pub struct FlattenedChunkPuller<'c, P>
60where
61    P: ChunkPuller + 'c,
62{
63    puller: P,
64    current_chunk: P::Chunk<'c>,
65}
66
67impl<P> From<P> for FlattenedChunkPuller<'_, P>
68where
69    P: ChunkPuller,
70{
71    fn from(puller: P) -> Self {
72        Self {
73            puller,
74            current_chunk: Default::default(),
75        }
76    }
77}
78
79impl<P> FlattenedChunkPuller<'_, P>
80where
81    P: ChunkPuller,
82{
83    /// Converts the flattened chunk puller back to the chunk puller it
84    /// is created from.
85    pub fn into_chunk_puller(self) -> P {
86        self.puller
87    }
88
89    fn next_chunk(&mut self) -> Option<P::ChunkItem> {
90        let puller = unsafe { &mut *(&mut self.puller as *mut P) };
91        match puller.pull() {
92            Some(chunk) => {
93                self.current_chunk = chunk;
94                self.next()
95            }
96            None => None,
97        }
98    }
99}
100
101impl<P> Iterator for FlattenedChunkPuller<'_, P>
102where
103    P: ChunkPuller,
104{
105    type Item = P::ChunkItem;
106
107    fn next(&mut self) -> Option<Self::Item> {
108        let next = self.current_chunk.next();
109        match next.is_some() {
110            true => next,
111            false => self.next_chunk(),
112        }
113    }
114}