use crate::interval::GenomicInterval;
use std::fmt;
/// A node in an expression tree describing filters and aggregates.
///
/// Trees are built either by constructing variants directly or through the
/// builder methods on `Expr` (`eq`, `gt`, `and`, `or`, `not`, ...) together with
/// the free functions `col` and `lit`. This type only describes the tree; how
/// each variant is evaluated is not defined in this part of the file.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// Reference to a column by name (see `col`).
Column(String),
/// Constant scalar value (see `lit`).
Literal(ScalarValue),
/// Binary comparison `left == right`.
Eq(Box<Expr>, Box<Expr>),
/// Binary comparison `left != right`.
Neq(Box<Expr>, Box<Expr>),
/// Binary comparison `left > right`.
Gt(Box<Expr>, Box<Expr>),
/// Binary comparison `left >= right`.
Gte(Box<Expr>, Box<Expr>),
/// Binary comparison `left < right`.
Lt(Box<Expr>, Box<Expr>),
/// Binary comparison `left <= right`.
Lte(Box<Expr>, Box<Expr>),
/// Conjunction over a list of operands; the `and` builder flattens nested
/// `And` nodes into a single list.
And(Vec<Expr>),
/// Disjunction over a list of operands; the `or` builder flattens nested
/// `Or` nodes into a single list.
Or(Vec<Expr>),
/// Logical negation of the inner expression.
Not(Box<Expr>),
// The five variants below take no arguments. Their names suggest
// classification of variant-call records (transition / transversion,
// SNP / indel, PASS filter status).
// NOTE(review): evaluation is not visible here — confirm the exact semantics
// where these nodes are compiled.
/// Displayed as `is_transition`.
IsTransition,
/// Displayed as `is_transversion`.
IsTransversion,
/// Displayed as `is_snp`.
IsSnp,
/// Displayed as `is_indel`.
IsIndel,
/// Displayed as `is_pass`.
IsPass,
/// Predicate parameterised by a single genomic interval.
InRegion(GenomicInterval),
/// Predicate parameterised by a list of genomic intervals.
InRegions(Vec<GenomicInterval>),
/// Predicate parameterised by a chromosome name.
OnChromosome(String),
/// Two sub-expressions, displayed as `a.contains(b)`.
Contains(Box<Expr>, Box<Expr>),
/// Two sub-expressions, displayed as `a.starts_with(b)`.
StartsWith(Box<Expr>, Box<Expr>),
/// A sub-expression and a pattern string, displayed as `a.matches('pattern')`.
/// Presumably the pattern is a regular expression — TODO confirm.
Matches(Box<Expr>, String),
// The remaining variants are displayed in function-call form (`count()`,
// `mean(x)`, ...). Presumably they are aggregates computed over many
// records — confirm against the evaluator.
/// Displayed as `count()`.
Count,
/// Displayed as `mean(x)`.
Mean(Box<Expr>),
/// Displayed as `sum(x)`.
Sum(Box<Expr>),
/// Displayed as `min(x)`.
Min(Box<Expr>),
/// Displayed as `max(x)`.
Max(Box<Expr>),
/// Displayed as `ts_tv_ratio()`.
TsTvRatio,
/// Displayed as `allele_frequency()`.
AlleleFrequency,
}
/// Constant value that can appear in an expression through `Expr::Literal`.
///
/// `From` conversions are provided in this file for `bool`, `i64`, `i32`,
/// `f64`, `f32`, `String` and `&str`, which is what lets `lit` accept plain
/// Rust values.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
/// Boolean constant.
Boolean(bool),
/// Signed 64-bit integer constant (`i32` inputs are widened to this).
Int64(i64),
/// 64-bit floating-point constant (`f32` inputs are widened to this).
Float64(f64),
/// Owned string constant.
String(String),
/// Explicit null; none of the `From` conversions in this file produce it.
Null,
}
/// Builder methods for composing expressions.
///
/// Every method consumes its operands and returns a new `Expr` node; nothing is
/// evaluated here.
impl Expr {
    /// Builds `Expr::Eq(self, other)`.
    ///
    /// Because this inherent method takes `self` by value, `a.eq(b)` resolves
    /// to this builder rather than to `PartialEq::eq`; use `==` to compare two
    /// expression trees structurally.
    pub fn eq(self, other: Expr) -> Expr {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Expr::Eq(lhs, rhs)
    }

    /// Builds `Expr::Neq(self, other)`.
    pub fn neq(self, other: Expr) -> Expr {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Expr::Neq(lhs, rhs)
    }

    /// Builds `Expr::Gt(self, other)`.
    pub fn gt(self, other: Expr) -> Expr {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Expr::Gt(lhs, rhs)
    }

