Skip to main content

reifydb_engine/expression/
call.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData, view::group_by::GroupByView};
5use reifydb_function::{AggregateFunction, AggregateFunctionContext, ScalarFunctionContext, registry::Functions};
6use reifydb_rql::{
7	expression::{CallExpression, Expression},
8	instruction::{CompiledFunctionDef, Instruction, ScopeType},
9	query::QueryPlan,
10};
11use reifydb_runtime::clock::Clock;
12use reifydb_type::{
13	error,
14	error::diagnostic::function,
15	fragment::Fragment,
16	params::Params,
17	value::{Value, r#type::Type},
18};
19
20use super::eval::evaluate;
21use crate::{
22	expression::context::EvalContext,
23	vm::{
24		scalar,
25		stack::{SymbolTable, Variable},
26	},
27};
28
/// Strip a single leading `$` from a variable name, if present.
///
/// Only one `$` is removed (`"$$x"` becomes `"$x"`); names without the prefix are
/// returned unchanged. The result is returned as an owned `String` because callers
/// hand it to the symbol table as a key.
fn strip_dollar_prefix(name: &str) -> String {
	name.strip_prefix('$').unwrap_or(name).to_string()
}
37
38/// Convert a slice of Values into ColumnData
39fn column_data_from_values(values: &[Value]) -> ColumnData {
40	if values.is_empty() {
41		return ColumnData::none_typed(Type::Boolean, 0);
42	}
43
44	let mut data = ColumnData::none_typed(Type::Boolean, 0);
45	for value in values {
46		data.push_value(value.clone());
47	}
48	data
49}
50
51pub(crate) fn call_eval(
52	ctx: &EvalContext,
53	call: &CallExpression,
54	functions: &Functions,
55	clock: &Clock,
56) -> crate::Result<Column> {
57	let function_name = call.func.0.text();
58
59	// Check if we're in aggregation context and if function exists as aggregate
60	// FIXME this is a quick hack - this should be derived from a call stack
61	if ctx.is_aggregate_context {
62		if let Some(aggregate_fn) = functions.get_aggregate(function_name) {
63			return handle_aggregate_function(ctx, call, aggregate_fn, functions, clock);
64		}
65	}
66
67	// Evaluate arguments first (needed for both user-defined and built-in functions)
68	let arguments = evaluate_arguments(ctx, &call.args, functions, clock)?;
69
70	// Try user-defined function from symbol table first
71	if let Some(func_def) = ctx.symbol_table.get_function(function_name) {
72		return call_user_defined_function(ctx, call, func_def.clone(), &arguments, functions, clock);
73	}
74
75	// Fall back to built-in scalar function handling
76	let functor =
77		functions.get_scalar(function_name).ok_or(error!(function::unknown_function(call.func.0.clone())))?;
78
79	let row_count = ctx.row_count;
80
81	let final_data = functor.scalar(ScalarFunctionContext {
82		fragment: call.func.0.clone(),
83		columns: &arguments,
84		row_count,
85		clock,
86	})?;
87
88	Ok(Column {
89		name: call.full_fragment_owned(),
90		data: final_data,
91	})
92}
93
94/// Execute a user-defined function for each row, returning a column of results
95fn call_user_defined_function(
96	ctx: &EvalContext,
97	call: &CallExpression,
98	func_def: CompiledFunctionDef,
99	arguments: &Columns,
100	functions: &Functions,
101	clock: &Clock,
102) -> crate::Result<Column> {
103	let row_count = ctx.row_count;
104	let mut results: Vec<Value> = Vec::with_capacity(row_count);
105
106	// Function body is already pre-compiled
107	let body_instructions = &func_def.body;
108
109	let mut func_symbol_table = ctx.symbol_table.clone();
110
111	// For each row, execute the function
112	for row_idx in 0..row_count {
113		let base_depth = func_symbol_table.scope_depth();
114		func_symbol_table.enter_scope(ScopeType::Function);
115
116		// Bind arguments to parameters
117		for (param, arg_col) in func_def.parameters.iter().zip(arguments.iter()) {
118			let param_name = strip_dollar_prefix(param.name.text());
119			let value = arg_col.data().get_value(row_idx);
120			func_symbol_table.set(param_name, Variable::scalar(value), true)?;
121		}
122
123		// Execute function body instructions and get result
124		let result = execute_function_body_for_scalar(
125			&body_instructions,
126			&mut func_symbol_table,
127			ctx.params,
128			functions,
129			clock,
130		)?;
131
132		while func_symbol_table.scope_depth() > base_depth {
133			let _ = func_symbol_table.exit_scope();
134		}
135
136		results.push(result);
137	}
138
139	// Convert results to ColumnData
140	let data = column_data_from_values(&results);
141	Ok(Column {
142		name: call.full_fragment_owned(),
143		data,
144	})
145}
146
147/// Execute function body instructions and return a scalar result.
148/// Uses a simple stack-based interpreter matching the new bytecode ISA.
149fn execute_function_body_for_scalar(
150	instructions: &[Instruction],
151	symbol_table: &mut SymbolTable,
152	params: &Params,
153	functions: &Functions,
154	clock: &Clock,
155) -> crate::Result<Value> {
156	let mut ip = 0;
157	let mut stack: Vec<Value> = Vec::new();
158
159	while ip < instructions.len() {
160		match &instructions[ip] {
161			Instruction::Halt => break,
162			Instruction::Nop => {}
163
164			// === Stack ===
165			Instruction::PushConst(v) => stack.push(v.clone()),
166			Instruction::PushNone => stack.push(Value::none()),
167			Instruction::Pop => {
168				stack.pop();
169			}
170			Instruction::Dup => {
171				if let Some(v) = stack.last() {
172					stack.push(v.clone());
173				}
174			}
175
176			// === Variables ===
177			Instruction::LoadVar(name) => {
178				let var_name = strip_dollar_prefix(name.text());
179				let val = symbol_table
180					.get(&var_name)
181					.map(|v| match v {
182						Variable::Scalar(c) => c.scalar_value(),
183						_ => Value::none(),
184					})
185					.unwrap_or(Value::none());
186				stack.push(val);
187			}
188			Instruction::StoreVar(name) => {
189				let val = stack.pop().unwrap_or(Value::none());
190				let var_name = strip_dollar_prefix(name.text());
191				symbol_table.set(var_name, Variable::scalar(val), true)?;
192			}
193			Instruction::DeclareVar(name) => {
194				let val = stack.pop().unwrap_or(Value::none());
195				let var_name = strip_dollar_prefix(name.text());
196				symbol_table.set(var_name, Variable::scalar(val), true)?;
197			}
198
199			// === Arithmetic ===
200			Instruction::Add => {
201				let r = stack.pop().unwrap_or(Value::none());
202				let l = stack.pop().unwrap_or(Value::none());
203				stack.push(scalar::scalar_add(l, r)?);
204			}
205			Instruction::Sub => {
206				let r = stack.pop().unwrap_or(Value::none());
207				let l = stack.pop().unwrap_or(Value::none());
208				stack.push(scalar::scalar_sub(l, r)?);
209			}
210			Instruction::Mul => {
211				let r = stack.pop().unwrap_or(Value::none());
212				let l = stack.pop().unwrap_or(Value::none());
213				stack.push(scalar::scalar_mul(l, r)?);
214			}
215			Instruction::Div => {
216				let r = stack.pop().unwrap_or(Value::none());
217				let l = stack.pop().unwrap_or(Value::none());
218				stack.push(scalar::scalar_div(l, r)?);
219			}
220			Instruction::Rem => {
221				let r = stack.pop().unwrap_or(Value::none());
222				let l = stack.pop().unwrap_or(Value::none());
223				stack.push(scalar::scalar_rem(l, r)?);
224			}
225
226			// === Unary ===
227			Instruction::Negate => {
228				let v = stack.pop().unwrap_or(Value::none());
229				stack.push(scalar::scalar_negate(v)?);
230			}
231			Instruction::LogicNot => {
232				let v = stack.pop().unwrap_or(Value::none());
233				stack.push(scalar::scalar_not(&v));
234			}
235
236			// === Comparison ===
237			Instruction::CmpEq => {
238				let r = stack.pop().unwrap_or(Value::none());
239				let l = stack.pop().unwrap_or(Value::none());
240				stack.push(scalar::scalar_eq(&l, &r));
241			}
242			Instruction::CmpNe => {
243				let r = stack.pop().unwrap_or(Value::none());
244				let l = stack.pop().unwrap_or(Value::none());
245				stack.push(scalar::scalar_ne(&l, &r));
246			}
247			Instruction::CmpLt => {
248				let r = stack.pop().unwrap_or(Value::none());
249				let l = stack.pop().unwrap_or(Value::none());
250				stack.push(scalar::scalar_lt(&l, &r));
251			}
252			Instruction::CmpLe => {
253				let r = stack.pop().unwrap_or(Value::none());
254				let l = stack.pop().unwrap_or(Value::none());
255				stack.push(scalar::scalar_le(&l, &r));
256			}
257			Instruction::CmpGt => {
258				let r = stack.pop().unwrap_or(Value::none());
259				let l = stack.pop().unwrap_or(Value::none());
260				stack.push(scalar::scalar_gt(&l, &r));
261			}
262			Instruction::CmpGe => {
263				let r = stack.pop().unwrap_or(Value::none());
264				let l = stack.pop().unwrap_or(Value::none());
265				stack.push(scalar::scalar_ge(&l, &r));
266			}
267
268			// === Logic ===
269			Instruction::LogicAnd => {
270				let r = stack.pop().unwrap_or(Value::none());
271				let l = stack.pop().unwrap_or(Value::none());
272				stack.push(scalar::scalar_and(&l, &r));
273			}
274			Instruction::LogicOr => {
275				let r = stack.pop().unwrap_or(Value::none());
276				let l = stack.pop().unwrap_or(Value::none());
277				stack.push(scalar::scalar_or(&l, &r));
278			}
279			Instruction::LogicXor => {
280				let r = stack.pop().unwrap_or(Value::none());
281				let l = stack.pop().unwrap_or(Value::none());
282				stack.push(scalar::scalar_xor(&l, &r));
283			}
284
285			// === Compound ===
286			Instruction::Cast(target) => {
287				let v = stack.pop().unwrap_or(Value::none());
288				stack.push(scalar::scalar_cast(v, target.clone())?);
289			}
290			Instruction::Between => {
291				let upper = stack.pop().unwrap_or(Value::none());
292				let lower = stack.pop().unwrap_or(Value::none());
293				let val = stack.pop().unwrap_or(Value::none());
294				let ge = scalar::scalar_ge(&val, &lower);
295				let le = scalar::scalar_le(&val, &upper);
296				let result = match (ge, le) {
297					(Value::Boolean(a), Value::Boolean(b)) => Value::Boolean(a && b),
298					_ => Value::none(),
299				};
300				stack.push(result);
301			}
302			Instruction::InList {
303				count,
304				negated,
305			} => {
306				let count = *count as usize;
307				let negated = *negated;
308				let mut items: Vec<Value> = Vec::with_capacity(count);
309				for _ in 0..count {
310					items.push(stack.pop().unwrap_or(Value::none()));
311				}
312				items.reverse();
313				let val = stack.pop().unwrap_or(Value::none());
314				let has_undefined = matches!(val, Value::None { .. })
315					|| items.iter().any(|item| matches!(item, Value::None { .. }));
316				if has_undefined {
317					stack.push(Value::none());
318				} else {
319					let found = items.iter().any(|item| {
320						matches!(scalar::scalar_eq(&val, item), Value::Boolean(true))
321					});
322					stack.push(Value::Boolean(if negated {
323						!found
324					} else {
325						found
326					}));
327				}
328			}
329
330			// === Control flow ===
331			Instruction::Jump(addr) => {
332				ip = *addr;
333				continue;
334			}
335			Instruction::JumpIfFalsePop(addr) => {
336				let v = stack.pop().unwrap_or(Value::none());
337				if !scalar::value_is_truthy(&v) {
338					ip = *addr;
339					continue;
340				}
341			}
342			Instruction::JumpIfTruePop(addr) => {
343				let v = stack.pop().unwrap_or(Value::none());
344				if scalar::value_is_truthy(&v) {
345					ip = *addr;
346					continue;
347				}
348			}
349
350			Instruction::EnterScope(scope_type) => {
351				symbol_table.enter_scope(scope_type.clone());
352			}
353			Instruction::ExitScope => {
354				let _ = symbol_table.exit_scope();
355			}
356
357			// === Return ===
358			Instruction::ReturnValue => {
359				let v = stack.pop().unwrap_or(Value::none());
360				return Ok(v);
361			}
362			Instruction::ReturnVoid => {
363				return Ok(Value::none());
364			}
365
366			// === Query ===
367			Instruction::Query(plan) => match plan {
368				QueryPlan::Map(map_node) => {
369					if map_node.input.is_none() && !map_node.map.is_empty() {
370						let evaluation_context = EvalContext {
371							target: None,
372							columns: Columns::empty(),
373							row_count: 1,
374							take: None,
375							params,
376							symbol_table,
377							is_aggregate_context: false,
378							functions,
379							clock,
380							arena: None,
381						};
382						let result_column = evaluate(
383							&evaluation_context,
384							&map_node.map[0],
385							functions,
386							clock,
387						)?;
388						if result_column.data.len() > 0 {
389							stack.push(result_column.data.get_value(0));
390						}
391					}
392				}
393				_ => {
394					// Other plan types would need full VM execution
395				}
396			},
397
398			Instruction::Emit => {
399				// Emit in function body context - the stack top is the result
400			}
401
402			// === Function calls within function body ===
403			Instruction::Call {
404				name,
405				arity,
406			} => {
407				let arity = *arity as usize;
408				let mut args: Vec<Value> = Vec::with_capacity(arity);
409				for _ in 0..arity {
410					args.push(stack.pop().unwrap_or(Value::none()));
411				}
412				args.reverse();
413
414				// Try user-defined function
415				if let Some(func_def) = symbol_table.get_function(name.text()) {
416					let func_def = func_def.clone();
417					let base_depth = symbol_table.scope_depth();
418					symbol_table.enter_scope(ScopeType::Function);
419					for (param, arg_val) in func_def.parameters.iter().zip(args.iter()) {
420						let param_name = strip_dollar_prefix(param.name.text());
421						symbol_table.set(
422							param_name,
423							Variable::scalar(arg_val.clone()),
424							true,
425						)?;
426					}
427					let result = execute_function_body_for_scalar(
428						&func_def.body,
429						symbol_table,
430						params,
431						functions,
432						clock,
433					)?;
434					while symbol_table.scope_depth() > base_depth {
435						let _ = symbol_table.exit_scope();
436					}
437					stack.push(result);
438				} else if let Some(functor) = functions.get_scalar(name.text()) {
439					let mut arg_cols = Vec::with_capacity(args.len());
440					for arg in &args {
441						let mut data = ColumnData::none_typed(Type::Boolean, 0);
442						data.push_value(arg.clone());
443						arg_cols.push(Column::new("_", data));
444					}
445					let columns = Columns::new(arg_cols);
446					let result_data = functor.scalar(ScalarFunctionContext {
447						fragment: name.clone(),
448						columns: &columns,
449						row_count: 1,
450						clock,
451					})?;
452					if result_data.len() > 0 {
453						stack.push(result_data.get_value(0));
454					} else {
455						stack.push(Value::none());
456					}
457				}
458			}
459
460			Instruction::DefineFunction(func_def) => {
461				symbol_table.define_function(func_def.name.text().to_string(), func_def.clone());
462			}
463
464			_ => {
465				// DDL/DML instructions not expected in function body
466			}
467		}
468		ip += 1;
469	}
470
471	// Return top of stack or Undefined
472	Ok(stack.pop().unwrap_or(Value::none()))
473}
474
475fn handle_aggregate_function(
476	ctx: &EvalContext,
477	call: &CallExpression,
478	mut aggregate_fn: Box<dyn AggregateFunction>,
479	functions: &Functions,
480	clock: &Clock,
481) -> crate::Result<Column> {
482	// Create a single group containing all row indices for aggregation
483	let mut group_view = GroupByView::new();
484	let all_indices: Vec<usize> = (0..ctx.row_count).collect();
485	group_view.insert(Vec::<Value>::new(), all_indices); // Empty group key for single group
486
487	// Determine which column to aggregate over
488	let column = if call.args.is_empty() {
489		// For count() with no arguments, create a dummy column
490		Column {
491			name: Fragment::internal("dummy"),
492			data: ColumnData::int4_with_capacity(ctx.row_count),
493		}
494	} else {
495		// For functions with arguments like sum(amount), use the first argument column
496		let arguments = evaluate_arguments(ctx, &call.args, functions, clock)?;
497		arguments[0].clone()
498	};
499
500	// Call the aggregate function
501	aggregate_fn.aggregate(AggregateFunctionContext {
502		fragment: call.func.0.clone(),
503		column: &column,
504		groups: &group_view,
505	})?;
506
507	// Finalize and get results
508	let (_keys, result_data) = aggregate_fn.finalize()?;
509
510	Ok(Column {
511		name: call.full_fragment_owned(),
512		data: result_data,
513	})
514}
515
516fn evaluate_arguments(
517	ctx: &EvalContext,
518	expressions: &Vec<Expression>,
519	functions: &Functions,
520	clock: &Clock,
521) -> crate::Result<Columns> {
522	let inner_ctx = EvalContext {
523		target: None,
524		columns: ctx.columns.clone(),
525		row_count: ctx.row_count,
526		take: ctx.take,
527		params: ctx.params,
528		symbol_table: ctx.symbol_table,
529		is_aggregate_context: ctx.is_aggregate_context,
530		functions: ctx.functions,
531		clock: ctx.clock,
532		arena: None,
533	};
534	let mut result: Vec<Column> = Vec::with_capacity(expressions.len());
535
536	for expression in expressions {
537		match expression {
538			Expression::Type(type_expr) => {
539				let values: Vec<Box<Value>> = (0..ctx.row_count)
540					.map(|_| Box::new(Value::Type(type_expr.ty.clone())))
541					.collect();
542				result.push(Column::new(type_expr.fragment.text(), ColumnData::any(values)));
543			}
544			_ => result.push(evaluate(&inner_ctx, expression, functions, clock)?),
545		}
546	}
547
548	Ok(Columns::new(result))
549}