Skip to main content

reifydb_engine/vm/volcano/
aggregate.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	collections::{HashMap, HashSet},
6	sync::Arc,
7};
8
9use reifydb_core::{
10	error::{CoreError, diagnostic::query},
11	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
12};
13use reifydb_routine::routine::{
14	Accumulator, FunctionKind, context::FunctionContext, error::RoutineError, registry::Routines,
15};
16use reifydb_rql::expression::{Expression, name::display_label};
17use reifydb_transaction::transaction::Transaction;
18use reifydb_type::{
19	error,
20	fragment::Fragment,
21	value::{Value, r#type::Type},
22};
23use tracing::instrument;
24
25use crate::{
26	Result,
27	vm::volcano::query::{QueryContext, QueryNode},
28};
29
30enum Projection {
31	Aggregate {
32		column: String,
33		column_fragment: Fragment,
34		alias: Fragment,
35		accumulator: Box<dyn Accumulator>,
36	},
37	Group {
38		column: String,
39		alias: Fragment,
40	},
41}
42
43pub(crate) struct AggregateNode {
44	input: Box<dyn QueryNode>,
45	by: Vec<Expression>,
46	map: Vec<Expression>,
47	headers: Option<ColumnHeaders>,
48	context: Option<Arc<QueryContext>>,
49}
50
51impl AggregateNode {
52	pub fn new(
53		input: Box<dyn QueryNode>,
54		by: Vec<Expression>,
55		map: Vec<Expression>,
56		context: Arc<QueryContext>,
57	) -> Self {
58		Self {
59			input,
60			by,
61			map,
62			headers: None,
63			context: Some(context),
64		}
65	}
66}
67
68impl QueryNode for AggregateNode {
69	#[instrument(level = "trace", skip_all, name = "volcano::aggregate::initialize")]
70	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
71		self.input.initialize(rx, ctx)?;
72		// Already has context from constructor
73		Ok(())
74	}
75
76	#[instrument(level = "trace", skip_all, name = "volcano::aggregate::next")]
77	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78		debug_assert!(self.context.is_some(), "AggregateNode::next() called before initialize()");
79		let stored_ctx = self.context.as_ref().unwrap();
80
81		if self.headers.is_some() {
82			return Ok(None);
83		}
84
85		let (keys, mut projections) =
86			parse_keys_and_aggregates(&self.by, &self.map, &stored_ctx.services.routines, stored_ctx)?;
87
88		let mut seen_groups = HashSet::<Vec<Value>>::new();
89		let mut group_key_order: Vec<Vec<Value>> = Vec::new();
90
91		while let Some(columns) = self.input.next(rx, ctx)? {
92			let groups = columns.group_by_view(&keys)?;
93
94			for (group_key, _) in &groups {
95				if seen_groups.insert(group_key.clone()) {
96					group_key_order.push(group_key.clone());
97				}
98			}
99
100			for projection in &mut projections {
101				if let Projection::Aggregate {
102					accumulator,
103					column,
104					column_fragment,
105					..
106				} = projection
107				{
108					let column_ref = columns.column(column).ok_or_else(|| {
109						error!(query::column_not_found(column_fragment.clone()))
110					})?;
111					let cwn = ColumnWithName::new(
112						column_ref.name().clone(),
113						column_ref.data().clone(),
114					);
115					accumulator.update(&Columns::new(vec![cwn]), &groups)?;
116				}
117			}
118		}
119
120		let mut result_columns = Vec::new();
121
122		for projection in projections {
123			match projection {
124				Projection::Group {
125					alias,
126					column,
127					..
128				} => {
129					let col_idx = keys.iter().position(|k| k == &column).unwrap();
130
131					let first_key_type = if group_key_order.is_empty() {
132						None
133					} else {
134						Some(group_key_order[0][col_idx].get_type())
135					};
136					let mut c = ColumnWithName {
137						name: Fragment::internal(alias.fragment()),
138						data: ColumnBuffer::none_typed(
139							first_key_type.unwrap_or(Type::Boolean),
140							0,
141						),
142					};
143					for key in &group_key_order {
144						c.data_mut().push_value(key[col_idx].clone());
145					}
146					result_columns.push(c);
147				}
148				Projection::Aggregate {
149					alias,
150					mut accumulator,
151					..
152				} => {
153					let (keys_out, mut data) = accumulator.finalize().unwrap();
154					align_column_data(&group_key_order, &keys_out, &mut data).unwrap();
155					result_columns.push(ColumnWithName {
156						name: Fragment::internal(alias.fragment()),
157						data,
158					});
159				}
160			}
161		}
162
163		let columns = Columns::new(result_columns);
164		self.headers = Some(ColumnHeaders::from_columns(&columns));
165
166		Ok(Some(columns))
167	}
168
169	fn headers(&self) -> Option<ColumnHeaders> {
170		self.headers.clone().or(self.input.headers())
171	}
172}
173
174fn parse_keys_and_aggregates<'a>(
175	by: &'a [Expression],
176	project: &'a [Expression],
177	routines: &'a Routines,
178	ctx: &QueryContext,
179) -> Result<(Vec<&'a str>, Vec<Projection>)> {
180	let mut keys = Vec::new();
181	let mut projections = Vec::new();
182
183	for gb in by {
184		match gb {
185			Expression::Column(c) => {
186				keys.push(c.0.name.text());
187				projections.push(Projection::Group {
188					column: c.0.name.text().to_string(),
189					alias: c.0.name.clone(),
190				})
191			}
192			Expression::AccessSource(access) => {
193				// Handle qualified column references like
194				// departments.dept_name
195				keys.push(access.column.name.text());
196				projections.push(Projection::Group {
197					column: access.column.name.text().to_string(),
198					alias: access.column.name.clone(),
199				})
200			}
201			// _ => return
202			// Err(reifydb_type::error::Error::Unsupported("Non-column
203			// group by not supported".into())),
204			expr => panic!("Non-column group by not supported: {expr:#?}"),
205		}
206	}
207
208	for p in project {
209		// Extract the actual expression, handling aliases
210		let (actual_expr, alias) = match p {
211			Expression::Alias(alias_expr) => {
212				// This is an aliased expression like
213				// "total_count: count(value)"
214				(alias_expr.expression.as_ref(), alias_expr.alias.0.clone())
215			}
216			expr => {
217				// Non-aliased expression, derive a deterministic display label.
218				(expr, display_label(expr))
219			}
220		};
221
222		match actual_expr {
223			Expression::Call(call) => {
224				let func_name = call.func.0.text();
225				let function = routines.get_aggregate_function(func_name).ok_or_else(|| {
226					RoutineError::FunctionNotFound {
227						function: call.func.0.clone(),
228					}
229				})?;
230				let _ = FunctionKind::Aggregate; // ensure kinds enum is in scope
231
232				let mut fn_ctx = FunctionContext {
233					fragment: call.func.0.clone(),
234					identity: ctx.identity,
235					row_count: 0,
236					runtime_context: &ctx.services.runtime_context,
237				};
238
239				let accumulator = function.accumulator(&mut fn_ctx).ok_or_else(|| {
240					RoutineError::FunctionExecutionFailed {
241						function: call.func.0.clone(),
242						reason: format!("Function {} is not an aggregate", func_name),
243					}
244				})?;
245
246				match call.args.first() {
247					Some(Expression::Column(c)) => {
248						projections.push(Projection::Aggregate {
249							column: c.0.name.text().to_string(),
250							column_fragment: c.0.name.clone(),
251							alias,
252							accumulator,
253						});
254					}
255					Some(Expression::AccessSource(access)) => {
256						// Handle qualified column
257						// references in aggregate
258						// functions
259						projections.push(Projection::Aggregate {
260							column: access.column.name.text().to_string(),
261							column_fragment: access.column.name.clone(),
262							alias,
263							accumulator,
264						});
265					}
266					None => {
267						return Err(RoutineError::FunctionArityMismatch {
268							function: call.func.0.clone(),
269							expected: 1,
270							actual: 0,
271						}
272						.into());
273					}
274					Some(arg) => {
275						let actual_type = arg.infer_type().ok_or_else(|| {
276							RoutineError::FunctionExecutionFailed {
277								function: call.func.0.clone(),
278								reason: "aggregate function arguments must be column references".to_string(),
279							}
280						})?;
281						let expected = function.accepted_types().expected_at(0).to_vec();
282						return Err(RoutineError::FunctionInvalidArgumentType {
283							function: call.func.0.clone(),
284							argument_index: 0,
285							expected,
286							actual: actual_type,
287						}
288						.into());
289					}
290				}
291			}
292			// _ => return
293			// Err(reifydb_type::error::Error::Unsupported("Expected
294			// aggregate call expression".into())),
295			_ => panic!("Expected aggregate call expression, got: {actual_expr:#?}"),
296		}
297	}
298	Ok((keys, projections))
299}
300
301fn align_column_data(group_key_order: &[Vec<Value>], keys: &[Vec<Value>], data: &mut ColumnBuffer) -> Result<()> {
302	let mut key_to_index = HashMap::new();
303	for (i, key) in keys.iter().enumerate() {
304		key_to_index.insert(key, i);
305	}
306
307	let reorder_indices: Vec<usize> = group_key_order
308		.iter()
309		.map(|k| {
310			key_to_index.get(k).copied().ok_or_else(|| {
311				CoreError::FrameError {
312					message: format!("Group key {:?} missing in aggregate output", k),
313				}
314				.into()
315			})
316		})
317		.collect::<Result<Vec<_>>>()?;
318
319	data.reorder(&reorder_indices);
320	Ok(())
321}