reifydb_engine/expression/
call.rs1use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData, view::group_by::GroupByView};
5use reifydb_routine::function::{FunctionCapability, FunctionContext, error::FunctionError, registry::Functions};
6use reifydb_rql::{
7 expression::CallExpression,
8 instruction::{CompiledFunction, Instruction, ScopeType},
9 query::QueryPlan,
10};
11use reifydb_runtime::context::RuntimeContext;
12use reifydb_type::{
13 error::Error,
14 fragment::Fragment,
15 params::Params,
16 value::{Value, identity::IdentityId, r#type::Type},
17};
18
19use super::eval::evaluate;
20use crate::{
21 Result,
22 error::EngineError,
23 expression::context::{EvalContext, EvalSession},
24 vm::{
25 scalar,
26 stack::{SymbolTable, Variable},
27 },
28};
29
/// Returns `name` with at most one leading `$` sigil removed, as an owned string.
///
/// Names that do not begin with `$` are returned unchanged.
fn strip_dollar_prefix(name: &str) -> String {
    let bare = match name.strip_prefix('$') {
        Some(rest) => rest,
        None => name,
    };
    String::from(bare)
}
34
35fn column_data_from_values(values: &[Value]) -> ColumnData {
37 if values.is_empty() {
38 return ColumnData::none_typed(Type::Boolean, 0);
39 }
40
41 let mut data = ColumnData::none_typed(Type::Boolean, 0);
42 for value in values {
43 data.push_value(value.clone());
44 }
45 data
46}
47
48pub(crate) fn call_eval_with_args(
50 ctx: &EvalContext,
51 call: &CallExpression,
52 arguments: Columns,
53 functions: &Functions,
54) -> Result<Column> {
55 let function_name = call.func.0.text();
56 let fn_fragment = call.func.0.clone();
57
58 if let Some(func_def) = ctx.symbols.get_function(function_name) {
60 return call_user_defined_function(ctx, call, func_def.clone(), &arguments, functions);
61 }
62
63 let function = functions.get(function_name).ok_or_else(|| -> Error {
64 EngineError::UnknownFunction {
65 name: function_name.to_string(),
66 fragment: fn_fragment.clone(),
67 }
68 .into()
69 })?;
70
71 let fn_ctx = FunctionContext::new(fn_fragment.clone(), ctx.runtime_context, ctx.identity, ctx.row_count);
72
73 if ctx.is_aggregate_context && function.capabilities().contains(&FunctionCapability::Aggregate) {
75 let mut accumulator = function.accumulator(&fn_ctx).ok_or_else(|| FunctionError::ExecutionFailed {
76 function: fn_fragment.clone(),
77 reason: format!("Function {} is not an aggregate", function_name),
78 })?;
79
80 let column = if call.args.is_empty() {
81 Column {
82 name: Fragment::internal("dummy"),
83 data: ColumnData::with_capacity(Type::Int4, ctx.row_count),
84 }
85 } else {
86 arguments[0].clone()
87 };
88
89 let mut group_view = GroupByView::new();
90 let all_indices: Vec<usize> = (0..ctx.row_count).collect();
91 group_view.insert(Vec::<Value>::new(), all_indices);
92
93 accumulator
94 .update(&Columns::new(vec![column]), &group_view)
95 .map_err(|e| e.with_context(fn_fragment.clone()))?;
96
97 let (_keys, result_data) = accumulator.finalize().map_err(|e| e.with_context(fn_fragment))?;
98
99 return Ok(Column {
100 name: call.full_fragment_owned(),
101 data: result_data,
102 });
103 }
104
105 let result_columns = function.call(&fn_ctx, &arguments).map_err(|e| e.with_context(fn_fragment))?;
106
107 let result_column = result_columns.into_iter().next().ok_or_else(|| FunctionError::ExecutionFailed {
109 function: call.func.0.clone(),
110 reason: "Function returned no columns".to_string(),
111 })?;
112
113 Ok(Column {
114 name: call.full_fragment_owned(),
115 data: result_column.data,
116 })
117}
118
119fn call_user_defined_function(
121 ctx: &EvalContext,
122 call: &CallExpression,
123 func_def: CompiledFunction,
124 arguments: &Columns,
125 functions: &Functions,
126) -> Result<Column> {
127 let row_count = ctx.row_count;
128 let mut results: Vec<Value> = Vec::with_capacity(row_count);
129
130 let body_instructions = &func_def.body;
132
133 let mut func_symbols = ctx.symbols.clone();
134
135 for row_idx in 0..row_count {
137 let base_depth = func_symbols.scope_depth();
138 func_symbols.enter_scope(ScopeType::Function);
139
140 for (param, arg_col) in func_def.parameters.iter().zip(arguments.iter()) {
142 let param_name = strip_dollar_prefix(param.name.text());
143 let value = arg_col.data().get_value(row_idx);
144 func_symbols.set(param_name, Variable::scalar(value), true)?;
145 }
146
147 let result = execute_function_body_for_scalar(
149 body_instructions,
150 &mut func_symbols,
151 ctx.params,
152 functions,
153 ctx.runtime_context,
154 ctx.identity,
155 )?;
156
157 while func_symbols.scope_depth() > base_depth {
158 let _ = func_symbols.exit_scope();
159 }
160
161 results.push(result);
162 }
163
164 let data = column_data_from_values(&results);
166 Ok(Column {
167 name: call.full_fragment_owned(),
168 data,
169 })
170}
171
172fn execute_function_body_for_scalar<'a>(
175 instructions: &[Instruction],
176 symbols: &mut SymbolTable,
177 params: &'a Params,
178 functions: &'a Functions,
179 runtime_context: &'a RuntimeContext,
180 identity: IdentityId,
181) -> Result<Value> {
182 let mut ip = 0;
183 let mut stack: Vec<Value> = Vec::new();
184
185 while ip < instructions.len() {
186 match &instructions[ip] {
187 Instruction::Halt => break,
188 Instruction::Nop => {}
189
190 Instruction::PushConst(v) => stack.push(v.clone()),
191 Instruction::PushNone => stack.push(Value::none()),
192 Instruction::Pop => {
193 stack.pop();
194 }
195 Instruction::Dup => {
196 if let Some(v) = stack.last() {
197 stack.push(v.clone());
198 }
199 }
200
201 Instruction::LoadVar(name) => {
202 let var_name = strip_dollar_prefix(name.text());
203 let val = symbols
204 .get(&var_name)
205 .map(|v| match v {
206 Variable::Scalar(c) => c.scalar_value(),
207 _ => Value::none(),
208 })
209 .unwrap_or(Value::none());
210 stack.push(val);
211 }
212 Instruction::StoreVar(name) => {
213 let val = stack.pop().unwrap_or(Value::none());
214 let var_name = strip_dollar_prefix(name.text());
215 symbols.set(var_name, Variable::scalar(val), true)?;
216 }
217 Instruction::DeclareVar(name) => {
218 let val = stack.pop().unwrap_or(Value::none());
219 let var_name = strip_dollar_prefix(name.text());
220 symbols.set(var_name, Variable::scalar(val), true)?;
221 }
222
223 Instruction::Add => {
224 let r = stack.pop().unwrap_or(Value::none());
225 let l = stack.pop().unwrap_or(Value::none());
226 stack.push(scalar::scalar_add(l, r)?);
227 }
228 Instruction::Sub => {
229 let r = stack.pop().unwrap_or(Value::none());
230 let l = stack.pop().unwrap_or(Value::none());
231 stack.push(scalar::scalar_sub(l, r)?);
232 }
233 Instruction::Mul => {
234 let r = stack.pop().unwrap_or(Value::none());
235 let l = stack.pop().unwrap_or(Value::none());
236 stack.push(scalar::scalar_mul(l, r)?);
237 }
238 Instruction::Div => {
239 let r = stack.pop().unwrap_or(Value::none());
240 let l = stack.pop().unwrap_or(Value::none());
241 stack.push(scalar::scalar_div(l, r)?);
242 }
243 Instruction::Rem => {
244 let r = stack.pop().unwrap_or(Value::none());
245 let l = stack.pop().unwrap_or(Value::none());
246 stack.push(scalar::scalar_rem(l, r)?);
247 }
248
249 Instruction::Negate => {
250 let v = stack.pop().unwrap_or(Value::none());
251 stack.push(scalar::scalar_negate(v)?);
252 }
253 Instruction::LogicNot => {
254 let v = stack.pop().unwrap_or(Value::none());
255 stack.push(scalar::scalar_not(&v));
256 }
257
258 Instruction::CmpEq => {
259 let r = stack.pop().unwrap_or(Value::none());
260 let l = stack.pop().unwrap_or(Value::none());
261 stack.push(scalar::scalar_eq(&l, &r));
262 }
263 Instruction::CmpNe => {
264 let r = stack.pop().unwrap_or(Value::none());
265 let l = stack.pop().unwrap_or(Value::none());
266 stack.push(scalar::scalar_ne(&l, &r));
267 }
268 Instruction::CmpLt => {
269 let r = stack.pop().unwrap_or(Value::none());
270 let l = stack.pop().unwrap_or(Value::none());
271 stack.push(scalar::scalar_lt(&l, &r));
272 }
273 Instruction::CmpLe => {
274 let r = stack.pop().unwrap_or(Value::none());
275 let l = stack.pop().unwrap_or(Value::none());
276 stack.push(scalar::scalar_le(&l, &r));
277 }
278 Instruction::CmpGt => {
279 let r = stack.pop().unwrap_or(Value::none());
280 let l = stack.pop().unwrap_or(Value::none());
281 stack.push(scalar::scalar_gt(&l, &r));
282 }
283 Instruction::CmpGe => {
284 let r = stack.pop().unwrap_or(Value::none());
285 let l = stack.pop().unwrap_or(Value::none());
286 stack.push(scalar::scalar_ge(&l, &r));
287 }
288
289 Instruction::LogicAnd => {
290 let r = stack.pop().unwrap_or(Value::none());
291 let l = stack.pop().unwrap_or(Value::none());
292 stack.push(scalar::scalar_and(&l, &r));
293 }
294 Instruction::LogicOr => {
295 let r = stack.pop().unwrap_or(Value::none());
296 let l = stack.pop().unwrap_or(Value::none());
297 stack.push(scalar::scalar_or(&l, &r));
298 }
299 Instruction::LogicXor => {
300 let r = stack.pop().unwrap_or(Value::none());
301 let l = stack.pop().unwrap_or(Value::none());
302 stack.push(scalar::scalar_xor(&l, &r));
303 }
304
305 Instruction::Cast(target) => {
306 let v = stack.pop().unwrap_or(Value::none());
307 stack.push(scalar::scalar_cast(v, target.clone())?);
308 }
309 Instruction::Between => {
310 let upper = stack.pop().unwrap_or(Value::none());
311 let lower = stack.pop().unwrap_or(Value::none());
312 let val = stack.pop().unwrap_or(Value::none());
313 let ge = scalar::scalar_ge(&val, &lower);
314 let le = scalar::scalar_le(&val, &upper);
315 let result = match (ge, le) {
316 (Value::Boolean(a), Value::Boolean(b)) => Value::Boolean(a && b),
317 _ => Value::none(),
318 };
319 stack.push(result);
320 }
321 Instruction::InList {
322 count,
323 negated,
324 } => {
325 let count = *count as usize;
326 let negated = *negated;
327 let mut items: Vec<Value> = Vec::with_capacity(count);
328 for _ in 0..count {
329 items.push(stack.pop().unwrap_or(Value::none()));
330 }
331 items.reverse();
332 let val = stack.pop().unwrap_or(Value::none());
333 let has_undefined = matches!(val, Value::None { .. })
334 || items.iter().any(|item| matches!(item, Value::None { .. }));
335 if has_undefined {
336 stack.push(Value::none());
337 } else {
338 let found = items.iter().any(|item| {
339 matches!(scalar::scalar_eq(&val, item), Value::Boolean(true))
340 });
341 stack.push(Value::Boolean(if negated {
342 !found
343 } else {
344 found
345 }));
346 }
347 }
348
349 Instruction::Jump(addr) => {
350 ip = *addr;
351 continue;
352 }
353 Instruction::JumpIfFalsePop(addr) => {
354 let v = stack.pop().unwrap_or(Value::none());
355 if !scalar::value_is_truthy(&v) {
356 ip = *addr;
357 continue;
358 }
359 }
360 Instruction::JumpIfTruePop(addr) => {
361 let v = stack.pop().unwrap_or(Value::none());
362 if scalar::value_is_truthy(&v) {
363 ip = *addr;
364 continue;
365 }
366 }
367
368 Instruction::EnterScope(scope_type) => {
369 symbols.enter_scope(scope_type.clone());
370 }
371 Instruction::ExitScope => {
372 let _ = symbols.exit_scope();
373 }
374
375 Instruction::ReturnValue => {
376 let v = stack.pop().unwrap_or(Value::none());
377 return Ok(v);
378 }
379 Instruction::ReturnVoid => {
380 return Ok(Value::none());
381 }
382
383 Instruction::Query(QueryPlan::Map(map_node)) => {
384 if map_node.input.is_none() && !map_node.map.is_empty() {
385 let call_session = EvalSession {
386 params,
387 symbols,
388 functions,
389 runtime_context,
390 arena: None,
391 identity,
392 is_aggregate_context: false,
393 };
394 let evaluation_context = call_session.eval_empty();
395 let result_column = evaluate(&evaluation_context, &map_node.map[0])?;
396 if !result_column.data.is_empty() {
397 stack.push(result_column.data.get_value(0));
398 }
399 }
400 }
401 Instruction::Query(_) => {
402 }
404
405 Instruction::Emit => {
406 }
408
409 Instruction::Call {
410 name,
411 arity,
412 ..
413 } => {
414 let arity = *arity as usize;
415 let mut args: Vec<Value> = Vec::with_capacity(arity);
416 for _ in 0..arity {
417 args.push(stack.pop().unwrap_or(Value::none()));
418 }
419 args.reverse();
420
421 if let Some(func_def) = symbols.get_function(name.text()) {
423 let func_def = func_def.clone();
424 let base_depth = symbols.scope_depth();
425 symbols.enter_scope(ScopeType::Function);
426 for (param, arg_val) in func_def.parameters.iter().zip(args.iter()) {
427 let param_name = strip_dollar_prefix(param.name.text());
428 symbols.set(param_name, Variable::scalar(arg_val.clone()), true)?;
429 }
430 let result = execute_function_body_for_scalar(
431 &func_def.body,
432 symbols,
433 params,
434 functions,
435 runtime_context,
436 identity,
437 )?;
438 while symbols.scope_depth() > base_depth {
439 let _ = symbols.exit_scope();
440 }
441 stack.push(result);
442 } else if let Some(function) = functions.get(name.text()) {
443 let mut arg_cols = Vec::with_capacity(args.len());
444 for arg in &args {
445 let mut data = ColumnData::none_typed(Type::Boolean, 0);
446 data.push_value(arg.clone());
447 arg_cols.push(Column::new("_", data));
448 }
449 let columns = Columns::new(arg_cols);
450 let fn_fragment = name.clone();
451 let fn_ctx =
452 FunctionContext::new(fn_fragment.clone(), runtime_context, identity, 1);
453
454 let result_columns = function
455 .call(&fn_ctx, &columns)
456 .map_err(|e| e.with_context(fn_fragment))?;
457
458 if let Some(col) = result_columns.into_iter().next() {
459 stack.push(col.data().get_value(0));
460 } else {
461 stack.push(Value::none());
462 }
463 }
464 }
465
466 Instruction::DefineFunction(func_def) => {
467 symbols.define_function(func_def.name.text().to_string(), func_def.clone());
468 }
469
470 _ => {
471 }
473 }
474 ip += 1;
475 }
476
477 Ok(stack.pop().unwrap_or(Value::none()))
479}