reifydb_engine/vm/volcano/
aggregate.rs1use std::{
5 collections::{HashMap, HashSet},
6 sync::Arc,
7};
8
9use reifydb_core::{
10 error::{CoreError, diagnostic::query},
11 value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
12};
13use reifydb_routine::routine::{
14 Accumulator, FunctionKind, context::FunctionContext, error::RoutineError, registry::Routines,
15};
16use reifydb_rql::expression::{Expression, name::display_label};
17use reifydb_transaction::transaction::Transaction;
18use reifydb_type::{
19 error,
20 fragment::Fragment,
21 value::{Value, r#type::Type},
22};
23use tracing::instrument;
24
25use crate::{
26 Result,
27 vm::volcano::query::{QueryContext, QueryNode},
28};
29
30enum Projection {
31 Aggregate {
32 column: String,
33 column_fragment: Fragment,
34 alias: Fragment,
35 accumulator: Box<dyn Accumulator>,
36 },
37 Group {
38 column: String,
39 alias: Fragment,
40 },
41}
42
43pub(crate) struct AggregateNode {
44 input: Box<dyn QueryNode>,
45 by: Vec<Expression>,
46 map: Vec<Expression>,
47 headers: Option<ColumnHeaders>,
48 context: Option<Arc<QueryContext>>,
49}
50
51impl AggregateNode {
52 pub fn new(
53 input: Box<dyn QueryNode>,
54 by: Vec<Expression>,
55 map: Vec<Expression>,
56 context: Arc<QueryContext>,
57 ) -> Self {
58 Self {
59 input,
60 by,
61 map,
62 headers: None,
63 context: Some(context),
64 }
65 }
66}
67
68impl QueryNode for AggregateNode {
69 #[instrument(level = "trace", skip_all, name = "volcano::aggregate::initialize")]
70 fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
71 self.input.initialize(rx, ctx)?;
72
73 Ok(())
74 }
75
76 #[instrument(level = "trace", skip_all, name = "volcano::aggregate::next")]
77 fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
78 debug_assert!(self.context.is_some(), "AggregateNode::next() called before initialize()");
79 let stored_ctx = self.context.as_ref().unwrap();
80
81 if self.headers.is_some() {
82 return Ok(None);
83 }
84
85 let (keys, mut projections) =
86 parse_keys_and_aggregates(&self.by, &self.map, &stored_ctx.services.routines, stored_ctx)?;
87
88 let mut seen_groups = HashSet::<Vec<Value>>::new();
89 let mut group_key_order: Vec<Vec<Value>> = Vec::new();
90
91 while let Some(columns) = self.input.next(rx, ctx)? {
92 let groups = columns.group_by_view(&keys)?;
93
94 for (group_key, _) in &groups {
95 if seen_groups.insert(group_key.clone()) {
96 group_key_order.push(group_key.clone());
97 }
98 }
99
100 for projection in &mut projections {
101 if let Projection::Aggregate {
102 accumulator,
103 column,
104 column_fragment,
105 ..
106 } = projection
107 {
108 let column_ref = columns.column(column).ok_or_else(|| {
109 error!(query::column_not_found(column_fragment.clone()))
110 })?;
111 let cwn = ColumnWithName::new(
112 column_ref.name().clone(),
113 column_ref.data().clone(),
114 );
115 accumulator.update(&Columns::new(vec![cwn]), &groups)?;
116 }
117 }
118 }
119
120 let mut result_columns = Vec::new();
121
122 for projection in projections {
123 match projection {
124 Projection::Group {
125 alias,
126 column,
127 ..
128 } => {
129 let col_idx = keys.iter().position(|k| k == &column).unwrap();
130
131 let first_key_type = if group_key_order.is_empty() {
132 None
133 } else {
134 Some(group_key_order[0][col_idx].get_type())
135 };
136 let mut c = ColumnWithName {
137 name: Fragment::internal(alias.fragment()),
138 data: ColumnBuffer::none_typed(
139 first_key_type.unwrap_or(Type::Boolean),
140 0,
141 ),
142 };
143 for key in &group_key_order {
144 c.data_mut().push_value(key[col_idx].clone());
145 }
146 result_columns.push(c);
147 }
148 Projection::Aggregate {
149 alias,
150 mut accumulator,
151 ..
152 } => {
153 let (keys_out, mut data) = accumulator.finalize().unwrap();
154 align_column_data(&group_key_order, &keys_out, &mut data).unwrap();
155 result_columns.push(ColumnWithName {
156 name: Fragment::internal(alias.fragment()),
157 data,
158 });
159 }
160 }
161 }
162
163 let columns = Columns::new(result_columns);
164 self.headers = Some(ColumnHeaders::from_columns(&columns));
165
166 Ok(Some(columns))
167 }
168
169 fn headers(&self) -> Option<ColumnHeaders> {
170 self.headers.clone().or(self.input.headers())
171 }
172}
173
174fn parse_keys_and_aggregates<'a>(
175 by: &'a [Expression],
176 project: &'a [Expression],
177 routines: &'a Routines,
178 ctx: &QueryContext,
179) -> Result<(Vec<&'a str>, Vec<Projection>)> {
180 let mut keys = Vec::new();
181 let mut projections = Vec::new();
182
183 for gb in by {
184 match gb {
185 Expression::Column(c) => {
186 keys.push(c.0.name.text());
187 projections.push(Projection::Group {
188 column: c.0.name.text().to_string(),
189 alias: c.0.name.clone(),
190 })
191 }
192 Expression::AccessSource(access) => {
193 keys.push(access.column.name.text());
194 projections.push(Projection::Group {
195 column: access.column.name.text().to_string(),
196 alias: access.column.name.clone(),
197 })
198 }
199
200 expr => panic!("Non-column group by not supported: {expr:#?}"),
201 }
202 }
203
204 for p in project {
205 let (actual_expr, alias) = match p {
206 Expression::Alias(alias_expr) => (alias_expr.expression.as_ref(), alias_expr.alias.0.clone()),
207 expr => (expr, display_label(expr)),
208 };
209
210 match actual_expr {
211 Expression::Call(call) => {
212 let func_name = call.func.0.text();
213 let function = routines.get_aggregate_function(func_name).ok_or_else(|| {
214 RoutineError::FunctionNotFound {
215 function: call.func.0.clone(),
216 }
217 })?;
218 let _ = FunctionKind::Aggregate;
219
220 let mut fn_ctx = FunctionContext {
221 fragment: call.func.0.clone(),
222 identity: ctx.identity,
223 row_count: 0,
224 runtime_context: &ctx.services.runtime_context,
225 };
226
227 let accumulator = function.accumulator(&mut fn_ctx).ok_or_else(|| {
228 RoutineError::FunctionExecutionFailed {
229 function: call.func.0.clone(),
230 reason: format!("Function {} is not an aggregate", func_name),
231 }
232 })?;
233
234 match call.args.first() {
235 Some(Expression::Column(c)) => {
236 projections.push(Projection::Aggregate {
237 column: c.0.name.text().to_string(),
238 column_fragment: c.0.name.clone(),
239 alias,
240 accumulator,
241 });
242 }
243 Some(Expression::AccessSource(access)) => {
244 projections.push(Projection::Aggregate {
245 column: access.column.name.text().to_string(),
246 column_fragment: access.column.name.clone(),
247 alias,
248 accumulator,
249 });
250 }
251 None => {
252 return Err(RoutineError::FunctionArityMismatch {
253 function: call.func.0.clone(),
254 expected: 1,
255 actual: 0,
256 }
257 .into());
258 }
259 Some(arg) => {
260 let actual_type = arg.infer_type().ok_or_else(|| {
261 RoutineError::FunctionExecutionFailed {
262 function: call.func.0.clone(),
263 reason: "aggregate function arguments must be column references".to_string(),
264 }
265 })?;
266 let expected = function.accepted_types().expected_at(0).to_vec();
267 return Err(RoutineError::FunctionInvalidArgumentType {
268 function: call.func.0.clone(),
269 argument_index: 0,
270 expected,
271 actual: actual_type,
272 }
273 .into());
274 }
275 }
276 }
277
278 _ => panic!("Expected aggregate call expression, got: {actual_expr:#?}"),
279 }
280 }
281 Ok((keys, projections))
282}
283
284fn align_column_data(group_key_order: &[Vec<Value>], keys: &[Vec<Value>], data: &mut ColumnBuffer) -> Result<()> {
285 let mut key_to_index = HashMap::new();
286 for (i, key) in keys.iter().enumerate() {
287 key_to_index.insert(key, i);
288 }
289
290 let reorder_indices: Vec<usize> = group_key_order
291 .iter()
292 .map(|k| {
293 key_to_index.get(k).copied().ok_or_else(|| {
294 CoreError::FrameError {
295 message: format!("Group key {:?} missing in aggregate output", k),
296 }
297 .into()
298 })
299 })
300 .collect::<Result<Vec<_>>>()?;
301
302 data.reorder(&reorder_indices);
303 Ok(())
304}