1use std::collections::HashMap;
39use std::error::Error;
40use std::fmt;
41
42use crate::type_utils::DataType;
43
/// A constant value that can appear in a filter expression.
///
/// Integers and floats are always stored at 64-bit width; narrower inputs are
/// widened by the `From` impls below.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// A signed 64-bit integer.
    I64(i64),
    /// A 64-bit floating-point number.
    F64(f64),
    /// An owned string.
    Str(String),
}

/// Widens an `i32` losslessly into [`Literal::I64`].
impl From<i32> for Literal {
    fn from(v: i32) -> Self {
        Self::I64(i64::from(v))
    }
}

/// Stores an `i64` as [`Literal::I64`].
impl From<i64> for Literal {
    fn from(v: i64) -> Self {
        Self::I64(v)
    }
}

/// Widens an `f32` into [`Literal::F64`].
impl From<f32> for Literal {
    fn from(v: f32) -> Self {
        Self::F64(f64::from(v))
    }
}

/// Stores an `f64` as [`Literal::F64`].
impl From<f64> for Literal {
    fn from(v: f64) -> Self {
        Self::F64(v)
    }
}

/// Copies a string slice into [`Literal::Str`].
impl From<&str> for Literal {
    fn from(v: &str) -> Self {
        Self::Str(v.to_owned())
    }
}

