Skip to main content

reifydb_engine/vm/volcano/
aggregate.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	collections::{HashMap, HashSet},
6	sync::Arc,
7};
8
9use reifydb_core::{
10	error::{CoreError, diagnostic::query},
11	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
12};
13use reifydb_routine::routine::{
14	Accumulator, FunctionKind, context::FunctionContext, error::RoutineError, registry::Routines,
15};
16use reifydb_rql::expression::{Expression, name::display_label};
17use reifydb_transaction::transaction::Transaction;
18use reifydb_type::{
19	error,
20	fragment::Fragment,
21	value::{Value, r#type::Type},
22};
23use tracing::instrument;
24
25use crate::{
26	Result,
27	vm::volcano::query::{QueryContext, QueryNode},
28};
29
30enum Projection {
31	Aggregate {
32		column: String,
33		column_fragment: Fragment,
34		alias: Fragment,
35		accumulator: Box<dyn Accumulator>,
36	},
37	Group {
38		column: String,
39		alias: Fragment,
40	},
41}
42
/// Volcano operator that groups its input by `by` and evaluates the aggregate
/// calls listed in `map`.
///
/// The whole input is drained on the first `next()` call, which emits a single
/// result batch; later calls yield nothing.
pub(crate) struct AggregateNode {
	/// Child operator supplying the rows to aggregate.
	input: Box<dyn QueryNode>,
	/// Group-by expressions; only column references are supported.
	by: Vec<Expression>,
	/// Output expressions: aggregate function calls, optionally aliased.
	map: Vec<Expression>,
	/// Headers of the emitted batch. `Some` once the result has been produced,
	/// which also marks the node as exhausted.
	headers: Option<ColumnHeaders>,
	/// Query context captured at construction; supplies the function registry,
	/// identity and runtime context used to build accumulators.
	context: Option<Arc<QueryContext>>,
}
50
51impl AggregateNode {
52	pub fn new(
53		input: Box<dyn QueryNode>,
54		by: Vec<Expression>,
55		map: Vec<Expression>,
56		context: Arc<QueryContext>,
57	) -> Self {
58		Self {
59			input,
60			by,
61			map,
62			headers: None,
63			context: Some(context),
64		}
65	}
66}
67
68impl QueryNode for AggregateNode {
69	#[instrument(level = "trace", skip_all, name = "volcano::aggregate::initialize")]
70	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
71		self.input.initialize(rx, ctx)?;
72
73		Ok(())
74	}
75
76	#[instrument(level = "trace", skip_all, name = "volcano::aggregate::next")]
77	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78		debug_assert!(self.context.is_some(), "AggregateNode::next() called before initialize()");
79		let stored_ctx = self.context.as_ref().unwrap();
80
81		if self.headers.is_some() {
82			return Ok(None);
83		}
84
85		let (keys, mut projections) =
86			parse_keys_and_aggregates(&self.by, &self.map, &stored_ctx.services.routines, stored_ctx)?;
87
88		let mut seen_groups = HashSet::<Vec<Value>>::new();
89		let mut group_key_order: Vec<Vec<Value>> = Vec::new();
90
91		while let Some(columns) = self.input.next(rx, ctx)? {
92			let groups = columns.group_by_view(&keys)?;
93
94			for (group_key, _) in &groups {
95				if seen_groups.insert(group_key.clone()) {
96					group_key_order.push(group_key.clone());
97				}
98			}
99
100			for projection in &mut projections {
101				if let Projection::Aggregate {
102					accumulator,
103					column,
104					column_fragment,
105					..
106				} = projection
107				{
108					let column_ref = columns.column(column).ok_or_else(|| {
109						error!(query::column_not_found(column_fragment.clone()))
110					})?;
111					let cwn = ColumnWithName::new(
112						column_ref.name().clone(),
113						column_ref.data().clone(),
114					);
115					accumulator.update(&Columns::new(vec![cwn]), &groups)?;
116				}
117			}
118		}
119
120		let mut result_columns = Vec::new();
121
122		for projection in projections {
123			match projection {
124				Projection::Group {
125					alias,
126					column,
127					..
128				} => {
129					let col_idx = keys.iter().position(|k| k == &column).unwrap();
130
131					let first_key_type = if group_key_order.is_empty() {
132						None
133					} else {
134						Some(group_key_order[0][col_idx].get_type())
135					};
136					let mut c = ColumnWithName {
137						name: Fragment::internal(alias.fragment()),
138						data: ColumnBuffer::none_typed(
139							first_key_type.unwrap_or(Type::Boolean),
140							0,
141						),
142					};
143					for key in &group_key_order {
144						c.data_mut().push_value(key[col_idx].clone());
145					}
146					result_columns.push(c);
147				}
148				Projection::Aggregate {
149					alias,
150					mut accumulator,
151					..
152				} => {
153					let (keys_out, mut data) = accumulator.finalize().unwrap();
154					align_column_data(&group_key_order, &keys_out, &mut data).unwrap();
155					result_columns.push(ColumnWithName {
156						name: Fragment::internal(alias.fragment()),
157						data,
158					});
159				}
160			}
161		}
162
163		let columns = Columns::new(result_columns);
164		self.headers = Some(ColumnHeaders::from_columns(&columns));
165
166		Ok(Some(columns))
167	}
168
169	fn headers(&self) -> Option<ColumnHeaders> {
170		self.headers.clone().or(self.input.headers())
171	}
172}
173
174fn parse_keys_and_aggregates<'a>(
175	by: &'a [Expression],
176	project: &'a [Expression],
177	routines: &'a Routines,
178	ctx: &QueryContext,
179) -> Result<(Vec<&'a str>, Vec<Projection>)> {
180	let mut keys = Vec::new();
181	let mut projections = Vec::new();
182
183	for gb in by {
184		match gb {
185			Expression::Column(c) => {
186				keys.push(c.0.name.text());
187				projections.push(Projection::Group {
188					column: c.0.name.text().to_string(),
189					alias: c.0.name.clone(),
190				})
191			}
192			Expression::AccessSource(access) => {
193				keys.push(access.column.name.text());
194				projections.push(Projection::Group {
195					column: access.column.name.text().to_string(),
196					alias: access.column.name.clone(),
197				})
198			}
199
200			expr => panic!("Non-column group by not supported: {expr:#?}"),
201		}
202	}
203
204	for p in project {
205		let (actual_expr, alias) = match p {
206			Expression::Alias(alias_expr) => (alias_expr.expression.as_ref(), alias_expr.alias.0.clone()),
207			expr => (expr, display_label(expr)),
208		};
209
210		match actual_expr {
211			Expression::Call(call) => {
212				let func_name = call.func.0.text();
213				let function = routines.get_aggregate_function(func_name).ok_or_else(|| {
214					RoutineError::FunctionNotFound {
215						function: call.func.0.clone(),
216					}
217				})?;
218				let _ = FunctionKind::Aggregate;
219
220				let mut fn_ctx = FunctionContext {
221					fragment: call.func.0.clone(),
222					identity: ctx.identity,
223					row_count: 0,
224					runtime_context: &ctx.services.runtime_context,
225				};
226
227				let accumulator = function.accumulator(&mut fn_ctx).ok_or_else(|| {
228					RoutineError::FunctionExecutionFailed {
229						function: call.func.0.clone(),
230						reason: format!("Function {} is not an aggregate", func_name),
231					}
232				})?;
233
234				match call.args.first() {
235					Some(Expression::Column(c)) => {
236						projections.push(Projection::Aggregate {
237							column: c.0.name.text().to_string(),
238							column_fragment: c.0.name.clone(),
239							alias,
240							accumulator,
241						});
242					}
243					Some(Expression::AccessSource(access)) => {
244						projections.push(Projection::Aggregate {
245							column: access.column.name.text().to_string(),
246							column_fragment: access.column.name.clone(),
247							alias,
248							accumulator,
249						});
250					}
251					None => {
252						return Err(RoutineError::FunctionArityMismatch {
253							function: call.func.0.clone(),
254							expected: 1,
255							actual: 0,
256						}
257						.into());
258					}
259					Some(arg) => {
260						let actual_type = arg.infer_type().ok_or_else(|| {
261							RoutineError::FunctionExecutionFailed {
262								function: call.func.0.clone(),
263								reason: "aggregate function arguments must be column references".to_string(),
264							}
265						})?;
266						let expected = function.accepted_types().expected_at(0).to_vec();
267						return Err(RoutineError::FunctionInvalidArgumentType {
268							function: call.func.0.clone(),
269							argument_index: 0,
270							expected,
271							actual: actual_type,
272						}
273						.into());
274					}
275				}
276			}
277
278			_ => panic!("Expected aggregate call expression, got: {actual_expr:#?}"),
279		}
280	}
281	Ok((keys, projections))
282}
283
284fn align_column_data(group_key_order: &[Vec<Value>], keys: &[Vec<Value>], data: &mut ColumnBuffer) -> Result<()> {
285	let mut key_to_index = HashMap::new();
286	for (i, key) in keys.iter().enumerate() {
287		key_to_index.insert(key, i);
288	}
289
290	let reorder_indices: Vec<usize> = group_key_order
291		.iter()
292		.map(|k| {
293			key_to_index.get(k).copied().ok_or_else(|| {
294				CoreError::FrameError {
295					message: format!("Group key {:?} missing in aggregate output", k),
296				}
297				.into()
298			})
299		})
300		.collect::<Result<Vec<_>>>()?;
301
302	data.reorder(&reorder_indices);
303	Ok(())
304}