Skip to main content

cynos_query/ast/
predicate.rs

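//! Predicate definitions for query filtering.
//!
//! Provides value predicates (column vs. literal), join predicates
//! (column vs. column) and combined AND/OR predicates, all evaluated
//! against [`Row`]s through the [`Predicate`] trait.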
2
3use crate::ast::expr::{BinaryOp, ColumnRef};
4use alloc::boxed::Box;
5use alloc::vec::Vec;
6use cynos_core::{Row, Value};
7
/// Comparison operator applied by a predicate.
///
/// [`ValuePredicate`] and [`JoinPredicate`] currently evaluate only the six
/// ordering operators (`Eq` through `Ge`); `Match`, `Between` and `In` make
/// their evaluation return `false`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalType {
    /// `left == right`.
    Eq,
    /// `left != right`.
    Ne,
    /// `left < right`.
    Lt,
    /// `left <= right`.
    Le,
    /// `left > right`.
    Gt,
    /// `left >= right`.
    Ge,
    /// Pattern match; produced from `BinaryOp::Like`.
    Match,
    /// Range test; produced from `BinaryOp::Between`.
    Between,
    /// Set membership; produced from `BinaryOp::In`.
    In,
}
21
22impl From<BinaryOp> for EvalType {
23    fn from(op: BinaryOp) -> Self {
24        match op {
25            BinaryOp::Eq => EvalType::Eq,
26            BinaryOp::Ne => EvalType::Ne,
27            BinaryOp::Lt => EvalType::Lt,
28            BinaryOp::Le => EvalType::Le,
29            BinaryOp::Gt => EvalType::Gt,
30            BinaryOp::Ge => EvalType::Ge,
31            BinaryOp::Like => EvalType::Match,
32            BinaryOp::In => EvalType::In,
33            BinaryOp::Between => EvalType::Between,
34            _ => EvalType::Eq,
35        }
36    }
37}
38
39/// A predicate that can be evaluated against rows.
40pub trait Predicate {
41    /// Evaluates the predicate against a row.
42    fn eval(&self, row: &Row) -> bool;
43
44    /// Returns the columns referenced by this predicate.
45    fn columns(&self) -> Vec<&ColumnRef>;
46
47    /// Returns the tables referenced by this predicate.
48    fn tables(&self) -> Vec<&str>;
49}
50
51/// A value predicate compares a column to a literal value.
52#[derive(Clone, Debug)]
53pub struct ValuePredicate {
54    pub column: ColumnRef,
55    pub eval_type: EvalType,
56    pub value: Value,
57}
58
59impl ValuePredicate {
60    pub fn new(column: ColumnRef, eval_type: EvalType, value: Value) -> Self {
61        Self {
62            column,
63            eval_type,
64            value,
65        }
66    }
67
68    pub fn eq(column: ColumnRef, value: Value) -> Self {
69        Self::new(column, EvalType::Eq, value)
70    }
71
72    pub fn ne(column: ColumnRef, value: Value) -> Self {
73        Self::new(column, EvalType::Ne, value)
74    }
75
76    pub fn lt(column: ColumnRef, value: Value) -> Self {
77        Self::new(column, EvalType::Lt, value)
78    }
79
80    pub fn le(column: ColumnRef, value: Value) -> Self {
81        Self::new(column, EvalType::Le, value)
82    }
83
84    pub fn gt(column: ColumnRef, value: Value) -> Self {
85        Self::new(column, EvalType::Gt, value)
86    }
87
88    pub fn ge(column: ColumnRef, value: Value) -> Self {
89        Self::new(column, EvalType::Ge, value)
90    }
91}
92
93impl Predicate for ValuePredicate {
94    fn eval(&self, row: &Row) -> bool {
95        let row_value = match row.get(self.column.index) {
96            Some(v) => v,
97            None => return false,
98        };
99
100        match self.eval_type {
101            EvalType::Eq => row_value == &self.value,
102            EvalType::Ne => row_value != &self.value,
103            EvalType::Lt => row_value < &self.value,
104            EvalType::Le => row_value <= &self.value,
105            EvalType::Gt => row_value > &self.value,
106            EvalType::Ge => row_value >= &self.value,
107            _ => false,
108        }
109    }
110
111    fn columns(&self) -> Vec<&ColumnRef> {
112        alloc::vec![&self.column]
113    }
114
115    fn tables(&self) -> Vec<&str> {
116        alloc::vec![self.column.table.as_str()]
117    }
118}
119
/// Kind of join that a [`JoinPredicate`] belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    LeftOuter,
    /// Right outer join.
    RightOuter,
    /// Full outer join.
    FullOuter,
    /// Cross join.
    Cross,
}
129
130/// A join predicate compares columns from two tables.
131#[derive(Clone, Debug)]
132pub struct JoinPredicate {
133    pub left_column: ColumnRef,
134    pub right_column: ColumnRef,
135    pub eval_type: EvalType,
136    pub join_type: JoinType,
137}
138
139impl JoinPredicate {
140    pub fn new(
141        left_column: ColumnRef,
142        right_column: ColumnRef,
143        eval_type: EvalType,
144        join_type: JoinType,
145    ) -> Self {
146        Self {
147            left_column,
148            right_column,
149            eval_type,
150            join_type,
151        }
152    }
153
154    pub fn inner(left_column: ColumnRef, right_column: ColumnRef, eval_type: EvalType) -> Self {
155        Self::new(left_column, right_column, eval_type, JoinType::Inner)
156    }
157
158    pub fn left_outer(
159        left_column: ColumnRef,
160        right_column: ColumnRef,
161        eval_type: EvalType,
162    ) -> Self {
163        Self::new(left_column, right_column, eval_type, JoinType::LeftOuter)
164    }
165
166    /// Reverses the join predicate (swaps left and right columns).
167    pub fn reverse(&self) -> Self {
168        let new_eval_type = match self.eval_type {
169            EvalType::Lt => EvalType::Gt,
170            EvalType::Le => EvalType::Ge,
171            EvalType::Gt => EvalType::Lt,
172            EvalType::Ge => EvalType::Le,
173            other => other,
174        };
175        Self::new(
176            self.right_column.clone(),
177            self.left_column.clone(),
178            new_eval_type,
179            self.join_type,
180        )
181    }
182
183    /// Checks if this is an equi-join (equality comparison).
184    pub fn is_equi_join(&self) -> bool {
185        self.eval_type == EvalType::Eq
186    }
187
188    /// Evaluates the join condition for two rows.
189    pub fn eval_rows(&self, left_row: &Row, right_row: &Row) -> bool {
190        let left_value = match left_row.get(self.left_column.index) {
191            Some(v) => v,
192            None => return false,
193        };
194        let right_value = match right_row.get(self.right_column.index) {
195            Some(v) => v,
196            None => return false,
197        };
198
199        // NULL values don't match in joins
200        if left_value.is_null() || right_value.is_null() {
201            return false;
202        }
203
204        match self.eval_type {
205            EvalType::Eq => left_value == right_value,
206            EvalType::Ne => left_value != right_value,
207            EvalType::Lt => left_value < right_value,
208            EvalType::Le => left_value <= right_value,
209            EvalType::Gt => left_value > right_value,
210            EvalType::Ge => left_value >= right_value,
211            _ => false,
212        }
213    }
214}
215
216impl Predicate for JoinPredicate {
217    fn eval(&self, row: &Row) -> bool {
218        // For a combined row, we need both column indices to be valid
219        let left_value = match row.get(self.left_column.index) {
220            Some(v) => v,
221            None => return false,
222        };
223        let right_value = match row.get(self.right_column.index) {
224            Some(v) => v,
225            None => return false,
226        };
227
228        if left_value.is_null() || right_value.is_null() {
229            return false;
230        }
231
232        match self.eval_type {
233            EvalType::Eq => left_value == right_value,
234            EvalType::Ne => left_value != right_value,
235            EvalType::Lt => left_value < right_value,
236            EvalType::Le => left_value <= right_value,
237            EvalType::Gt => left_value > right_value,
238            EvalType::Ge => left_value >= right_value,
239            _ => false,
240        }
241    }
242
243    fn columns(&self) -> Vec<&ColumnRef> {
244        alloc::vec![&self.left_column, &self.right_column]
245    }
246
247    fn tables(&self) -> Vec<&str> {
248        alloc::vec![
249            self.left_column.table.as_str(),
250            self.right_column.table.as_str()
251        ]
252    }
253}
254
/// Boolean connective that [`CombinedPredicate`] uses to fold its children.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogicalOp {
    /// Every child must match.
    And,
    /// At least one child must match.
    Or,
}
261
262/// A combined predicate joins multiple predicates with AND/OR.
263#[derive(Clone, Debug)]
264pub struct CombinedPredicate {
265    pub op: LogicalOp,
266    pub children: Vec<Box<dyn PredicateClone>>,
267}
268
269/// Helper trait for cloning boxed predicates.
270pub trait PredicateClone: Predicate {
271    fn clone_box(&self) -> Box<dyn PredicateClone>;
272}
273
274impl<T: Predicate + Clone + 'static> PredicateClone for T {
275    fn clone_box(&self) -> Box<dyn PredicateClone> {
276        Box::new(self.clone())
277    }
278}
279
280impl Clone for Box<dyn PredicateClone> {
281    fn clone(&self) -> Self {
282        self.clone_box()
283    }
284}
285
286impl core::fmt::Debug for Box<dyn PredicateClone> {
287    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
288        write!(f, "PredicateClone")
289    }
290}
291
292impl CombinedPredicate {
293    pub fn and(children: Vec<Box<dyn PredicateClone>>) -> Self {
294        Self {
295            op: LogicalOp::And,
296            children,
297        }
298    }
299
300    pub fn or(children: Vec<Box<dyn PredicateClone>>) -> Self {
301        Self {
302            op: LogicalOp::Or,
303            children,
304        }
305    }
306}
307
308impl Predicate for CombinedPredicate {
309    fn eval(&self, row: &Row) -> bool {
310        match self.op {
311            LogicalOp::And => self.children.iter().all(|p| p.eval(row)),
312            LogicalOp::Or => self.children.iter().any(|p| p.eval(row)),
313        }
314    }
315
316    fn columns(&self) -> Vec<&ColumnRef> {
317        self.children.iter().flat_map(|p| p.columns()).collect()
318    }
319
320    fn tables(&self) -> Vec<&str> {
321        self.children.iter().flat_map(|p| p.tables()).collect()
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_value_predicate_eq() {
        let col = ColumnRef::new("t", "id", 0);
        let pred = ValuePredicate::eq(col, Value::Int64(42));

        let row_match = Row::new(1, vec![Value::Int64(42)]);
        let row_no_match = Row::new(2, vec![Value::Int64(100)]);

        assert!(pred.eval(&row_match));
        assert!(!pred.eval(&row_no_match));
    }

    #[test]
    fn test_value_predicate_comparison() {
        let col = ColumnRef::new("t", "value", 0);

        let pred_lt = ValuePredicate::lt(col.clone(), Value::Int64(50));
        let pred_gt = ValuePredicate::gt(col.clone(), Value::Int64(50));

        let row = Row::new(1, vec![Value::Int64(30)]);

        assert!(pred_lt.eval(&row));
        assert!(!pred_gt.eval(&row));
    }

    #[test]
    fn test_value_predicate_boundaries() {
        let col = ColumnRef::new("t", "value", 0);
        let row = Row::new(1, vec![Value::Int64(50)]);

        assert!(ValuePredicate::le(col.clone(), Value::Int64(50)).eval(&row));
        assert!(ValuePredicate::ge(col.clone(), Value::Int64(50)).eval(&row));
        assert!(!ValuePredicate::ne(col, Value::Int64(50)).eval(&row));
    }

    #[test]
    fn test_value_predicate_missing_column() {
        // Index 5 does not exist in a one-column row, so nothing can match.
        let pred = ValuePredicate::eq(ColumnRef::new("t", "id", 5), Value::Int64(42));
        let row = Row::new(1, vec![Value::Int64(42)]);

        assert!(!pred.eval(&row));
    }

    #[test]
    fn test_join_predicate() {
        let left_col = ColumnRef::new("a", "id", 0);
        // `eval_rows` reads each column from its own row, so both use index 0.
        let right_col = ColumnRef::new("b", "a_id", 0);
        let pred = JoinPredicate::inner(left_col, right_col, EvalType::Eq);

        let left_row = Row::new(1, vec![Value::Int64(10)]);
        let right_row_match = Row::new(2, vec![Value::Int64(10)]);
        let right_row_no_match = Row::new(3, vec![Value::Int64(20)]);

        assert!(pred.is_equi_join());
        assert!(pred.eval_rows(&left_row, &right_row_match));
        assert!(!pred.eval_rows(&left_row, &right_row_no_match));
    }

    #[test]
    fn test_join_predicate_combined_row() {
        // `Predicate::eval` reads both columns from one already-joined row.
        let left_col = ColumnRef::new("a", "id", 0);
        let right_col = ColumnRef::new("b", "a_id", 1);
        let pred = JoinPredicate::inner(left_col, right_col, EvalType::Eq);

        let joined_match = Row::new(1, vec![Value::Int64(10), Value::Int64(10)]);
        let joined_no_match = Row::new(2, vec![Value::Int64(10), Value::Int64(20)]);

        assert!(pred.eval(&joined_match));
        assert!(!pred.eval(&joined_no_match));
    }

    #[test]
    fn test_join_predicate_reverse() {
        let left_col = ColumnRef::new("a", "id", 0);
        let right_col = ColumnRef::new("b", "a_id", 1);
        let pred = JoinPredicate::inner(left_col, right_col, EvalType::Lt);

        let reversed = pred.reverse();
        assert_eq!(reversed.eval_type, EvalType::Gt);
        assert_eq!(reversed.left_column.table, "b");
        assert_eq!(reversed.right_column.table, "a");
        assert_eq!(reversed.join_type, JoinType::Inner);

        // The reversed predicate with swapped rows must agree with the original.
        let left_row = Row::new(1, vec![Value::Int64(5)]);
        let right_row = Row::new(2, vec![Value::Int64(0), Value::Int64(7)]);
        assert!(pred.eval_rows(&left_row, &right_row));
        assert!(reversed.eval_rows(&right_row, &left_row));
    }

    #[test]
    fn test_join_predicate_reverse_symmetric() {
        let pred = JoinPredicate::inner(
            ColumnRef::new("a", "id", 0),
            ColumnRef::new("b", "a_id", 1),
            EvalType::Eq,
        );
        assert_eq!(pred.reverse().eval_type, EvalType::Eq);
    }

    #[test]
    fn test_combined_predicate() {
        let col = ColumnRef::new("t", "value", 0);
        let gt_10: Box<dyn PredicateClone> =
            Box::new(ValuePredicate::gt(col.clone(), Value::Int64(10)));
        let lt_20: Box<dyn PredicateClone> = Box::new(ValuePredicate::lt(col, Value::Int64(20)));

        let both = CombinedPredicate::and(vec![gt_10.clone(), lt_20.clone()]);
        let either = CombinedPredicate::or(vec![gt_10, lt_20]);

        let inside = Row::new(1, vec![Value::Int64(15)]);
        let outside = Row::new(2, vec![Value::Int64(25)]);

        assert!(both.eval(&inside));
        assert!(!both.eval(&outside));
        assert!(either.eval(&outside));
    }

    #[test]
    fn test_combined_predicate_empty() {
        let row = Row::new(1, vec![Value::Int64(1)]);

        // Vacuous truth for AND, no witness for OR.
        assert!(CombinedPredicate::and(vec![]).eval(&row));
        assert!(!CombinedPredicate::or(vec![]).eval(&row));
    }
}