    /// Builds `Expr::Gte(self, other)`.
    pub fn gte(self, other: Expr) -> Expr {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Expr::Gte(lhs, rhs)
    }

    /// Builds `Expr::Lt(self, other)`.
    pub fn lt(self, other: Expr) -> Expr {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Expr::Lt(lhs, rhs)
    }

    /// Builds `Expr::Lte(self, other)`.
    pub fn lte(self, other: Expr) -> Expr {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Expr::Lte(lhs, rhs)
    }

    /// Combines two expressions into a single `Expr::And`.
    ///
    /// An operand that is already an `And` contributes its operands directly,
    /// so chained calls produce one flat list in left-to-right order instead
    /// of nested `And` nodes.
    pub fn and(self, other: Expr) -> Expr {
        // Terms contributed by the left operand.
        let mut terms = match self {
            Expr::And(existing) => existing,
            single => vec![single],
        };
        // Terms contributed by the right operand, appended after the left ones.
        match other {
            Expr::And(more) => terms.extend(more),
            single => terms.push(single),
        }
        Expr::And(terms)
    }

    /// Combines two expressions into a single `Expr::Or`.
    ///
    /// An operand that is already an `Or` contributes its operands directly,
    /// so chained calls produce one flat list in left-to-right order instead
    /// of nested `Or` nodes.
    pub fn or(self, other: Expr) -> Expr {
        // Terms contributed by the left operand.
        let mut terms = match self {
            Expr::Or(existing) => existing,
            single => vec![single],
        };
        // Terms contributed by the right operand, appended after the left ones.
        match other {
            Expr::Or(more) => terms.extend(more),
            single => terms.push(single),
        }
        Expr::Or(terms)
    }

