1use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData, view::group_by::GroupByView};
5use reifydb_function::{AggregateFunction, AggregateFunctionContext, ScalarFunctionContext, registry::Functions};
6use reifydb_rql::{
7 expression::{CallExpression, Expression},
8 instruction::{CompiledFunctionDef, Instruction, ScopeType},
9 query::QueryPlan,
10};
11use reifydb_runtime::clock::Clock;
12use reifydb_type::{
13 error::Error,
14 fragment::Fragment,
15 params::Params,
16 value::{Value, identity::IdentityId, r#type::Type},
17};
18
19use super::eval::evaluate;
20use crate::{
21 Result,
22 error::EngineError,
23 expression::context::EvalContext,
24 vm::{
25 scalar,
26 stack::{SymbolTable, Variable},
27 },
28};
29
/// Returns `name` with a single leading `$` sigil removed, as an owned `String`.
///
/// Only one `$` is stripped (`"$$a"` becomes `"$a"`); names without the sigil are
/// returned unchanged.
fn strip_dollar_prefix(name: &str) -> String {
    name.strip_prefix('$').unwrap_or(name).to_owned()
}
38
39fn column_data_from_values(values: &[Value]) -> ColumnData {
41 if values.is_empty() {
42 return ColumnData::none_typed(Type::Boolean, 0);
43 }
44
45 let mut data = ColumnData::none_typed(Type::Boolean, 0);
46 for value in values {
47 data.push_value(value.clone());
48 }
49 data
50}
51
52pub(crate) fn call_eval(
53 ctx: &EvalContext,
54 call: &CallExpression,
55 functions: &Functions,
56 clock: &Clock,
57) -> Result<Column> {
58 let function_name = call.func.0.text();
59
60 if ctx.is_aggregate_context {
63 if let Some(aggregate_fn) = functions.get_aggregate(function_name) {
64 return handle_aggregate_function(ctx, call, aggregate_fn, functions, clock);
65 }
66 }
67
68 let arguments = evaluate_arguments(ctx, &call.args, functions, clock)?;
70
71 if let Some(func_def) = ctx.symbol_table.get_function(function_name) {
73 return call_user_defined_function(ctx, call, func_def.clone(), &arguments, functions, clock);
74 }
75
76 let functor = functions.get_scalar(function_name).ok_or_else(|| -> Error {
78 EngineError::UnknownFunction {
79 name: call.func.0.text().to_string(),
80 fragment: call.func.0.clone(),
81 }
82 .into()
83 })?;
84
85 let row_count = ctx.row_count;
86
87 let final_data = functor.scalar(ScalarFunctionContext {
88 fragment: call.func.0.clone(),
89 columns: &arguments,
90 row_count,
91 clock,
92 identity: ctx.identity,
93 })?;
94
95 Ok(Column {
96 name: call.full_fragment_owned(),
97 data: final_data,
98 })
99}
100
101fn call_user_defined_function(
103 ctx: &EvalContext,
104 call: &CallExpression,
105 func_def: CompiledFunctionDef,
106 arguments: &Columns,
107 functions: &Functions,
108 clock: &Clock,
109) -> Result<Column> {
110 let row_count = ctx.row_count;
111 let mut results: Vec<Value> = Vec::with_capacity(row_count);
112
113 let body_instructions = &func_def.body;
115
116 let mut func_symbol_table = ctx.symbol_table.clone();
117
118 for row_idx in 0..row_count {
120 let base_depth = func_symbol_table.scope_depth();
121 func_symbol_table.enter_scope(ScopeType::Function);
122
123 for (param, arg_col) in func_def.parameters.iter().zip(arguments.iter()) {
125 let param_name = strip_dollar_prefix(param.name.text());
126 let value = arg_col.data().get_value(row_idx);
127 func_symbol_table.set(param_name, Variable::scalar(value), true)?;
128 }
129
130 let result = execute_function_body_for_scalar(
132 &body_instructions,
133 &mut func_symbol_table,
134 ctx.params,
135 functions,
136 clock,
137 ctx.identity,
138 )?;
139
140 while func_symbol_table.scope_depth() > base_depth {
141 let _ = func_symbol_table.exit_scope();
142 }
143
144 results.push(result);
145 }
146
147 let data = column_data_from_values(&results);
149 Ok(Column {
150 name: call.full_fragment_owned(),
151 data,
152 })
153}
154
155fn execute_function_body_for_scalar(
158 instructions: &[Instruction],
159 symbol_table: &mut SymbolTable,
160 params: &Params,
161 functions: &Functions,
162 clock: &Clock,
163 identity: IdentityId,
164) -> Result<Value> {
165 let mut ip = 0;
166 let mut stack: Vec<Value> = Vec::new();
167
168 while ip < instructions.len() {
169 match &instructions[ip] {
170 Instruction::Halt => break,
171 Instruction::Nop => {}
172
173 Instruction::PushConst(v) => stack.push(v.clone()),
175 Instruction::PushNone => stack.push(Value::none()),
176 Instruction::Pop => {
177 stack.pop();
178 }
179 Instruction::Dup => {
180 if let Some(v) = stack.last() {
181 stack.push(v.clone());
182 }
183 }
184
185 Instruction::LoadVar(name) => {
187 let var_name = strip_dollar_prefix(name.text());
188 let val = symbol_table
189 .get(&var_name)
190 .map(|v| match v {
191 Variable::Scalar(c) => c.scalar_value(),
192 _ => Value::none(),
193 })
194 .unwrap_or(Value::none());
195 stack.push(val);
196 }
197 Instruction::StoreVar(name) => {
198 let val = stack.pop().unwrap_or(Value::none());
199 let var_name = strip_dollar_prefix(name.text());
200 symbol_table.set(var_name, Variable::scalar(val), true)?;
201 }
202 Instruction::DeclareVar(name) => {
203 let val = stack.pop().unwrap_or(Value::none());
204 let var_name = strip_dollar_prefix(name.text());
205 symbol_table.set(var_name, Variable::scalar(val), true)?;
206 }
207
208 Instruction::Add => {
210 let r = stack.pop().unwrap_or(Value::none());
211 let l = stack.pop().unwrap_or(Value::none());
212 stack.push(scalar::scalar_add(l, r)?);
213 }
214 Instruction::Sub => {
215 let r = stack.pop().unwrap_or(Value::none());
216 let l = stack.pop().unwrap_or(Value::none());
217 stack.push(scalar::scalar_sub(l, r)?);
218 }
219 Instruction::Mul => {
220 let r = stack.pop().unwrap_or(Value::none());
221 let l = stack.pop().unwrap_or(Value::none());
222 stack.push(scalar::scalar_mul(l, r)?);
223 }
224 Instruction::Div => {
225 let r = stack.pop().unwrap_or(Value::none());
226 let l = stack.pop().unwrap_or(Value::none());
227 stack.push(scalar::scalar_div(l, r)?);
228 }
229 Instruction::Rem => {
230 let r = stack.pop().unwrap_or(Value::none());
231 let l = stack.pop().unwrap_or(Value::none());
232 stack.push(scalar::scalar_rem(l, r)?);
233 }
234
235 Instruction::Negate => {
237 let v = stack.pop().unwrap_or(Value::none());
238 stack.push(scalar::scalar_negate(v)?);
239 }
240 Instruction::LogicNot => {
241 let v = stack.pop().unwrap_or(Value::none());
242 stack.push(scalar::scalar_not(&v));
243 }
244
245 Instruction::CmpEq => {
247 let r = stack.pop().unwrap_or(Value::none());
248 let l = stack.pop().unwrap_or(Value::none());
249 stack.push(scalar::scalar_eq(&l, &r));
250 }
251 Instruction::CmpNe => {
252 let r = stack.pop().unwrap_or(Value::none());
253 let l = stack.pop().unwrap_or(Value::none());
254 stack.push(scalar::scalar_ne(&l, &r));
255 }
256 Instruction::CmpLt => {
257 let r = stack.pop().unwrap_or(Value::none());
258 let l = stack.pop().unwrap_or(Value::none());
259 stack.push(scalar::scalar_lt(&l, &r));
260 }
261 Instruction::CmpLe => {
262 let r = stack.pop().unwrap_or(Value::none());
263 let l = stack.pop().unwrap_or(Value::none());
264 stack.push(scalar::scalar_le(&l, &r));
265 }
266 Instruction::CmpGt => {
267 let r = stack.pop().unwrap_or(Value::none());
268 let l = stack.pop().unwrap_or(Value::none());
269 stack.push(scalar::scalar_gt(&l, &r));
270 }
271 Instruction::CmpGe => {
272 let r = stack.pop().unwrap_or(Value::none());
273 let l = stack.pop().unwrap_or(Value::none());
274 stack.push(scalar::scalar_ge(&l, &r));
275 }
276
277 Instruction::LogicAnd => {
279 let r = stack.pop().unwrap_or(Value::none());
280 let l = stack.pop().unwrap_or(Value::none());
281 stack.push(scalar::scalar_and(&l, &r));
282 }
283 Instruction::LogicOr => {
284 let r = stack.pop().unwrap_or(Value::none());
285 let l = stack.pop().unwrap_or(Value::none());
286 stack.push(scalar::scalar_or(&l, &r));
287 }
288 Instruction::LogicXor => {
289 let r = stack.pop().unwrap_or(Value::none());
290 let l = stack.pop().unwrap_or(Value::none());
291 stack.push(scalar::scalar_xor(&l, &r));
292 }
293
294 Instruction::Cast(target) => {
296 let v = stack.pop().unwrap_or(Value::none());
297 stack.push(scalar::scalar_cast(v, target.clone())?);
298 }
299 Instruction::Between => {
300 let upper = stack.pop().unwrap_or(Value::none());
301 let lower = stack.pop().unwrap_or(Value::none());
302 let val = stack.pop().unwrap_or(Value::none());
303 let ge = scalar::scalar_ge(&val, &lower);
304 let le = scalar::scalar_le(&val, &upper);
305 let result = match (ge, le) {
306 (Value::Boolean(a), Value::Boolean(b)) => Value::Boolean(a && b),
307 _ => Value::none(),
308 };
309 stack.push(result);
310 }
311 Instruction::InList {
312 count,
313 negated,
314 } => {
315 let count = *count as usize;
316 let negated = *negated;
317 let mut items: Vec<Value> = Vec::with_capacity(count);
318 for _ in 0..count {
319 items.push(stack.pop().unwrap_or(Value::none()));
320 }
321 items.reverse();
322 let val = stack.pop().unwrap_or(Value::none());
323 let has_undefined = matches!(val, Value::None { .. })
324 || items.iter().any(|item| matches!(item, Value::None { .. }));
325 if has_undefined {
326 stack.push(Value::none());
327 } else {
328 let found = items.iter().any(|item| {
329 matches!(scalar::scalar_eq(&val, item), Value::Boolean(true))
330 });
331 stack.push(Value::Boolean(if negated {
332 !found
333 } else {
334 found
335 }));
336 }
337 }
338
339 Instruction::Jump(addr) => {
341 ip = *addr;
342 continue;
343 }
344 Instruction::JumpIfFalsePop(addr) => {
345 let v = stack.pop().unwrap_or(Value::none());
346 if !scalar::value_is_truthy(&v) {
347 ip = *addr;
348 continue;
349 }
350 }
351 Instruction::JumpIfTruePop(addr) => {
352 let v = stack.pop().unwrap_or(Value::none());
353 if scalar::value_is_truthy(&v) {
354 ip = *addr;
355 continue;
356 }
357 }
358
359 Instruction::EnterScope(scope_type) => {
360 symbol_table.enter_scope(scope_type.clone());
361 }
362 Instruction::ExitScope => {
363 let _ = symbol_table.exit_scope();
364 }
365
366 Instruction::ReturnValue => {
368 let v = stack.pop().unwrap_or(Value::none());
369 return Ok(v);
370 }
371 Instruction::ReturnVoid => {
372 return Ok(Value::none());
373 }
374
375 Instruction::Query(plan) => match plan {
377 QueryPlan::Map(map_node) => {
378 if map_node.input.is_none() && !map_node.map.is_empty() {
379 let evaluation_context = EvalContext {
380 target: None,
381 columns: Columns::empty(),
382 row_count: 1,
383 take: None,
384 params,
385 symbol_table,
386 is_aggregate_context: false,
387 functions,
388 clock,
389 arena: None,
390 identity,
391 };
392 let result_column = evaluate(
393 &evaluation_context,
394 &map_node.map[0],
395 functions,
396 clock,
397 )?;
398 if result_column.data.len() > 0 {
399 stack.push(result_column.data.get_value(0));
400 }
401 }
402 }
403 _ => {
404 }
406 },
407
408 Instruction::Emit => {
409 }
411
412 Instruction::Call {
414 name,
415 arity,
416 ..
417 } => {
418 let arity = *arity as usize;
419 let mut args: Vec<Value> = Vec::with_capacity(arity);
420 for _ in 0..arity {
421 args.push(stack.pop().unwrap_or(Value::none()));
422 }
423 args.reverse();
424
425 if let Some(func_def) = symbol_table.get_function(name.text()) {
427 let func_def = func_def.clone();
428 let base_depth = symbol_table.scope_depth();
429 symbol_table.enter_scope(ScopeType::Function);
430 for (param, arg_val) in func_def.parameters.iter().zip(args.iter()) {
431 let param_name = strip_dollar_prefix(param.name.text());
432 symbol_table.set(
433 param_name,
434 Variable::scalar(arg_val.clone()),
435 true,
436 )?;
437 }
438 let result = execute_function_body_for_scalar(
439 &func_def.body,
440 symbol_table,
441 params,
442 functions,
443 clock,
444 identity,
445 )?;
446 while symbol_table.scope_depth() > base_depth {
447 let _ = symbol_table.exit_scope();
448 }
449 stack.push(result);
450 } else if let Some(functor) = functions.get_scalar(name.text()) {
451 let mut arg_cols = Vec::with_capacity(args.len());
452 for arg in &args {
453 let mut data = ColumnData::none_typed(Type::Boolean, 0);
454 data.push_value(arg.clone());
455 arg_cols.push(Column::new("_", data));
456 }
457 let columns = Columns::new(arg_cols);
458 let result_data = functor.scalar(ScalarFunctionContext {
459 fragment: name.clone(),
460 columns: &columns,
461 row_count: 1,
462 clock,
463 identity,
464 })?;
465 if result_data.len() > 0 {
466 stack.push(result_data.get_value(0));
467 } else {
468 stack.push(Value::none());
469 }
470 }
471 }
472
473 Instruction::DefineFunction(func_def) => {
474 symbol_table.define_function(func_def.name.text().to_string(), func_def.clone());
475 }
476
477 _ => {
478 }
480 }
481 ip += 1;
482 }
483
484 Ok(stack.pop().unwrap_or(Value::none()))
486}
487
488fn handle_aggregate_function(
489 ctx: &EvalContext,
490 call: &CallExpression,
491 mut aggregate_fn: Box<dyn AggregateFunction>,
492 functions: &Functions,
493 clock: &Clock,
494) -> Result<Column> {
495 let mut group_view = GroupByView::new();
497 let all_indices: Vec<usize> = (0..ctx.row_count).collect();
498 group_view.insert(Vec::<Value>::new(), all_indices); let column = if call.args.is_empty() {
502 Column {
504 name: Fragment::internal("dummy"),
505 data: ColumnData::int4_with_capacity(ctx.row_count),
506 }
507 } else {
508 let arguments = evaluate_arguments(ctx, &call.args, functions, clock)?;
510 arguments[0].clone()
511 };
512
513 aggregate_fn.aggregate(AggregateFunctionContext {
515 fragment: call.func.0.clone(),
516 column: &column,
517 groups: &group_view,
518 })?;
519
520 let (_keys, result_data) = aggregate_fn.finalize()?;
522
523 Ok(Column {
524 name: call.full_fragment_owned(),
525 data: result_data,
526 })
527}
528
529fn evaluate_arguments(
530 ctx: &EvalContext,
531 expressions: &Vec<Expression>,
532 functions: &Functions,
533 clock: &Clock,
534) -> Result<Columns> {
535 let inner_ctx = EvalContext {
536 target: None,
537 columns: ctx.columns.clone(),
538 row_count: ctx.row_count,
539 take: ctx.take,
540 params: ctx.params,
541 symbol_table: ctx.symbol_table,
542 is_aggregate_context: ctx.is_aggregate_context,
543 functions: ctx.functions,
544 clock: ctx.clock,
545 arena: None,
546 identity: ctx.identity,
547 };
548 let mut result: Vec<Column> = Vec::with_capacity(expressions.len());
549
550 for expression in expressions {
551 match expression {
552 Expression::Type(type_expr) => {
553 let values: Vec<Box<Value>> = (0..ctx.row_count)
554 .map(|_| Box::new(Value::Type(type_expr.ty.clone())))
555 .collect();
556 result.push(Column::new(type_expr.fragment.text(), ColumnData::any(values)));
557 }
558 _ => result.push(evaluate(&inner_ctx, expression, functions, clock)?),
559 }
560 }
561
562 Ok(Columns::new(result))
563}