orx_concurrent_iter/pullers/flattened_chunk_puller.rs
use super::ChunkPuller;

/// Flattened version of a [`ChunkPuller`] which conveniently implements [`Iterator`].
///
/// Like the regular chunk puller, a flattened chunk puller is created from, linked to,
/// and pulls its elements from a [`ConcurrentIter`].
///
/// It can be created by calling the [`flattened`] method on a chunk puller, which in turn
/// is created by the [`chunk_puller`] method of a concurrent iterator.
///
/// [`ChunkPuller`]: crate::ChunkPuller
/// [`ConcurrentIter`]: crate::ConcurrentIter
/// [`chunk_puller`]: crate::ConcurrentIter::chunk_puller
/// [`flattened`]: crate::ChunkPuller::flattened
///
/// # Examples
///
/// See the [`ItemPuller`] documentation for notes on how pullers bring the convenience of
/// Iterator methods to concurrent programs, demonstrated there by a four-line implementation
/// of a parallelized [`reduce`]. Since the flattened chunk puller also implements Iterator,
/// we can add the iteration-by-chunks optimization on top while keeping the implementation
/// just as simple.
///
/// In the following code, partial sums are computed by 8 threads, each pulling elements in
/// chunks of 64.
///
/// ```
/// use orx_concurrent_iter::*;
///
/// fn parallel_reduce<T, F>(
///     num_threads: usize,
///     chunk: usize,
///     con_iter: impl ConcurrentIter<Item = T>,
///     reduce: F,
/// ) -> Option<T>
/// where
///     T: Send + Sync,
///     F: Fn(T, T) -> T + Send + Sync,
/// {
///     std::thread::scope(|s| {
///         (0..num_threads)
///             .map(|_| s.spawn(|| con_iter.chunk_puller(chunk).flattened().reduce(&reduce))) // reduce inside each thread
///             .collect::<Vec<_>>() // collect eagerly so that all threads are spawned before any is joined
///             .into_iter()
///             .filter_map(|x| x.join().unwrap()) // join threads
///             .reduce(&reduce) // reduce thread results to final result
///     })
/// }
///
/// let sum = parallel_reduce(8, 64, (0..0).into_con_iter(), |a, b| a + b);
/// assert_eq!(sum, None);
///
/// let n = 10_000;
/// let data: Vec<_> = (0..n).collect();
/// let sum = parallel_reduce(8, 64, data.con_iter().copied(), |a, b| a + b);
/// assert_eq!(sum, Some(n * (n - 1) / 2));
/// ```
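///
/// As a minimal single-threaded sketch of the same flattening behavior (the `Vec`
/// source and chunk size below are illustrative choices, not requirements), the
/// flattened puller yields the chunks' elements one by one through the regular
/// Iterator interface:
///
/// ```
/// use orx_concurrent_iter::*;
///
/// let data = vec![1, 2, 3, 4, 5];
/// let con_iter = data.con_iter();
///
/// // pull in chunks of 2 under the hood, but iterate element by element
/// let flattened = con_iter.chunk_puller(2).flattened();
/// let sum: i32 = flattened.copied().sum();
/// assert_eq!(sum, 15);
/// ```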
///
/// [`reduce`]: Iterator::reduce
/// [`ItemPuller`]: crate::ItemPuller
pub struct FlattenedChunkPuller<'c, P>
where
    P: ChunkPuller + 'c,
{
    /// The chunk puller that this flattened puller wraps and pulls chunks from.
    puller: P,
    /// The chunk currently being iterated; replaced once exhausted.
    current_chunk: P::Chunk<'c>,
}

impl<P> From<P> for FlattenedChunkPuller<'_, P>
where
    P: ChunkPuller,
{
    fn from(puller: P) -> Self {
        Self {
            puller,
            current_chunk: Default::default(),
        }
    }
}

impl<P> FlattenedChunkPuller<'_, P>
where
    P: ChunkPuller,
{
    /// Converts the flattened chunk puller back into the chunk puller it
    /// was created from.
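    ///
    /// A minimal sketch of the round trip (the `Vec` source and chunk size are
    /// illustrative assumptions):
    ///
    /// ```
    /// use orx_concurrent_iter::*;
    ///
    /// let data = vec![1, 2, 3, 4];
    /// let con_iter = data.con_iter();
    ///
    /// let flattened = con_iter.chunk_puller(2).flattened();
    ///
    /// // recover the underlying chunk puller; we can pull chunk by chunk again
    /// let mut puller = flattened.into_chunk_puller();
    /// let mut sum = 0;
    /// while let Some(chunk) = puller.pull() {
    ///     sum += chunk.sum::<i32>();
    /// }
    /// assert_eq!(sum, 10);
    /// ```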
    pub fn into_chunk_puller(self) -> P {
        self.puller
    }

    fn next_chunk(&mut self) -> Option<P::ChunkItem> {
        // SAFETY: the raw-pointer round trip detaches the borrow of `puller`
        // from `self`, so that the pulled chunk, which borrows the puller, can
        // be stored back in `self.current_chunk`. This is sound because the
        // chunk never outlives the puller: both fields live and are dropped
        // together within this struct.
        let puller = unsafe { &mut *(&mut self.puller as *mut P) };
        match puller.pull() {
            Some(chunk) => {
                self.current_chunk = chunk;
                self.next()
            }
            None => None,
        }
    }
}

impl<P> Iterator for FlattenedChunkPuller<'_, P>
where
    P: ChunkPuller,
{
    type Item = P::ChunkItem;

    fn next(&mut self) -> Option<Self::Item> {
        // yield from the current chunk; once it is exhausted, pull the next
        // chunk from the linked concurrent iterator and continue from there
        let next = self.current_chunk.next();
        match next.is_some() {
            true => next,
            false => self.next_chunk(),
        }
    }
}
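
// A minimal sketch of a test (an illustrative addition, assuming the crate's
// `con_iter` and `copied` for `Vec` as used in the docs above): in a
// single-threaded setting, the flattened puller must yield every element of
// the source, in order.
#[cfg(test)]
mod flattened_sketch_tests {
    use crate::*;

    #[test]
    fn flattened_yields_all_elements_in_order() {
        let data: Vec<i32> = (0..10).collect();
        let collected: Vec<i32> = data.con_iter().copied().chunk_puller(3).flattened().collect();
        assert_eq!(collected, (0..10).collect::<Vec<i32>>());
    }
}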