reifydb_engine/vm/volcano/
aggregate.rs1use std::{
5 collections::{HashMap, HashSet},
6 sync::Arc,
7};
8
9use reifydb_core::{
10 error::{CoreError, diagnostic::query},
11 value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
12};
13use reifydb_routine::routine::{
14 Accumulator, FunctionKind, context::FunctionContext, error::RoutineError, registry::Routines,
15};
16use reifydb_rql::expression::{Expression, name::display_label};
17use reifydb_transaction::transaction::Transaction;
18use reifydb_type::{
19 error,
20 fragment::Fragment,
21 value::{Value, r#type::Type},
22};
23use tracing::instrument;
24
25use crate::{
26 Result,
27 vm::volcano::query::{QueryContext, QueryNode},
28};
29
30enum Projection {
31 Aggregate {
32 column: String,
33 column_fragment: Fragment,
34 alias: Fragment,
35 accumulator: Box<dyn Accumulator>,
36 },
37 Group {
38 column: String,
39 alias: Fragment,
40 },
41}
42
43pub(crate) struct AggregateNode {
44 input: Box<dyn QueryNode>,
45 by: Vec<Expression>,
46 map: Vec<Expression>,
47 headers: Option<ColumnHeaders>,
48 context: Option<Arc<QueryContext>>,
49}
50
51impl AggregateNode {
52 pub fn new(
53 input: Box<dyn QueryNode>,
54 by: Vec<Expression>,
55 map: Vec<Expression>,
56 context: Arc<QueryContext>,
57 ) -> Self {
58 Self {
59 input,
60 by,
61 map,
62 headers: None,
63 context: Some(context),
64 }
65 }
66}
67
68impl QueryNode for AggregateNode {
69 #[instrument(level = "trace", skip_all, name = "volcano::aggregate::initialize")]
70 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
71 self.input.initialize(rx, ctx)?;
72 Ok(())
74 }
75
76 #[instrument(level = "trace", skip_all, name = "volcano::aggregate::next")]
77 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78 debug_assert!(self.context.is_some(), "AggregateNode::next() called before initialize()");
79 let stored_ctx = self.context.as_ref().unwrap();
80
81 if self.headers.is_some() {
82 return Ok(None);
83 }
84
85 let (keys, mut projections) =
86 parse_keys_and_aggregates(&self.by, &self.map, &stored_ctx.services.routines, stored_ctx)?;
87
88 let mut seen_groups = HashSet::<Vec<Value>>::new();
89 let mut group_key_order: Vec<Vec<Value>> = Vec::new();
90
91 while let Some(columns) = self.input.next(rx, ctx)? {
92 let groups = columns.group_by_view(&keys)?;
93
94 for (group_key, _) in &groups {
95 if seen_groups.insert(group_key.clone()) {
96 group_key_order.push(group_key.clone());
97 }
98 }
99
100 for projection in &mut projections {
101 if let Projection::Aggregate {
102 accumulator,
103 column,
104 column_fragment,
105 ..
106 } = projection
107 {
108 let column_ref = columns.column(column).ok_or_else(|| {
109 error!(query::column_not_found(column_fragment.clone()))
110 })?;
111 let cwn = ColumnWithName::new(
112 column_ref.name().clone(),
113 column_ref.data().clone(),
114 );
115 accumulator.update(&Columns::new(vec![cwn]), &groups)?;
116 }
117 }
118 }
119
120 let mut result_columns = Vec::new();
121
122 for projection in projections {
123 match projection {
124 Projection::Group {
125 alias,
126 column,
127 ..
128 } => {
129 let col_idx = keys.iter().position(|k| k == &column).unwrap();
130
131 let first_key_type = if group_key_order.is_empty() {
132 None
133 } else {
134 Some(group_key_order[0][col_idx].get_type())
135 };
136 let mut c = ColumnWithName {
137 name: Fragment::internal(alias.fragment()),
138 data: ColumnBuffer::none_typed(
139 first_key_type.unwrap_or(Type::Boolean),
140 0,
141 ),
142 };
143 for key in &group_key_order {
144 c.data_mut().push_value(key[col_idx].clone());
145 }
146 result_columns.push(c);
147 }
148 Projection::Aggregate {
149 alias,
150 mut accumulator,
151 ..
152 } => {
153 let (keys_out, mut data) = accumulator.finalize().unwrap();
154 align_column_data(&group_key_order, &keys_out, &mut data).unwrap();
155 result_columns.push(ColumnWithName {
156 name: Fragment::internal(alias.fragment()),
157 data,
158 });
159 }
160 }
161 }
162
163 let columns = Columns::new(result_columns);
164 self.headers = Some(ColumnHeaders::from_columns(&columns));
165
166 Ok(Some(columns))
167 }
168
169 fn headers(&self) -> Option<ColumnHeaders> {
170 self.headers.clone().or(self.input.headers())
171 }
172}
173
174fn parse_keys_and_aggregates<'a>(
175 by: &'a [Expression],
176 project: &'a [Expression],
177 routines: &'a Routines,
178 ctx: &QueryContext,
179) -> Result<(Vec<&'a str>, Vec<Projection>)> {
180 let mut keys = Vec::new();
181 let mut projections = Vec::new();
182
183 for gb in by {
184 match gb {
185 Expression::Column(c) => {
186 keys.push(c.0.name.text());
187 projections.push(Projection::Group {
188 column: c.0.name.text().to_string(),
189 alias: c.0.name.clone(),
190 })
191 }
192 Expression::AccessSource(access) => {
193 keys.push(access.column.name.text());
196 projections.push(Projection::Group {
197 column: access.column.name.text().to_string(),
198 alias: access.column.name.clone(),
199 })
200 }
201 expr => panic!("Non-column group by not supported: {expr:#?}"),
205 }
206 }
207
208 for p in project {
209 let (actual_expr, alias) = match p {
211 Expression::Alias(alias_expr) => {
212 (alias_expr.expression.as_ref(), alias_expr.alias.0.clone())
215 }
216 expr => {
217 (expr, display_label(expr))
219 }
220 };
221
222 match actual_expr {
223 Expression::Call(call) => {
224 let func_name = call.func.0.text();
225 let function = routines.get_aggregate_function(func_name).ok_or_else(|| {
226 RoutineError::FunctionNotFound {
227 function: call.func.0.clone(),
228 }
229 })?;
230 let _ = FunctionKind::Aggregate; let mut fn_ctx = FunctionContext {
233 fragment: call.func.0.clone(),
234 identity: ctx.identity,
235 row_count: 0,
236 runtime_context: &ctx.services.runtime_context,
237 };
238
239 let accumulator = function.accumulator(&mut fn_ctx).ok_or_else(|| {
240 RoutineError::FunctionExecutionFailed {
241 function: call.func.0.clone(),
242 reason: format!("Function {} is not an aggregate", func_name),
243 }
244 })?;
245
246 match call.args.first() {
247 Some(Expression::Column(c)) => {
248 projections.push(Projection::Aggregate {
249 column: c.0.name.text().to_string(),
250 column_fragment: c.0.name.clone(),
251 alias,
252 accumulator,
253 });
254 }
255 Some(Expression::AccessSource(access)) => {
256 projections.push(Projection::Aggregate {
260 column: access.column.name.text().to_string(),
261 column_fragment: access.column.name.clone(),
262 alias,
263 accumulator,
264 });
265 }
266 None => {
267 return Err(RoutineError::FunctionArityMismatch {
268 function: call.func.0.clone(),
269 expected: 1,
270 actual: 0,
271 }
272 .into());
273 }
274 Some(arg) => {
275 let actual_type = arg.infer_type().ok_or_else(|| {
276 RoutineError::FunctionExecutionFailed {
277 function: call.func.0.clone(),
278 reason: "aggregate function arguments must be column references".to_string(),
279 }
280 })?;
281 let expected = function.accepted_types().expected_at(0).to_vec();
282 return Err(RoutineError::FunctionInvalidArgumentType {
283 function: call.func.0.clone(),
284 argument_index: 0,
285 expected,
286 actual: actual_type,
287 }
288 .into());
289 }
290 }
291 }
292 _ => panic!("Expected aggregate call expression, got: {actual_expr:#?}"),
296 }
297 }
298 Ok((keys, projections))
299}
300
301fn align_column_data(group_key_order: &[Vec<Value>], keys: &[Vec<Value>], data: &mut ColumnBuffer) -> Result<()> {
302 let mut key_to_index = HashMap::new();
303 for (i, key) in keys.iter().enumerate() {
304 key_to_index.insert(key, i);
305 }
306
307 let reorder_indices: Vec<usize> = group_key_order
308 .iter()
309 .map(|k| {
310 key_to_index.get(k).copied().ok_or_else(|| {
311 CoreError::FrameError {
312 message: format!("Group key {:?} missing in aggregate output", k),
313 }
314 .into()
315 })
316 })
317 .collect::<Result<Vec<_>>>()?;
318
319 data.reorder(&reorder_indices);
320 Ok(())
321}