Skip to main content

oxigdal_query/executor/
filter.rs

1//! Filter executor.
2
3use crate::error::{QueryError, Result};
4use crate::executor::scan::{ColumnData, RecordBatch};
5use crate::parser::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
6use oxigdal_core::error::OxiGdalError;
7
8/// Filter operator.
9pub struct Filter {
10    /// Filter predicate.
11    pub predicate: Expr,
12}
13
14impl Filter {
15    /// Create a new filter.
16    pub fn new(predicate: Expr) -> Self {
17        Self { predicate }
18    }
19
20    /// Execute the filter on a record batch.
21    pub fn execute(&self, batch: &RecordBatch) -> Result<RecordBatch> {
22        let mut selection = vec![false; batch.num_rows];
23
24        // Evaluate predicate for each row
25        for (row_idx, sel) in selection.iter_mut().enumerate().take(batch.num_rows) {
26            let result = self.evaluate_expr(&self.predicate, batch, row_idx)?;
27            if let Value::Boolean(b) = result {
28                *sel = b;
29            } else {
30                return Err(QueryError::execution(
31                    OxiGdalError::invalid_operation_builder(
32                        "Filter predicate must evaluate to boolean type",
33                    )
34                    .with_operation("filter_evaluation")
35                    .with_parameter("row_index", row_idx.to_string())
36                    .with_parameter("actual_type", format!("{:?}", result))
37                    .with_suggestion("Ensure WHERE clause uses comparison or boolean operators")
38                    .build()
39                    .to_string(),
40                ));
41            }
42        }
43
44        // Filter columns based on selection
45        let mut filtered_columns = Vec::new();
46        for column in &batch.columns {
47            filtered_columns.push(self.filter_column(column, &selection));
48        }
49
50        let filtered_rows = selection.iter().filter(|&&b| b).count();
51
52        RecordBatch::new(batch.schema.clone(), filtered_columns, filtered_rows)
53    }
54
55    /// Filter a column based on selection.
56    fn filter_column(&self, column: &ColumnData, selection: &[bool]) -> ColumnData {
57        match column {
58            ColumnData::Boolean(data) => {
59                let filtered: Vec<Option<bool>> = data
60                    .iter()
61                    .zip(selection)
62                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
63                    .collect();
64                ColumnData::Boolean(filtered)
65            }
66            ColumnData::Int32(data) => {
67                let filtered: Vec<Option<i32>> = data
68                    .iter()
69                    .zip(selection)
70                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
71                    .collect();
72                ColumnData::Int32(filtered)
73            }
74            ColumnData::Int64(data) => {
75                let filtered: Vec<Option<i64>> = data
76                    .iter()
77                    .zip(selection)
78                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
79                    .collect();
80                ColumnData::Int64(filtered)
81            }
82            ColumnData::Float32(data) => {
83                let filtered: Vec<Option<f32>> = data
84                    .iter()
85                    .zip(selection)
86                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
87                    .collect();
88                ColumnData::Float32(filtered)
89            }
90            ColumnData::Float64(data) => {
91                let filtered: Vec<Option<f64>> = data
92                    .iter()
93                    .zip(selection)
94                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
95                    .collect();
96                ColumnData::Float64(filtered)
97            }
98            ColumnData::String(data) => {
99                let filtered: Vec<Option<String>> = data
100                    .iter()
101                    .zip(selection)
102                    .filter_map(|(v, &sel)| if sel { Some(v.clone()) } else { None })
103                    .collect();
104                ColumnData::String(filtered)
105            }
106            ColumnData::Binary(data) => {
107                let filtered = data
108                    .iter()
109                    .zip(selection)
110                    .filter_map(|(v, &sel)| if sel { Some(v.clone()) } else { None })
111                    .collect();
112                ColumnData::Binary(filtered)
113            }
114        }
115    }
116
117    /// Evaluate an expression for a specific row.
118    fn evaluate_expr(&self, expr: &Expr, batch: &RecordBatch, row_idx: usize) -> Result<Value> {
119        match expr {
120            Expr::Column { table: _, name } => {
121                let column = batch
122                    .column_by_name(name)
123                    .ok_or_else(|| QueryError::ColumnNotFound(name.clone()))?;
124                self.get_column_value(column, row_idx)
125            }
126            Expr::Literal(lit) => Ok(Value::from_literal(lit)),
127            Expr::BinaryOp { left, op, right } => {
128                let left_val = self.evaluate_expr(left, batch, row_idx)?;
129                let right_val = self.evaluate_expr(right, batch, row_idx)?;
130                self.evaluate_binary_op(&left_val, *op, &right_val)
131            }
132            Expr::UnaryOp { op, expr } => {
133                let val = self.evaluate_expr(expr, batch, row_idx)?;
134                self.evaluate_unary_op(*op, &val)
135            }
136            Expr::IsNull(expr) => {
137                let val = self.evaluate_expr(expr, batch, row_idx)?;
138                Ok(Value::Boolean(matches!(val, Value::Null)))
139            }
140            Expr::IsNotNull(expr) => {
141                let val = self.evaluate_expr(expr, batch, row_idx)?;
142                Ok(Value::Boolean(!matches!(val, Value::Null)))
143            }
144            Expr::Function { name, args } => {
145                // Evaluate each argument first.
146                let arg_values: Vec<Value> = args
147                    .iter()
148                    .map(|a| self.evaluate_expr(a, batch, row_idx))
149                    .collect::<Result<Vec<_>>>()?;
150                // Dispatch to the spatial-function evaluator. The coordinate
151                // dimension is 2-D for the current row-based interpreter.
152                crate::executor::spatial_funcs::evaluate_spatial_function(name, &arg_values, 2)
153            }
154            _ => Err(QueryError::unsupported(
155                OxiGdalError::not_supported_builder("Unsupported expression type in filter")
156                    .with_operation("filter_evaluation")
157                    .with_parameter("expression_type", format!("{:?}", expr))
158                    .with_suggestion(
159                        "Use simpler expressions: columns, literals, binary/unary operators, IS [NOT] NULL",
160                    )
161                    .build()
162                    .to_string(),
163            )),
164        }
165    }
166
167    /// Get value from column at row index.
168    fn get_column_value(&self, column: &ColumnData, row_idx: usize) -> Result<Value> {
169        match column {
170            ColumnData::Boolean(data) => Ok(data
171                .get(row_idx)
172                .and_then(|v| v.as_ref())
173                .map(|&v| Value::Boolean(v))
174                .unwrap_or(Value::Null)),
175            ColumnData::Int32(data) => Ok(data
176                .get(row_idx)
177                .and_then(|v| v.as_ref())
178                .map(|&v| Value::Int32(v))
179                .unwrap_or(Value::Null)),
180            ColumnData::Int64(data) => Ok(data
181                .get(row_idx)
182                .and_then(|v| v.as_ref())
183                .map(|&v| Value::Int64(v))
184                .unwrap_or(Value::Null)),
185            ColumnData::Float32(data) => Ok(data
186                .get(row_idx)
187                .and_then(|v| v.as_ref())
188                .map(|&v| Value::Float32(v))
189                .unwrap_or(Value::Null)),
190            ColumnData::Float64(data) => Ok(data
191                .get(row_idx)
192                .and_then(|v| v.as_ref())
193                .map(|&v| Value::Float64(v))
194                .unwrap_or(Value::Null)),
195            ColumnData::String(data) => Ok(data
196                .get(row_idx)
197                .and_then(|v| v.as_ref())
198                .map(|v| Value::String(v.clone()))
199                .unwrap_or(Value::Null)),
200            ColumnData::Binary(_) => Err(QueryError::unsupported(
201                OxiGdalError::not_supported_builder(
202                    "Binary column type not supported in filter predicates",
203                )
204                .with_operation("column_value_extraction")
205                .with_parameter("row_index", row_idx.to_string())
206                .with_suggestion(
207                    "Cast binary columns to supported types or filter at a different stage",
208                )
209                .build()
210                .to_string(),
211            )),
212        }
213    }
214
215    /// Evaluate a binary operation.
216    fn evaluate_binary_op(&self, left: &Value, op: BinaryOperator, right: &Value) -> Result<Value> {
217        match (left, right) {
218            (Value::Null, _) | (_, Value::Null) => Ok(Value::Null),
219            // Type coercion: Int32 with Int64
220            (Value::Int32(l), Value::Int64(r)) => {
221                self.evaluate_binary_op(&Value::Int64(*l as i64), op, &Value::Int64(*r))
222            }
223            (Value::Int64(l), Value::Int32(r)) => {
224                self.evaluate_binary_op(&Value::Int64(*l), op, &Value::Int64(*r as i64))
225            }
226            (Value::Int32(l), Value::Int32(r)) => match op {
227                BinaryOperator::Plus => Ok(Value::Int32(l + r)),
228                BinaryOperator::Minus => Ok(Value::Int32(l - r)),
229                BinaryOperator::Multiply => Ok(Value::Int32(l * r)),
230                BinaryOperator::Divide => {
231                    if *r == 0 {
232                        Ok(Value::Null)
233                    } else {
234                        Ok(Value::Int32(l / r))
235                    }
236                }
237                BinaryOperator::Modulo => {
238                    if *r == 0 {
239                        Ok(Value::Null)
240                    } else {
241                        Ok(Value::Int32(l % r))
242                    }
243                }
244                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
245                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
246                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
247                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
248                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
249                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
250                _ => Err(QueryError::unsupported("Unsupported operator for integers")),
251            },
252            (Value::Int64(l), Value::Int64(r)) => match op {
253                BinaryOperator::Plus => Ok(Value::Int64(l + r)),
254                BinaryOperator::Minus => Ok(Value::Int64(l - r)),
255                BinaryOperator::Multiply => Ok(Value::Int64(l * r)),
256                BinaryOperator::Divide => {
257                    if *r == 0 {
258                        Ok(Value::Null)
259                    } else {
260                        Ok(Value::Int64(l / r))
261                    }
262                }
263                BinaryOperator::Modulo => {
264                    if *r == 0 {
265                        Ok(Value::Null)
266                    } else {
267                        Ok(Value::Int64(l % r))
268                    }
269                }
270                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
271                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
272                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
273                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
274                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
275                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
276                _ => Err(QueryError::unsupported("Unsupported operator for integers")),
277            },
278            // Type coercion: Float32 with Float64
279            (Value::Float32(l), Value::Float64(r)) => {
280                self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
281            }
282            (Value::Float64(l), Value::Float32(r)) => {
283                self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
284            }
285            (Value::Float32(l), Value::Float32(r)) => match op {
286                BinaryOperator::Plus => Ok(Value::Float32(l + r)),
287                BinaryOperator::Minus => Ok(Value::Float32(l - r)),
288                BinaryOperator::Multiply => Ok(Value::Float32(l * r)),
289                BinaryOperator::Divide => Ok(Value::Float32(l / r)),
290                BinaryOperator::Eq => Ok(Value::Boolean((l - r).abs() < f32::EPSILON)),
291                BinaryOperator::NotEq => Ok(Value::Boolean((l - r).abs() >= f32::EPSILON)),
292                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
293                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
294                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
295                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
296                _ => Err(QueryError::unsupported("Unsupported operator for floats")),
297            },
298            (Value::Float64(l), Value::Float64(r)) => match op {
299                BinaryOperator::Plus => Ok(Value::Float64(l + r)),
300                BinaryOperator::Minus => Ok(Value::Float64(l - r)),
301                BinaryOperator::Multiply => Ok(Value::Float64(l * r)),
302                BinaryOperator::Divide => Ok(Value::Float64(l / r)),
303                BinaryOperator::Eq => Ok(Value::Boolean((l - r).abs() < f64::EPSILON)),
304                BinaryOperator::NotEq => Ok(Value::Boolean((l - r).abs() >= f64::EPSILON)),
305                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
306                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
307                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
308                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
309                _ => Err(QueryError::unsupported("Unsupported operator for floats")),
310            },
311            // Type coercion: Int with Float
312            (Value::Int32(l), Value::Float64(r)) => {
313                self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
314            }
315            (Value::Int64(l), Value::Float64(r)) => {
316                self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
317            }
318            (Value::Float64(l), Value::Int32(r)) => {
319                self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
320            }
321            (Value::Float64(l), Value::Int64(r)) => {
322                self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
323            }
324            (Value::Boolean(l), Value::Boolean(r)) => match op {
325                BinaryOperator::And => Ok(Value::Boolean(*l && *r)),
326                BinaryOperator::Or => Ok(Value::Boolean(*l || *r)),
327                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
328                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
329                _ => Err(QueryError::unsupported("Unsupported operator for booleans")),
330            },
331            (Value::String(l), Value::String(r)) => match op {
332                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
333                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
334                BinaryOperator::Concat => Ok(Value::String(format!("{}{}", l, r))),
335                _ => Err(QueryError::unsupported("Unsupported operator for strings")),
336            },
337            _ => Err(QueryError::execution(
338                OxiGdalError::invalid_operation_builder("Type mismatch in binary operation")
339                    .with_operation("binary_operator_evaluation")
340                    .with_parameter("left_type", format!("{:?}", left))
341                    .with_parameter("right_type", format!("{:?}", right))
342                    .with_parameter("operator", format!("{:?}", op))
343                    .with_suggestion(
344                        "Ensure both operands have compatible types or use explicit type casts",
345                    )
346                    .build()
347                    .to_string(),
348            )),
349        }
350    }
351
352    /// Evaluate a unary operation.
353    fn evaluate_unary_op(&self, op: UnaryOperator, val: &Value) -> Result<Value> {
354        match (op, val) {
355            (UnaryOperator::Minus, Value::Int64(i)) => Ok(Value::Int64(-i)),
356            (UnaryOperator::Minus, Value::Float64(f)) => Ok(Value::Float64(-f)),
357            (UnaryOperator::Not, Value::Boolean(b)) => Ok(Value::Boolean(!b)),
358            (_, Value::Null) => Ok(Value::Null),
359            _ => Err(QueryError::unsupported("Unsupported unary operation")),
360        }
361    }
362}
363
364/// Runtime value.
365#[derive(Debug, Clone, PartialEq)]
366pub enum Value {
367    /// Null value.
368    Null,
369    /// Boolean value.
370    Boolean(bool),
371    /// 32-bit integer value.
372    Int32(i32),
373    /// 64-bit integer value.
374    Int64(i64),
375    /// 32-bit float value.
376    Float32(f32),
377    /// 64-bit float value.
378    Float64(f64),
379    /// String value.
380    String(String),
381    /// Geometry value (constructed by spatial functions or parsed from WKT).
382    Geometry(geo::Geometry<f64>),
383}
384
385impl Value {
386    /// Convert from a literal.
387    fn from_literal(lit: &Literal) -> Self {
388        match lit {
389            Literal::Null => Value::Null,
390            Literal::Boolean(b) => Value::Boolean(*b),
391            Literal::Integer(i) => Value::Int64(*i),
392            Literal::Float(f) => Value::Float64(*f),
393            Literal::String(s) => Value::String(s.clone()),
394        }
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Build the predicate `<column> > <threshold>`.
    fn greater_than(column: &str, threshold: i64) -> Expr {
        let column_ref = Expr::Column {
            table: None,
            name: column.to_string(),
        };
        Expr::BinaryOp {
            left: Box::new(column_ref),
            op: BinaryOperator::Gt,
            right: Box::new(Expr::Literal(Literal::Integer(threshold))),
        }
    }

    /// Filtering five rows with `value > 25` keeps the three rows whose
    /// value is 30, 40 or 50.
    #[test]
    fn test_filter_execution() -> Result<()> {
        let fields = vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ];
        let schema = Arc::new(Schema::new(fields));

        let ids = ColumnData::Int64(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        let values = ColumnData::Int64(vec![Some(10), Some(20), Some(30), Some(40), Some(50)]);
        let batch = RecordBatch::new(schema, vec![ids, values], 5)?;

        let filtered = Filter::new(greater_than("value", 25)).execute(&batch)?;

        assert_eq!(filtered.num_rows, 3); // 30, 40, 50 are > 25

        Ok(())
    }
}