1use crate::interval::GenomicInterval;
22use std::fmt;
23
/// A node in the filter/query expression tree.
///
/// Expressions are built with [`col`], [`lit`] and the builder methods on `Expr`, e.g.
/// `col("qual").gt(lit(30.0)).and(Expr::IsPass)`, and rendered for humans through the
/// `Display` impl. This type does not evaluate anything itself; compilation into a
/// record filter goes through [`ExprToFilter`] (presumably implemented for `Expr`
/// elsewhere — confirm).
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// Reference to a named column/field of the record (see [`col`]); displayed as the
    /// bare name.
    Column(String),

    /// A constant value (see [`lit`]).
    Literal(ScalarValue),

    /// Comparison `left == right`; displayed as `(left == right)`.
    Eq(Box<Expr>, Box<Expr>),

    /// Comparison `left != right`; displayed as `(left != right)`.
    Neq(Box<Expr>, Box<Expr>),

    /// Comparison `left > right`; displayed as `(left > right)`.
    Gt(Box<Expr>, Box<Expr>),

    /// Comparison `left >= right`; displayed as `(left >= right)`.
    Gte(Box<Expr>, Box<Expr>),

    /// Comparison `left < right`; displayed as `(left < right)`.
    Lt(Box<Expr>, Box<Expr>),

    /// Comparison `left <= right`; displayed as `(left <= right)`.
    Lte(Box<Expr>, Box<Expr>),

    /// Conjunction of every child. [`Expr::and`] flattens nested `And`s into this single
    /// list. How an empty list evaluates is not defined in this file.
    And(Vec<Expr>),

    /// Disjunction of every child. [`Expr::or`] flattens nested `Or`s into this single
    /// list. How an empty list evaluates is not defined in this file.
    Or(Vec<Expr>),

    /// Logical negation of the child; displayed as `NOT child`.
    Not(Box<Expr>),

    /// Genomic predicate displayed as `is_transition`. Presumably true for transition
    /// SNVs (A<->G, C<->T); the evaluation lives outside this file — confirm.
    IsTransition,

    /// Genomic predicate displayed as `is_transversion`. Presumably the complement of
    /// transitions among SNVs; evaluation lives outside this file — confirm.
    IsTransversion,

    /// Genomic predicate displayed as `is_snp` (presumably: single-base substitution).
    IsSnp,

    /// Genomic predicate displayed as `is_indel` (presumably: insertion or deletion).
    IsIndel,

    /// Genomic predicate displayed as `is_pass` (presumably: the record's FILTER is PASS).
    IsPass,

    /// Region predicate over a single interval; displayed as `in_region(<interval>)`.
    /// Presumably true when the record position lies within the interval — confirm
    /// the inclusive/exclusive bounds against `GenomicInterval`.
    InRegion(GenomicInterval),

    /// Region predicate over several intervals. `Display` shows only the number of
    /// intervals, not their contents.
    InRegions(Vec<GenomicInterval>),

    /// Chromosome predicate; displayed as `on_chromosome(<name>)`.
    OnChromosome(String),

    /// String predicate displayed as `expr.contains(needle)`; presumably a substring test.
    Contains(Box<Expr>, Box<Expr>),

    /// String predicate displayed as `expr.starts_with(prefix)`; presumably a prefix test.
    StartsWith(Box<Expr>, Box<Expr>),

    /// Pattern predicate displayed as `expr.matches('pattern')`. Whether the pattern is
    /// a regex, a glob or something else is not determined here — confirm with the
    /// compiler/evaluator.
    Matches(Box<Expr>, String),

    /// Aggregate displayed as `count()`.
    Count,

    /// Aggregate displayed as `mean(expr)`.
    Mean(Box<Expr>),

    /// Aggregate displayed as `sum(expr)`.
    Sum(Box<Expr>),

    /// Aggregate displayed as `min(expr)`.
    Min(Box<Expr>),

    /// Aggregate displayed as `max(expr)`.
    Max(Box<Expr>),

    /// Aggregate displayed as `ts_tv_ratio()`; presumably the transition/transversion
    /// ratio over the records — confirm.
    TsTvRatio,

    /// Value displayed as `allele_frequency()`; how it is computed is not defined in
    /// this file.
    AlleleFrequency,
}
140
/// A typed constant carried by [`Expr::Literal`].
///
/// Usually built from plain Rust values through the `From` impls, which is what [`lit`]
/// relies on.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// Boolean constant (from `bool`).
    Boolean(bool),

    /// Signed 64-bit integer (from `i64`; `i32` is widened into it).
    Int64(i64),

    /// 64-bit float (from `f64`; `f32` is widened into it).
    Float64(f64),

    /// String constant (from `String` or `&str`); displayed inside single quotes.
    String(String),

    /// Absence of a value; displayed as `null`. None of the `From` impls in this file
    /// produce it, so it has to be constructed explicitly.
    Null,
}
163
164impl Expr {
169 pub fn eq(self, other: Expr) -> Expr {
171 Expr::Eq(Box::new(self), Box::new(other))
172 }
173
174 pub fn neq(self, other: Expr) -> Expr {
176 Expr::Neq(Box::new(self), Box::new(other))
177 }
178
179 pub fn gt(self, other: Expr) -> Expr {
181 Expr::Gt(Box::new(self), Box::new(other))
182 }
183
184 pub fn gte(self, other: Expr) -> Expr {
186 Expr::Gte(Box::new(self), Box::new(other))
187 }
188
189 pub fn lt(self, other: Expr) -> Expr {
191 Expr::Lt(Box::new(self), Box::new(other))
192 }
193
194 pub fn lte(self, other: Expr) -> Expr {
196 Expr::Lte(Box::new(self), Box::new(other))
197 }
198
199 pub fn and(self, other: Expr) -> Expr {
201 match (self, other) {
202 (Expr::And(mut left), Expr::And(right)) => {
204 left.extend(right);
205 Expr::And(left)
206 }
207 (Expr::And(mut exprs), other) => {
208 exprs.push(other);
209 Expr::And(exprs)
210 }
211 (this, Expr::And(mut exprs)) => {
212 exprs.insert(0, this);
213 Expr::And(exprs)
214 }
215 (left, right) => Expr::And(vec![left, right]),
216 }
217 }
218
219 pub fn or(self, other: Expr) -> Expr {
221 match (self, other) {
222 (Expr::Or(mut left), Expr::Or(right)) => {
224 left.extend(right);
225 Expr::Or(left)
226 }
227 (Expr::Or(mut exprs), other) => {
228 exprs.push(other);
229 Expr::Or(exprs)
230 }
231 (this, Expr::Or(mut exprs)) => {
232 exprs.insert(0, this);
233 Expr::Or(exprs)
234 }
235 (left, right) => Expr::Or(vec![left, right]),
236 }
237 }
238
239 pub fn not(self) -> Expr {
241 Expr::Not(Box::new(self))
242 }
243}
244
245pub fn col(name: &str) -> Expr {
257 Expr::Column(name.to_string())
258}
259
260pub fn lit<T: Into<ScalarValue>>(value: T) -> Expr {
271 Expr::Literal(value.into())
272}
273
274impl From<bool> for ScalarValue {
279 fn from(v: bool) -> Self {
280 ScalarValue::Boolean(v)
281 }
282}
283
284impl From<i64> for ScalarValue {
285 fn from(v: i64) -> Self {
286 ScalarValue::Int64(v)
287 }
288}
289
290impl From<i32> for ScalarValue {
291 fn from(v: i32) -> Self {
292 ScalarValue::Int64(v as i64)
293 }
294}
295
296impl From<f64> for ScalarValue {
297 fn from(v: f64) -> Self {
298 ScalarValue::Float64(v)
299 }
300}
301
302impl From<f32> for ScalarValue {
303 fn from(v: f32) -> Self {
304 ScalarValue::Float64(v as f64)
305 }
306}
307
308impl From<String> for ScalarValue {
309 fn from(v: String) -> Self {
310 ScalarValue::String(v)
311 }
312}
313
314impl From<&str> for ScalarValue {
315 fn from(v: &str) -> Self {
316 ScalarValue::String(v.to_string())
317 }
318}
319
320impl fmt::Display for Expr {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 match self {
327 Expr::Column(name) => write!(f, "{}", name),
328 Expr::Literal(val) => write!(f, "{}", val),
329 Expr::Eq(left, right) => write!(f, "({} == {})", left, right),
330 Expr::Neq(left, right) => write!(f, "({} != {})", left, right),
331 Expr::Gt(left, right) => write!(f, "({} > {})", left, right),
332 Expr::Gte(left, right) => write!(f, "({} >= {})", left, right),
333 Expr::Lt(left, right) => write!(f, "({} < {})", left, right),
334 Expr::Lte(left, right) => write!(f, "({} <= {})", left, right),
335 Expr::And(exprs) => {
336 write!(f, "(")?;
337 for (i, expr) in exprs.iter().enumerate() {
338 if i > 0 {
339 write!(f, " AND ")?;
340 }
341 write!(f, "{}", expr)?;
342 }
343 write!(f, ")")
344 }
345 Expr::Or(exprs) => {
346 write!(f, "(")?;
347 for (i, expr) in exprs.iter().enumerate() {
348 if i > 0 {
349 write!(f, " OR ")?;
350 }
351 write!(f, "{}", expr)?;
352 }
353 write!(f, ")")
354 }
355 Expr::Not(expr) => write!(f, "NOT {}", expr),
356 Expr::IsTransition => write!(f, "is_transition"),
357 Expr::IsTransversion => write!(f, "is_transversion"),
358 Expr::IsSnp => write!(f, "is_snp"),
359 Expr::IsIndel => write!(f, "is_indel"),
360 Expr::IsPass => write!(f, "is_pass"),
361 Expr::InRegion(interval) => write!(f, "in_region({})", interval),
362 Expr::InRegions(intervals) => {
363 write!(f, "in_regions([{} intervals])", intervals.len())
364 }
365 Expr::OnChromosome(chrom) => write!(f, "on_chromosome({})", chrom),
366 Expr::Contains(expr, substr) => write!(f, "{}.contains({})", expr, substr),
367 Expr::StartsWith(expr, prefix) => write!(f, "{}.starts_with({})", expr, prefix),
368 Expr::Matches(expr, pattern) => write!(f, "{}.matches('{}')", expr, pattern),
369 Expr::Count => write!(f, "count()"),
370 Expr::Mean(expr) => write!(f, "mean({})", expr),
371 Expr::Sum(expr) => write!(f, "sum({})", expr),
372 Expr::Min(expr) => write!(f, "min({})", expr),
373 Expr::Max(expr) => write!(f, "max({})", expr),
374 Expr::TsTvRatio => write!(f, "ts_tv_ratio()"),
375 Expr::AlleleFrequency => write!(f, "allele_frequency()"),
376 }
377 }
378}
379
380impl fmt::Display for ScalarValue {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 match self {
383 ScalarValue::Boolean(b) => write!(f, "{}", b),
384 ScalarValue::Int64(i) => write!(f, "{}", i),
385 ScalarValue::Float64(fl) => write!(f, "{}", fl),
386 ScalarValue::String(s) => write!(f, "'{}'", s),
387 ScalarValue::Null => write!(f, "null"),
388 }
389 }
390}
391
392use crate::error::{Error, Result};
397use crate::filters::RecordFilter;
398
/// Compiles an expression into an executable [`RecordFilter`] over records of type `R`.
///
/// The record type is a trait parameter, so one expression type can be compiled for
/// several record representations. Implementors are not part of this file.
pub trait ExprToFilter<R> {
    /// Builds a boxed filter that decides, per record, whether this expression holds.
    ///
    /// # Errors
    ///
    /// Returns an error if the expression cannot be compiled for `R`. Which expressions
    /// are supported is up to each implementation.
    fn compile(&self) -> Result<Box<dyn RecordFilter<R>>>;
}
408pub struct CompiledAndFilter<R> {
410 pub left: Box<dyn RecordFilter<R>>,
411 pub right: Box<dyn RecordFilter<R>>,
412}
413
414impl<R: Send + Sync> RecordFilter<R> for CompiledAndFilter<R> {
415 fn test(&self, record: &R) -> bool {
416 self.left.test(record) && self.right.test(record)
417 }
418}
419
420pub struct CompiledOrFilter<R> {
421 pub left: Box<dyn RecordFilter<R>>,
422 pub right: Box<dyn RecordFilter<R>>,
423}
424
425impl<R: Send + Sync> RecordFilter<R> for CompiledOrFilter<R> {
426 fn test(&self, record: &R) -> bool {
427 self.left.test(record) || self.right.test(record)
428 }
429}
430
431pub struct CompiledNotFilter<R> {
432 pub inner: Box<dyn RecordFilter<R>>,
433}
434
435impl<R: Send + Sync> RecordFilter<R> for CompiledNotFilter<R> {
436 fn test(&self, record: &R) -> bool {
437 !self.inner.test(record)
438 }
439}
440
441pub fn extract_f64(expr: &Expr) -> Result<f64> {
446 match expr {
447 Expr::Literal(ScalarValue::Float64(v)) => Ok(*v),
448 Expr::Literal(ScalarValue::Int64(v)) => Ok(*v as f64),
449 _ => Err(Error::invalid_input(format!(
450 "Expected float literal, got {}",
451 expr
452 ))),
453 }
454}
455
456pub fn extract_i64(expr: &Expr) -> Result<i64> {
457 match expr {
458 Expr::Literal(ScalarValue::Int64(v)) => Ok(*v),
459 _ => Err(Error::invalid_input(format!(
460 "Expected int literal, got {}",
461 expr
462 ))),
463 }
464}
465
466pub fn extract_u32(expr: &Expr) -> Result<u32> {
467 match expr {
468 Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 && *v <= u32::MAX as i64 => Ok(*v as u32),
469 _ => Err(Error::invalid_input(format!(
470 "Expected u32 literal, got {}",
471 expr
472 ))),
473 }
474}
475
476pub fn extract_u64(expr: &Expr) -> Result<u64> {
477 match expr {
478 Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 => Ok(*v as u64),
479 _ => Err(Error::invalid_input(format!(
480 "Expected u64 literal, got {}",
481 expr
482 ))),
483 }
484}
485
486pub fn extract_u8(expr: &Expr) -> Result<u8> {
487 match expr {
488 Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 && *v <= 255 => Ok(*v as u8),
489 _ => Err(Error::invalid_input(format!(
490 "Expected u8 literal, got {}",
491 expr
492 ))),
493 }
494}
495
496pub fn extract_usize(expr: &Expr) -> Result<usize> {
497 match expr {
498 Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 => Ok(*v as usize),
499 _ => Err(Error::invalid_input(format!(
500 "Expected usize literal, got {}",
501 expr
502 ))),
503 }
504}
505
506pub fn extract_string(expr: &Expr) -> Result<String> {
507 match expr {
508 Expr::Literal(ScalarValue::String(s)) => Ok(s.clone()),
509 _ => Err(Error::invalid_input(format!(
510 "Expected string literal, got {}",
511 expr
512 ))),
513 }
514}
515
/// Unit tests for the expression builders, flattening rules, display output and the
/// literal-extraction helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_reference() {
        let expr = col("qual");
        assert_eq!(expr, Expr::Column("qual".to_string()));
        assert_eq!(format!("{}", expr), "qual");
    }

    #[test]
    fn test_literal_values() {
        assert_eq!(lit(30.0), Expr::Literal(ScalarValue::Float64(30.0)));
        assert_eq!(lit(42), Expr::Literal(ScalarValue::Int64(42)));
        assert_eq!(
            lit("PASS"),
            Expr::Literal(ScalarValue::String("PASS".to_string()))
        );
        assert_eq!(lit(true), Expr::Literal(ScalarValue::Boolean(true)));
    }

    #[test]
    fn test_comparison_builders() {
        let expr = col("qual").gt(lit(30.0));
        match expr {
            Expr::Gt(left, right) => {
                assert_eq!(*left, col("qual"));
                assert_eq!(*right, lit(30.0));
            }
            _ => panic!("Expected Gt"),
        }
    }

    #[test]
    fn test_and_flattening() {
        let expr = col("qual").gt(lit(30.0)).and(Expr::IsSnp).and(Expr::IsPass);

        match expr {
            Expr::And(exprs) => {
                assert_eq!(exprs.len(), 3);
            }
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_or_flattening() {
        let expr = Expr::IsSnp.or(Expr::IsIndel).or(Expr::IsPass);
        assert_eq!(
            expr,
            Expr::Or(vec![Expr::IsSnp, Expr::IsIndel, Expr::IsPass])
        );
    }

    #[test]
    fn test_flattening_preserves_operand_order() {
        // Right-nested conjunctions must splice in after the left operand.
        let right_nested = Expr::IsSnp.and(Expr::IsPass.and(Expr::IsTransition));
        assert_eq!(
            right_nested,
            Expr::And(vec![Expr::IsSnp, Expr::IsPass, Expr::IsTransition])
        );

        let both_nested = Expr::IsSnp
            .and(Expr::IsPass)
            .and(Expr::IsIndel.and(Expr::IsTransversion));
        assert_eq!(
            both_nested,
            Expr::And(vec![
                Expr::IsSnp,
                Expr::IsPass,
                Expr::IsIndel,
                Expr::IsTransversion,
            ])
        );
    }

    #[test]
    fn test_and_does_not_flatten_or() {
        let expr = Expr::IsSnp.or(Expr::IsIndel).and(Expr::IsPass);
        assert_eq!(
            expr,
            Expr::And(vec![
                Expr::Or(vec![Expr::IsSnp, Expr::IsIndel]),
                Expr::IsPass,
            ])
        );
    }

    #[test]
    fn test_complex_expression() {
        let expr = col("qual").gt(lit(30.0)).and(Expr::IsSnp).or(Expr::IsPass);

        let display = format!("{}", expr);
        assert!(display.contains("qual"));
        assert!(display.contains("30"));
        assert!(display.contains("is_snp"));
        assert!(display.contains("is_pass"));
    }

    #[test]
    fn test_display_of_comparisons_and_lists() {
        assert_eq!(format!("{}", col("dp").gte(lit(10))), "(dp >= 10)");
        assert_eq!(
            format!("{}", Expr::IsSnp.and(Expr::IsPass)),
            "(is_snp AND is_pass)"
        );
    }

    #[test]
    fn test_genomic_predicates() {
        let transition = Expr::IsTransition;
        assert_eq!(format!("{}", transition), "is_transition");

        let region = Expr::InRegion(GenomicInterval::new("chr1", 1000, 2000));
        assert!(format!("{}", region).contains("in_region"));
    }

    #[test]
    fn test_extract_integer_ranges() {
        assert_eq!(extract_u8(&lit(255)).ok(), Some(255));
        assert!(extract_u8(&lit(256)).is_err());
        assert!(extract_u8(&lit(-1)).is_err());

        assert_eq!(extract_u32(&lit(i64::from(u32::MAX))).ok(), Some(u32::MAX));
        assert!(extract_u32(&lit(i64::from(u32::MAX) + 1)).is_err());
        assert!(extract_u32(&col("qual")).is_err());

        assert!(extract_u64(&lit(-1)).is_err());
        assert_eq!(extract_u64(&lit(7)).ok(), Some(7));

        assert!(extract_usize(&lit(-1)).is_err());
        assert_eq!(extract_usize(&lit(42)).ok(), Some(42));

        assert_eq!(extract_i64(&lit(-7)).ok(), Some(-7));
        assert!(extract_i64(&lit(1.0)).is_err());
    }

    #[test]
    fn test_extract_float_and_string() {
        assert_eq!(extract_f64(&lit(3)).ok(), Some(3.0));
        assert_eq!(extract_f64(&lit(2.5)).ok(), Some(2.5));
        assert!(extract_f64(&lit("x")).is_err());

        assert_eq!(
            extract_string(&lit("PASS")).ok(),
            Some("PASS".to_string())
        );
        assert!(extract_string(&lit(1)).is_err());
    }
}