// locustdb/engine/operators/merge_aggregate.rs

1use ordered_float::OrderedFloat;
2
3use crate::engine::*;
4
5#[derive(Debug)]
6pub struct MergeAggregate<T> {
7    pub merge_ops: BufferRef<MergeOp>,
8    pub left: BufferRef<T>,
9    pub right: BufferRef<T>,
10    pub aggregated: BufferRef<T>,
11    pub aggregator: Aggregator,
12}
13
14impl<'a, T> VecOperator<'a> for MergeAggregate<T>
15where
16    T: VecData<T> + Combinable<T> + 'a,
17{
18    fn execute(&mut self, _: bool, scratchpad: &mut Scratchpad<'a>) -> Result<(), QueryError> {
19        let aggregated = {
20            let ops = scratchpad.get(self.merge_ops);
21            let left = scratchpad.get(self.left);
22            let right = scratchpad.get(self.right);
23            merge_aggregate(&ops, &left, &right, self.aggregator)
24        };
25        scratchpad.set(self.aggregated, aggregated?);
26        Ok(())
27    }
28
29    fn inputs(&self) -> Vec<BufferRef<Any>> {
30        vec![self.left.any(), self.right.any(), self.merge_ops.any()]
31    }
32    fn inputs_mut(&mut self) -> Vec<&mut usize> {
33        vec![&mut self.left.i, &mut self.right.i, &mut self.merge_ops.i]
34    }
35    fn outputs(&self) -> Vec<BufferRef<Any>> {
36        vec![self.aggregated.any()]
37    }
38    fn can_stream_input(&self, _: usize) -> bool {
39        false
40    }
41    fn can_stream_output(&self, _: usize) -> bool {
42        false
43    }
44    fn allocates(&self) -> bool {
45        true
46    }
47
48    fn display_op(&self, _: bool) -> String {
49        format!(
50            "merge_aggregate({:?}; {}, {}, {})",
51            self.aggregator, self.merge_ops, self.left, self.right
52        )
53    }
54}
55
56fn merge_aggregate<T: Combinable<T>>(
57    ops: &[MergeOp],
58    left: &[T],
59    right: &[T],
60    aggregator: Aggregator,
61) -> Result<Vec<T>, QueryError> {
62    if left.is_empty() {
63        return Ok(right.to_vec());
64    } else if right.is_empty() {
65        return Ok(left.to_vec());
66    }
67
68    let mut result = Vec::with_capacity(ops.len());
69    let mut i = 0;
70    let mut j = 0;
71    for op in ops {
72        match *op {
73            MergeOp::TakeLeft => {
74                if i == left.len() {
75                    error!("{} {} {}", left.len(), right.len(), ops.len());
76                }
77                result.push(left[i]);
78                i += 1;
79            }
80            MergeOp::TakeRight => {
81                if j == right.len() {
82                    error!("{} {} {}", left.len(), right.len(), ops.len());
83                }
84                result.push(right[j]);
85                j += 1;
86            }
87            MergeOp::MergeRight => {
88                let last = result.len() - 1;
89                result[last] = T::combine(aggregator, result[last], right[j])?;
90                j += 1;
91            }
92        }
93    }
94    Ok(result)
95}
96
97trait Combinable<T>: Clone + Copy {
98    fn combine(op: Aggregator, a: T, b: T) -> Result<T, QueryError>;
99}
100
101impl Combinable<i64> for i64 {
102    fn combine(op: Aggregator, a: i64, b: i64) -> Result<i64, QueryError> {
103        fn null_coalesce(a: i64, b: i64, combined: i64) -> Result<i64, QueryError> {
104            if a == I64_NULL {
105                Ok(b)
106            } else if b == I64_NULL {
107                Ok(a)
108            } else {
109                Ok(combined)
110            }
111        }
112        // TODO: remove null handling hack?
113        match op {
114            Aggregator::SumI64 => {
115                if a == I64_NULL {
116                    Ok(b)
117                } else if b == I64_NULL {
118                    Ok(a)
119                } else {
120                    a.checked_add(b).ok_or(QueryError::Overflow)
121                }
122            }
123            Aggregator::Count => {
124                if a == I64_NULL {
125                    Ok(b)
126                } else if b == I64_NULL {
127                    Ok(a)
128                } else {
129                    Ok(a + b)
130                }
131            }
132            Aggregator::MaxI64 => null_coalesce(a, b, std::cmp::max(a, b)),
133            Aggregator::MinI64 => null_coalesce(a, b, std::cmp::min(a, b)),
134            _ => Err(fatal!("Unsupported aggregator for i64: {:?}", op)),
135        }
136    }
137}
138
139impl Combinable<OrderedFloat<f64>> for OrderedFloat<f64> {
140    fn combine(op: Aggregator, a: of64, b: of64) -> Result<OrderedFloat<f64>, QueryError> {
141        // possibly Aggregator::XI64 is masking a bug
142        fn null_coalesce(a: of64, b: of64, combined: of64) -> Result<of64, QueryError> {
143            if a.to_bits() == F64_NULL.to_bits() {
144                Ok(b)
145            } else if b.to_bits() == F64_NULL.to_bits() {
146                Ok(a)
147            } else {
148                Ok(combined)
149            }
150        }
151        match op {
152            Aggregator::SumF64 | Aggregator::SumI64 => null_coalesce(a, b, a + b),
153            Aggregator::MaxF64 | Aggregator::MaxI64 => null_coalesce(a, b, std::cmp::max(a, b)),
154            Aggregator::MinF64 | Aggregator::MinI64 => null_coalesce(a, b, std::cmp::min(a, b)),
155            _ => Err(fatal!("Unsupported aggregator for f64: {:?}", op)),
156        }
157    }
158}