Skip to main content

reifydb_engine/expression/
call.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData, view::group_by::GroupByView};
5use reifydb_function::{AggregateFunction, AggregateFunctionContext, ScalarFunctionContext, registry::Functions};
6use reifydb_rql::{
7	expression::{CallExpression, Expression},
8	instruction::{CompiledFunctionDef, Instruction, ScopeType},
9	query::QueryPlan,
10};
11use reifydb_runtime::clock::Clock;
12use reifydb_type::{
13	error::Error,
14	fragment::Fragment,
15	params::Params,
16	value::{Value, identity::IdentityId, r#type::Type},
17};
18
19use super::eval::evaluate;
20use crate::{
21	Result,
22	error::EngineError,
23	expression::context::EvalContext,
24	vm::{
25		scalar,
26		stack::{SymbolTable, Variable},
27	},
28};
29
/// Strip a single leading `$` from a variable name, if present.
///
/// Names without the prefix are returned unchanged as an owned copy; only one
/// `$` is removed (`"$$a"` becomes `"$a"`).
fn strip_dollar_prefix(name: &str) -> String {
	name.strip_prefix('$').unwrap_or(name).to_string()
}
38
39/// Convert a slice of Values into ColumnData
40fn column_data_from_values(values: &[Value]) -> ColumnData {
41	if values.is_empty() {
42		return ColumnData::none_typed(Type::Boolean, 0);
43	}
44
45	let mut data = ColumnData::none_typed(Type::Boolean, 0);
46	for value in values {
47		data.push_value(value.clone());
48	}
49	data
50}
51
52pub(crate) fn call_eval(
53	ctx: &EvalContext,
54	call: &CallExpression,
55	functions: &Functions,
56	clock: &Clock,
57) -> Result<Column> {
58	let function_name = call.func.0.text();
59
60	// Check if we're in aggregation context and if function exists as aggregate
61	// FIXME this is a quick hack - this should be derived from a call stack
62	if ctx.is_aggregate_context {
63		if let Some(aggregate_fn) = functions.get_aggregate(function_name) {
64			return handle_aggregate_function(ctx, call, aggregate_fn, functions, clock);
65		}
66	}
67
68	// Evaluate arguments first (needed for both user-defined and built-in functions)
69	let arguments = evaluate_arguments(ctx, &call.args, functions, clock)?;
70
71	// Try user-defined function from symbol table first
72	if let Some(func_def) = ctx.symbol_table.get_function(function_name) {
73		return call_user_defined_function(ctx, call, func_def.clone(), &arguments, functions, clock);
74	}
75
76	// Fall back to built-in scalar function handling
77	let functor = functions.get_scalar(function_name).ok_or_else(|| -> Error {
78		EngineError::UnknownFunction {
79			name: call.func.0.text().to_string(),
80			fragment: call.func.0.clone(),
81		}
82		.into()
83	})?;
84
85	let row_count = ctx.row_count;
86
87	let final_data = functor.scalar(ScalarFunctionContext {
88		fragment: call.func.0.clone(),
89		columns: &arguments,
90		row_count,
91		clock,
92		identity: ctx.identity,
93	})?;
94
95	Ok(Column {
96		name: call.full_fragment_owned(),
97		data: final_data,
98	})
99}
100
101/// Execute a user-defined function for each row, returning a column of results
102fn call_user_defined_function(
103	ctx: &EvalContext,
104	call: &CallExpression,
105	func_def: CompiledFunctionDef,
106	arguments: &Columns,
107	functions: &Functions,
108	clock: &Clock,
109) -> Result<Column> {
110	let row_count = ctx.row_count;
111	let mut results: Vec<Value> = Vec::with_capacity(row_count);
112
113	// Function body is already pre-compiled
114	let body_instructions = &func_def.body;
115
116	let mut func_symbol_table = ctx.symbol_table.clone();
117
118	// For each row, execute the function
119	for row_idx in 0..row_count {
120		let base_depth = func_symbol_table.scope_depth();
121		func_symbol_table.enter_scope(ScopeType::Function);
122
123		// Bind arguments to parameters
124		for (param, arg_col) in func_def.parameters.iter().zip(arguments.iter()) {
125			let param_name = strip_dollar_prefix(param.name.text());
126			let value = arg_col.data().get_value(row_idx);
127			func_symbol_table.set(param_name, Variable::scalar(value), true)?;
128		}
129
130		// Execute function body instructions and get result
131		let result = execute_function_body_for_scalar(
132			&body_instructions,
133			&mut func_symbol_table,
134			ctx.params,
135			functions,
136			clock,
137			ctx.identity,
138		)?;
139
140		while func_symbol_table.scope_depth() > base_depth {
141			let _ = func_symbol_table.exit_scope();
142		}
143
144		results.push(result);
145	}
146
147	// Convert results to ColumnData
148	let data = column_data_from_values(&results);
149	Ok(Column {
150		name: call.full_fragment_owned(),
151		data,
152	})
153}
154
155/// Execute function body instructions and return a scalar result.
156/// Uses a simple stack-based interpreter matching the new bytecode ISA.
157fn execute_function_body_for_scalar(
158	instructions: &[Instruction],
159	symbol_table: &mut SymbolTable,
160	params: &Params,
161	functions: &Functions,
162	clock: &Clock,
163	identity: IdentityId,
164) -> Result<Value> {
165	let mut ip = 0;
166	let mut stack: Vec<Value> = Vec::new();
167
168	while ip < instructions.len() {
169		match &instructions[ip] {
170			Instruction::Halt => break,
171			Instruction::Nop => {}
172
173			// === Stack ===
174			Instruction::PushConst(v) => stack.push(v.clone()),
175			Instruction::PushNone => stack.push(Value::none()),
176			Instruction::Pop => {
177				stack.pop();
178			}
179			Instruction::Dup => {
180				if let Some(v) = stack.last() {
181					stack.push(v.clone());
182				}
183			}
184
185			// === Variables ===
186			Instruction::LoadVar(name) => {
187				let var_name = strip_dollar_prefix(name.text());
188				let val = symbol_table
189					.get(&var_name)
190					.map(|v| match v {
191						Variable::Scalar(c) => c.scalar_value(),
192						_ => Value::none(),
193					})
194					.unwrap_or(Value::none());
195				stack.push(val);
196			}
197			Instruction::StoreVar(name) => {
198				let val = stack.pop().unwrap_or(Value::none());
199				let var_name = strip_dollar_prefix(name.text());
200				symbol_table.set(var_name, Variable::scalar(val), true)?;
201			}
202			Instruction::DeclareVar(name) => {
203				let val = stack.pop().unwrap_or(Value::none());
204				let var_name = strip_dollar_prefix(name.text());
205				symbol_table.set(var_name, Variable::scalar(val), true)?;
206			}
207
208			// === Arithmetic ===
209			Instruction::Add => {
210				let r = stack.pop().unwrap_or(Value::none());
211				let l = stack.pop().unwrap_or(Value::none());
212				stack.push(scalar::scalar_add(l, r)?);
213			}
214			Instruction::Sub => {
215				let r = stack.pop().unwrap_or(Value::none());
216				let l = stack.pop().unwrap_or(Value::none());
217				stack.push(scalar::scalar_sub(l, r)?);
218			}
219			Instruction::Mul => {
220				let r = stack.pop().unwrap_or(Value::none());
221				let l = stack.pop().unwrap_or(Value::none());
222				stack.push(scalar::scalar_mul(l, r)?);
223			}
224			Instruction::Div => {
225				let r = stack.pop().unwrap_or(Value::none());
226				let l = stack.pop().unwrap_or(Value::none());
227				stack.push(scalar::scalar_div(l, r)?);
228			}
229			Instruction::Rem => {
230				let r = stack.pop().unwrap_or(Value::none());
231				let l = stack.pop().unwrap_or(Value::none());
232				stack.push(scalar::scalar_rem(l, r)?);
233			}
234
235			// === Unary ===
236			Instruction::Negate => {
237				let v = stack.pop().unwrap_or(Value::none());
238				stack.push(scalar::scalar_negate(v)?);
239			}
240			Instruction::LogicNot => {
241				let v = stack.pop().unwrap_or(Value::none());
242				stack.push(scalar::scalar_not(&v));
243			}
244
245			// === Comparison ===
246			Instruction::CmpEq => {
247				let r = stack.pop().unwrap_or(Value::none());
248				let l = stack.pop().unwrap_or(Value::none());
249				stack.push(scalar::scalar_eq(&l, &r));
250			}
251			Instruction::CmpNe => {
252				let r = stack.pop().unwrap_or(Value::none());
253				let l = stack.pop().unwrap_or(Value::none());
254				stack.push(scalar::scalar_ne(&l, &r));
255			}
256			Instruction::CmpLt => {
257				let r = stack.pop().unwrap_or(Value::none());
258				let l = stack.pop().unwrap_or(Value::none());
259				stack.push(scalar::scalar_lt(&l, &r));
260			}
261			Instruction::CmpLe => {
262				let r = stack.pop().unwrap_or(Value::none());
263				let l = stack.pop().unwrap_or(Value::none());
264				stack.push(scalar::scalar_le(&l, &r));
265			}
266			Instruction::CmpGt => {
267				let r = stack.pop().unwrap_or(Value::none());
268				let l = stack.pop().unwrap_or(Value::none());
269				stack.push(scalar::scalar_gt(&l, &r));
270			}
271			Instruction::CmpGe => {
272				let r = stack.pop().unwrap_or(Value::none());
273				let l = stack.pop().unwrap_or(Value::none());
274				stack.push(scalar::scalar_ge(&l, &r));
275			}
276
277			// === Logic ===
278			Instruction::LogicAnd => {
279				let r = stack.pop().unwrap_or(Value::none());
280				let l = stack.pop().unwrap_or(Value::none());
281				stack.push(scalar::scalar_and(&l, &r));
282			}
283			Instruction::LogicOr => {
284				let r = stack.pop().unwrap_or(Value::none());
285				let l = stack.pop().unwrap_or(Value::none());
286				stack.push(scalar::scalar_or(&l, &r));
287			}
288			Instruction::LogicXor => {
289				let r = stack.pop().unwrap_or(Value::none());
290				let l = stack.pop().unwrap_or(Value::none());
291				stack.push(scalar::scalar_xor(&l, &r));
292			}
293
294			// === Compound ===
295			Instruction::Cast(target) => {
296				let v = stack.pop().unwrap_or(Value::none());
297				stack.push(scalar::scalar_cast(v, target.clone())?);
298			}
299			Instruction::Between => {
300				let upper = stack.pop().unwrap_or(Value::none());
301				let lower = stack.pop().unwrap_or(Value::none());
302				let val = stack.pop().unwrap_or(Value::none());
303				let ge = scalar::scalar_ge(&val, &lower);
304				let le = scalar::scalar_le(&val, &upper);
305				let result = match (ge, le) {
306					(Value::Boolean(a), Value::Boolean(b)) => Value::Boolean(a && b),
307					_ => Value::none(),
308				};
309				stack.push(result);
310			}
311			Instruction::InList {
312				count,
313				negated,
314			} => {
315				let count = *count as usize;
316				let negated = *negated;
317				let mut items: Vec<Value> = Vec::with_capacity(count);
318				for _ in 0..count {
319					items.push(stack.pop().unwrap_or(Value::none()));
320				}
321				items.reverse();
322				let val = stack.pop().unwrap_or(Value::none());
323				let has_undefined = matches!(val, Value::None { .. })
324					|| items.iter().any(|item| matches!(item, Value::None { .. }));
325				if has_undefined {
326					stack.push(Value::none());
327				} else {
328					let found = items.iter().any(|item| {
329						matches!(scalar::scalar_eq(&val, item), Value::Boolean(true))
330					});
331					stack.push(Value::Boolean(if negated {
332						!found
333					} else {
334						found
335					}));
336				}
337			}
338
339			// === Control flow ===
340			Instruction::Jump(addr) => {
341				ip = *addr;
342				continue;
343			}
344			Instruction::JumpIfFalsePop(addr) => {
345				let v = stack.pop().unwrap_or(Value::none());
346				if !scalar::value_is_truthy(&v) {
347					ip = *addr;
348					continue;
349				}
350			}
351			Instruction::JumpIfTruePop(addr) => {
352				let v = stack.pop().unwrap_or(Value::none());
353				if scalar::value_is_truthy(&v) {
354					ip = *addr;
355					continue;
356				}
357			}
358
359			Instruction::EnterScope(scope_type) => {
360				symbol_table.enter_scope(scope_type.clone());
361			}
362			Instruction::ExitScope => {
363				let _ = symbol_table.exit_scope();
364			}
365
366			// === Return ===
367			Instruction::ReturnValue => {
368				let v = stack.pop().unwrap_or(Value::none());
369				return Ok(v);
370			}
371			Instruction::ReturnVoid => {
372				return Ok(Value::none());
373			}
374
375			// === Query ===
376			Instruction::Query(plan) => match plan {
377				QueryPlan::Map(map_node) => {
378					if map_node.input.is_none() && !map_node.map.is_empty() {
379						let evaluation_context = EvalContext {
380							target: None,
381							columns: Columns::empty(),
382							row_count: 1,
383							take: None,
384							params,
385							symbol_table,
386							is_aggregate_context: false,
387							functions,
388							clock,
389							arena: None,
390							identity,
391						};
392						let result_column = evaluate(
393							&evaluation_context,
394							&map_node.map[0],
395							functions,
396							clock,
397						)?;
398						if result_column.data.len() > 0 {
399							stack.push(result_column.data.get_value(0));
400						}
401					}
402				}
403				_ => {
404					// Other plan types would need full VM execution
405				}
406			},
407
408			Instruction::Emit => {
409				// Emit in function body context - the stack top is the result
410			}
411
412			// === Function calls within function body ===
413			Instruction::Call {
414				name,
415				arity,
416				..
417			} => {
418				let arity = *arity as usize;
419				let mut args: Vec<Value> = Vec::with_capacity(arity);
420				for _ in 0..arity {
421					args.push(stack.pop().unwrap_or(Value::none()));
422				}
423				args.reverse();
424
425				// Try user-defined function
426				if let Some(func_def) = symbol_table.get_function(name.text()) {
427					let func_def = func_def.clone();
428					let base_depth = symbol_table.scope_depth();
429					symbol_table.enter_scope(ScopeType::Function);
430					for (param, arg_val) in func_def.parameters.iter().zip(args.iter()) {
431						let param_name = strip_dollar_prefix(param.name.text());
432						symbol_table.set(
433							param_name,
434							Variable::scalar(arg_val.clone()),
435							true,
436						)?;
437					}
438					let result = execute_function_body_for_scalar(
439						&func_def.body,
440						symbol_table,
441						params,
442						functions,
443						clock,
444						identity,
445					)?;
446					while symbol_table.scope_depth() > base_depth {
447						let _ = symbol_table.exit_scope();
448					}
449					stack.push(result);
450				} else if let Some(functor) = functions.get_scalar(name.text()) {
451					let mut arg_cols = Vec::with_capacity(args.len());
452					for arg in &args {
453						let mut data = ColumnData::none_typed(Type::Boolean, 0);
454						data.push_value(arg.clone());
455						arg_cols.push(Column::new("_", data));
456					}
457					let columns = Columns::new(arg_cols);
458					let result_data = functor.scalar(ScalarFunctionContext {
459						fragment: name.clone(),
460						columns: &columns,
461						row_count: 1,
462						clock,
463						identity,
464					})?;
465					if result_data.len() > 0 {
466						stack.push(result_data.get_value(0));
467					} else {
468						stack.push(Value::none());
469					}
470				}
471			}
472
473			Instruction::DefineFunction(func_def) => {
474				symbol_table.define_function(func_def.name.text().to_string(), func_def.clone());
475			}
476
477			_ => {
478				// DDL/DML instructions not expected in function body
479			}
480		}
481		ip += 1;
482	}
483
484	// Return top of stack or Undefined
485	Ok(stack.pop().unwrap_or(Value::none()))
486}
487
488fn handle_aggregate_function(
489	ctx: &EvalContext,
490	call: &CallExpression,
491	mut aggregate_fn: Box<dyn AggregateFunction>,
492	functions: &Functions,
493	clock: &Clock,
494) -> Result<Column> {
495	// Create a single group containing all row indices for aggregation
496	let mut group_view = GroupByView::new();
497	let all_indices: Vec<usize> = (0..ctx.row_count).collect();
498	group_view.insert(Vec::<Value>::new(), all_indices); // Empty group key for single group
499
500	// Determine which column to aggregate over
501	let column = if call.args.is_empty() {
502		// For count() with no arguments, create a dummy column
503		Column {
504			name: Fragment::internal("dummy"),
505			data: ColumnData::int4_with_capacity(ctx.row_count),
506		}
507	} else {
508		// For functions with arguments like sum(amount), use the first argument column
509		let arguments = evaluate_arguments(ctx, &call.args, functions, clock)?;
510		arguments[0].clone()
511	};
512
513	// Call the aggregate function
514	aggregate_fn.aggregate(AggregateFunctionContext {
515		fragment: call.func.0.clone(),
516		column: &column,
517		groups: &group_view,
518	})?;
519
520	// Finalize and get results
521	let (_keys, result_data) = aggregate_fn.finalize()?;
522
523	Ok(Column {
524		name: call.full_fragment_owned(),
525		data: result_data,
526	})
527}
528
529fn evaluate_arguments(
530	ctx: &EvalContext,
531	expressions: &Vec<Expression>,
532	functions: &Functions,
533	clock: &Clock,
534) -> Result<Columns> {
535	let inner_ctx = EvalContext {
536		target: None,
537		columns: ctx.columns.clone(),
538		row_count: ctx.row_count,
539		take: ctx.take,
540		params: ctx.params,
541		symbol_table: ctx.symbol_table,
542		is_aggregate_context: ctx.is_aggregate_context,
543		functions: ctx.functions,
544		clock: ctx.clock,
545		arena: None,
546		identity: ctx.identity,
547	};
548	let mut result: Vec<Column> = Vec::with_capacity(expressions.len());
549
550	for expression in expressions {
551		match expression {
552			Expression::Type(type_expr) => {
553				let values: Vec<Box<Value>> = (0..ctx.row_count)
554					.map(|_| Box::new(Value::Type(type_expr.ty.clone())))
555					.collect();
556				result.push(Column::new(type_expr.fragment.text(), ColumnData::any(values)));
557			}
558			_ => result.push(evaluate(&inner_ctx, expression, functions, clock)?),
559		}
560	}
561
562	Ok(Columns::new(result))
563}