Skip to main content

palimpsest_dataflow/algorithms/
prefix_sum.rs

1//! Implementation of Parallel Prefix Sum
2
3use timely::dataflow::Scope;
4
5use crate::lattice::Lattice;
6use crate::operators::*;
7use crate::{ExchangeData, VecCollection};
8
9/// Extension trait for the prefix_sum method.
10pub trait PrefixSum<G: Scope, K, D> {
11    /// Computes the prefix sum for each element in the collection.
12    ///
13    /// The prefix sum is data-parallel, in the sense that the sums are computed independently for
14    /// each key of type `K`. For a single prefix sum this type can be `()`, but this permits the
15    /// more general accumulation of multiple independent sequences.
16    fn prefix_sum<F>(&self, zero: D, combine: F) -> Self
17    where
18        F: Fn(&K, &D, &D) -> D + 'static;
19
20    /// Determine the prefix sum at each element of `location`.
21    fn prefix_sum_at<F>(
22        &self,
23        locations: VecCollection<G, (usize, K)>,
24        zero: D,
25        combine: F,
26    ) -> Self
27    where
28        F: Fn(&K, &D, &D) -> D + 'static;
29}
30
31impl<G, K, D> PrefixSum<G, K, D> for VecCollection<G, ((usize, K), D)>
32where
33    G: Scope<Timestamp: Lattice>,
34    K: ExchangeData + ::std::hash::Hash,
35    D: ExchangeData + ::std::hash::Hash,
36{
37    fn prefix_sum<F>(&self, zero: D, combine: F) -> Self
38    where
39        F: Fn(&K, &D, &D) -> D + 'static,
40    {
41        self.prefix_sum_at(self.map(|(x, _)| x), zero, combine)
42    }
43
44    fn prefix_sum_at<F>(&self, locations: VecCollection<G, (usize, K)>, zero: D, combine: F) -> Self
45    where
46        F: Fn(&K, &D, &D) -> D + 'static,
47    {
48        let combine1 = ::std::rc::Rc::new(combine);
49        let combine2 = combine1.clone();
50
51        let ranges = aggregate(self.clone(), move |k, x, y| (*combine1)(k, x, y));
52        broadcast(ranges, locations, zero, move |k, x, y| (*combine2)(k, x, y))
53    }
54}
55
56/// Accumulate data in `collection` into all powers-of-two intervals containing them.
57pub fn aggregate<G, K, D, F>(
58    collection: VecCollection<G, ((usize, K), D)>,
59    combine: F,
60) -> VecCollection<G, ((usize, usize, K), D)>
61where
62    G: Scope<Timestamp: Lattice>,
63    K: ExchangeData + ::std::hash::Hash,
64    D: ExchangeData + ::std::hash::Hash,
65    F: Fn(&K, &D, &D) -> D + 'static,
66{
67    // initial ranges are at each index, and with width 2^0.
68    let unit_ranges = collection.map(|((index, key), data)| ((index, 0, key), data));
69
70    unit_ranges.iterate(|ranges|
71
72            // Each available range, of size less than usize::max_value(), advertises itself as the range
73            // twice as large, aligned to integer multiples of its size. Each range, which may contain at
74            // most two elements, then summarizes itself using the `combine` function. Finally, we re-add
75            // the initial `unit_ranges` intervals, so that the set of ranges grows monotonically.
76
77            ranges
78                .filter(|&((_pos, log, _), _)| log < 64)
79                .map(|((pos, log, key), data)| ((pos >> 1, log + 1, key), (pos, data)))
80                .reduce(move |&(_pos, _log, ref key), input, output| {
81                    let mut result = (input[0].0).1.clone();
82                    if input.len() > 1 { result = combine(key, &result, &(input[1].0).1); }
83                    output.push((result, 1));
84                })
85                .concat(&unit_ranges.enter(&ranges.scope())))
86}
87
88/// Produces the accumulated values at each of the `usize` locations in `queries`.
89pub fn broadcast<G, K, D, F>(
90    ranges: VecCollection<G, ((usize, usize, K), D)>,
91    queries: VecCollection<G, (usize, K)>,
92    zero: D,
93    combine: F,
94) -> VecCollection<G, ((usize, K), D)>
95where
96    G: Scope<Timestamp: Lattice + Ord + ::std::fmt::Debug>,
97    K: ExchangeData + ::std::hash::Hash,
98    D: ExchangeData + ::std::hash::Hash,
99    F: Fn(&K, &D, &D) -> D + 'static,
100{
101    let zero0 = zero.clone();
102    let zero1 = zero.clone();
103    let zero2 = zero.clone();
104
105    // The `queries` collection may not line up with an existing element of `ranges`, and so we must
106    // track down the first range that matches. If it doesn't exist, we will need to produce a zero
107    // value. We could produce the full path from (0, key) to (idx, key), and aggregate any and all
108    // matches. This has the defect of being n log n rather than linear, as the root ranges will be
109    // replicated for each query.
110    //
111    // I think it works to have each (idx, key) propose each of the intervals it knows should be used
112    // to assemble its input. We then `distinct` these and intersect them with the offered `ranges`,
113    // essentially performing a semijoin. We then perform the unfolding, where we might need to use
114    // empty ranges if none exist in `ranges`.
115
116    // We extract desired ranges for each `idx` from its binary representation: each set bit requires
117    // the contribution of a range, and we call out each of these. This could produce a super-linear
118    // amount of data (multiple requests for the roots), but it will be compacted down in `distinct`.
119    // We could reduce the amount of data by producing the requests iteratively, with a distinct in
120    // the loop to pre-suppress duplicate requests. This comes at a complexity cost, though.
121    let requests = queries
122        .flat_map(
123            |(idx, key)| {
124                (0 .. 64)
125                    .filter(move |i| (idx & (1usize << i)) != 0)    // set bits require help.
126                    .map(move |i| ((idx >> i) - 1, i, key.clone()))
127            }, // width 2^i interval.
128        )
129        .distinct();
130
131    // Acquire each requested range.
132    let full_ranges = ranges.semijoin(&requests);
133
134    // Each requested range should exist, even if as a zero range, for correct reconstruction.
135    let zero_ranges = full_ranges
136        .map(move |((idx, log, key), _)| ((idx, log, key), zero0.clone()))
137        .negate()
138        .concat(&requests.map(move |(idx, log, key)| ((idx, log, key), zero1.clone())));
139
140    // Merge occupied and empty ranges.
141    let used_ranges = full_ranges.concat(&zero_ranges);
142
143    // Each key should initiate a value of `zero` at position `0`.
144    let init_states = queries
145        .map(move |(_, key)| ((0, key), zero2.clone()))
146        .distinct();
147
148    // Iteratively expand assigned values by joining existing ranges with current assignments.
149    init_states
150        .iterate(|states| {
151            used_ranges
152                .enter(&states.scope())
153                .map(|((pos, log, key), data)| ((pos << log, key), (log, data)))
154                .join_map(states, move |&(pos, ref key), &(log, ref data), state| {
155                    ((pos + (1 << log), key.clone()), combine(key, state, data))
156                })
157                .concat(&init_states.enter(&states.scope()))
158                .distinct()
159        })
160        .semijoin(&queries)
161}