    /// Wraps this expression in `Expr::Not`.
    pub fn not(self) -> Expr {
        let operand = Box::new(self);
        Expr::Not(operand)
    }
}
/// Creates a column-reference expression for the column called `name`.
///
/// The name is copied into an owned `String` stored in `Expr::Column`.
pub fn col(name: &str) -> Expr {
    let owned_name = String::from(name);
    Expr::Column(owned_name)
}
/// Creates a literal expression from any value convertible into a
/// `ScalarValue` (for example `bool`, `i32`, `i64`, `f32`, `f64`, `&str` or
/// `String`).
pub fn lit<T: Into<ScalarValue>>(value: T) -> Expr {
    let scalar: ScalarValue = value.into();
    Expr::Literal(scalar)
}
impl From<bool> for ScalarValue {
fn from(v: bool) -> Self {
ScalarValue::Boolean(v)
}
}
impl From<i64> for ScalarValue {
fn from(v: i64) -> Self {
ScalarValue::Int64(v)
}
}
impl From<i32> for ScalarValue {
fn from(v: i32) -> Self {
ScalarValue::Int64(v as i64)
}
}
impl From<f64> for ScalarValue {
fn from(v: f64) -> Self {
ScalarValue::Float64(v)
}
}
impl From<f32> for ScalarValue {
fn from(v: f32) -> Self {
ScalarValue::Float64(v as f64)
}
}
impl From<String> for ScalarValue {
fn from(v: String) -> Self {
ScalarValue::String(v)
}
}
impl From<&str> for ScalarValue {
fn from(v: &str) -> Self {
ScalarValue::String(v.to_string())
}
}
/// Human-readable rendering of an expression tree.
///
/// Comparisons and `AND` / `OR` groups are wrapped in parentheses, while
/// predicates and aggregates are printed in a function-like notation.
/// `InRegions` prints only the number of intervals it holds, not the
/// intervals themselves.
impl fmt::Display for Expr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes `operands` inside one pair of parentheses, placing
        // `separator` between consecutive operands. An empty list yields "()".
        fn write_group(
            out: &mut fmt::Formatter<'_>,
            operands: &[Expr],
            separator: &str,
        ) -> fmt::Result {
            out.write_str("(")?;
            let mut pending = operands.iter();
            if let Some(first) = pending.next() {
                write!(out, "{}", first)?;
            }
            for operand in pending {
                out.write_str(separator)?;
                write!(out, "{}", operand)?;
            }
            out.write_str(")")
        }

        match self {
            // Leaves.
            Expr::Column(name) => f.write_str(name),
            Expr::Literal(value) => write!(f, "{}", value),

            // Binary comparisons, always parenthesised.
            Expr::Eq(lhs, rhs) => write!(f, "({} == {})", lhs, rhs),
            Expr::Neq(lhs, rhs) => write!(f, "({} != {})", lhs, rhs),
            Expr::Gt(lhs, rhs) => write!(f, "({} > {})", lhs, rhs),
            Expr::Gte(lhs, rhs) => write!(f, "({} >= {})", lhs, rhs),
            Expr::Lt(lhs, rhs) => write!(f, "({} < {})", lhs, rhs),
            Expr::Lte(lhs, rhs) => write!(f, "({} <= {})", lhs, rhs),

            // Boolean connectives.
            Expr::And(operands) => write_group(f, operands, " AND "),
            Expr::Or(operands) => write_group(f, operands, " OR "),
            Expr::Not(inner) => write!(f, "NOT {}", inner),

            // Argument-free predicates.
            Expr::IsTransition => f.write_str("is_transition"),
            Expr::IsTransversion => f.write_str("is_transversion"),
            Expr::IsSnp => f.write_str("is_snp"),
            Expr::IsIndel => f.write_str("is_indel"),
            Expr::IsPass => f.write_str("is_pass"),

            // Location predicates.
            Expr::InRegion(region) => write!(f, "in_region({})", region),
            Expr::InRegions(regions) => {
                let how_many = regions.len();
                write!(f, "in_regions([{} intervals])", how_many)
            }
            Expr::OnChromosome(name) => write!(f, "on_chromosome({})", name),

            // String predicates, printed in method-call form.
            Expr::Contains(subject, needle) => write!(f, "{}.contains({})", subject, needle),
            Expr::StartsWith(subject, head) => write!(f, "{}.starts_with({})", subject, head),
            Expr::Matches(subject, pattern) => write!(f, "{}.matches('{}')", subject, pattern),

            // Function-call forms.
            Expr::Count => f.write_str("count()"),
            Expr::Mean(inner) => write!(f, "mean({})", inner),
            Expr::Sum(inner) => write!(f, "sum({})", inner),
            Expr::Min(inner) => write!(f, "min({})", inner),
            Expr::Max(inner) => write!(f, "max({})", inner),
            Expr::TsTvRatio => f.write_str("ts_tv_ratio()"),
            Expr::AlleleFrequency => f.write_str("allele_frequency()"),
        }
    }
}
/// Textual form of a scalar as it appears inside a rendered expression.
///
/// Strings are wrapped in single quotes (embedded quotes are not escaped) and
/// `Null` prints as `null`. Floats use Rust's default `Display`, so a
/// whole-number float such as `30.0` prints as `30`.
impl fmt::Display for ScalarValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ScalarValue::Null => f.write_str("null"),
            ScalarValue::Boolean(flag) => write!(f, "{}", flag),
            ScalarValue::Int64(number) => write!(f, "{}", number),
            ScalarValue::Float64(number) => write!(f, "{}", number),
            ScalarValue::String(ref text) => write!(f, "'{}'", text),
        }
    }
}
use crate::error::{Error, Result};
use crate::filters::RecordFilter;
/// Conversion of a value into an executable filter over records of type `R`.
///
/// NOTE(review): the name suggests this is implemented for `Expr`, once per
/// supported record type; no implementation is visible in this part of the
/// file, so confirm against the implementing modules.
pub trait ExprToFilter<R> {
/// Builds a boxed `RecordFilter<R>` from `self`.
///
/// Returns the crate's `Result`, so implementations can report failure;
/// which inputs are rejected is up to each implementation.
fn compile(&self) -> Result<Box<dyn RecordFilter<R>>>;
}
/// Filter whose `test` is the logical AND of two boxed filters.
pub struct CompiledAndFilter<R> {
    /// Filter consulted first.
    pub left: Box<dyn RecordFilter<R>>,
    /// Filter consulted only when `left` returned `true`.
    pub right: Box<dyn RecordFilter<R>>,
}

impl<R: Send + Sync> RecordFilter<R> for CompiledAndFilter<R> {
    /// Returns `true` only when both `left` and `right` return `true` for
    /// `record`. `right` is not called when `left` returns `false`.
    fn test(&self, record: &R) -> bool {
        if !self.left.test(record) {
            return false;
        }
        self.right.test(record)
    }
}
/// Filter whose `test` is the logical OR of two boxed filters.
pub struct CompiledOrFilter<R> {
    /// Filter consulted first.
    pub left: Box<dyn RecordFilter<R>>,
    /// Filter consulted only when `left` returned `false`.
    pub right: Box<dyn RecordFilter<R>>,
}

impl<R: Send + Sync> RecordFilter<R> for CompiledOrFilter<R> {
    /// Returns `true` when either `left` or `right` returns `true` for
    /// `record`. `right` is not called when `left` returns `true`.
    fn test(&self, record: &R) -> bool {
        if self.left.test(record) {
            return true;
        }
        self.right.test(record)
    }
}
/// Filter that inverts the result of another boxed filter.
pub struct CompiledNotFilter<R> {
    /// Filter whose result is negated.
    pub inner: Box<dyn RecordFilter<R>>,
}

