Skip to main content

oxigdal_query/executor/
filter.rs

//! Filter executor: evaluates a predicate row by row and keeps the matching rows of a record batch.
2
3use crate::error::{QueryError, Result};
4use crate::executor::scan::{ColumnData, RecordBatch};
5use crate::parser::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
6use oxigdal_core::error::OxiGdalError;
7
8/// Filter operator.
9pub struct Filter {
10    /// Filter predicate.
11    pub predicate: Expr,
12}
13
14impl Filter {
15    /// Create a new filter.
16    pub fn new(predicate: Expr) -> Self {
17        Self { predicate }
18    }
19
20    /// Execute the filter on a record batch.
21    pub fn execute(&self, batch: &RecordBatch) -> Result<RecordBatch> {
22        let mut selection = vec![false; batch.num_rows];
23
24        // Evaluate predicate for each row
25        for (row_idx, sel) in selection.iter_mut().enumerate().take(batch.num_rows) {
26            let result = self.evaluate_expr(&self.predicate, batch, row_idx)?;
27            if let Value::Boolean(b) = result {
28                *sel = b;
29            } else {
30                return Err(QueryError::execution(
31                    OxiGdalError::invalid_operation_builder(
32                        "Filter predicate must evaluate to boolean type",
33                    )
34                    .with_operation("filter_evaluation")
35                    .with_parameter("row_index", row_idx.to_string())
36                    .with_parameter("actual_type", format!("{:?}", result))
37                    .with_suggestion("Ensure WHERE clause uses comparison or boolean operators")
38                    .build()
39                    .to_string(),
40                ));
41            }
42        }
43
44        // Filter columns based on selection
45        let mut filtered_columns = Vec::new();
46        for column in &batch.columns {
47            filtered_columns.push(self.filter_column(column, &selection));
48        }
49
50        let filtered_rows = selection.iter().filter(|&&b| b).count();
51
52        RecordBatch::new(batch.schema.clone(), filtered_columns, filtered_rows)
53    }
54
55    /// Filter a column based on selection.
56    fn filter_column(&self, column: &ColumnData, selection: &[bool]) -> ColumnData {
57        match column {
58            ColumnData::Boolean(data) => {
59                let filtered: Vec<Option<bool>> = data
60                    .iter()
61                    .zip(selection)
62                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
63                    .collect();
64                ColumnData::Boolean(filtered)
65            }
66            ColumnData::Int32(data) => {
67                let filtered: Vec<Option<i32>> = data
68                    .iter()
69                    .zip(selection)
70                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
71                    .collect();
72                ColumnData::Int32(filtered)
73            }
74            ColumnData::Int64(data) => {
75                let filtered: Vec<Option<i64>> = data
76                    .iter()
77                    .zip(selection)
78                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
79                    .collect();
80                ColumnData::Int64(filtered)
81            }
82            ColumnData::Float32(data) => {
83                let filtered: Vec<Option<f32>> = data
84                    .iter()
85                    .zip(selection)
86                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
87                    .collect();
88                ColumnData::Float32(filtered)
89            }
90            ColumnData::Float64(data) => {
91                let filtered: Vec<Option<f64>> = data
92                    .iter()
93                    .zip(selection)
94                    .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
95                    .collect();
96                ColumnData::Float64(filtered)
97            }
98            ColumnData::String(data) => {
99                let filtered: Vec<Option<String>> = data
100                    .iter()
101                    .zip(selection)
102                    .filter_map(|(v, &sel)| if sel { Some(v.clone()) } else { None })
103                    .collect();
104                ColumnData::String(filtered)
105            }
106            ColumnData::Binary(data) => {
107                let filtered = data
108                    .iter()
109                    .zip(selection)
110                    .filter_map(|(v, &sel)| if sel { Some(v.clone()) } else { None })
111                    .collect();
112                ColumnData::Binary(filtered)
113            }
114        }
115    }
116
117    /// Evaluate an expression for a specific row.
118    fn evaluate_expr(&self, expr: &Expr, batch: &RecordBatch, row_idx: usize) -> Result<Value> {
119        match expr {
120            Expr::Column { table: _, name } => {
121                let column = batch
122                    .column_by_name(name)
123                    .ok_or_else(|| QueryError::ColumnNotFound(name.clone()))?;
124                self.get_column_value(column, row_idx)
125            }
126            Expr::Literal(lit) => Ok(Value::from_literal(lit)),
127            Expr::BinaryOp { left, op, right } => {
128                let left_val = self.evaluate_expr(left, batch, row_idx)?;
129                let right_val = self.evaluate_expr(right, batch, row_idx)?;
130                self.evaluate_binary_op(&left_val, *op, &right_val)
131            }
132            Expr::UnaryOp { op, expr } => {
133                let val = self.evaluate_expr(expr, batch, row_idx)?;
134                self.evaluate_unary_op(*op, &val)
135            }
136            Expr::IsNull(expr) => {
137                let val = self.evaluate_expr(expr, batch, row_idx)?;
138                Ok(Value::Boolean(matches!(val, Value::Null)))
139            }
140            Expr::IsNotNull(expr) => {
141                let val = self.evaluate_expr(expr, batch, row_idx)?;
142                Ok(Value::Boolean(!matches!(val, Value::Null)))
143            }
144            _ => Err(QueryError::unsupported(
145                OxiGdalError::not_supported_builder("Unsupported expression type in filter")
146                    .with_operation("filter_evaluation")
147                    .with_parameter("expression_type", format!("{:?}", expr))
148                    .with_suggestion(
149                        "Use simpler expressions: columns, literals, binary/unary operators, IS [NOT] NULL",
150                    )
151                    .build()
152                    .to_string(),
153            )),
154        }
155    }
156
157    /// Get value from column at row index.
158    fn get_column_value(&self, column: &ColumnData, row_idx: usize) -> Result<Value> {
159        match column {
160            ColumnData::Boolean(data) => Ok(data
161                .get(row_idx)
162                .and_then(|v| v.as_ref())
163                .map(|&v| Value::Boolean(v))
164                .unwrap_or(Value::Null)),
165            ColumnData::Int32(data) => Ok(data
166                .get(row_idx)
167                .and_then(|v| v.as_ref())
168                .map(|&v| Value::Int32(v))
169                .unwrap_or(Value::Null)),
170            ColumnData::Int64(data) => Ok(data
171                .get(row_idx)
172                .and_then(|v| v.as_ref())
173                .map(|&v| Value::Int64(v))
174                .unwrap_or(Value::Null)),
175            ColumnData::Float32(data) => Ok(data
176                .get(row_idx)
177                .and_then(|v| v.as_ref())
178                .map(|&v| Value::Float32(v))
179                .unwrap_or(Value::Null)),
180            ColumnData::Float64(data) => Ok(data
181                .get(row_idx)
182                .and_then(|v| v.as_ref())
183                .map(|&v| Value::Float64(v))
184                .unwrap_or(Value::Null)),
185            ColumnData::String(data) => Ok(data
186                .get(row_idx)
187                .and_then(|v| v.as_ref())
188                .map(|v| Value::String(v.clone()))
189                .unwrap_or(Value::Null)),
190            ColumnData::Binary(_) => Err(QueryError::unsupported(
191                OxiGdalError::not_supported_builder(
192                    "Binary column type not supported in filter predicates",
193                )
194                .with_operation("column_value_extraction")
195                .with_parameter("row_index", row_idx.to_string())
196                .with_suggestion(
197                    "Cast binary columns to supported types or filter at a different stage",
198                )
199                .build()
200                .to_string(),
201            )),
202        }
203    }
204
205    /// Evaluate a binary operation.
206    fn evaluate_binary_op(&self, left: &Value, op: BinaryOperator, right: &Value) -> Result<Value> {
207        match (left, right) {
208            (Value::Null, _) | (_, Value::Null) => Ok(Value::Null),
209            // Type coercion: Int32 with Int64
210            (Value::Int32(l), Value::Int64(r)) => {
211                self.evaluate_binary_op(&Value::Int64(*l as i64), op, &Value::Int64(*r))
212            }
213            (Value::Int64(l), Value::Int32(r)) => {
214                self.evaluate_binary_op(&Value::Int64(*l), op, &Value::Int64(*r as i64))
215            }
216            (Value::Int32(l), Value::Int32(r)) => match op {
217                BinaryOperator::Plus => Ok(Value::Int32(l + r)),
218                BinaryOperator::Minus => Ok(Value::Int32(l - r)),
219                BinaryOperator::Multiply => Ok(Value::Int32(l * r)),
220                BinaryOperator::Divide => {
221                    if *r == 0 {
222                        Ok(Value::Null)
223                    } else {
224                        Ok(Value::Int32(l / r))
225                    }
226                }
227                BinaryOperator::Modulo => {
228                    if *r == 0 {
229                        Ok(Value::Null)
230                    } else {
231                        Ok(Value::Int32(l % r))
232                    }
233                }
234                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
235                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
236                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
237                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
238                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
239                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
240                _ => Err(QueryError::unsupported("Unsupported operator for integers")),
241            },
242            (Value::Int64(l), Value::Int64(r)) => match op {
243                BinaryOperator::Plus => Ok(Value::Int64(l + r)),
244                BinaryOperator::Minus => Ok(Value::Int64(l - r)),
245                BinaryOperator::Multiply => Ok(Value::Int64(l * r)),
246                BinaryOperator::Divide => {
247                    if *r == 0 {
248                        Ok(Value::Null)
249                    } else {
250                        Ok(Value::Int64(l / r))
251                    }
252                }
253                BinaryOperator::Modulo => {
254                    if *r == 0 {
255                        Ok(Value::Null)
256                    } else {
257                        Ok(Value::Int64(l % r))
258                    }
259                }
260                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
261                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
262                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
263                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
264                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
265                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
266                _ => Err(QueryError::unsupported("Unsupported operator for integers")),
267            },
268            // Type coercion: Float32 with Float64
269            (Value::Float32(l), Value::Float64(r)) => {
270                self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
271            }
272            (Value::Float64(l), Value::Float32(r)) => {
273                self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
274            }
275            (Value::Float32(l), Value::Float32(r)) => match op {
276                BinaryOperator::Plus => Ok(Value::Float32(l + r)),
277                BinaryOperator::Minus => Ok(Value::Float32(l - r)),
278                BinaryOperator::Multiply => Ok(Value::Float32(l * r)),
279                BinaryOperator::Divide => Ok(Value::Float32(l / r)),
280                BinaryOperator::Eq => Ok(Value::Boolean((l - r).abs() < f32::EPSILON)),
281                BinaryOperator::NotEq => Ok(Value::Boolean((l - r).abs() >= f32::EPSILON)),
282                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
283                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
284                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
285                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
286                _ => Err(QueryError::unsupported("Unsupported operator for floats")),
287            },
288            (Value::Float64(l), Value::Float64(r)) => match op {
289                BinaryOperator::Plus => Ok(Value::Float64(l + r)),
290                BinaryOperator::Minus => Ok(Value::Float64(l - r)),
291                BinaryOperator::Multiply => Ok(Value::Float64(l * r)),
292                BinaryOperator::Divide => Ok(Value::Float64(l / r)),
293                BinaryOperator::Eq => Ok(Value::Boolean((l - r).abs() < f64::EPSILON)),
294                BinaryOperator::NotEq => Ok(Value::Boolean((l - r).abs() >= f64::EPSILON)),
295                BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
296                BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
297                BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
298                BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
299                _ => Err(QueryError::unsupported("Unsupported operator for floats")),
300            },
301            // Type coercion: Int with Float
302            (Value::Int32(l), Value::Float64(r)) => {
303                self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
304            }
305            (Value::Int64(l), Value::Float64(r)) => {
306                self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
307            }
308            (Value::Float64(l), Value::Int32(r)) => {
309                self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
310            }
311            (Value::Float64(l), Value::Int64(r)) => {
312                self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
313            }
314            (Value::Boolean(l), Value::Boolean(r)) => match op {
315                BinaryOperator::And => Ok(Value::Boolean(*l && *r)),
316                BinaryOperator::Or => Ok(Value::Boolean(*l || *r)),
317                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
318                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
319                _ => Err(QueryError::unsupported("Unsupported operator for booleans")),
320            },
321            (Value::String(l), Value::String(r)) => match op {
322                BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
323                BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
324                BinaryOperator::Concat => Ok(Value::String(format!("{}{}", l, r))),
325                _ => Err(QueryError::unsupported("Unsupported operator for strings")),
326            },
327            _ => Err(QueryError::execution(
328                OxiGdalError::invalid_operation_builder("Type mismatch in binary operation")
329                    .with_operation("binary_operator_evaluation")
330                    .with_parameter("left_type", format!("{:?}", left))
331                    .with_parameter("right_type", format!("{:?}", right))
332                    .with_parameter("operator", format!("{:?}", op))
333                    .with_suggestion(
334                        "Ensure both operands have compatible types or use explicit type casts",
335                    )
336                    .build()
337                    .to_string(),
338            )),
339        }
340    }
341
342    /// Evaluate a unary operation.
343    fn evaluate_unary_op(&self, op: UnaryOperator, val: &Value) -> Result<Value> {
344        match (op, val) {
345            (UnaryOperator::Minus, Value::Int64(i)) => Ok(Value::Int64(-i)),
346            (UnaryOperator::Minus, Value::Float64(f)) => Ok(Value::Float64(-f)),
347            (UnaryOperator::Not, Value::Boolean(b)) => Ok(Value::Boolean(!b)),
348            (_, Value::Null) => Ok(Value::Null),
349            _ => Err(QueryError::unsupported("Unsupported unary operation")),
350        }
351    }
352}
353
/// A single dynamically-typed value produced while evaluating an expression.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// Absent value (SQL `NULL`).
    Null,
    /// A `true` / `false` flag.
    Boolean(bool),
    /// A signed 32-bit integer.
    Int32(i32),
    /// A signed 64-bit integer.
    Int64(i64),
    /// A single-precision floating point number.
    Float32(f32),
    /// A double-precision floating point number.
    Float64(f64),
    /// An owned text value.
    String(String),
}
372
373impl Value {
374    /// Convert from a literal.
375    fn from_literal(lit: &Literal) -> Self {
376        match lit {
377            Literal::Null => Value::Null,
378            Literal::Boolean(b) => Value::Boolean(*b),
379            Literal::Integer(i) => Value::Int64(*i),
380            Literal::Float(f) => Value::Float64(*f),
381            Literal::String(s) => Value::String(s.clone()),
382        }
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{DataType, Field, Schema};
    use std::sync::Arc;

    /// `value > 25` over five rows must keep exactly the last three rows,
    /// and every column must be reduced to those same rows.
    #[test]
    fn test_filter_execution() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ]));

        let columns = vec![
            ColumnData::Int64(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]),
            ColumnData::Int64(vec![Some(10), Some(20), Some(30), Some(40), Some(50)]),
        ];

        let batch = RecordBatch::new(schema, columns, 5)?;

        // Filter: value > 25
        let predicate = Expr::BinaryOp {
            left: Box::new(Expr::Column {
                table: None,
                name: "value".to_string(),
            }),
            op: BinaryOperator::Gt,
            right: Box::new(Expr::Literal(Literal::Integer(25))),
        };

        let filter = Filter::new(predicate);
        let filtered = filter.execute(&batch)?;

        assert_eq!(filtered.num_rows, 3); // 30, 40, 50 are > 25

        // The row count alone does not prove the right rows survived:
        // check the cell contents of both columns as well.
        let id_name = String::from("id");
        match filtered.column_by_name(&id_name) {
            Some(ColumnData::Int64(ids)) => {
                assert_eq!(ids, &vec![Some(3_i64), Some(4_i64), Some(5_i64)]);
            }
            _ => panic!("filtered batch must keep `id` as an Int64 column"),
        }

        let value_name = String::from("value");
        match filtered.column_by_name(&value_name) {
            Some(ColumnData::Int64(values)) => {
                assert_eq!(values, &vec![Some(30_i64), Some(40_i64), Some(50_i64)]);
            }
            _ => panic!("filtered batch must keep `value` as an Int64 column"),
        }

        Ok(())
    }
}
431}