locustdb/engine/operators/
merge_aggregate.rs1use ordered_float::OrderedFloat;
2
3use crate::engine::*;
4
5#[derive(Debug)]
6pub struct MergeAggregate<T> {
7 pub merge_ops: BufferRef<MergeOp>,
8 pub left: BufferRef<T>,
9 pub right: BufferRef<T>,
10 pub aggregated: BufferRef<T>,
11 pub aggregator: Aggregator,
12}
13
14impl<'a, T> VecOperator<'a> for MergeAggregate<T>
15where
16 T: VecData<T> + Combinable<T> + 'a,
17{
18 fn execute(&mut self, _: bool, scratchpad: &mut Scratchpad<'a>) -> Result<(), QueryError> {
19 let aggregated = {
20 let ops = scratchpad.get(self.merge_ops);
21 let left = scratchpad.get(self.left);
22 let right = scratchpad.get(self.right);
23 merge_aggregate(&ops, &left, &right, self.aggregator)
24 };
25 scratchpad.set(self.aggregated, aggregated?);
26 Ok(())
27 }
28
29 fn inputs(&self) -> Vec<BufferRef<Any>> {
30 vec![self.left.any(), self.right.any(), self.merge_ops.any()]
31 }
32 fn inputs_mut(&mut self) -> Vec<&mut usize> {
33 vec![&mut self.left.i, &mut self.right.i, &mut self.merge_ops.i]
34 }
35 fn outputs(&self) -> Vec<BufferRef<Any>> {
36 vec![self.aggregated.any()]
37 }
38 fn can_stream_input(&self, _: usize) -> bool {
39 false
40 }
41 fn can_stream_output(&self, _: usize) -> bool {
42 false
43 }
44 fn allocates(&self) -> bool {
45 true
46 }
47
48 fn display_op(&self, _: bool) -> String {
49 format!(
50 "merge_aggregate({:?}; {}, {}, {})",
51 self.aggregator, self.merge_ops, self.left, self.right
52 )
53 }
54}
55
56fn merge_aggregate<T: Combinable<T>>(
57 ops: &[MergeOp],
58 left: &[T],
59 right: &[T],
60 aggregator: Aggregator,
61) -> Result<Vec<T>, QueryError> {
62 if left.is_empty() {
63 return Ok(right.to_vec());
64 } else if right.is_empty() {
65 return Ok(left.to_vec());
66 }
67
68 let mut result = Vec::with_capacity(ops.len());
69 let mut i = 0;
70 let mut j = 0;
71 for op in ops {
72 match *op {
73 MergeOp::TakeLeft => {
74 if i == left.len() {
75 error!("{} {} {}", left.len(), right.len(), ops.len());
76 }
77 result.push(left[i]);
78 i += 1;
79 }
80 MergeOp::TakeRight => {
81 if j == right.len() {
82 error!("{} {} {}", left.len(), right.len(), ops.len());
83 }
84 result.push(right[j]);
85 j += 1;
86 }
87 MergeOp::MergeRight => {
88 let last = result.len() - 1;
89 result[last] = T::combine(aggregator, result[last], right[j])?;
90 j += 1;
91 }
92 }
93 }
94 Ok(result)
95}
96
97trait Combinable<T>: Clone + Copy {
98 fn combine(op: Aggregator, a: T, b: T) -> Result<T, QueryError>;
99}
100
101impl Combinable<i64> for i64 {
102 fn combine(op: Aggregator, a: i64, b: i64) -> Result<i64, QueryError> {
103 fn null_coalesce(a: i64, b: i64, combined: i64) -> Result<i64, QueryError> {
104 if a == I64_NULL {
105 Ok(b)
106 } else if b == I64_NULL {
107 Ok(a)
108 } else {
109 Ok(combined)
110 }
111 }
112 match op {
114 Aggregator::SumI64 => {
115 if a == I64_NULL {
116 Ok(b)
117 } else if b == I64_NULL {
118 Ok(a)
119 } else {
120 a.checked_add(b).ok_or(QueryError::Overflow)
121 }
122 }
123 Aggregator::Count => {
124 if a == I64_NULL {
125 Ok(b)
126 } else if b == I64_NULL {
127 Ok(a)
128 } else {
129 Ok(a + b)
130 }
131 }
132 Aggregator::MaxI64 => null_coalesce(a, b, std::cmp::max(a, b)),
133 Aggregator::MinI64 => null_coalesce(a, b, std::cmp::min(a, b)),
134 _ => Err(fatal!("Unsupported aggregator for i64: {:?}", op)),
135 }
136 }
137}
138
139impl Combinable<OrderedFloat<f64>> for OrderedFloat<f64> {
140 fn combine(op: Aggregator, a: of64, b: of64) -> Result<OrderedFloat<f64>, QueryError> {
141 fn null_coalesce(a: of64, b: of64, combined: of64) -> Result<of64, QueryError> {
143 if a.to_bits() == F64_NULL.to_bits() {
144 Ok(b)
145 } else if b.to_bits() == F64_NULL.to_bits() {
146 Ok(a)
147 } else {
148 Ok(combined)
149 }
150 }
151 match op {
152 Aggregator::SumF64 | Aggregator::SumI64 => null_coalesce(a, b, a + b),
153 Aggregator::MaxF64 | Aggregator::MaxI64 => null_coalesce(a, b, std::cmp::max(a, b)),
154 Aggregator::MinF64 | Aggregator::MinI64 => null_coalesce(a, b, std::cmp::min(a, b)),
155 _ => Err(fatal!("Unsupported aggregator for f64: {:?}", op)),
156 }
157 }
158}