impl<R: Send + Sync> RecordFilter<R> for CompiledNotFilter<R> {
    /// Returns `true` exactly when `inner` returns `false` for `record`.
    fn test(&self, record: &R) -> bool {
        let inner_result = self.inner.test(record);
        !inner_result
    }
}
pub fn extract_f64(expr: &Expr) -> Result<f64> {
match expr {
Expr::Literal(ScalarValue::Float64(v)) => Ok(*v),
Expr::Literal(ScalarValue::Int64(v)) => Ok(*v as f64),
_ => Err(Error::invalid_input(format!(
"Expected float literal, got {}",
expr
))),
}
}
pub fn extract_i64(expr: &Expr) -> Result<i64> {
match expr {
Expr::Literal(ScalarValue::Int64(v)) => Ok(*v),
_ => Err(Error::invalid_input(format!(
"Expected int literal, got {}",
expr
))),
}
}
pub fn extract_u32(expr: &Expr) -> Result<u32> {
match expr {
Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 && *v <= u32::MAX as i64 => Ok(*v as u32),
_ => Err(Error::invalid_input(format!(
"Expected u32 literal, got {}",
expr
))),
}
}
pub fn extract_u64(expr: &Expr) -> Result<u64> {
match expr {
Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 => Ok(*v as u64),
_ => Err(Error::invalid_input(format!(
"Expected u64 literal, got {}",
expr
))),
}
}
pub fn extract_u8(expr: &Expr) -> Result<u8> {
match expr {
Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 && *v <= 255 => Ok(*v as u8),
_ => Err(Error::invalid_input(format!(
"Expected u8 literal, got {}",
expr
))),
}
}
pub fn extract_usize(expr: &Expr) -> Result<usize> {
match expr {
Expr::Literal(ScalarValue::Int64(v)) if *v >= 0 => Ok(*v as usize),
_ => Err(Error::invalid_input(format!(
"Expected usize literal, got {}",
expr
))),
}
}
pub fn extract_string(expr: &Expr) -> Result<String> {
match expr {
Expr::Literal(ScalarValue::String(s)) => Ok(s.clone()),
_ => Err(Error::invalid_input(format!(
"Expected string literal, got {}",
expr
))),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // `col` builds a `Column` node that displays as the bare column name.
    #[test]
    fn test_column_reference() {
        let column = col("qual");
        assert_eq!(column, Expr::Column(String::from("qual")));
        assert_eq!(column.to_string(), "qual");
    }

    // `lit` maps plain Rust values onto the matching `ScalarValue` variant.
    #[test]
    fn test_literal_values() {
        let cases = vec![
            (lit(30.0), ScalarValue::Float64(30.0)),
            (lit(42), ScalarValue::Int64(42)),
            (lit("PASS"), ScalarValue::String("PASS".to_string())),
            (lit(true), ScalarValue::Boolean(true)),
        ];
        for (built, expected) in cases {
            assert_eq!(built, Expr::Literal(expected));
        }
    }

    // A comparison builder boxes both operands in order.
    #[test]
    fn test_comparison_builders() {
        if let Expr::Gt(left, right) = col("qual").gt(lit(30.0)) {
            assert_eq!(*left, col("qual"));
            assert_eq!(*right, lit(30.0));
        } else {
            panic!("Expected Gt");
        }
    }

    // Chained `and` calls produce one flat operand list.
    #[test]
    fn test_and_flattening() {
        let combined = col("qual").gt(lit(30.0)).and(Expr::IsSnp).and(Expr::IsPass);
        if let Expr::And(operands) = combined {
            assert_eq!(operands.len(), 3);
        } else {
            panic!("Expected And");
        }
    }

    // The rendered text mentions every leaf of a mixed AND/OR tree.
    #[test]
    fn test_complex_expression() {
        let tree = col("qual").gt(lit(30.0)).and(Expr::IsSnp).or(Expr::IsPass);
        let rendered = tree.to_string();
        for fragment in &["qual", "30", "is_snp", "is_pass"] {
            assert!(rendered.contains(*fragment));
        }
    }

    // Genomic predicates render with their snake_case names.
    #[test]
    fn test_genomic_predicates() {
        assert_eq!(Expr::IsTransition.to_string(), "is_transition");

        let region = Expr::InRegion(GenomicInterval::new("chr1", 1000, 2000));
        assert!(region.to_string().contains("in_region"));
    }
}