1use crate::error::{DbxError, DbxResult};
4use crate::sql::planner::{BinaryOperator, PhysicalExpr, ScalarFunction};
5use crate::storage::columnar::ScalarValue;
6use arrow::array::*;
7use arrow::compute::{self, kernels::cmp};
8use arrow::datatypes::DataType;
9use std::sync::Arc;
10
11pub fn evaluate_expr(expr: &PhysicalExpr, batch: &RecordBatch) -> DbxResult<ArrayRef> {
13 match expr {
14 PhysicalExpr::Column(idx) => {
15 if *idx >= batch.num_columns() {
16 return Err(DbxError::SqlExecution {
17 message: format!(
18 "column index {} out of range ({})",
19 idx,
20 batch.num_columns()
21 ),
22 context: "evaluate_expr".to_string(),
23 });
24 }
25 Ok(Arc::clone(batch.column(*idx)))
26 }
27 PhysicalExpr::Literal(scalar) => scalar_to_array(scalar, batch.num_rows()),
28 PhysicalExpr::BinaryOp { left, op, right } => {
29 let left_arr = evaluate_expr(left, batch)?;
30 let right_arr = evaluate_expr(right, batch)?;
31 evaluate_binary_op(&left_arr, op, &right_arr)
32 }
33 PhysicalExpr::IsNull(expr) => {
34 let arr = evaluate_expr(expr, batch)?;
35 Ok(Arc::new(compute::is_null(&arr)?))
36 }
37 PhysicalExpr::IsNotNull(expr) => {
38 let arr = evaluate_expr(expr, batch)?;
39 Ok(Arc::new(compute::is_not_null(&arr)?))
40 }
41 PhysicalExpr::ScalarFunc { func, args } => {
42 let arg_arrays = args
43 .iter()
44 .map(|arg| evaluate_expr(arg, batch))
45 .collect::<DbxResult<Vec<_>>>()?;
46 evaluate_scalar_func(func, &arg_arrays)
47 }
48 }
49}
50
51fn scalar_to_array(scalar: &ScalarValue, len: usize) -> DbxResult<ArrayRef> {
53 match scalar {
54 ScalarValue::Int32(v) => {
55 let arr: Int32Array = vec![Some(*v); len].into_iter().collect();
56 Ok(Arc::new(arr))
57 }
58 ScalarValue::Int64(v) => {
59 let arr: Int64Array = vec![Some(*v); len].into_iter().collect();
60 Ok(Arc::new(arr))
61 }
62 ScalarValue::Float64(v) => {
63 let arr: Float64Array = vec![Some(*v); len].into_iter().collect();
64 Ok(Arc::new(arr))
65 }
66 ScalarValue::Utf8(v) => {
67 let arr: StringArray = vec![Some(v.as_str()); len].into_iter().collect();
68 Ok(Arc::new(arr))
69 }
70 ScalarValue::Boolean(v) => {
71 let arr: BooleanArray = vec![Some(*v); len].into_iter().collect();
72 Ok(Arc::new(arr))
73 }
74 ScalarValue::Binary(v) => {
75 let arr: BinaryArray = vec![Some(v.as_slice()); len].into_iter().collect();
76 Ok(Arc::new(arr))
77 }
78 ScalarValue::Null => {
79 let arr: Int32Array = vec![None; len].into_iter().collect();
81 Ok(Arc::new(arr))
82 }
83 }
84}
85
86fn evaluate_binary_op(
88 left: &ArrayRef,
89 op: &BinaryOperator,
90 right: &ArrayRef,
91) -> DbxResult<ArrayRef> {
92 match op {
93 BinaryOperator::Eq
94 | BinaryOperator::NotEq
95 | BinaryOperator::Lt
96 | BinaryOperator::LtEq
97 | BinaryOperator::Gt
98 | BinaryOperator::GtEq => comparison_op(left, right, op),
99
100 BinaryOperator::And | BinaryOperator::Or => logical_op(left, right, op),
101
102 BinaryOperator::Plus
103 | BinaryOperator::Minus
104 | BinaryOperator::Multiply
105 | BinaryOperator::Divide
106 | BinaryOperator::Modulo => arithmetic_op(left, right, op),
107 }
108}
109
110fn evaluate_scalar_func(func: &ScalarFunction, args: &[ArrayRef]) -> DbxResult<ArrayRef> {
112 match func {
113 ScalarFunction::Upper => {
115 let array = args[0]
116 .as_any()
117 .downcast_ref::<StringArray>()
118 .ok_or_else(|| DbxError::SqlExecution {
119 message: format!(
120 "UPPER requires StringArray but found {:?}",
121 args[0].data_type()
122 ),
123 context: "UPPER".into(),
124 })?;
125 let result: StringArray = array.iter().map(|s| s.map(|v| v.to_uppercase())).collect();
126 Ok(Arc::new(result))
127 }
128 ScalarFunction::Lower => {
129 let array = args[0]
130 .as_any()
131 .downcast_ref::<StringArray>()
132 .ok_or_else(|| DbxError::SqlExecution {
133 message: format!(
134 "LOWER requires StringArray but found {:?}",
135 args[0].data_type()
136 ),
137 context: "LOWER".into(),
138 })?;
139 let result: StringArray = array.iter().map(|s| s.map(|v| v.to_lowercase())).collect();
140 Ok(Arc::new(result))
141 }
142 ScalarFunction::Trim => {
143 let array = args[0]
144 .as_any()
145 .downcast_ref::<StringArray>()
146 .ok_or_else(|| DbxError::SqlExecution {
147 message: format!(
148 "TRIM requires StringArray but found {:?}",
149 args[0].data_type()
150 ),
151 context: "TRIM".into(),
152 })?;
153 let result: StringArray = array.iter().map(|s| s.map(|v| v.trim())).collect();
154 Ok(Arc::new(result))
155 }
156 ScalarFunction::Length => {
157 let array = args[0]
158 .as_any()
159 .downcast_ref::<StringArray>()
160 .ok_or_else(|| DbxError::SqlExecution {
161 message: format!(
162 "LENGTH requires StringArray but found {:?}",
163 args[0].data_type()
164 ),
165 context: "LENGTH".into(),
166 })?;
167 let result: Int32Array = array.iter().map(|s| s.map(|v| v.len() as i32)).collect();
168 Ok(Arc::new(result))
169 }
170 ScalarFunction::Concat => {
171 let num_rows = args[0].len();
172 let mut result_vec = Vec::with_capacity(num_rows);
173
174 for i in 0..num_rows {
175 let mut joined = String::new();
176 for arg in args {
177 let s_arr = arg.as_any().downcast_ref::<StringArray>().unwrap();
178 if !s_arr.is_null(i) {
179 joined.push_str(s_arr.value(i));
180 }
181 }
182 result_vec.push(Some(joined));
183 }
184 let result: StringArray = result_vec.into_iter().collect();
185 Ok(Arc::new(result))
186 }
187
188 ScalarFunction::Abs => match args[0].data_type() {
190 DataType::Int32 => {
191 let array = args[0].as_any().downcast_ref::<Int32Array>().unwrap();
192 let result: Int32Array = array.iter().map(|v| v.map(|x| x.abs())).collect();
193 Ok(Arc::new(result))
194 }
195 DataType::Float64 => {
196 let array = args[0].as_any().downcast_ref::<Float64Array>().unwrap();
197 let result: Float64Array = array.iter().map(|v| v.map(|x| x.abs())).collect();
198 Ok(Arc::new(result))
199 }
200 _ => Err(DbxError::NotImplemented(format!(
201 "ABS for {:?}",
202 args[0].data_type()
203 ))),
204 },
205 ScalarFunction::Round => {
206 let array = args[0]
207 .as_any()
208 .downcast_ref::<Float64Array>()
209 .ok_or_else(|| DbxError::SqlExecution {
210 message: "ROUND requires float argument".into(),
211 context: "ROUND".into(),
212 })?;
213 let result: Float64Array = array.iter().map(|v| v.map(|x| x.round())).collect();
214 Ok(Arc::new(result))
215 }
216 ScalarFunction::Sqrt => {
217 let array = args[0]
218 .as_any()
219 .downcast_ref::<Float64Array>()
220 .ok_or_else(|| DbxError::SqlExecution {
221 message: "SQRT requires float argument".into(),
222 context: "SQRT".into(),
223 })?;
224 let result: Float64Array = array.iter().map(|v| v.map(|x| x.sqrt())).collect();
225 Ok(Arc::new(result))
226 }
227
228 ScalarFunction::Now | ScalarFunction::CurrentDate | ScalarFunction::CurrentTime => {
230 let now = std::time::SystemTime::now()
231 .duration_since(std::time::UNIX_EPOCH)
232 .unwrap()
233 .as_secs();
234 let len = if args.is_empty() { 1 } else { args[0].len() };
235 let result: Int64Array = vec![Some(now as i64); len].into_iter().collect();
236 Ok(Arc::new(result))
237 }
238
239 _ => Err(DbxError::NotImplemented(format!(
240 "Scalar function {:?}",
241 func
242 ))),
243 }
244}
245
246fn coerce_for_compare(left: &ArrayRef, right: &ArrayRef) -> DbxResult<(ArrayRef, ArrayRef)> {
248 if left.data_type() == right.data_type() {
249 return Ok((Arc::clone(left), Arc::clone(right)));
250 }
251
252 match (left.data_type(), right.data_type()) {
254 (DataType::Int32, DataType::Int64) => {
255 let cast_left = compute::cast(left, &DataType::Int64)?;
256 Ok((cast_left, Arc::clone(right)))
257 }
258 (DataType::Int64, DataType::Int32) => {
259 let cast_right = compute::cast(right, &DataType::Int64)?;
260 Ok((Arc::clone(left), cast_right))
261 }
262 (DataType::Int32 | DataType::Int64, DataType::Float64) => {
264 let cast_left = compute::cast(left, &DataType::Float64)?;
265 Ok((cast_left, Arc::clone(right)))
266 }
267 (DataType::Float64, DataType::Int32 | DataType::Int64) => {
268 let cast_right = compute::cast(right, &DataType::Float64)?;
269 Ok((Arc::clone(left), cast_right))
270 }
271 _ => Ok((Arc::clone(left), Arc::clone(right))),
272 }
273}
274
275fn comparison_op(left: &ArrayRef, right: &ArrayRef, op: &BinaryOperator) -> DbxResult<ArrayRef> {
277 let (left, right) = coerce_for_compare(left, right)?;
278
279 let result: BooleanArray = match left.data_type() {
280 DataType::Int32 => {
281 let l = left.as_any().downcast_ref::<Int32Array>().unwrap();
282 let r = right.as_any().downcast_ref::<Int32Array>().unwrap();
283 match op {
284 BinaryOperator::Eq => cmp::eq(l, r)?,
285 BinaryOperator::NotEq => cmp::neq(l, r)?,
286 BinaryOperator::Lt => cmp::lt(l, r)?,
287 BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
288 BinaryOperator::Gt => cmp::gt(l, r)?,
289 BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
290 _ => unreachable!(),
291 }
292 }
293 DataType::Int64 => {
294 let l = left.as_any().downcast_ref::<Int64Array>().unwrap();
295 let r = right.as_any().downcast_ref::<Int64Array>().unwrap();
296 match op {
297 BinaryOperator::Eq => cmp::eq(l, r)?,
298 BinaryOperator::NotEq => cmp::neq(l, r)?,
299 BinaryOperator::Lt => cmp::lt(l, r)?,
300 BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
301 BinaryOperator::Gt => cmp::gt(l, r)?,
302 BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
303 _ => unreachable!(),
304 }
305 }
306 DataType::Float64 => {
307 let l = left.as_any().downcast_ref::<Float64Array>().unwrap();
308 let r = right.as_any().downcast_ref::<Float64Array>().unwrap();
309 match op {
310 BinaryOperator::Eq => cmp::eq(l, r)?,
311 BinaryOperator::NotEq => cmp::neq(l, r)?,
312 BinaryOperator::Lt => cmp::lt(l, r)?,
313 BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
314 BinaryOperator::Gt => cmp::gt(l, r)?,
315 BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
316 _ => unreachable!(),
317 }
318 }
319 DataType::Utf8 => {
320 let l = left.as_any().downcast_ref::<StringArray>().unwrap();
321 let r = right.as_any().downcast_ref::<StringArray>().unwrap();
322 match op {
323 BinaryOperator::Eq => cmp::eq(l, r)?,
324 BinaryOperator::NotEq => cmp::neq(l, r)?,
325 BinaryOperator::Lt => cmp::lt(l, r)?,
326 BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
327 BinaryOperator::Gt => cmp::gt(l, r)?,
328 BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
329 _ => unreachable!(),
330 }
331 }
332 DataType::Binary => {
333 let l = left.as_any().downcast_ref::<BinaryArray>().unwrap();
334 let r = right.as_any().downcast_ref::<BinaryArray>().unwrap();
335 match op {
336 BinaryOperator::Eq => cmp::eq(l, r)?,
337 BinaryOperator::NotEq => cmp::neq(l, r)?,
338 BinaryOperator::Lt => cmp::lt(l, r)?,
339 BinaryOperator::LtEq => cmp::lt_eq(l, r)?,
340 BinaryOperator::Gt => cmp::gt(l, r)?,
341 BinaryOperator::GtEq => cmp::gt_eq(l, r)?,
342 _ => unreachable!(),
343 }
344 }
345 dt => {
346 return Err(DbxError::NotImplemented(format!(
347 "comparison for type {:?}",
348 dt
349 )));
350 }
351 };
352 Ok(Arc::new(result))
353}
354
355fn arithmetic_op(left: &ArrayRef, right: &ArrayRef, op: &BinaryOperator) -> DbxResult<ArrayRef> {
357 match left.data_type() {
358 DataType::Int32 => {
359 let l = left.as_any().downcast_ref::<Int32Array>().unwrap();
360 let r = right.as_any().downcast_ref::<Int32Array>().unwrap();
361 match op {
362 BinaryOperator::Plus => Ok(compute::kernels::numeric::add(l, r)?),
363 BinaryOperator::Minus => Ok(compute::kernels::numeric::sub(l, r)?),
364 BinaryOperator::Multiply => Ok(compute::kernels::numeric::mul(l, r)?),
365 BinaryOperator::Divide => Ok(compute::kernels::numeric::div(l, r)?),
366 BinaryOperator::Modulo => Ok(compute::kernels::numeric::rem(l, r)?),
367 _ => unreachable!(),
368 }
369 }
370 DataType::Int64 => {
371 let l = left.as_any().downcast_ref::<Int64Array>().unwrap();
372 let r = right.as_any().downcast_ref::<Int64Array>().unwrap();
373 match op {
374 BinaryOperator::Plus => Ok(compute::kernels::numeric::add(l, r)?),
375 BinaryOperator::Minus => Ok(compute::kernels::numeric::sub(l, r)?),
376 BinaryOperator::Multiply => Ok(compute::kernels::numeric::mul(l, r)?),
377 BinaryOperator::Divide => Ok(compute::kernels::numeric::div(l, r)?),
378 BinaryOperator::Modulo => Ok(compute::kernels::numeric::rem(l, r)?),
379 _ => unreachable!(),
380 }
381 }
382 DataType::Float64 => {
383 let l = left.as_any().downcast_ref::<Float64Array>().unwrap();
384 let r = right.as_any().downcast_ref::<Float64Array>().unwrap();
385 match op {
386 BinaryOperator::Plus => Ok(compute::kernels::numeric::add(l, r)?),
387 BinaryOperator::Minus => Ok(compute::kernels::numeric::sub(l, r)?),
388 BinaryOperator::Multiply => Ok(compute::kernels::numeric::mul(l, r)?),
389 BinaryOperator::Divide => Ok(compute::kernels::numeric::div(l, r)?),
390 BinaryOperator::Modulo => Ok(compute::kernels::numeric::rem(l, r)?),
391 _ => unreachable!(),
392 }
393 }
394 dt => Err(DbxError::NotImplemented(format!(
395 "arithmetic for type {:?}",
396 dt
397 ))),
398 }
399}
400
401fn logical_op(left: &ArrayRef, right: &ArrayRef, op: &BinaryOperator) -> DbxResult<ArrayRef> {
403 let l = left.as_any().downcast_ref::<BooleanArray>().unwrap();
404 let r = right.as_any().downcast_ref::<BooleanArray>().unwrap();
405 let result = match op {
406 BinaryOperator::And => compute::and(l, r)?,
407 BinaryOperator::Or => compute::or(l, r)?,
408 _ => unreachable!(),
409 };
410 Ok(Arc::new(result))
411}