1use crate::ast::expr::{BinaryOp, ColumnRef};
4use alloc::boxed::Box;
5use alloc::vec::Vec;
6use cynos_core::{Row, Value};
7
/// Comparison operator carried by a predicate.
///
/// The value and join predicates in this module evaluate only the six
/// equality/ordering variants; `Match`, `Between` and `In` always evaluate
/// to `false` there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EvalType {
    /// Equality.
    Eq,
    /// Inequality.
    Ne,
    /// Strictly less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal.
    Ge,
    /// Pattern match (converted from `BinaryOp::Like`).
    Match,
    /// Range test (converted from `BinaryOp::Between`).
    Between,
    /// Set membership (converted from `BinaryOp::In`).
    In,
}
21
22impl From<BinaryOp> for EvalType {
23 fn from(op: BinaryOp) -> Self {
24 match op {
25 BinaryOp::Eq => EvalType::Eq,
26 BinaryOp::Ne => EvalType::Ne,
27 BinaryOp::Lt => EvalType::Lt,
28 BinaryOp::Le => EvalType::Le,
29 BinaryOp::Gt => EvalType::Gt,
30 BinaryOp::Ge => EvalType::Ge,
31 BinaryOp::Like => EvalType::Match,
32 BinaryOp::In => EvalType::In,
33 BinaryOp::Between => EvalType::Between,
34 _ => EvalType::Eq,
35 }
36 }
37}
38
39pub trait Predicate {
41 fn eval(&self, row: &Row) -> bool;
43
44 fn columns(&self) -> Vec<&ColumnRef>;
46
47 fn tables(&self) -> Vec<&str>;
49}
50
51#[derive(Clone, Debug)]
53pub struct ValuePredicate {
54 pub column: ColumnRef,
55 pub eval_type: EvalType,
56 pub value: Value,
57}
58
59impl ValuePredicate {
60 pub fn new(column: ColumnRef, eval_type: EvalType, value: Value) -> Self {
61 Self {
62 column,
63 eval_type,
64 value,
65 }
66 }
67
68 pub fn eq(column: ColumnRef, value: Value) -> Self {
69 Self::new(column, EvalType::Eq, value)
70 }
71
72 pub fn ne(column: ColumnRef, value: Value) -> Self {
73 Self::new(column, EvalType::Ne, value)
74 }
75
76 pub fn lt(column: ColumnRef, value: Value) -> Self {
77 Self::new(column, EvalType::Lt, value)
78 }
79
80 pub fn le(column: ColumnRef, value: Value) -> Self {
81 Self::new(column, EvalType::Le, value)
82 }
83
84 pub fn gt(column: ColumnRef, value: Value) -> Self {
85 Self::new(column, EvalType::Gt, value)
86 }
87
88 pub fn ge(column: ColumnRef, value: Value) -> Self {
89 Self::new(column, EvalType::Ge, value)
90 }
91}
92
93impl Predicate for ValuePredicate {
94 fn eval(&self, row: &Row) -> bool {
95 let row_value = match row.get(self.column.index) {
96 Some(v) => v,
97 None => return false,
98 };
99
100 match self.eval_type {
101 EvalType::Eq => row_value == &self.value,
102 EvalType::Ne => row_value != &self.value,
103 EvalType::Lt => row_value < &self.value,
104 EvalType::Le => row_value <= &self.value,
105 EvalType::Gt => row_value > &self.value,
106 EvalType::Ge => row_value >= &self.value,
107 _ => false,
108 }
109 }
110
111 fn columns(&self) -> Vec<&ColumnRef> {
112 alloc::vec![&self.column]
113 }
114
115 fn tables(&self) -> Vec<&str> {
116 alloc::vec![self.column.table.as_str()]
117 }
118}
119
/// Kind of SQL join a [`JoinPredicate`] belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT OUTER JOIN`.
    LeftOuter,
    /// `RIGHT OUTER JOIN`.
    RightOuter,
    /// `FULL OUTER JOIN`.
    FullOuter,
    /// `CROSS JOIN`.
    Cross,
}
129
130#[derive(Clone, Debug)]
132pub struct JoinPredicate {
133 pub left_column: ColumnRef,
134 pub right_column: ColumnRef,
135 pub eval_type: EvalType,
136 pub join_type: JoinType,
137}
138
139impl JoinPredicate {
140 pub fn new(
141 left_column: ColumnRef,
142 right_column: ColumnRef,
143 eval_type: EvalType,
144 join_type: JoinType,
145 ) -> Self {
146 Self {
147 left_column,
148 right_column,
149 eval_type,
150 join_type,
151 }
152 }
153
154 pub fn inner(left_column: ColumnRef, right_column: ColumnRef, eval_type: EvalType) -> Self {
155 Self::new(left_column, right_column, eval_type, JoinType::Inner)
156 }
157
158 pub fn left_outer(
159 left_column: ColumnRef,
160 right_column: ColumnRef,
161 eval_type: EvalType,
162 ) -> Self {
163 Self::new(left_column, right_column, eval_type, JoinType::LeftOuter)
164 }
165
166 pub fn reverse(&self) -> Self {
168 let new_eval_type = match self.eval_type {
169 EvalType::Lt => EvalType::Gt,
170 EvalType::Le => EvalType::Ge,
171 EvalType::Gt => EvalType::Lt,
172 EvalType::Ge => EvalType::Le,
173 other => other,
174 };
175 Self::new(
176 self.right_column.clone(),
177 self.left_column.clone(),
178 new_eval_type,
179 self.join_type,
180 )
181 }
182
183 pub fn is_equi_join(&self) -> bool {
185 self.eval_type == EvalType::Eq
186 }
187
188 pub fn eval_rows(&self, left_row: &Row, right_row: &Row) -> bool {
190 let left_value = match left_row.get(self.left_column.index) {
191 Some(v) => v,
192 None => return false,
193 };
194 let right_value = match right_row.get(self.right_column.index) {
195 Some(v) => v,
196 None => return false,
197 };
198
199 if left_value.is_null() || right_value.is_null() {
201 return false;
202 }
203
204 match self.eval_type {
205 EvalType::Eq => left_value == right_value,
206 EvalType::Ne => left_value != right_value,
207 EvalType::Lt => left_value < right_value,
208 EvalType::Le => left_value <= right_value,
209 EvalType::Gt => left_value > right_value,
210 EvalType::Ge => left_value >= right_value,
211 _ => false,
212 }
213 }
214}
215
216impl Predicate for JoinPredicate {
217 fn eval(&self, row: &Row) -> bool {
218 let left_value = match row.get(self.left_column.index) {
220 Some(v) => v,
221 None => return false,
222 };
223 let right_value = match row.get(self.right_column.index) {
224 Some(v) => v,
225 None => return false,
226 };
227
228 if left_value.is_null() || right_value.is_null() {
229 return false;
230 }
231
232 match self.eval_type {
233 EvalType::Eq => left_value == right_value,
234 EvalType::Ne => left_value != right_value,
235 EvalType::Lt => left_value < right_value,
236 EvalType::Le => left_value <= right_value,
237 EvalType::Gt => left_value > right_value,
238 EvalType::Ge => left_value >= right_value,
239 _ => false,
240 }
241 }
242
243 fn columns(&self) -> Vec<&ColumnRef> {
244 alloc::vec![&self.left_column, &self.right_column]
245 }
246
247 fn tables(&self) -> Vec<&str> {
248 alloc::vec![
249 self.left_column.table.as_str(),
250 self.right_column.table.as_str()
251 ]
252 }
253}
254
/// Connective used by [`CombinedPredicate`] to merge its children's results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogicalOp {
    /// Conjunction (`AND`).
    And,
    /// Disjunction (`OR`).
    Or,
}
261
262#[derive(Clone, Debug)]
264pub struct CombinedPredicate {
265 pub op: LogicalOp,
266 pub children: Vec<Box<dyn PredicateClone>>,
267}
268
269pub trait PredicateClone: Predicate {
271 fn clone_box(&self) -> Box<dyn PredicateClone>;
272}
273
274impl<T: Predicate + Clone + 'static> PredicateClone for T {
275 fn clone_box(&self) -> Box<dyn PredicateClone> {
276 Box::new(self.clone())
277 }
278}
279
280impl Clone for Box<dyn PredicateClone> {
281 fn clone(&self) -> Self {
282 self.clone_box()
283 }
284}
285
286impl core::fmt::Debug for Box<dyn PredicateClone> {
287 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
288 write!(f, "PredicateClone")
289 }
290}
291
292impl CombinedPredicate {
293 pub fn and(children: Vec<Box<dyn PredicateClone>>) -> Self {
294 Self {
295 op: LogicalOp::And,
296 children,
297 }
298 }
299
300 pub fn or(children: Vec<Box<dyn PredicateClone>>) -> Self {
301 Self {
302 op: LogicalOp::Or,
303 children,
304 }
305 }
306}
307
308impl Predicate for CombinedPredicate {
309 fn eval(&self, row: &Row) -> bool {
310 match self.op {
311 LogicalOp::And => self.children.iter().all(|p| p.eval(row)),
312 LogicalOp::Or => self.children.iter().any(|p| p.eval(row)),
313 }
314 }
315
316 fn columns(&self) -> Vec<&ColumnRef> {
317 self.children.iter().flat_map(|p| p.columns()).collect()
318 }
319
320 fn tables(&self) -> Vec<&str> {
321 self.children.iter().flat_map(|p| p.tables()).collect()
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_value_predicate_eq() {
        let col = ColumnRef::new("t", "id", 0);
        let pred = ValuePredicate::eq(col, Value::Int64(42));

        let row_match = Row::new(1, vec![Value::Int64(42)]);
        let row_no_match = Row::new(2, vec![Value::Int64(100)]);

        assert!(pred.eval(&row_match));
        assert!(!pred.eval(&row_no_match));
    }

    #[test]
    fn test_value_predicate_comparison() {
        let col = ColumnRef::new("t", "value", 0);

        let pred_lt = ValuePredicate::lt(col.clone(), Value::Int64(50));
        let pred_gt = ValuePredicate::gt(col.clone(), Value::Int64(50));

        let row = Row::new(1, vec![Value::Int64(30)]);

        assert!(pred_lt.eval(&row));
        assert!(!pred_gt.eval(&row));
    }

    #[test]
    fn test_value_predicate_missing_column() {
        // Index 3 is past the end of a one-cell row, so nothing can match.
        let pred = ValuePredicate::eq(ColumnRef::new("t", "id", 3), Value::Int64(42));
        let row = Row::new(1, vec![Value::Int64(42)]);

        assert!(!pred.eval(&row));
    }

    #[test]
    fn test_join_predicate() {
        let left_col = ColumnRef::new("a", "id", 0);
        let right_col = ColumnRef::new("b", "a_id", 1);
        let pred = JoinPredicate::inner(left_col, right_col, EvalType::Eq);

        // `eval_rows` reads `right_column.index` (1) from the right row, so the
        // join key must sit in the second cell there.
        let left_row = Row::new(1, vec![Value::Int64(10)]);
        let right_row_match = Row::new(2, vec![Value::Int64(0), Value::Int64(10)]);
        let right_row_no_match = Row::new(3, vec![Value::Int64(0), Value::Int64(20)]);
        let right_row_too_short = Row::new(4, vec![Value::Int64(10)]);

        assert!(pred.is_equi_join());
        assert!(pred.eval_rows(&left_row, &right_row_match));
        assert!(!pred.eval_rows(&left_row, &right_row_no_match));
        assert!(!pred.eval_rows(&left_row, &right_row_too_short));

        // `Predicate::eval` reads both indices from a single, already-joined row.
        let joined_match = Row::new(5, vec![Value::Int64(10), Value::Int64(10)]);
        let joined_no_match = Row::new(6, vec![Value::Int64(10), Value::Int64(20)]);
        assert!(pred.eval(&joined_match));
        assert!(!pred.eval(&joined_no_match));
    }

    #[test]
    fn test_join_predicate_reverse() {
        let left_col = ColumnRef::new("a", "id", 0);
        let right_col = ColumnRef::new("b", "a_id", 1);
        let pred = JoinPredicate::inner(left_col, right_col, EvalType::Lt);

        let reversed = pred.reverse();
        assert_eq!(reversed.eval_type, EvalType::Gt);
        assert_eq!(reversed.left_column.table, "b");
        assert_eq!(reversed.right_column.table, "a");

        // The reversed predicate accepts the same pair with the rows swapped.
        let a_row = Row::new(1, vec![Value::Int64(5)]);
        let b_row = Row::new(2, vec![Value::Int64(0), Value::Int64(7)]);
        assert!(pred.eval_rows(&a_row, &b_row));
        assert!(reversed.eval_rows(&b_row, &a_row));
    }

    #[test]
    fn test_combined_predicate() {
        let col = ColumnRef::new("t", "value", 0);
        let above_10: Box<dyn PredicateClone> =
            Box::new(ValuePredicate::gt(col.clone(), Value::Int64(10)));
        let below_50: Box<dyn PredicateClone> =
            Box::new(ValuePredicate::lt(col.clone(), Value::Int64(50)));

        let both = CombinedPredicate::and(vec![above_10.clone(), below_50.clone()]);
        let either = CombinedPredicate::or(vec![above_10, below_50]);

        let inside = Row::new(1, vec![Value::Int64(30)]);
        let outside = Row::new(2, vec![Value::Int64(60)]);

        assert!(both.eval(&inside));
        assert!(!both.eval(&outside));
        assert!(either.eval(&outside));
        assert_eq!(both.tables(), vec!["t", "t"]);

        // Empty conjunction is vacuously true; empty disjunction is false.
        assert!(CombinedPredicate::and(vec![]).eval(&inside));
        assert!(!CombinedPredicate::or(vec![]).eval(&inside));
    }
}
381}