/// Moves an owned `String` into [`Literal::Str`] without copying.
impl From<String> for Literal {
    fn from(v: String) -> Self {
        Self::Str(v)
    }
}
81
/// Comparison operator carried by an [`Expr::Cmp`] node and by compiled filters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equal (`==`).
    Eq,
    /// Not equal (`!=`).
    Neq,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Lte,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Gte,
}
92
/// Untyped filter expression tree, built with [`col`], [`lit`] and the
/// builder methods on `Expr`, then type-checked by `Expr::compile`.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// Reference to a column by name.
    Column(String),
    /// A constant value.
    Literal(Literal),
    /// Comparison `left <op> right`. Compilation only accepts a column on the
    /// left and a literal on the right.
    Cmp {
        left: Box<Expr>,
        right: Box<Expr>,
        op: CmpOp,
    },
    /// Logical conjunction of two sub-expressions.
    And(Box<Expr>, Box<Expr>),
    /// Logical disjunction of two sub-expressions.
    Or(Box<Expr>, Box<Expr>),
}
107
108pub fn col(name: &str) -> Expr {
110 Expr::Column(name.to_string())
111}
112pub fn lit<T: Into<Literal>>(v: T) -> Expr {
114 Expr::Literal(v.into())
115}
116
117impl Expr {
118 pub fn eq<T: Into<Literal>>(self, v: T) -> Expr {
121 Expr::Cmp {
122 left: Box::new(self),
123 right: Box::new(lit(v)),
124 op: CmpOp::Eq,
125 }
126 }
127 pub fn neq<T: Into<Literal>>(self, v: T) -> Expr {
129 Expr::Cmp {
130 left: Box::new(self),
131 right: Box::new(lit(v)),
132 op: CmpOp::Neq,
133 }
134 }
135 pub fn lt<T: Into<Literal>>(self, v: T) -> Expr {
137 Expr::Cmp {
138 left: Box::new(self),
139 right: Box::new(lit(v)),
140 op: CmpOp::Lt,
141 }
142 }
143 pub fn lte<T: Into<Literal>>(self, v: T) -> Expr {
145 Expr::Cmp {
146 left: Box::new(self),
147 right: Box::new(lit(v)),
148 op: CmpOp::Lte,
149 }
150 }
151 pub fn gt<T: Into<Literal>>(self, v: T) -> Expr {
153 Expr::Cmp {
154 left: Box::new(self),
155 right: Box::new(lit(v)),
156 op: CmpOp::Gt,
157 }
158 }
159 pub fn gte<T: Into<Literal>>(self, v: T) -> Expr {
161 Expr::Cmp {
162 left: Box::new(self),
163 right: Box::new(lit(v)),
164 op: CmpOp::Gte,
165 }
166 }
167
168 pub fn and(self, other: Expr) -> Expr {
170 Expr::And(Box::new(self), Box::new(other))
171 }
172 pub fn or(self, other: Expr) -> Expr {
174 Expr::Or(Box::new(self), Box::new(other))
175 }
176}
177
178impl std::ops::BitAnd for Expr {
180 type Output = Expr;
181 fn bitand(self, rhs: Self) -> Self::Output {
182 self.and(rhs)
183 }
184}
185impl std::ops::BitOr for Expr {
186 type Output = Expr;
187 fn bitor(self, rhs: Self) -> Self::Output {
188 self.or(rhs)
189 }
190}
191
/// Right-hand side of a compiled numeric comparison.
#[derive(Debug, Clone, PartialEq)]
pub enum NumericLiteral {
    /// Integer operand. Compilation also uses this for datetime columns,
    /// where it holds the parsed literal as milliseconds.
    I64(i64),
    /// Floating-point operand, used for float columns.
    F64(f64),
}
197
/// A single type-checked `column <cmp> constant` predicate.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnFilter {
    /// Comparison against a numeric constant. Produced for integer, float
    /// and datetime columns.
    Numeric {
        column: String,
        cmp: CmpOp,
        rhs: NumericLiteral,
    },
    /// Comparison against a string constant. Compilation only ever emits
    /// `CmpOp::Eq` or `CmpOp::Neq` here.
    String {
        column: String,
        cmp: CmpOp,
        rhs: String,
    },
}
212
/// A lowered filter: a list of clauses, each clause a list of column filters.
///
/// As constructed in this file, `AND` appends clause lists and `OR` merges
/// clauses pairwise, so the outer list reads as a conjunction and each inner
/// list as a disjunction (conjunctive normal form).
pub type Plan = Vec<Vec<ColumnFilter>>;
221
/// Result of `Expr::compile`: a type-checked, normalized filter plan.
#[derive(Debug, Clone, PartialEq)]
pub struct CompiledFilter {
    /// The normalized clauses; see [`Plan`] for the layout.
    pub clauses: Plan,
}
227
/// Errors produced while compiling an [`Expr`] against a schema.
#[derive(Debug, PartialEq)]
pub enum ExprError {
    /// The named column is not present in the schema.
    UnknownColumn(String),
    /// The literal does not fit the column: column name, the column's schema
    /// type, and a short label describing the offending literal.
    TypeMismatch(String, DataType, &'static str),
    /// An ordering comparator was used on the named string column; only
    /// equality and inequality are supported there.
    UnsupportedStringOp(String),
    /// A comparison was not of the form `column <op> literal`.
    InvalidComparison,
    /// A bare column or literal appeared where a comparison was required.
    InvalidExpression,
}
237
238impl fmt::Display for ExprError {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 match self {
241 ExprError::UnknownColumn(c) => write!(f, "Unknown column '{c}'"),
242 ExprError::TypeMismatch(c, dt, got) => {
243 write!(
244 f,
245 "Type mismatch for column '{c}': expected {dt:?}, got literal {got}"
246 )
247 }
248 ExprError::UnsupportedStringOp(c) => {
249 write!(f, "Unsupported comparator for string column '{c}'")
250 }
251 ExprError::InvalidComparison => write!(
252 f,
253 "Invalid expression shape for comparison (expect column vs literal)"
254 ),
255 ExprError::InvalidExpression => write!(
256 f,
257 "Invalid expression (unexpected literal or column without comparator)"
258 ),
259 }
260 }
261}
262
// The default `Error` methods are sufficient: `Debug` and `Display` are
// provided above and no underlying source error is exposed.
impl Error for ExprError {}
264
265fn parse_datetime_literal_millis(s: &str) -> Option<i64> {
268 use chrono::{DateTime as ChronoDateTime, NaiveDate, NaiveDateTime, Utc};
270
271 if let Ok(dt) = ChronoDateTime::parse_from_rfc3339(s) {
272 return Some(dt.with_timezone(&Utc).timestamp_millis());
273 }
274 if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d")
275 && let Some(dt) = date.and_hms_opt(0, 0, 0)
276 {
277 return Some(dt.and_utc().timestamp_millis());
278 }
279 if let Ok(dt) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
280 return Some(dt.and_utc().timestamp_millis());
281 }
282 None
283}
284
285impl Expr {
286 pub fn compile(&self, schema: &HashMap<String, DataType>) -> Result<CompiledFilter, ExprError> {
291 let plan = lower_to_plan(self, schema)?;
294 Ok(CompiledFilter {
295 clauses: normalize_plan(plan),
296 })
297 }
298}
299
300fn normalize_plan(mut plan: Plan) -> Plan {
303 let mut out: Plan = Vec::with_capacity(plan.len());
304
305 for clause in plan.drain(..) {
306 let mut is_tautology = false;
307
308 for lf in &clause {
310 match lf {
311 ColumnFilter::Numeric { column, cmp, rhs } if *cmp == CmpOp::Eq => {
312 let conflict = clause.iter().any(|x| matches!(
313 x,
314 ColumnFilter::Numeric { column: c2, cmp: CmpOp::Neq, rhs: v2 } if c2 == column && v2 == rhs
315 ));
316 if conflict {
317 is_tautology = true;
318 break;
319 }
320 }
321 ColumnFilter::String { column, cmp, rhs } if *cmp == CmpOp::Eq => {
322 let conflict = clause.iter().any(|x| matches!(
323 x,
324 ColumnFilter::String { column: c2, cmp: CmpOp::Neq, rhs: v2 } if c2 == column && v2 == rhs
325 ));
326 if conflict {
327 is_tautology = true;
328 break;
329 }
330 }
331 _ => {}
332 }
333 }
334
335 if is_tautology {
336 continue;
337 }
338
339 out.push(clause);
341 }
342 out
343}
344
345fn lower_to_plan(expr: &Expr, schema: &HashMap<String, DataType>) -> Result<Plan, ExprError> {
356 match expr {
357 Expr::And(a, b) => {
358 let left = lower_to_plan(a, schema)?;
359 let right = lower_to_plan(b, schema)?;
360 Ok(and_concat_clauses(left, right))
361 }
362 Expr::Or(a, b) => {
363 let left = lower_to_plan(a, schema)?;
364 let right = lower_to_plan(b, schema)?;
365 Ok(or_distribute_clauses(left, right))
366 }
367 Expr::Cmp { left, right, op } => {
368 compile_cmp_leaf(left, right, *op, schema).map(|f| vec![vec![f]])
369 }
370 Expr::Column(_) | Expr::Literal(_) => Err(ExprError::InvalidExpression),
371 }
372}
373
374fn compile_cmp_leaf(
386 left: &Expr,
387 right: &Expr,
388 op: CmpOp,
389 schema: &HashMap<String, DataType>,
390) -> Result<ColumnFilter, ExprError> {
391 let (col_name, lit) = match (left, right) {
392 (Expr::Column(name), Expr::Literal(l)) => (name.clone(), l.clone()),
393 _ => return Err(ExprError::InvalidComparison),
394 };
395
396 let dtype = schema
397 .get(&col_name)
398 .ok_or_else(|| ExprError::UnknownColumn(col_name.clone()))?;
399
400 match dtype {
401 DataType::String => {
402 let cmp = match op {
404 CmpOp::Eq => CmpOp::Eq,
405 CmpOp::Neq => CmpOp::Neq,
406 _ => return Err(ExprError::UnsupportedStringOp(col_name)),
407 };
408 let rhs = match lit {
409 Literal::Str(s) => s,
410 Literal::I64(_) | Literal::F64(_) => {
411 return Err(ExprError::TypeMismatch(col_name, *dtype, "string"));
412 }
413 };
414 Ok(ColumnFilter::String {
415 column: col_name,
416 cmp,
417 rhs,
418 })
419 }
420 DataType::Int32 | DataType::Int64 => {
421 let rhs = match lit {
423 Literal::I64(v) => NumericLiteral::I64(v),
424 Literal::F64(_) => return Err(ExprError::TypeMismatch(col_name, *dtype, "float")),
425 Literal::Str(_) => return Err(ExprError::TypeMismatch(col_name, *dtype, "string")),
426 };
427 Ok(ColumnFilter::Numeric {
428 column: col_name,
429 cmp: op,
430 rhs,
431 })
432 }
433 DataType::DateTime => {
434 let millis = match lit {
436 Literal::Str(s) => match parse_datetime_literal_millis(&s) {
437 Some(ms) => ms,
438 None => {
439 return Err(ExprError::TypeMismatch(col_name, *dtype, "datetime string"));
440 }
441 },
442 Literal::I64(_) | Literal::F64(_) => {
443 return Err(ExprError::TypeMismatch(col_name, *dtype, "datetime string"));
444 }
445 };
446 Ok(ColumnFilter::Numeric {
447 column: col_name,
448 cmp: op,
449 rhs: NumericLiteral::I64(millis),
450 })
451 }
452 DataType::Float32 | DataType::Float64 => {
453 let rhs = match lit {
455 Literal::I64(v) => NumericLiteral::F64(v as f64),
456 Literal::F64(v) => NumericLiteral::F64(v),
457 Literal::Str(_) => return Err(ExprError::TypeMismatch(col_name, *dtype, "string")),
458 };
459 Ok(ColumnFilter::Numeric {
460 column: col_name,
461 cmp: op,
462 rhs,
463 })
464 }
465 }
466}
467
468fn and_concat_clauses(mut a: Plan, b: Plan) -> Plan {
475 if a.is_empty() {
476 return b;
477 }
478 if b.is_empty() {
479 return a;
480 }
481 a.extend(b);
482 a
483}
484
485fn or_distribute_clauses(a: Plan, b: Plan) -> Plan {
495 if a.is_empty() {
496 return b;
497 }
498 if b.is_empty() {
499 return a;
500 }
501 a.iter()
502 .flat_map(|ca| {
503 b.iter().map(move |cb| {
504 let mut merged = Vec::with_capacity(ca.len() + cb.len());
505 merged.extend_from_slice(ca);
506 merged.extend_from_slice(cb);
507 merged
508 })
509 })
510 .collect()
511}