Skip to main content

dbx_core/sql/executor/
expr.rs

1//! Physical Expression Evaluation
2
3use crate::error::{DbxError, DbxResult};
4use crate::sql::planner::{BinaryOperator, PhysicalExpr, ScalarFunction};
5use crate::storage::columnar::ScalarValue;
6use arrow::array::*;
7use arrow::compute::{self, kernels::cmp};
8use arrow::datatypes::DataType;
9use std::sync::Arc;
10
11/// Evaluate a PhysicalExpr against a RecordBatch, producing an ArrayRef.
12pub fn evaluate_expr(expr: &PhysicalExpr, batch: &RecordBatch) -> DbxResult<ArrayRef> {
13    match expr {
14        PhysicalExpr::Column(idx) => {
15            if *idx >= batch.num_columns() {
16                return Err(DbxError::SqlExecution {
17                    message: format!(
18                        "column index {} out of range ({})",
19                        idx,
20                        batch.num_columns()
21                    ),
22                    context: "evaluate_expr".to_string(),
23                });
24            }
25            Ok(Arc::clone(batch.column(*idx)))
26        }
27        PhysicalExpr::Literal(scalar) => scalar_to_array(scalar, batch.num_rows()),
28        PhysicalExpr::BinaryOp { left, op, right } => {
29            let left_arr = evaluate_expr(left, batch)?;
30            let right_arr = evaluate_expr(right, batch)?;
31            evaluate_binary_op(&left_arr, op, &right_arr)
32        }
33        PhysicalExpr::IsNull(expr) => {
34            let arr = evaluate_expr(expr, batch)?;
35            Ok(Arc::new(compute::is_null(&arr)?))
36        }
37        PhysicalExpr::IsNotNull(expr) => {
38            let arr = evaluate_expr(expr, batch)?;
39            Ok(Arc::new(compute::is_not_null(&arr)?))
40        }
41        PhysicalExpr::ScalarFunc { func, args } => {
42            let arg_arrays = args
43                .iter()
44                .map(|arg| evaluate_expr(arg, batch))
45                .collect::<DbxResult<Vec<_>>>()?;
46            evaluate_scalar_func(func, &arg_arrays)
47        }
48    }
49}
50
51/// Convert a ScalarValue to a constant array of `len` rows.
52fn scalar_to_array(scalar: &ScalarValue, len: usize) -> DbxResult<ArrayRef> {
53    match scalar {
54        ScalarValue::Int32(v) => {
55            let arr: Int32Array = vec![Some(*v); len].into_iter().collect();
56            Ok(Arc::new(arr))
57        }
58        ScalarValue::Int64(v) => {
59            let arr: Int64Array = vec![Some(*v); len].into_iter().collect();
60            Ok(Arc::new(arr))
61        }
62        ScalarValue::Float64(v) => {
63            let arr: Float64Array = vec![Some(*v); len].into_iter().collect();
64            Ok(Arc::new(arr))
65        }
66        ScalarValue::Utf8(v) => {
67            let arr: StringArray = vec![Some(v.as_str()); len].into_iter().collect();
68            Ok(Arc::new(arr))
69        }
70        ScalarValue::Boolean(v) => {
71            let arr: BooleanArray = vec![Some(*v); len].into_iter().collect();
72            Ok(Arc::new(arr))
73        }
74        ScalarValue::Binary(v) => {
75            let arr: BinaryArray = vec![Some(v.as_slice()); len].into_iter().collect();
76            Ok(Arc::new(arr))
77        }
78        ScalarValue::Null => {
79            // Default to Int32 null array
80            let arr: Int32Array = vec![None; len].into_iter().collect();
81            Ok(Arc::new(arr))
82        }
83    }
84}
85
86/// Evaluate a binary operation on two arrays.
87fn evaluate_binary_op(
88    left: &ArrayRef,
89    op: &BinaryOperator,
90    right: &ArrayRef,
91) -> DbxResult<ArrayRef> {
92    match op {
93        BinaryOperator::Eq
94        | BinaryOperator::NotEq
95        | BinaryOperator::Lt
96        | BinaryOperator::LtEq
97        | BinaryOperator::Gt
98        | BinaryOperator::GtEq => comparison_op(left, right, op),
99
100        BinaryOperator::And | BinaryOperator::Or => logical_op(left, right, op),
101
102        BinaryOperator::Plus
103        | BinaryOperator::Minus
104        | BinaryOperator::Multiply
105        | BinaryOperator::Divide
106        | BinaryOperator::Modulo => arithmetic_op(left, right, op),
107    }
108}
109
110/// Evaluate a scalar function.
111fn evaluate_scalar_func(func: &ScalarFunction, args: &[ArrayRef]) -> DbxResult<ArrayRef> {
112    match func {
113        // --- String Functions ---
114        ScalarFunction::Upper => {
115            let array = args[0]
116                .as_any()
117                .downcast_ref::<StringArray>()
118                .ok_or_else(|| DbxError::SqlExecution {
119                    message: format!(
120                        "UPPER requires StringArray but found {:?}",
121                        args[0].data_type()
122                    ),
123                    context: "UPPER".into(),
124                })?;
125            let result: StringArray = array.iter().map(|s| s.map(|v| v.to_uppercase())).collect();
126            Ok(Arc::new(result))
127        }
128        ScalarFunction::Lower => {
129            let array = args[0]
130                .as_any()
131                .downcast_ref::<StringArray>()
132                .ok_or_else(|| DbxError::SqlExecution {
133                    message: format!(
134                        "LOWER requires StringArray but found {:?}",
135                        args[0].data_type()
136                    ),
137                    context: "LOWER".into(),
138                })?;
139            let result: StringArray = array.iter().map(|s| s.map(|v| v.to_lowercase())).collect();
140            Ok(Arc::new(result))
141        }
142        ScalarFunction::Trim => {
143            let array = args[0]
144                .as_any()
145                .downcast_ref::<StringArray>()
146                .ok_or_else(|| DbxError::SqlExecution {
147                    message: format!(
148                        "TRIM requires StringArray but found {:?}",
149                        args[0].data_type()
150                    ),
151                    context: "TRIM".into(),
152                })?;
153            let result: StringArray = array.iter().map(|s| s.map(|v| v.trim())).collect();
154            Ok(Arc::new(result))
155        }
156        ScalarFunction::Length => {
157            let array = args[0]
158                .as_any()
159                .downcast_ref::<StringArray>()
160                .ok_or_else(|| DbxError::SqlExecution {
161                    message: format!(
162                        "LENGTH requires StringArray but found {:?}",
163                        args[0].data_type()
164                    ),
165                    context: "LENGTH".into(),
166                })?;
167            let result: Int32Array = array.iter().map(|s| s.map(|v| v.len() as i32)).collect();
168            Ok(Arc::new(result))
169        }
170        ScalarFunction::Concat => {
171            let num_rows = args[0].len();
172            let mut result_vec = Vec::with_capacity(num_rows);
173
174            for i in 0..num_rows {
175                let mut joined = String::new();
176                for arg in args {
177                    let s_arr = arg.as_any().downcast_ref::<StringArray>().unwrap();
178                    if !s_arr.is_null(i) {
179                        joined.push_str(s_arr.value(i));
180                    }
181                }
182                result_vec.push(Some(joined));
183            }
184            let result: StringArray = result_vec.into_iter().collect();
185            Ok(Arc::new(result))
186        }
187
188        // --- Math Functions ---
189        ScalarFunction::Abs => match args[0].data_type() {
190            DataType::Int32 => {
191                let array = args[0].as_any().downcast_ref::<Int32Array>().unwrap();
192                let result: Int32Array = array.iter().map(|v| v.map(|x| x.abs())).collect();
193                Ok(Arc::new(result))
194            }
195            DataType::Float64 => {
196                let array = args[0].as_any().downcast_ref::<Float64Array>().unwrap();
197                let result: Float64Array = array.iter().map(|v| v.map(|x| x.abs())).collect();
198                Ok(Arc::new(result))
199            }
200            _ => Err(DbxError::NotImplemented(format!(
201                "ABS for {:?}",
202                args[0].data_type()
203            ))),
204        },
205        ScalarFunction::Round => {
206            let array = args[0]
207                .as_any()
208                .downcast_ref::<Float64Array>()
209                .ok_or_else(|| DbxError::SqlExecution {
210                    message: "ROUND requires float argument".into(),
211                    context: "ROUND".into(),
212                })?;
213            let result: Float64Array = array.iter().map(|v| v.map(|x| x.round())).collect();
214            Ok(Arc::new(result))
215        }
216        ScalarFunction::Sqrt => {
217            let array = args[0]
218                .as_any()
219                .downcast_ref::<Float64Array>()
220                .ok_or_else(|| DbxError::SqlExecution {
221                    message: "SQRT requires float argument".into(),
222                    context: "SQRT".into(),
223                })?;
224            let result: Float64Array = array.iter().map(|v| v.map(|x| x.sqrt())).collect();
225            Ok(Arc::new(result))
226        }
227
228        // --- Date/Time Functions (Simple Stub) ---
229        ScalarFunction::Now | ScalarFunction::CurrentDate | ScalarFunction::CurrentTime => {
230            let now = std::time::SystemTime::now()
231                .duration_since(std::time::UNIX_EPOCH)
232                .unwrap()
233                .as_secs();
234            let len = if args.is_empty() { 1 } else { args[0].len() };
235            let result: Int64Array = vec![Some(now as i64); len].into_iter().collect();
236            Ok(Arc::new(result))
237        }
238
239        _ => Err(DbxError::NotImplemented(format!(
240            "Scalar function {:?}",
241            func
242        ))),
243    }
244}
245
246/// Coerce two arrays to a common type for comparison.
247fn coerce_for_compare(left: &ArrayRef, right: &ArrayRef) -> DbxResult<(ArrayRef, ArrayRef)> {
248    if left.data_type() == right.data_type() {
249        return Ok((Arc::clone(left), Arc::clone(right)));
250    }
251
252    // Int32 ↔ Int64 → promote both to Int64
253    match (left.data_type(), right.data_type()) {
254        (DataType::Int32, DataType::Int64) => {
255            let cast_left = compute::cast(left, &DataType::Int64)?;
256            Ok((cast_left, Arc::clone(right)))
257        }
258        (DataType::Int64, DataType::Int32) => {
259            let cast_right = compute::cast(right, &DataType::Int64)?;
260            Ok((Arc::clone(left), cast_right))
261        }
262        // Int32/Int64 ↔ Float64 → promote to Float64
263        (DataType::Int32 | DataType::Int64, DataType::Float64) => {
264            let cast_left = compute::cast(left, &DataType::Float64)?;
265            Ok((cast_left, Arc::clone(right)))
266        }
267        (DataType::Float64, DataType::Int32 | DataType::Int64) => {
268            let cast_right = compute::cast(right, &DataType::Float64)?;
269            Ok((Arc::clone(left), cast_right))
270        }
271        _ => Ok((Arc::clone(left), Arc::clone(right))),
272    }
273}
274
275/// Comparison operations on arrays.
276fn comparison_op(left: &ArrayRef, right: &ArrayRef, op: &BinaryOperator) -> DbxResult<ArrayRef> {
277    let (left, right) = coerce_for_compare(left, right)?;
278
279    let result: BooleanArray = match left.data_type() {
280        DataType::Int32 => {
281            let l = left.as_any().downcast_ref::<Int32Array>().unwrap();
282            let r = right.as_any().downcast_ref::<Int32Array>().unwrap();
283            match op {
284                BinaryOperator::Eq => cmp::eq(l, r)?,
285                BinaryOperator::NotEq => cmp::neq(l, r)?,
286                BinaryOperator::Lt => cmp::lt(l, r)?,
287                BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
288                BinaryOperator::Gt => cmp::gt(l, r)?,
289                BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
290                _ => unreachable!(),
291            }
292        }
293        DataType::Int64 => {
294            let l = left.as_any().downcast_ref::<Int64Array>().unwrap();
295            let r = right.as_any().downcast_ref::<Int64Array>().unwrap();
296            match op {
297                BinaryOperator::Eq => cmp::eq(l, r)?,
298                BinaryOperator::NotEq => cmp::neq(l, r)?,
299                BinaryOperator::Lt => cmp::lt(l, r)?,
300                BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
301                BinaryOperator::Gt => cmp::gt(l, r)?,
302                BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
303                _ => unreachable!(),
304            }
305        }
306        DataType::Float64 => {
307            let l = left.as_any().downcast_ref::<Float64Array>().unwrap();
308            let r = right.as_any().downcast_ref::<Float64Array>().unwrap();
309            match op {
310                BinaryOperator::Eq => cmp::eq(l, r)?,
311                BinaryOperator::NotEq => cmp::neq(l, r)?,
312                BinaryOperator::Lt => cmp::lt(l, r)?,
313                BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
314                BinaryOperator::Gt => cmp::gt(l, r)?,
315                BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
316                _ => unreachable!(),
317            }
318        }
319        DataType::Utf8 => {
320            let l = left.as_any().downcast_ref::<StringArray>().unwrap();
321            let r = right.as_any().downcast_ref::<StringArray>().unwrap();
322            match op {
323                BinaryOperator::Eq => cmp::eq(l, r)?,
324                BinaryOperator::NotEq => cmp::neq(l, r)?,
325                BinaryOperator::Lt => cmp::lt(l, r)?,
326                BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
327                BinaryOperator::Gt => cmp::gt(l, r)?,
328                BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
329                _ => unreachable!(),
330            }
331        }
332        DataType::Binary => {
333            let l = left.as_any().downcast_ref::<BinaryArray>().unwrap();
334            let r = right.as_any().downcast_ref::<BinaryArray>().unwrap();
335            match op {
336                BinaryOperator::Eq => cmp::eq(l, r)?,
337                BinaryOperator::NotEq => cmp::neq(l, r)?,
338                BinaryOperator::Lt => cmp::lt(l, r)?,
339                BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
340                BinaryOperator::Gt => cmp::gt(l, r)?,
341                BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
342                _ => unreachable!(),
343            }
344        }
345        dt => {
346            return Err(DbxError::NotImplemented(format!(
347                "comparison for type {:?}",
348                dt
349            )));
350        }
351    };
352    Ok(Arc::new(result))
353}
354
355/// Arithmetic operations on numeric arrays.
356fn arithmetic_op(left: &ArrayRef, right: &ArrayRef, op: &BinaryOperator) -> DbxResult<ArrayRef> {
357    match left.data_type() {
358        DataType::Int32 => {
359            let l = left.as_any().downcast_ref::<Int32Array>().unwrap();
360            let r = right.as_any().downcast_ref::<Int32Array>().unwrap();
361            match op {
362                BinaryOperator::Plus => Ok(compute::kernels::numeric::add(l, r)?),
363                BinaryOperator::Minus => Ok(compute::kernels::numeric::sub(l, r)?),
364                BinaryOperator::Multiply => Ok(compute::kernels::numeric::mul(l, r)?),
365                BinaryOperator::Divide => Ok(compute::kernels::numeric::div(l, r)?),
366                BinaryOperator::Modulo => Ok(compute::kernels::numeric::rem(l, r)?),
367                _ => unreachable!(),
368            }
369        }
370        DataType::Int64 => {
371            let l = left.as_any().downcast_ref::<Int64Array>().unwrap();
372            let r = right.as_any().downcast_ref::<Int64Array>().unwrap();
373            match op {
374                BinaryOperator::Plus => Ok(compute::kernels::numeric::add(l, r)?),
375                BinaryOperator::Minus => Ok(compute::kernels::numeric::sub(l, r)?),
376                BinaryOperator::Multiply => Ok(compute::kernels::numeric::mul(l, r)?),
377                BinaryOperator::Divide => Ok(compute::kernels::numeric::div(l, r)?),
378                BinaryOperator::Modulo => Ok(compute::kernels::numeric::rem(l, r)?),
379                _ => unreachable!(),
380            }
381        }
382        DataType::Float64 => {
383            let l = left.as_any().downcast_ref::<Float64Array>().unwrap();
384            let r = right.as_any().downcast_ref::<Float64Array>().unwrap();
385            match op {
386                BinaryOperator::Plus => Ok(compute::kernels::numeric::add(l, r)?),
387                BinaryOperator::Minus => Ok(compute::kernels::numeric::sub(l, r)?),
388                BinaryOperator::Multiply => Ok(compute::kernels::numeric::mul(l, r)?),
389                BinaryOperator::Divide => Ok(compute::kernels::numeric::div(l, r)?),
390                BinaryOperator::Modulo => Ok(compute::kernels::numeric::rem(l, r)?),
391                _ => unreachable!(),
392            }
393        }
394        dt => Err(DbxError::NotImplemented(format!(
395            "arithmetic for type {:?}",
396            dt
397        ))),
398    }
399}
400
401/// Logical operations on boolean arrays.
402fn logical_op(left: &ArrayRef, right: &ArrayRef, op: &BinaryOperator) -> DbxResult<ArrayRef> {
403    let l = left.as_any().downcast_ref::<BooleanArray>().unwrap();
404    let r = right.as_any().downcast_ref::<BooleanArray>().unwrap();
405    let result = match op {
406        BinaryOperator::And => compute::and(l, r)?,
407        BinaryOperator::Or => compute::or(l, r)?,
408        _ => unreachable!(),
409    };
410    Ok(Arc::new(result))
411}