1use timely::dataflow::Scope;
4
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 difference::{Multiply, Semigroup},
9 hashable::Hashable,
10 lattice::Lattice,
11 operators::{Join, Reduce, Threshold},
12 Data, ExchangeData, VecCollection,
13};
14
/// Aggregate functions that [`aggregate_i64`] can evaluate over the `i64`
/// values of one group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateFunc {
    /// Total number of rows in the group (the sum of their multiplicities).
    Count,
    /// Sum of the values, each weighted by its multiplicity.
    Sum,
    /// Smallest value present in the group.
    Min,
    /// Largest value present in the group.
    Max,
    /// Average, reported as an exact sum/count pair (see
    /// [`AggregateValue::Average`]) rather than as a quotient.
    Avg,
    /// Number of distinct values present in the group.
    CountDistinct,
}
31
/// Result of evaluating a single [`AggregateFunc`] over one group.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum AggregateValue {
    /// Integral result; produced by every aggregate except [`AggregateFunc::Avg`].
    Integer(i128),
    /// Average kept as its two exact components; no division is performed here.
    Average {
        /// Sum of the values, each weighted by its multiplicity.
        sum: i128,
        /// Sum of the multiplicities; the divisor a consumer would use.
        count: i128,
    },
}
45
/// Order in which [`topk`] ranks rows before applying its offset and limit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Rank rows in the order the reduction presents them.
    Ascending,
    /// Rank rows in the reverse of that order.
    Descending,
}
54
/// Relational selection: keeps the records of `input` for which `predicate`
/// returns `true`.
///
/// Thin wrapper that delegates to the collection's own `filter` operator.
pub fn filter<G, D, R, F>(input: &VecCollection<G, D, R>, predicate: F) -> VecCollection<G, D, R>
where
    G: Scope,
    D: Clone + 'static,
    R: Clone + 'static,
    F: FnMut(&D) -> bool + 'static,
{
    input.filter(predicate)
}
65
/// Relational projection: applies `projection` to every record of `input`.
///
/// Thin wrapper that delegates to the collection's own `map` operator; the
/// difference type `R` is carried through unchanged.
pub fn project<G, D, D2, R, F>(
    input: &VecCollection<G, D, R>,
    projection: F,
) -> VecCollection<G, D2, R>
where
    G: Scope,
    D: Clone + 'static,
    D2: Data,
    R: Clone + 'static,
    F: FnMut(D) -> D2 + 'static,
{
    input.map(projection)
}
80
/// Inner equi-join of two keyed collections.
///
/// For every pair of `left` and `right` records that share a key,
/// `projection(key, left_value, right_value)` produces one output record.
/// The output difference type is the product `<R as Multiply<R2>>::Output`.
///
/// Thin wrapper that delegates to `join_map`.
pub fn equi_join<G, K, V, V2, R, R2, D, F>(
    left: &VecCollection<G, (K, V), R>,
    right: &VecCollection<G, (K, V2), R2>,
    projection: F,
) -> VecCollection<G, D, <R as Multiply<R2>>::Output>
where
    G: Scope<Timestamp: Lattice + Ord>,
    K: ExchangeData + Hashable,
    V: ExchangeData,
    V2: ExchangeData,
    R: ExchangeData + Semigroup + Multiply<R2, Output: Semigroup + 'static>,
    R2: ExchangeData + Semigroup,
    D: Data,
    F: FnMut(&K, &V, &V2) -> D + 'static,
{
    left.join_map(right, projection)
}
99
100pub fn left_join<G, K, V, V2, D, F>(
102 left: &VecCollection<G, (K, V), isize>,
103 right: &VecCollection<G, (K, V2), isize>,
104 mut projection: F,
105) -> VecCollection<G, D, isize>
106where
107 G: Scope<Timestamp: Lattice + Ord>,
108 K: ExchangeData + Hashable,
109 V: ExchangeData,
110 V2: ExchangeData,
111 D: Data,
112 F: FnMut(&K, &V, Option<&V2>) -> D + Clone + 'static,
113{
114 let matched = left.join_map(right, {
115 let mut projection = projection.clone();
116 move |key, left, right| projection(key, left, Some(right))
117 });
118 let right_keys = right.map(|(key, _value)| key).distinct();
119 let unmatched = left
120 .antijoin(&right_keys)
121 .map(move |(key, value)| projection(&key, &value, None));
122
123 matched.concat(&unmatched)
124}
125
/// Set semantics for `input`: delegates to the collection's `distinct`
/// operator, producing a collection with `isize` differences.
pub fn distinct<G, D, R>(input: &VecCollection<G, D, R>) -> VecCollection<G, D, isize>
where
    G: Scope<Timestamp: Lattice + Ord>,
    D: ExchangeData + Hashable,
    R: ExchangeData + Semigroup,
{
    input.distinct()
}
135
136pub fn aggregate_i64<G, K>(
138 input: &VecCollection<G, (K, i64), isize>,
139 funcs: Vec<AggregateFunc>,
140) -> VecCollection<G, (K, Vec<AggregateValue>), isize>
141where
142 G: Scope<Timestamp: Lattice + Ord>,
143 K: ExchangeData + Hashable,
144{
145 input.reduce(move |_key, values, output| {
146 let aggregates = funcs
147 .iter()
148 .map(|func| evaluate_i64_aggregate(*func, values))
149 .collect();
150 output.push((aggregates, 1));
151 })
152}
153
154fn evaluate_i64_aggregate(func: AggregateFunc, values: &[(&i64, isize)]) -> AggregateValue {
155 let positive_values = values
156 .iter()
157 .filter_map(|(value, diff)| usize::try_from(*diff).ok().map(|count| (**value, count)));
158
159 match func {
160 AggregateFunc::Count => AggregateValue::Integer(
161 values
162 .iter()
163 .filter_map(|(_value, diff)| i128::try_from(*diff).ok())
164 .sum(),
165 ),
166 AggregateFunc::Sum => AggregateValue::Integer(
167 values
168 .iter()
169 .map(|(value, diff)| i128::from(**value) * (*diff as i128))
170 .sum(),
171 ),
172 AggregateFunc::Min => AggregateValue::Integer(
173 positive_values
174 .clone()
175 .map(|(value, _count)| i128::from(value))
176 .min()
177 .unwrap_or_default(),
178 ),
179 AggregateFunc::Max => AggregateValue::Integer(
180 positive_values
181 .clone()
182 .map(|(value, _count)| i128::from(value))
183 .max()
184 .unwrap_or_default(),
185 ),
186 AggregateFunc::Avg => {
187 let mut sum = 0_i128;
188 let mut count = 0_i128;
189 for (value, diff) in values {
190 sum += i128::from(**value) * (*diff as i128);
191 count += *diff as i128;
192 }
193 AggregateValue::Average { sum, count }
194 }
195 AggregateFunc::CountDistinct => {
196 AggregateValue::Integer(i128::try_from(positive_values.count()).unwrap_or(i128::MAX))
197 }
198 }
199}
200
201pub fn topk<G, D>(
203 input: &VecCollection<G, D, isize>,
204 direction: SortDirection,
205 limit: usize,
206 offset: usize,
207) -> VecCollection<G, D, isize>
208where
209 G: Scope<Timestamp: Lattice + Ord>,
210 D: ExchangeData + Hashable,
211{
212 input
213 .map(|value| ((), value))
214 .reduce(move |_key, values, output| {
215 let mut expanded = Vec::new();
216 for (value, diff) in values {
217 if let Ok(count) = usize::try_from(*diff) {
218 expanded.extend(std::iter::repeat_with(|| (*value).clone()).take(count));
219 }
220 }
221 if direction == SortDirection::Descending {
222 expanded.reverse();
223 }
224 for value in expanded.into_iter().skip(offset).take(limit) {
225 output.push((value, 1));
226 }
227 })
228 .map(|(_key, value)| value)
229}
230
/// Bag union (`UNION ALL`): concatenates `left` and `right`, keeping
/// duplicates.
///
/// Thin wrapper that delegates to `concat`.
pub fn union<G, D, R>(
    left: &VecCollection<G, D, R>,
    right: &VecCollection<G, D, R>,
) -> VecCollection<G, D, R>
where
    G: Scope,
    D: Clone + 'static,
    R: Clone + 'static,
{
    left.concat(right)
}
243
244pub fn union_distinct<G, D, R>(
246 left: &VecCollection<G, D, R>,
247 right: &VecCollection<G, D, R>,
248) -> VecCollection<G, D, isize>
249where
250 G: Scope<Timestamp: Lattice + Ord>,
251 D: ExchangeData + Hashable,
252 R: ExchangeData + Semigroup,
253{
254 union(left, right).distinct()
255}
256
/// Dataflow-level tests: each test builds small input collections inside
/// `timely::example`, applies one operator from this module and compares the
/// result with an expected collection via `assert_eq`.
#[cfg(test)]
mod tests {
    use crate::input::Input;

