Skip to main content

differential_dataflow/algorithms/
prefix_sum.rs

1//! Implementation of Parallel Prefix Sum
2
3use timely::progress::Timestamp;
4
5use crate::{VecCollection, ExchangeData};
6use crate::lattice::Lattice;
7use crate::operators::*;
8
9/// Extension trait for the prefix_sum method.
10pub trait PrefixSum<'scope, T: Timestamp, K, D> {
11    /// Computes the prefix sum for each element in the collection.
12    ///
13    /// The prefix sum is data-parallel, in the sense that the sums are computed independently for
14    /// each key of type `K`. For a single prefix sum this type can be `()`, but this permits the
15    /// more general accumulation of multiple independent sequences.
16    fn prefix_sum<F>(self, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static;
17
18    /// Determine the prefix sum at each element of `location`.
19    fn prefix_sum_at<F>(self, locations: VecCollection<'scope, T, (usize, K)>, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static;
20}
21
22impl<'scope, T, K, D> PrefixSum<'scope, T, K, D> for VecCollection<'scope, T, ((usize, K), D)>
23where
24    T: Timestamp + Lattice,
25    K: ExchangeData + ::std::hash::Hash,
26    D: ExchangeData + ::std::hash::Hash,
27{
28    fn prefix_sum<F>(self, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static {
29        self.clone().prefix_sum_at(self.map(|(x,_)| x), zero, combine)
30    }
31
32    fn prefix_sum_at<F>(self, locations: VecCollection<'scope, T, (usize, K)>, zero: D, combine: F) -> Self where F: Fn(&K,&D,&D)->D + 'static {
33
34        let combine1 = ::std::rc::Rc::new(combine);
35        let combine2 = combine1.clone();
36
37        let ranges = aggregate(self.clone(), move |k,x,y| (*combine1)(k,x,y));
38        broadcast(ranges, locations, zero, move |k,x,y| (*combine2)(k,x,y))
39    }
40}
41
42/// Accumulate data in `collection` into all powers-of-two intervals containing them.
43pub fn aggregate<'scope, T, K, D, F>(collection: VecCollection<'scope, T, ((usize, K), D)>, combine: F) -> VecCollection<'scope, T, ((usize, usize, K), D)>
44where
45    T: Timestamp + Lattice,
46    K: ExchangeData + ::std::hash::Hash,
47    D: ExchangeData + ::std::hash::Hash,
48    F: Fn(&K,&D,&D)->D + 'static,
49{
50    // initial ranges are at each index, and with width 2^0.
51    let unit_ranges = collection.map(|((index, key), data)| ((index, 0, key), data));
52
53    unit_ranges
54        .clone()
55        .iterate(|scope, ranges| {
56
57            // Each available range, of size less than usize::max_value(), advertises itself as the range
58            // twice as large, aligned to integer multiples of its size. Each range, which may contain at
59            // most two elements, then summarizes itself using the `combine` function. Finally, we re-add
60            // the initial `unit_ranges` intervals, so that the set of ranges grows monotonically.
61
62            let unit_ranges = unit_ranges.enter(scope);
63            ranges
64                .filter(|&((_pos, log, _), _)| log < 64)
65                .map(|((pos, log, key), data)| ((pos >> 1, log + 1, key), (pos, data)))
66                .reduce(move |&(_pos, _log, ref key), input, output| {
67                    let mut result = (input[0].0).1.clone();
68                    if input.len() > 1 { result = combine(key, &result, &(input[1].0).1); }
69                    output.push((result, 1));
70                })
71                .concat(unit_ranges)
72        })
73}
74
75/// Produces the accumulated values at each of the `usize` locations in `queries`.
76pub fn broadcast<'scope, T, K, D, F>(
77    ranges: VecCollection<'scope, T, ((usize, usize, K), D)>,
78    queries: VecCollection<'scope, T, (usize, K)>,
79    zero: D,
80    combine: F) -> VecCollection<'scope, T, ((usize, K), D)>
81where
82    T: Timestamp + Lattice + Ord + ::std::fmt::Debug,
83    K: ExchangeData + ::std::hash::Hash,
84    D: ExchangeData + ::std::hash::Hash,
85    F: Fn(&K,&D,&D)->D + 'static,
86{
87    let zero0 = zero.clone();
88    let zero1 = zero.clone();
89    let zero2 = zero.clone();
90
91    // The `queries` collection may not line up with an existing element of `ranges`, and so we must
92    // track down the first range that matches. If it doesn't exist, we will need to produce a zero
93    // value. We could produce the full path from (0, key) to (idx, key), and aggregate any and all
94    // matches. This has the defect of being n log n rather than linear, as the root ranges will be
95    // replicated for each query.
96    //
97    // I think it works to have each (idx, key) propose each of the intervals it knows should be used
98    // to assemble its input. We then `distinct` these and intersect them with the offered `ranges`,
99    // essentially performing a semijoin. We then perform the unfolding, where we might need to use
100    // empty ranges if none exist in `ranges`.
101
102    // We extract desired ranges for each `idx` from its binary representation: each set bit requires
103    // the contribution of a range, and we call out each of these. This could produce a super-linear
104    // amount of data (multiple requests for the roots), but it will be compacted down in `distinct`.
105    // We could reduce the amount of data by producing the requests iteratively, with a distinct in
106    // the loop to pre-suppress duplicate requests. This comes at a complexity cost, though.
107    let requests =
108        queries
109            .clone()
110            .flat_map(|(idx, key)|
111                (0 .. 64)
112                    .filter(move |i| (idx & (1usize << i)) != 0)    // set bits require help.
113                    .map(move |i| ((idx >> i) - 1, i, key.clone())) // width 2^i interval.
114            )
115            .distinct();
116
117    // Acquire each requested range.
118    let full_ranges =
119        ranges
120            .semijoin(requests.clone());
121
122    // Each requested range should exist, even if as a zero range, for correct reconstruction.
123    let zero_ranges =
124        full_ranges
125            .clone()
126            .map(move |((idx, log, key), _)| ((idx, log, key), zero0.clone()))
127            .negate()
128            .concat(requests.map(move |(idx, log, key)| ((idx, log, key), zero1.clone())));
129
130    // Merge occupied and empty ranges.
131    let used_ranges = full_ranges.concat(zero_ranges);
132
133    // Each key should initiate a value of `zero` at position `0`.
134    let init_states =
135        queries
136            .clone()
137            .map(move |(_, key)| ((0, key), zero2.clone()))
138            .distinct();
139
140    // Iteratively expand assigned values by joining existing ranges with current assignments.
141    init_states
142        .clone()
143        .iterate(|scope, states| {
144            let init_states = init_states.enter(scope);
145            used_ranges
146                .enter(scope)
147                .map(|((pos, log, key), data)| ((pos << log, key), (log, data)))
148                .join_map(states, move |&(pos, ref key), &(log, ref data), state|
149                    ((pos + (1 << log), key.clone()), combine(key, state, data)))
150                .concat(init_states)
151                .distinct()
152        })
153        .semijoin(queries)
154}