Skip to main content

genomicframe_core/
expression.rs

1//! Expression system for GenomicFrame query language
2//!
3//! This module provides the expression AST that GenomicFrame uses for building
4//! lazy query plans. Expressions can represent predicates, aggregations, and
5//! transformations that will later be compiled into actual filter operations.
6//!
7//! # Examples
8//!
9//! ```
10//! use genomicframe_core::expression::{col, lit, Expr};
11//!
12//! // Quality filter: qual > 30.0
13//! let predicate = col("qual").gt(lit(30.0));
14//!
15//! // Complex predicate: (qual > 30 AND is_snp) OR is_pass
16//! let complex = col("qual").gt(lit(30.0))
17//!     .and(Expr::IsSnp)
18//!     .or(Expr::IsPass);
19//! ```
20
21use crate::interval::GenomicInterval;
22use std::fmt;
23
24// ============================================================================
25// Core Expression Type
26// ============================================================================
27
/// Expression in the GenomicFrame query language
///
/// Expressions form an abstract syntax tree (AST) that represents operations
/// on genomic data. They are lazy - constructing an expression doesn't execute
/// anything, it just builds a plan that can later be optimized and executed.
///
/// Leaf nodes are usually created with [`col`] and [`lit`] and combined with
/// the builder methods on `Expr` (`eq`, `gt`, `and`, `or`, `not`, ...).
/// Turning an expression into a concrete record filter goes through
/// [`ExprToFilter`].
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    // ========================================================================
    // Leaf Nodes (Values)
    // ========================================================================
    /// Reference to a column by name (e.g., "qual", "chrom", "pos")
    Column(String),

    /// Literal scalar value (see [`ScalarValue`])
    Literal(ScalarValue),

    // ========================================================================
    // Comparison Operators
    // ========================================================================
    // Both operands are arbitrary boxed expressions. The builder style used in
    // the module docs (`col("qual").gt(lit(30.0))`) puts a column on the left
    // and a literal on the right, but nothing here enforces that.
    /// Equal: left == right (built by [`Expr::eq`])
    Eq(Box<Expr>, Box<Expr>),

    /// Not equal: left != right (built by [`Expr::neq`])
    Neq(Box<Expr>, Box<Expr>),

    /// Greater than: left > right (built by [`Expr::gt`])
    Gt(Box<Expr>, Box<Expr>),

    /// Greater than or equal: left >= right (built by [`Expr::gte`])
    Gte(Box<Expr>, Box<Expr>),

    /// Less than: left < right (built by [`Expr::lt`])
    Lt(Box<Expr>, Box<Expr>),

    /// Less than or equal: left <= right (built by [`Expr::lte`])
    Lte(Box<Expr>, Box<Expr>),

    // ========================================================================
    // Boolean Logic
    // ========================================================================
    /// Logical AND of multiple expressions.
    ///
    /// [`Expr::and`] splices operands that are already `And` nodes into one
    /// flat list, so chained `.and()` calls produce a single node.
    And(Vec<Expr>),

    /// Logical OR of multiple expressions.
    ///
    /// [`Expr::or`] flattens nested `Or` operands the same way.
    Or(Vec<Expr>),

    /// Logical NOT (built by [`Expr::not`])
    Not(Box<Expr>),

    // ========================================================================
    // Genomic-Specific Predicates
    // ========================================================================
    /// Test if variant is a transition (A<->G or C<->T)
    IsTransition,

    /// Test if variant is a transversion
    IsTransversion,

    /// Test if variant is a SNP (single nucleotide polymorphism)
    IsSnp,

    /// Test if variant is an indel (insertion/deletion)
    IsIndel,

    /// Test if variant passed all filters (FILTER == "PASS")
    IsPass,

    /// Test if record overlaps a genomic region
    InRegion(GenomicInterval),

    /// Test if record overlaps any of multiple regions.
    ///
    /// Note: `Display` prints only the number of intervals, not their contents.
    InRegions(Vec<GenomicInterval>),

    /// Test if chromosome matches
    OnChromosome(String),

    // ========================================================================
    // String Operations
    // ========================================================================
    /// String contains substring
    Contains(Box<Expr>, Box<Expr>),

    /// String starts with prefix
    StartsWith(Box<Expr>, Box<Expr>),

    /// String matches regex pattern.
    ///
    /// The pattern is stored as raw text. NOTE(review): nothing in this module
    /// compiles or validates it — confirm where that happens.
    Matches(Box<Expr>, String),

    // ========================================================================
    // Aggregations (for future group_by support)
    // ========================================================================
    /// Count records
    Count,

    /// Mean of expression values
    Mean(Box<Expr>),

    /// Sum of expression values
    Sum(Box<Expr>),

    /// Minimum value
    Min(Box<Expr>),

    /// Maximum value
    Max(Box<Expr>),

    /// Transition/transversion ratio (genomic-specific)
    TsTvRatio,

    /// Allele frequency calculation
    AlleleFrequency,
}
140
141// ============================================================================
142// Scalar Value Types
143// ============================================================================
144
/// Scalar values that can appear in expressions
///
/// Usually created implicitly through [`lit`], which accepts any type with a
/// `From` conversion into `ScalarValue` (`bool`, `i32`, `i64`, `f32`, `f64`,
/// `String`, `&str`). The narrower `i32` and `f32` inputs are widened to
/// `Int64` and `Float64` respectively.
///
/// Equality is the derived `PartialEq`, so `Float64(f64::NAN)` is not equal
/// to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// Boolean value
    Boolean(bool),

    /// 64-bit signed integer
    Int64(i64),

    /// 64-bit float
    Float64(f64),

    /// String value (rendered in single quotes by `Display`)
    String(String),

    /// Null/missing value (rendered as `null` by `Display`)
    Null,
}
163
164// ============================================================================
165// Expression Builder API (Ergonomic Methods)
166// ============================================================================
167
168impl Expr {
169    /// Create an equality comparison
170    pub fn eq(self, other: Expr) -> Expr {
171        Expr::Eq(Box::new(self), Box::new(other))
172    }
173
174    /// Create a not-equal comparison
175    pub fn neq(self, other: Expr) -> Expr {
176        Expr::Neq(Box::new(self), Box::new(other))
177    }
178
179    /// Create a greater-than comparison
180    pub fn gt(self, other: Expr) -> Expr {
181        Expr::Gt(Box::new(self), Box::new(other))
182    }
183
184    /// Create a greater-than-or-equal comparison
185    pub fn gte(self, other: Expr) -> Expr {
186        Expr::Gte(Box::new(self), Box::new(other))
187    }
188
189    /// Create a less-than comparison
190    pub fn lt(self, other: Expr) -> Expr {
191        Expr::Lt(Box::new(self), Box::new(other))
192    }
193
194    /// Create a less-than-or-equal comparison
195    pub fn lte(self, other: Expr) -> Expr {
196        Expr::Lte(Box::new(self), Box::new(other))
197    }
198
199    /// Combine with AND
200    pub fn and(self, other: Expr) -> Expr {
201        match (self, other) {
202            // Flatten nested ANDs
203            (Expr::And(mut left), Expr::And(right)) => {
204                left.extend(right);
205                Expr::And(left)
206            }
207            (Expr::And(mut exprs), other) => {
208                exprs.push(other);
209                Expr::And(exprs)
210            }
211            (this, Expr::And(mut exprs)) => {
212                exprs.insert(0, this);
213                Expr::And(exprs)
214            }
215            (left, right) => Expr::And(vec![left, right]),
216        }
217    }
218
219    /// Combine with OR
220    pub fn or(self, other: Expr) -> Expr {
221        match (self, other) {
222            // Flatten nested ORs
223            (Expr::Or(mut left), Expr::Or(right)) => {
224                left.extend(right);
225                Expr::Or(left)
226            }
227            (Expr::Or(mut exprs), other) => {
228                exprs.push(other);
229                Expr::Or(exprs)
230            }
231            (this, Expr::Or(mut exprs)) => {
232                exprs.insert(0, this);
233                Expr::Or(exprs)
234            }
235            (left, right) => Expr::Or(vec![left, right]),
236        }
237    }
238
239    /// Negate this expression
240    pub fn not(self) -> Expr {
241        Expr::Not(Box::new(self))
242    }
243}
244
245// ============================================================================
246// Helper Functions for Ergonomic Construction
247// ============================================================================
248
249/// Create a column reference expression
250///
251/// # Example
252/// ```
253/// use genomicframe_core::expression::col;
254/// let quality = col("qual");
255/// ```
256pub fn col(name: &str) -> Expr {
257    Expr::Column(name.to_string())
258}
259
260/// Create a literal value expression
261///
262/// # Example
263/// ```
264/// use genomicframe_core::expression::lit;
265///
266/// let num = lit(30.0);
267/// let text = lit("PASS");
268/// let flag = lit(true);
269/// ```
270pub fn lit<T: Into<ScalarValue>>(value: T) -> Expr {
271    Expr::Literal(value.into())
272}
273
274// ============================================================================
275// Conversions to ScalarValue
276// ============================================================================
277
278impl From<bool> for ScalarValue {
279    fn from(v: bool) -> Self {
280        ScalarValue::Boolean(v)
281    }
282}
283
284impl From<i64> for ScalarValue {
285    fn from(v: i64) -> Self {
286        ScalarValue::Int64(v)
287    }
288}
289
290impl From<i32> for ScalarValue {
291    fn from(v: i32) -> Self {
292        ScalarValue::Int64(v as i64)
293    }
294}
295
296impl From<f64> for ScalarValue {
297    fn from(v: f64) -> Self {
298        ScalarValue::Float64(v)
299    }
300}
301
302impl From<f32> for ScalarValue {
303    fn from(v: f32) -> Self {
304        ScalarValue::Float64(v as f64)
305    }
306}
307
308impl From<String> for ScalarValue {
309    fn from(v: String) -> Self {
310        ScalarValue::String(v)
311    }
312}
313
314impl From<&str> for ScalarValue {
315    fn from(v: &str) -> Self {
316        ScalarValue::String(v.to_string())
317    }
318}
319
320// ============================================================================
321// Display Implementations
322// ============================================================================
323
324impl fmt::Display for Expr {
325    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326        match self {
327            Expr::Column(name) => write!(f, "{}", name),
328            Expr::Literal(val) => write!(f, "{}", val),
329            Expr::Eq(left, right) => write!(f, "({} == {})", left, right),
330            Expr::Neq(left, right) => write!(f, "({} != {})", left, right),
331            Expr::Gt(left, right) => write!(f, "({} > {})", left, right),
332            Expr::Gte(left, right) => write!(f, "({} >= {})", left, right),
333            Expr::Lt(left, right) => write!(f, "({} < {})", left, right),
334            Expr::Lte(left, right) => write!(f, "({} <= {})", left, right),
335            Expr::And(exprs) => {
336                write!(f, "(")?;
337                for (i, expr) in exprs.iter().enumerate() {
338                    if i > 0 {
339                        write!(f, " AND ")?;
340                    }
341                    write!(f, "{}", expr)?;
342                }
343                write!(f, ")")
344            }
345            Expr::Or(exprs) => {
346                write!(f, "(")?;
347                for (i, expr) in exprs.iter().enumerate() {
348                    if i > 0 {
349                        write!(f, " OR ")?;
350                    }
351                    write!(f, "{}", expr)?;
352                }
353                write!(f, ")")
354            }
355            Expr::Not(expr) => write!(f, "NOT {}", expr),
356            Expr::IsTransition => write!(f, "is_transition"),
357            Expr::IsTransversion => write!(f, "is_transversion"),
358            Expr::IsSnp => write!(f, "is_snp"),
359            Expr::IsIndel => write!(f, "is_indel"),
360            Expr::IsPass => write!(f, "is_pass"),
361            Expr::InRegion(interval) => write!(f, "in_region({})", interval),
362            Expr::InRegions(intervals) => {
363                write!(f, "in_regions([{} intervals])", intervals.len())
364            }
365            Expr::OnChromosome(chrom) => write!(f, "on_chromosome({})", chrom),
366            Expr::Contains(expr, substr) => write!(f, "{}.contains({})", expr, substr),
367            Expr::StartsWith(expr, prefix) => write!(f, "{}.starts_with({})", expr, prefix),
368            Expr::Matches(expr, pattern) => write!(f, "{}.matches('{}')", expr, pattern),
369            Expr::Count => write!(f, "count()"),
370            Expr::Mean(expr) => write!(f, "mean({})", expr),
371            Expr::Sum(expr) => write!(f, "sum({})", expr),
372            Expr::Min(expr) => write!(f, "min({})", expr),
373            Expr::Max(expr) => write!(f, "max({})", expr),
374            Expr::TsTvRatio => write!(f, "ts_tv_ratio()"),
375            Expr::AlleleFrequency => write!(f, "allele_frequency()"),
376        }
377    }
378}
379
380impl fmt::Display for ScalarValue {
381    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382        match self {
383            ScalarValue::Boolean(b) => write!(f, "{}", b),
384            ScalarValue::Int64(i) => write!(f, "{}", i),
385            ScalarValue::Float64(fl) => write!(f, "{}", fl),
386            ScalarValue::String(s) => write!(f, "'{}'", s),
387            ScalarValue::Null => write!(f, "null"),
388        }
389    }
390}
391
392// ============================================================================
393// Expression to Filter Compilation
394// ============================================================================
395
396use crate::error::{Error, Result};
397use crate::filters::RecordFilter;
398
/// Trait for compiling expressions into record filters
///
/// This is the critical piece that connects the lazy expression world
/// to the eager filtering world. Each format implements this to convert
/// generic expressions into format-specific filters.
///
/// `R` is the record type that the produced [`RecordFilter`] is tested
/// against. The `CompiledAndFilter`, `CompiledOrFilter` and
/// `CompiledNotFilter` types in this module are available for combining
/// already-compiled sub-filters.
pub trait ExprToFilter<R> {
    /// Compile this expression into a filter for record type `R`.
    ///
    /// # Errors
    ///
    /// Returns an error when the expression cannot be turned into a filter
    /// for `R`. NOTE(review): the exact failure conditions are defined by each
    /// implementor (e.g. unsupported variants or non-literal operands) —
    /// confirm against the format-specific impls.
    fn compile(&self) -> Result<Box<dyn RecordFilter<R>>>;
}
408// Helper filters for compiled boolean logic
409pub struct CompiledAndFilter<R> {
410    pub left: Box<dyn RecordFilter<R>>,
411    pub right: Box<dyn RecordFilter<R>>,
412}
413
414impl<R: Send + Sync> RecordFilter<R> for CompiledAndFilter<R> {
415    fn test(&self, record: &R) -> bool {
416        self.left.test(record) && self.right.test(record)
417    }
418}
419
420pub struct CompiledOrFilter<R> {
421    pub left: Box<dyn RecordFilter<R>>,
422    pub right: Box<dyn RecordFilter<R>>,
423}
424
425impl<R: Send + Sync> RecordFilter<R> for CompiledOrFilter<R> {
426    fn test(&self, record: &R) -> bool {
427        self.left.test(record) || self.right.test(record)
428    }
429}
430
431pub struct CompiledNotFilter<R> {
432    pub inner: Box<dyn RecordFilter<R>>,
433}
434
435impl<R: Send + Sync> RecordFilter<R> for CompiledNotFilter<R> {
436    fn test(&self, record: &R) -> bool {
437        !self.inner.test(record)
438    }
439}
440
441// ============================================================================
442// Helper Functions for Value Extraction
443// ============================================================================
444
445pub fn extract_f64(expr: &Expr) -> Result<f64> {
446    match expr {
447        Expr::Literal(ScalarValue::Float64(v)) => Ok(*v),
448        Expr::Literal(ScalarValue::Int64(v)) => Ok(*v as f64),
449        _ => Err(Error::invalid_input(format!(
450            "Expected float literal, got {}",
451            expr
452        ))),
453    }
454}
455
456pub fn extract_i64(expr: &Expr) -> Result<i64> {
457    match expr {
458        Expr::Literal(ScalarValue::Int64(v)) => Ok(*v),
459        _ => Err(Error::invalid_input(format!(
460            "Expected int literal, got {}",
461            expr
462        ))),
463    }
464}
465
466pub fn extract_u32(expr: &Expr) -> Result<u32> {
467    match expr {
468        Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 && *v <= u32::MAX as i64 => Ok(*v as u32),
469        _ => Err(Error::invalid_input(format!(
470            "Expected u32 literal, got {}",
471            expr
472        ))),
473    }
474}
475
476pub fn extract_u64(expr: &Expr) -> Result<u64> {
477    match expr {
478        Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 => Ok(*v as u64),
479        _ => Err(Error::invalid_input(format!(
480            "Expected u64 literal, got {}",
481            expr
482        ))),
483    }
484}
485
486pub fn extract_u8(expr: &Expr) -> Result<u8> {
487    match expr {
488        Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 && *v <= 255 => Ok(*v as u8),
489        _ => Err(Error::invalid_input(format!(
490            "Expected u8 literal, got {}",
491            expr
492        ))),
493    }
494}
495
496pub fn extract_usize(expr: &Expr) -> Result<usize> {
497    match expr {
498        Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 => Ok(*v as usize),
499        _ => Err(Error::invalid_input(format!(
500            "Expected usize literal, got {}",
501            expr
502        ))),
503    }
504}
505
506pub fn extract_string(expr: &Expr) -> Result<String> {
507    match expr {
508        Expr::Literal(ScalarValue::String(s)) => Ok(s.clone()),
509        _ => Err(Error::invalid_input(format!(
510            "Expected string literal, got {}",
511            expr
512        ))),
513    }
514}
515
516// ============================================================================
517// Tests
518// ============================================================================
519
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_reference() {
        let expr = col("qual");
        assert_eq!(expr, Expr::Column("qual".to_string()));
        assert_eq!(format!("{}", expr), "qual");
    }

    #[test]
    fn test_literal_values() {
        assert_eq!(lit(30.0), Expr::Literal(ScalarValue::Float64(30.0)));
        assert_eq!(lit(42), Expr::Literal(ScalarValue::Int64(42)));
        assert_eq!(
            lit("PASS"),
            Expr::Literal(ScalarValue::String("PASS".to_string()))
        );
        assert_eq!(lit(true), Expr::Literal(ScalarValue::Boolean(true)));
    }

    #[test]
    fn test_widening_conversions() {
        assert_eq!(lit(7i64), Expr::Literal(ScalarValue::Int64(7)));
        assert_eq!(lit(1.5f32), Expr::Literal(ScalarValue::Float64(1.5)));
        assert_eq!(
            lit(String::from("chr1")),
            Expr::Literal(ScalarValue::String("chr1".to_string()))
        );
    }

    #[test]
    fn test_comparison_builders() {
        let expr = col("qual").gt(lit(30.0));
        match expr {
            Expr::Gt(left, right) => {
                assert_eq!(*left, col("qual"));
                assert_eq!(*right, lit(30.0));
            }
            _ => panic!("Expected Gt"),
        }
    }

    #[test]
    fn test_and_flattening() {
        let expr = col("qual").gt(lit(30.0)).and(Expr::IsSnp).and(Expr::IsPass);

        match expr {
            Expr::And(exprs) => {
                assert_eq!(exprs.len(), 3);
            }
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_and_flattens_right_operand_in_order() {
        let expr = Expr::IsSnp.and(Expr::IsPass.and(Expr::IsIndel));
        assert_eq!(
            expr,
            Expr::And(vec![Expr::IsSnp, Expr::IsPass, Expr::IsIndel])
        );
    }

    #[test]
    fn test_or_flattening() {
        let expr = Expr::IsSnp.or(Expr::IsIndel).or(Expr::IsPass);
        assert_eq!(
            expr,
            Expr::Or(vec![Expr::IsSnp, Expr::IsIndel, Expr::IsPass])
        );
    }

    #[test]
    fn test_and_inside_or_keeps_grouping() {
        let expr = Expr::IsSnp.and(Expr::IsPass).or(Expr::IsIndel);
        assert_eq!(
            expr,
            Expr::Or(vec![
                Expr::And(vec![Expr::IsSnp, Expr::IsPass]),
                Expr::IsIndel
            ])
        );
    }

    #[test]
    fn test_complex_expression() {
        let expr = col("qual").gt(lit(30.0)).and(Expr::IsSnp).or(Expr::IsPass);

        let display = format!("{}", expr);
        assert!(display.contains("qual"));
        assert!(display.contains("30"));
        assert!(display.contains("is_snp"));
        assert!(display.contains("is_pass"));
    }

    #[test]
    fn test_display_exact() {
        let expr = col("qual").gt(lit(30)).and(Expr::IsSnp).or(Expr::IsPass);
        assert_eq!(expr.to_string(), "(((qual > 30) AND is_snp) OR is_pass)");
        assert_eq!(Expr::IsSnp.not().to_string(), "NOT is_snp");
        assert_eq!(lit("PASS").to_string(), "'PASS'");
        assert_eq!(Expr::Literal(ScalarValue::Null).to_string(), "null");
    }

    #[test]
    fn test_genomic_predicates() {
        let transition = Expr::IsTransition;
        assert_eq!(format!("{}", transition), "is_transition");

        let region = Expr::InRegion(GenomicInterval::new("chr1", 1000, 2000));
        assert!(format!("{}", region).contains("in_region"));
    }

    #[test]
    fn test_extract_helpers_bounds() {
        assert_eq!(extract_f64(&lit(2.5)).ok(), Some(2.5));
        assert_eq!(extract_f64(&lit(7)).ok(), Some(7.0));
        assert!(extract_f64(&col("qual")).is_err());

        assert_eq!(extract_i64(&lit(-3)).ok(), Some(-3));
        assert!(extract_i64(&lit(1.0)).is_err());

        assert_eq!(extract_u8(&lit(255)).ok(), Some(255));
        assert!(extract_u8(&lit(256)).is_err());
        assert!(extract_u8(&lit(-1)).is_err());

        assert_eq!(extract_u32(&lit(i64::from(u32::MAX))).ok(), Some(u32::MAX));
        assert!(extract_u32(&lit(i64::from(u32::MAX) + 1)).is_err());

        assert_eq!(extract_u64(&lit(0)).ok(), Some(0));
        assert!(extract_u64(&lit(-1)).is_err());

        assert_eq!(extract_usize(&lit(42)).ok(), Some(42));
        assert!(extract_usize(&lit(-1)).is_err());

        assert_eq!(extract_string(&lit("PASS")).ok(), Some("PASS".to_string()));
        assert!(extract_string(&col("filter")).is_err());
    }
}