    use super::{
        aggregate_i64, distinct, equi_join, filter, left_join, project, topk, union,
        union_distinct, AggregateFunc, AggregateValue, SortDirection,
    };

    #[test]
    fn filter_and_project_delegate_to_differential_operators() {
        timely::example(|scope| {
            let input = scope.new_collection_from(0..5).1;
            let actual = project(&filter(&input, |value| value % 2 == 0), |value| value * 10);
            let expected = scope.new_collection_from(vec![0, 20, 40]).1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn equi_join_uses_keyed_arrangements() {
        timely::example(|scope| {
            let left = scope
                .new_collection_from(vec![(1_u64, String::from("a")), (2, String::from("b"))])
                .1;
            let right = scope.new_collection_from(vec![(1_u64, 10), (3, 30)]).1;
            let actual = equi_join(&left, &right, |key, left, right| {
                (*key, format!("{left}:{right}"))
            });
            // Only key 1 exists on both sides.
            let expected = scope
                .new_collection_from(vec![(1_u64, String::from("a:10"))])
                .1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn left_join_emits_matches_and_null_extended_unmatched_rows() {
        timely::example(|scope| {
            let left = scope
                .new_collection_from(vec![(1_u64, String::from("a")), (2, String::from("b"))])
                .1;
            let right = scope.new_collection_from(vec![(1_u64, 10), (3, 30)]).1;
            let actual = left_join(&left, &right, |key, left, right| {
                (*key, left.clone(), right.copied())
            });
            // Key 2 has no partner on the right, so it is padded with `None`.
            let expected = scope
                .new_collection_from(vec![
                    (1_u64, String::from("a"), Some(10)),
                    (2, String::from("b"), None),
                ])
                .1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn aggregate_i64_supports_numeric_aggregate_set() {
        timely::example(|scope| {
            // Key 1 holds 10 once and 20 twice; key 2 holds a single 7.
            let input = scope
                .new_collection_from(vec![(1_u64, 10), (1, 20), (1, 20), (2, 7)])
                .1;
            let actual = aggregate_i64(
                &input,
                vec![
                    AggregateFunc::Count,
                    AggregateFunc::Sum,
                    AggregateFunc::Min,
                    AggregateFunc::Max,
                    AggregateFunc::Avg,
                    AggregateFunc::CountDistinct,
                ],
            );
            let expected = scope
                .new_collection_from(vec![
                    (
                        1_u64,
                        vec![
                            AggregateValue::Integer(3),
                            AggregateValue::Integer(50),
                            AggregateValue::Integer(10),
                            AggregateValue::Integer(20),
                            AggregateValue::Average { sum: 50, count: 3 },
                            AggregateValue::Integer(2),
                        ],
                    ),
                    (
                        2,
                        vec![
                            AggregateValue::Integer(1),
                            AggregateValue::Integer(7),
                            AggregateValue::Integer(7),
                            AggregateValue::Integer(7),
                            AggregateValue::Average { sum: 7, count: 1 },
                            AggregateValue::Integer(1),
                        ],
                    ),
                ])
                .1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn topk_slices_ordered_rows_with_reduce() {
        timely::example(|scope| {
            let input = scope.new_collection_from(vec![5, 1, 3, 2, 4]).1;
            // Descending order is 5, 4, 3, 2, 1; skip one, keep two.
            let actual = topk(&input, SortDirection::Descending, 2, 1);
            let expected = scope.new_collection_from(vec![4, 3]).1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn topk_ascending_ranks_each_duplicate_separately() {
        timely::example(|scope| {
            let input = scope.new_collection_from(vec![1, 1, 2, 3]).1;
            // Ascending order is 1, 1, 2, 3; skipping one leaves the second
            // copy of 1 followed by 2.
            let actual = topk(&input, SortDirection::Ascending, 2, 1);
            let expected = scope.new_collection_from(vec![1, 2]).1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn topk_limit_past_the_end_returns_remaining_rows() {
        timely::example(|scope| {
            let input = scope.new_collection_from(vec![3, 1, 2]).1;
            // Descending order is 3, 2, 1; after the offset only two rows remain.
            let actual = topk(&input, SortDirection::Descending, 10, 1);
            let expected = scope.new_collection_from(vec![2, 1]).1;

            actual.assert_eq(&expected);
        });
    }

    #[test]
    fn distinct_and_union_delegate_to_differential_operators() {
        timely::example(|scope| {
            let left = scope.new_collection_from(vec![1, 1, 2]).1;
            let right = scope.new_collection_from(vec![2, 3]).1;
            let all = union(&left, &right).consolidate();
            let all_expected = scope.new_collection_from(vec![1, 1, 2, 2, 3]).1;
            let distinct_actual = distinct(&all);
            let distinct_expected = scope.new_collection_from(vec![1, 2, 3]).1;

            all.assert_eq(&all_expected);
            distinct_actual.assert_eq(&distinct_expected);
        });
    }

    #[test]
    fn union_distinct_concats_then_distincts() {
        timely::example(|scope| {
            let left = scope.new_collection_from(vec![1, 2]).1;
            let right = scope.new_collection_from(vec![2, 3]).1;
            let actual = union_distinct(&left, &right);
            let expected = scope.new_collection_from(vec![1, 2, 3]).1;

            actual.assert_eq(&expected);
        });
    }
}