Skip to main content

reifydb_engine/expression/
call.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData, view::group_by::GroupByView};
5use reifydb_routine::function::{AggregateFunctionContext, ScalarFunctionContext, registry::Functions};
6use reifydb_rql::{
7	expression::CallExpression,
8	instruction::{CompiledFunction, Instruction, ScopeType},
9	query::QueryPlan,
10};
11use reifydb_runtime::context::RuntimeContext;
12use reifydb_type::{
13	error::Error,
14	fragment::Fragment,
15	params::Params,
16	value::{Value, identity::IdentityId, r#type::Type},
17};
18
19use super::eval::evaluate;
20use crate::{
21	Result,
22	error::EngineError,
23	expression::context::{EvalContext, EvalSession},
24	vm::{
25		scalar,
26		stack::{SymbolTable, Variable},
27	},
28};
29
/// Returns `name` as an owned `String` with one leading `$` removed, if present.
///
/// Names that do not start with `$` are copied unchanged. Only the first `$` is
/// stripped, so `"$$x"` becomes `"$x"`.
fn strip_dollar_prefix(name: &str) -> String {
	// `strip_prefix` replaces the manual `starts_with` + byte-slice pair.
	name.strip_prefix('$').unwrap_or(name).to_string()
}
38
39/// Convert a slice of Values into ColumnData
40fn column_data_from_values(values: &[Value]) -> ColumnData {
41	if values.is_empty() {
42		return ColumnData::none_typed(Type::Boolean, 0);
43	}
44
45	let mut data = ColumnData::none_typed(Type::Boolean, 0);
46	for value in values {
47		data.push_value(value.clone());
48	}
49	data
50}
51
52/// Evaluate a call expression with pre-evaluated arguments (avoids re-compiling argument expressions).
53pub(crate) fn call_eval_with_args(
54	ctx: &EvalContext,
55	call: &CallExpression,
56	arguments: Columns,
57	functions: &Functions,
58) -> Result<Column> {
59	let function_name = call.func.0.text();
60
61	// Check if we're in aggregation context and if function exists as aggregate
62	if ctx.is_aggregate_context {
63		if let Some(mut aggregate_fn) = functions.get_aggregate(function_name) {
64			let column = if call.args.is_empty() {
65				Column {
66					name: Fragment::internal("dummy"),
67					data: ColumnData::with_capacity(Type::Int4, ctx.row_count),
68				}
69			} else {
70				arguments[0].clone()
71			};
72
73			let mut group_view = GroupByView::new();
74			let all_indices: Vec<usize> = (0..ctx.row_count).collect();
75			group_view.insert(Vec::<Value>::new(), all_indices);
76
77			let agg_fragment = call.func.0.clone();
78			aggregate_fn
79				.aggregate(AggregateFunctionContext {
80					fragment: agg_fragment.clone(),
81					column: &column,
82					groups: &group_view,
83				})
84				.map_err(|e| e.with_context(agg_fragment.clone()))?;
85
86			let (_keys, result_data) = aggregate_fn.finalize().map_err(|e| e.with_context(agg_fragment))?;
87
88			return Ok(Column {
89				name: call.full_fragment_owned(),
90				data: result_data,
91			});
92		}
93	}
94
95	// Try user-defined function from symbol table first
96	if let Some(func_def) = ctx.symbols.get_function(function_name) {
97		return call_user_defined_function(ctx, call, func_def.clone(), &arguments, functions);
98	}
99
100	// Fall back to built-in scalar function handling
101	let functor = functions.get_scalar(function_name).ok_or_else(|| -> Error {
102		EngineError::UnknownFunction {
103			name: call.func.0.text().to_string(),
104			fragment: call.func.0.clone(),
105		}
106		.into()
107	})?;
108
109	let row_count = ctx.row_count;
110
111	let fn_fragment = call.func.0.clone();
112	let final_data = functor
113		.scalar(ScalarFunctionContext {
114			fragment: fn_fragment.clone(),
115			columns: &arguments,
116			row_count,
117			runtime_context: ctx.runtime_context,
118			identity: ctx.identity,
119		})
120		.map_err(|e| e.with_context(fn_fragment))?;
121
122	Ok(Column {
123		name: call.full_fragment_owned(),
124		data: final_data,
125	})
126}
127
128/// Execute a user-defined function for each row, returning a column of results
129fn call_user_defined_function(
130	ctx: &EvalContext,
131	call: &CallExpression,
132	func_def: CompiledFunction,
133	arguments: &Columns,
134	functions: &Functions,
135) -> Result<Column> {
136	let row_count = ctx.row_count;
137	let mut results: Vec<Value> = Vec::with_capacity(row_count);
138
139	// Function body is already pre-compiled
140	let body_instructions = &func_def.body;
141
142	let mut func_symbols = ctx.symbols.clone();
143
144	// For each row, execute the function
145	for row_idx in 0..row_count {
146		let base_depth = func_symbols.scope_depth();
147		func_symbols.enter_scope(ScopeType::Function);
148
149		// Bind arguments to parameters
150		for (param, arg_col) in func_def.parameters.iter().zip(arguments.iter()) {
151			let param_name = strip_dollar_prefix(param.name.text());
152			let value = arg_col.data().get_value(row_idx);
153			func_symbols.set(param_name, Variable::scalar(value), true)?;
154		}
155
156		// Execute function body instructions and get result
157		let result = execute_function_body_for_scalar(
158			&body_instructions,
159			&mut func_symbols,
160			ctx.params,
161			functions,
162			ctx.runtime_context,
163			ctx.identity,
164		)?;
165
166		while func_symbols.scope_depth() > base_depth {
167			let _ = func_symbols.exit_scope();
168		}
169
170		results.push(result);
171	}
172
173	// Convert results to ColumnData
174	let data = column_data_from_values(&results);
175	Ok(Column {
176		name: call.full_fragment_owned(),
177		data,
178	})
179}
180
181/// Execute function body instructions and return a scalar result.
182/// Uses a simple stack-based interpreter matching the new bytecode ISA.
183fn execute_function_body_for_scalar(
184	instructions: &[Instruction],
185	symbols: &mut SymbolTable,
186	params: &Params,
187	functions: &Functions,
188	runtime_context: &RuntimeContext,
189	identity: IdentityId,
190) -> Result<Value> {
191	let mut ip = 0;
192	let mut stack: Vec<Value> = Vec::new();
193
194	while ip < instructions.len() {
195		match &instructions[ip] {
196			Instruction::Halt => break,
197			Instruction::Nop => {}
198
199			Instruction::PushConst(v) => stack.push(v.clone()),
200			Instruction::PushNone => stack.push(Value::none()),
201			Instruction::Pop => {
202				stack.pop();
203			}
204			Instruction::Dup => {
205				if let Some(v) = stack.last() {
206					stack.push(v.clone());
207				}
208			}
209
210			Instruction::LoadVar(name) => {
211				let var_name = strip_dollar_prefix(name.text());
212				let val = symbols
213					.get(&var_name)
214					.map(|v| match v {
215						Variable::Scalar(c) => c.scalar_value(),
216						_ => Value::none(),
217					})
218					.unwrap_or(Value::none());
219				stack.push(val);
220			}
221			Instruction::StoreVar(name) => {
222				let val = stack.pop().unwrap_or(Value::none());
223				let var_name = strip_dollar_prefix(name.text());
224				symbols.set(var_name, Variable::scalar(val), true)?;
225			}
226			Instruction::DeclareVar(name) => {
227				let val = stack.pop().unwrap_or(Value::none());
228				let var_name = strip_dollar_prefix(name.text());
229				symbols.set(var_name, Variable::scalar(val), true)?;
230			}
231
232			Instruction::Add => {
233				let r = stack.pop().unwrap_or(Value::none());
234				let l = stack.pop().unwrap_or(Value::none());
235				stack.push(scalar::scalar_add(l, r)?);
236			}
237			Instruction::Sub => {
238				let r = stack.pop().unwrap_or(Value::none());
239				let l = stack.pop().unwrap_or(Value::none());
240				stack.push(scalar::scalar_sub(l, r)?);
241			}
242			Instruction::Mul => {
243				let r = stack.pop().unwrap_or(Value::none());
244				let l = stack.pop().unwrap_or(Value::none());
245				stack.push(scalar::scalar_mul(l, r)?);
246			}
247			Instruction::Div => {
248				let r = stack.pop().unwrap_or(Value::none());
249				let l = stack.pop().unwrap_or(Value::none());
250				stack.push(scalar::scalar_div(l, r)?);
251			}
252			Instruction::Rem => {
253				let r = stack.pop().unwrap_or(Value::none());
254				let l = stack.pop().unwrap_or(Value::none());
255				stack.push(scalar::scalar_rem(l, r)?);
256			}
257
258			Instruction::Negate => {
259				let v = stack.pop().unwrap_or(Value::none());
260				stack.push(scalar::scalar_negate(v)?);
261			}
262			Instruction::LogicNot => {
263				let v = stack.pop().unwrap_or(Value::none());
264				stack.push(scalar::scalar_not(&v));
265			}
266
267			Instruction::CmpEq => {
268				let r = stack.pop().unwrap_or(Value::none());
269				let l = stack.pop().unwrap_or(Value::none());
270				stack.push(scalar::scalar_eq(&l, &r));
271			}
272			Instruction::CmpNe => {
273				let r = stack.pop().unwrap_or(Value::none());
274				let l = stack.pop().unwrap_or(Value::none());
275				stack.push(scalar::scalar_ne(&l, &r));
276			}
277			Instruction::CmpLt => {
278				let r = stack.pop().unwrap_or(Value::none());
279				let l = stack.pop().unwrap_or(Value::none());
280				stack.push(scalar::scalar_lt(&l, &r));
281			}
282			Instruction::CmpLe => {
283				let r = stack.pop().unwrap_or(Value::none());
284				let l = stack.pop().unwrap_or(Value::none());
285				stack.push(scalar::scalar_le(&l, &r));
286			}
287			Instruction::CmpGt => {
288				let r = stack.pop().unwrap_or(Value::none());
289				let l = stack.pop().unwrap_or(Value::none());
290				stack.push(scalar::scalar_gt(&l, &r));
291			}
292			Instruction::CmpGe => {
293				let r = stack.pop().unwrap_or(Value::none());
294				let l = stack.pop().unwrap_or(Value::none());
295				stack.push(scalar::scalar_ge(&l, &r));
296			}
297
298			Instruction::LogicAnd => {
299				let r = stack.pop().unwrap_or(Value::none());
300				let l = stack.pop().unwrap_or(Value::none());
301				stack.push(scalar::scalar_and(&l, &r));
302			}
303			Instruction::LogicOr => {
304				let r = stack.pop().unwrap_or(Value::none());
305				let l = stack.pop().unwrap_or(Value::none());
306				stack.push(scalar::scalar_or(&l, &r));
307			}
308			Instruction::LogicXor => {
309				let r = stack.pop().unwrap_or(Value::none());
310				let l = stack.pop().unwrap_or(Value::none());
311				stack.push(scalar::scalar_xor(&l, &r));
312			}
313
314			Instruction::Cast(target) => {
315				let v = stack.pop().unwrap_or(Value::none());
316				stack.push(scalar::scalar_cast(v, target.clone())?);
317			}
318			Instruction::Between => {
319				let upper = stack.pop().unwrap_or(Value::none());
320				let lower = stack.pop().unwrap_or(Value::none());
321				let val = stack.pop().unwrap_or(Value::none());
322				let ge = scalar::scalar_ge(&val, &lower);
323				let le = scalar::scalar_le(&val, &upper);
324				let result = match (ge, le) {
325					(Value::Boolean(a), Value::Boolean(b)) => Value::Boolean(a && b),
326					_ => Value::none(),
327				};
328				stack.push(result);
329			}
330			Instruction::InList {
331				count,
332				negated,
333			} => {
334				let count = *count as usize;
335				let negated = *negated;
336				let mut items: Vec<Value> = Vec::with_capacity(count);
337				for _ in 0..count {
338					items.push(stack.pop().unwrap_or(Value::none()));
339				}
340				items.reverse();
341				let val = stack.pop().unwrap_or(Value::none());
342				let has_undefined = matches!(val, Value::None { .. })
343					|| items.iter().any(|item| matches!(item, Value::None { .. }));
344				if has_undefined {
345					stack.push(Value::none());
346				} else {
347					let found = items.iter().any(|item| {
348						matches!(scalar::scalar_eq(&val, item), Value::Boolean(true))
349					});
350					stack.push(Value::Boolean(if negated {
351						!found
352					} else {
353						found
354					}));
355				}
356			}
357
358			Instruction::Jump(addr) => {
359				ip = *addr;
360				continue;
361			}
362			Instruction::JumpIfFalsePop(addr) => {
363				let v = stack.pop().unwrap_or(Value::none());
364				if !scalar::value_is_truthy(&v) {
365					ip = *addr;
366					continue;
367				}
368			}
369			Instruction::JumpIfTruePop(addr) => {
370				let v = stack.pop().unwrap_or(Value::none());
371				if scalar::value_is_truthy(&v) {
372					ip = *addr;
373					continue;
374				}
375			}
376
377			Instruction::EnterScope(scope_type) => {
378				symbols.enter_scope(scope_type.clone());
379			}
380			Instruction::ExitScope => {
381				let _ = symbols.exit_scope();
382			}
383
384			Instruction::ReturnValue => {
385				let v = stack.pop().unwrap_or(Value::none());
386				return Ok(v);
387			}
388			Instruction::ReturnVoid => {
389				return Ok(Value::none());
390			}
391
392			Instruction::Query(plan) => match plan {
393				QueryPlan::Map(map_node) => {
394					if map_node.input.is_none() && !map_node.map.is_empty() {
395						let call_session = EvalSession {
396							params,
397							symbols,
398							functions,
399							runtime_context,
400							arena: None,
401							identity,
402							is_aggregate_context: false,
403						};
404						let evaluation_context = call_session.eval_empty();
405						let result_column = evaluate(&evaluation_context, &map_node.map[0])?;
406						if result_column.data.len() > 0 {
407							stack.push(result_column.data.get_value(0));
408						}
409					}
410				}
411				_ => {
412					// Other plan types would need full VM execution
413				}
414			},
415
416			Instruction::Emit => {
417				// Emit in function body context - the stack top is the result
418			}
419
420			Instruction::Call {
421				name,
422				arity,
423				..
424			} => {
425				let arity = *arity as usize;
426				let mut args: Vec<Value> = Vec::with_capacity(arity);
427				for _ in 0..arity {
428					args.push(stack.pop().unwrap_or(Value::none()));
429				}
430				args.reverse();
431
432				// Try user-defined function
433				if let Some(func_def) = symbols.get_function(name.text()) {
434					let func_def = func_def.clone();
435					let base_depth = symbols.scope_depth();
436					symbols.enter_scope(ScopeType::Function);
437					for (param, arg_val) in func_def.parameters.iter().zip(args.iter()) {
438						let param_name = strip_dollar_prefix(param.name.text());
439						symbols.set(param_name, Variable::scalar(arg_val.clone()), true)?;
440					}
441					let result = execute_function_body_for_scalar(
442						&func_def.body,
443						symbols,
444						params,
445						functions,
446						runtime_context,
447						identity,
448					)?;
449					while symbols.scope_depth() > base_depth {
450						let _ = symbols.exit_scope();
451					}
452					stack.push(result);
453				} else if let Some(functor) = functions.get_scalar(name.text()) {
454					let mut arg_cols = Vec::with_capacity(args.len());
455					for arg in &args {
456						let mut data = ColumnData::none_typed(Type::Boolean, 0);
457						data.push_value(arg.clone());
458						arg_cols.push(Column::new("_", data));
459					}
460					let columns = Columns::new(arg_cols);
461					let fn_fragment = name.clone();
462					let result_data = functor
463						.scalar(ScalarFunctionContext {
464							fragment: fn_fragment.clone(),
465							columns: &columns,
466							row_count: 1,
467							runtime_context,
468							identity,
469						})
470						.map_err(|e| e.with_context(fn_fragment))?;
471					if result_data.len() > 0 {
472						stack.push(result_data.get_value(0));
473					} else {
474						stack.push(Value::none());
475					}
476				}
477			}
478
479			Instruction::DefineFunction(func_def) => {
480				symbols.define_function(func_def.name.text().to_string(), func_def.clone());
481			}
482
483			_ => {
484				// DDL/DML instructions not expected in function body
485			}
486		}
487		ip += 1;
488	}
489
490	// Return top of stack or Undefined
491	Ok(stack.pop().unwrap_or(Value::none()))
492}