1use crate::error::{QueryError, Result};
4use crate::executor::scan::{ColumnData, RecordBatch};
5use crate::parser::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
6use oxigdal_core::error::OxiGdalError;
7
8pub struct Filter {
10 pub predicate: Expr,
12}
13
14impl Filter {
15 pub fn new(predicate: Expr) -> Self {
17 Self { predicate }
18 }
19
20 pub fn execute(&self, batch: &RecordBatch) -> Result<RecordBatch> {
22 let mut selection = vec![false; batch.num_rows];
23
24 for (row_idx, sel) in selection.iter_mut().enumerate().take(batch.num_rows) {
26 let result = self.evaluate_expr(&self.predicate, batch, row_idx)?;
27 if let Value::Boolean(b) = result {
28 *sel = b;
29 } else {
30 return Err(QueryError::execution(
31 OxiGdalError::invalid_operation_builder(
32 "Filter predicate must evaluate to boolean type",
33 )
34 .with_operation("filter_evaluation")
35 .with_parameter("row_index", row_idx.to_string())
36 .with_parameter("actual_type", format!("{:?}", result))
37 .with_suggestion("Ensure WHERE clause uses comparison or boolean operators")
38 .build()
39 .to_string(),
40 ));
41 }
42 }
43
44 let mut filtered_columns = Vec::new();
46 for column in &batch.columns {
47 filtered_columns.push(self.filter_column(column, &selection));
48 }
49
50 let filtered_rows = selection.iter().filter(|&&b| b).count();
51
52 RecordBatch::new(batch.schema.clone(), filtered_columns, filtered_rows)
53 }
54
55 fn filter_column(&self, column: &ColumnData, selection: &[bool]) -> ColumnData {
57 match column {
58 ColumnData::Boolean(data) => {
59 let filtered: Vec<Option<bool>> = data
60 .iter()
61 .zip(selection)
62 .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
63 .collect();
64 ColumnData::Boolean(filtered)
65 }
66 ColumnData::Int32(data) => {
67 let filtered: Vec<Option<i32>> = data
68 .iter()
69 .zip(selection)
70 .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
71 .collect();
72 ColumnData::Int32(filtered)
73 }
74 ColumnData::Int64(data) => {
75 let filtered: Vec<Option<i64>> = data
76 .iter()
77 .zip(selection)
78 .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
79 .collect();
80 ColumnData::Int64(filtered)
81 }
82 ColumnData::Float32(data) => {
83 let filtered: Vec<Option<f32>> = data
84 .iter()
85 .zip(selection)
86 .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
87 .collect();
88 ColumnData::Float32(filtered)
89 }
90 ColumnData::Float64(data) => {
91 let filtered: Vec<Option<f64>> = data
92 .iter()
93 .zip(selection)
94 .filter_map(|(v, &sel)| if sel { Some(*v) } else { None })
95 .collect();
96 ColumnData::Float64(filtered)
97 }
98 ColumnData::String(data) => {
99 let filtered: Vec<Option<String>> = data
100 .iter()
101 .zip(selection)
102 .filter_map(|(v, &sel)| if sel { Some(v.clone()) } else { None })
103 .collect();
104 ColumnData::String(filtered)
105 }
106 ColumnData::Binary(data) => {
107 let filtered = data
108 .iter()
109 .zip(selection)
110 .filter_map(|(v, &sel)| if sel { Some(v.clone()) } else { None })
111 .collect();
112 ColumnData::Binary(filtered)
113 }
114 }
115 }
116
117 fn evaluate_expr(&self, expr: &Expr, batch: &RecordBatch, row_idx: usize) -> Result<Value> {
119 match expr {
120 Expr::Column { table: _, name } => {
121 let column = batch
122 .column_by_name(name)
123 .ok_or_else(|| QueryError::ColumnNotFound(name.clone()))?;
124 self.get_column_value(column, row_idx)
125 }
126 Expr::Literal(lit) => Ok(Value::from_literal(lit)),
127 Expr::BinaryOp { left, op, right } => {
128 let left_val = self.evaluate_expr(left, batch, row_idx)?;
129 let right_val = self.evaluate_expr(right, batch, row_idx)?;
130 self.evaluate_binary_op(&left_val, *op, &right_val)
131 }
132 Expr::UnaryOp { op, expr } => {
133 let val = self.evaluate_expr(expr, batch, row_idx)?;
134 self.evaluate_unary_op(*op, &val)
135 }
136 Expr::IsNull(expr) => {
137 let val = self.evaluate_expr(expr, batch, row_idx)?;
138 Ok(Value::Boolean(matches!(val, Value::Null)))
139 }
140 Expr::IsNotNull(expr) => {
141 let val = self.evaluate_expr(expr, batch, row_idx)?;
142 Ok(Value::Boolean(!matches!(val, Value::Null)))
143 }
144 _ => Err(QueryError::unsupported(
145 OxiGdalError::not_supported_builder("Unsupported expression type in filter")
146 .with_operation("filter_evaluation")
147 .with_parameter("expression_type", format!("{:?}", expr))
148 .with_suggestion(
149 "Use simpler expressions: columns, literals, binary/unary operators, IS [NOT] NULL",
150 )
151 .build()
152 .to_string(),
153 )),
154 }
155 }
156
157 fn get_column_value(&self, column: &ColumnData, row_idx: usize) -> Result<Value> {
159 match column {
160 ColumnData::Boolean(data) => Ok(data
161 .get(row_idx)
162 .and_then(|v| v.as_ref())
163 .map(|&v| Value::Boolean(v))
164 .unwrap_or(Value::Null)),
165 ColumnData::Int32(data) => Ok(data
166 .get(row_idx)
167 .and_then(|v| v.as_ref())
168 .map(|&v| Value::Int32(v))
169 .unwrap_or(Value::Null)),
170 ColumnData::Int64(data) => Ok(data
171 .get(row_idx)
172 .and_then(|v| v.as_ref())
173 .map(|&v| Value::Int64(v))
174 .unwrap_or(Value::Null)),
175 ColumnData::Float32(data) => Ok(data
176 .get(row_idx)
177 .and_then(|v| v.as_ref())
178 .map(|&v| Value::Float32(v))
179 .unwrap_or(Value::Null)),
180 ColumnData::Float64(data) => Ok(data
181 .get(row_idx)
182 .and_then(|v| v.as_ref())
183 .map(|&v| Value::Float64(v))
184 .unwrap_or(Value::Null)),
185 ColumnData::String(data) => Ok(data
186 .get(row_idx)
187 .and_then(|v| v.as_ref())
188 .map(|v| Value::String(v.clone()))
189 .unwrap_or(Value::Null)),
190 ColumnData::Binary(_) => Err(QueryError::unsupported(
191 OxiGdalError::not_supported_builder(
192 "Binary column type not supported in filter predicates",
193 )
194 .with_operation("column_value_extraction")
195 .with_parameter("row_index", row_idx.to_string())
196 .with_suggestion(
197 "Cast binary columns to supported types or filter at a different stage",
198 )
199 .build()
200 .to_string(),
201 )),
202 }
203 }
204
205 fn evaluate_binary_op(&self, left: &Value, op: BinaryOperator, right: &Value) -> Result<Value> {
207 match (left, right) {
208 (Value::Null, _) | (_, Value::Null) => Ok(Value::Null),
209 (Value::Int32(l), Value::Int64(r)) => {
211 self.evaluate_binary_op(&Value::Int64(*l as i64), op, &Value::Int64(*r))
212 }
213 (Value::Int64(l), Value::Int32(r)) => {
214 self.evaluate_binary_op(&Value::Int64(*l), op, &Value::Int64(*r as i64))
215 }
216 (Value::Int32(l), Value::Int32(r)) => match op {
217 BinaryOperator::Plus => Ok(Value::Int32(l + r)),
218 BinaryOperator::Minus => Ok(Value::Int32(l - r)),
219 BinaryOperator::Multiply => Ok(Value::Int32(l * r)),
220 BinaryOperator::Divide => {
221 if *r == 0 {
222 Ok(Value::Null)
223 } else {
224 Ok(Value::Int32(l / r))
225 }
226 }
227 BinaryOperator::Modulo => {
228 if *r == 0 {
229 Ok(Value::Null)
230 } else {
231 Ok(Value::Int32(l % r))
232 }
233 }
234 BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
235 BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
236 BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
237 BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
238 BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
239 BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
240 _ => Err(QueryError::unsupported("Unsupported operator for integers")),
241 },
242 (Value::Int64(l), Value::Int64(r)) => match op {
243 BinaryOperator::Plus => Ok(Value::Int64(l + r)),
244 BinaryOperator::Minus => Ok(Value::Int64(l - r)),
245 BinaryOperator::Multiply => Ok(Value::Int64(l * r)),
246 BinaryOperator::Divide => {
247 if *r == 0 {
248 Ok(Value::Null)
249 } else {
250 Ok(Value::Int64(l / r))
251 }
252 }
253 BinaryOperator::Modulo => {
254 if *r == 0 {
255 Ok(Value::Null)
256 } else {
257 Ok(Value::Int64(l % r))
258 }
259 }
260 BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
261 BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
262 BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
263 BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
264 BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
265 BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
266 _ => Err(QueryError::unsupported("Unsupported operator for integers")),
267 },
268 (Value::Float32(l), Value::Float64(r)) => {
270 self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
271 }
272 (Value::Float64(l), Value::Float32(r)) => {
273 self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
274 }
275 (Value::Float32(l), Value::Float32(r)) => match op {
276 BinaryOperator::Plus => Ok(Value::Float32(l + r)),
277 BinaryOperator::Minus => Ok(Value::Float32(l - r)),
278 BinaryOperator::Multiply => Ok(Value::Float32(l * r)),
279 BinaryOperator::Divide => Ok(Value::Float32(l / r)),
280 BinaryOperator::Eq => Ok(Value::Boolean((l - r).abs() < f32::EPSILON)),
281 BinaryOperator::NotEq => Ok(Value::Boolean((l - r).abs() >= f32::EPSILON)),
282 BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
283 BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
284 BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
285 BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
286 _ => Err(QueryError::unsupported("Unsupported operator for floats")),
287 },
288 (Value::Float64(l), Value::Float64(r)) => match op {
289 BinaryOperator::Plus => Ok(Value::Float64(l + r)),
290 BinaryOperator::Minus => Ok(Value::Float64(l - r)),
291 BinaryOperator::Multiply => Ok(Value::Float64(l * r)),
292 BinaryOperator::Divide => Ok(Value::Float64(l / r)),
293 BinaryOperator::Eq => Ok(Value::Boolean((l - r).abs() < f64::EPSILON)),
294 BinaryOperator::NotEq => Ok(Value::Boolean((l - r).abs() >= f64::EPSILON)),
295 BinaryOperator::Lt => Ok(Value::Boolean(l < r)),
296 BinaryOperator::LtEq => Ok(Value::Boolean(l <= r)),
297 BinaryOperator::Gt => Ok(Value::Boolean(l > r)),
298 BinaryOperator::GtEq => Ok(Value::Boolean(l >= r)),
299 _ => Err(QueryError::unsupported("Unsupported operator for floats")),
300 },
301 (Value::Int32(l), Value::Float64(r)) => {
303 self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
304 }
305 (Value::Int64(l), Value::Float64(r)) => {
306 self.evaluate_binary_op(&Value::Float64(*l as f64), op, &Value::Float64(*r))
307 }
308 (Value::Float64(l), Value::Int32(r)) => {
309 self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
310 }
311 (Value::Float64(l), Value::Int64(r)) => {
312 self.evaluate_binary_op(&Value::Float64(*l), op, &Value::Float64(*r as f64))
313 }
314 (Value::Boolean(l), Value::Boolean(r)) => match op {
315 BinaryOperator::And => Ok(Value::Boolean(*l && *r)),
316 BinaryOperator::Or => Ok(Value::Boolean(*l || *r)),
317 BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
318 BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
319 _ => Err(QueryError::unsupported("Unsupported operator for booleans")),
320 },
321 (Value::String(l), Value::String(r)) => match op {
322 BinaryOperator::Eq => Ok(Value::Boolean(l == r)),
323 BinaryOperator::NotEq => Ok(Value::Boolean(l != r)),
324 BinaryOperator::Concat => Ok(Value::String(format!("{}{}", l, r))),
325 _ => Err(QueryError::unsupported("Unsupported operator for strings")),
326 },
327 _ => Err(QueryError::execution(
328 OxiGdalError::invalid_operation_builder("Type mismatch in binary operation")
329 .with_operation("binary_operator_evaluation")
330 .with_parameter("left_type", format!("{:?}", left))
331 .with_parameter("right_type", format!("{:?}", right))
332 .with_parameter("operator", format!("{:?}", op))
333 .with_suggestion(
334 "Ensure both operands have compatible types or use explicit type casts",
335 )
336 .build()
337 .to_string(),
338 )),
339 }
340 }
341
342 fn evaluate_unary_op(&self, op: UnaryOperator, val: &Value) -> Result<Value> {
344 match (op, val) {
345 (UnaryOperator::Minus, Value::Int64(i)) => Ok(Value::Int64(-i)),
346 (UnaryOperator::Minus, Value::Float64(f)) => Ok(Value::Float64(-f)),
347 (UnaryOperator::Not, Value::Boolean(b)) => Ok(Value::Boolean(!b)),
348 (_, Value::Null) => Ok(Value::Null),
349 _ => Err(QueryError::unsupported("Unsupported unary operation")),
350 }
351 }
352}
353
/// Scalar produced while evaluating a filter expression for one row.
///
/// `Null` stands for SQL `NULL`. The other variants carry one cell's worth of data for
/// each column type that predicates can read.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// SQL `NULL`: a missing or unknown value.
    Null,
    /// A boolean value.
    Boolean(bool),
    /// A 32-bit signed integer.
    Int32(i32),
    /// A 64-bit signed integer (also the type of integer literals).
    Int64(i64),
    /// A 32-bit float.
    Float32(f32),
    /// A 64-bit float (also the type of float literals).
    Float64(f64),
    /// An owned UTF-8 string.
    String(String),
}
372
373impl Value {
374 fn from_literal(lit: &Literal) -> Self {
376 match lit {
377 Literal::Null => Value::Null,
378 Literal::Boolean(b) => Value::Boolean(*b),
379 Literal::Integer(i) => Value::Int64(*i),
380 Literal::Float(f) => Value::Float64(*f),
381 Literal::String(s) => Value::String(s.clone()),
382 }
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{Field, Schema};
    use std::sync::Arc;

    #[test]
    fn test_filter_execution() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "id".to_string(),
                crate::executor::scan::DataType::Int64,
                false,
            ),
            Field::new(
                "value".to_string(),
                crate::executor::scan::DataType::Int64,
                false,
            ),
        ]));

        let columns = vec![
            ColumnData::Int64(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]),
            ColumnData::Int64(vec![Some(10), Some(20), Some(30), Some(40), Some(50)]),
        ];

        let batch = RecordBatch::new(schema, columns, 5)?;

        // WHERE value > 25
        let predicate = Expr::BinaryOp {
            left: Box::new(Expr::Column {
                table: None,
                name: "value".to_string(),
            }),
            op: BinaryOperator::Gt,
            right: Box::new(Expr::Literal(Literal::Integer(25))),
        };

        let filter = Filter::new(predicate);
        let filtered = filter.execute(&batch)?;

        assert_eq!(filtered.num_rows, 3);

        // Every column must be filtered with the same selection, preserving row order.
        match filtered.columns.first() {
            Some(ColumnData::Int64(ids)) => assert_eq!(ids, &[Some(3), Some(4), Some(5)]),
            _ => panic!("expected an Int64 `id` column"),
        }
        match filtered.columns.get(1) {
            Some(ColumnData::Int64(values)) => {
                assert_eq!(values, &[Some(30), Some(40), Some(50)])
            }
            _ => panic!("expected an Int64 `value` column"),
        }

        Ok(())
    }
}