use anyhow::{bail, Result};
use polars::lazy::dsl::Expr as PolarsExpr;
use polars::prelude::*;
use crate::parser::{Expr, Operator};
use super::*;
pub fn eval(args: &[Expr], ctx: &mut Context) -> Result<()> {
if let Some(mut df) = ctx.take_df() {
for arg in args {
let expr = df
.collect_schema()
.map_err(anyhow::Error::from)
.and_then(|schema| eval_expr(arg, &schema))
.map_err(|e| anyhow!("filter error: {e}"))?;
df = df.filter(expr);
}
ctx.set_df(df)?;
} else if ctx.is_grouping() {
bail!("filter error: must call summarize after a group_by");
} else {
bail!("filter error: missing input dataframe");
}
Ok(())
}
fn eval_expr(expr: &Expr, schema: &Schema) -> Result<PolarsExpr> {
match expr {
Expr::BinaryOp(lhs, op, rhs) => {
let lhs = eval_expr(lhs, schema)?;
let rhs = eval_expr(rhs, schema)?;
let result = match op {
Operator::Eq => lhs.eq(rhs),
Operator::NotEq => lhs.neq(rhs),
Operator::Lt => lhs.lt(rhs),
Operator::LtEq => lhs.lt_eq(rhs),
Operator::Gt => lhs.gt(rhs),
Operator::GtEq => lhs.gt_eq(rhs),
Operator::And => lhs.and(rhs),
Operator::Or => lhs.or(rhs),
_ => panic!("Unexpected filter operator {op}"),
};
Ok(result)
}
Expr::Identifier(_) => args::column(expr, schema),
Expr::String(s) => Ok(lit(s.clone())),
Expr::Number(n) => Ok(lit(*n)),
Expr::Function(name, args) if name == "dt" => {
let ts = args::timestamp(&args[0])?;
Ok(lit(ts))
}
Expr::UnaryOp(Operator::Not, expr) => eval_predicate(expr, schema).map(|expr| expr.not()),
Expr::Function(_, _) => eval_predicate(expr, schema),
_ => panic!("Unexpected filter expression {expr}"),
}
}
fn eval_predicate(expr: &Expr, schema: &Schema) -> Result<PolarsExpr> {
match expr {
Expr::Function(name, args) if name == "contains" => {
let column = args::identifier(&args[0]);
let column_type = schema
.get(&column)
.ok_or_else(|| anyhow!("Unknown contains column '{column}'"))?;
match column_type {
DataType::List(elem_type) => list_contains(&column, &args[1], elem_type),
DataType::String => string_contains(&column, &args[1]),
_ => Err(anyhow!("Column '{column}' must be a str or a list")),
}
}
Expr::Function(name, args) if name == "is_null" => {
let r = args::column(&args[0], schema).map(|c| c.is_null());
println!("{r:?}");
r
}
_ => panic!("Unexpected filter expression {expr}"),
}
}
fn list_contains(column: &str, pattern: &Expr, elem_type: &DataType) -> Result<PolarsExpr> {
println!("contain {elem_type}");
use DataType::*;
match (elem_type, pattern) {
(Int8, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as i8))),
(Int16, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as i16))),
(Int32, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as i32))),
(Int64, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as i64))),
(UInt8, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as u8))),
(UInt16, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as u16))),
(UInt32, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as u32))),
(UInt64, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as u64))),
(Float32, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n as f32))),
(Float64, Expr::Number(n)) => Ok(col(column).list().contains(lit(*n))),
(String, Expr::String(s)) => {
let re = regex::Regex::new(s)
.map_err(|_| anyhow!("invalid contains regex '{s}' for column '{column}'"))?;
let function = move |c: Column| {
let ca = c.list()?;
let mut bools = Vec::with_capacity(ca.len());
ca.into_iter().for_each(|arr| {
let found = if let Some(s) = arr {
s.str()
.map(|ca| {
ca.into_iter()
.any(|s| s.map(|s| re.is_match(s)).unwrap_or(false))
})
.unwrap_or_default()
} else {
false
};
bools.push(found);
});
let s = BooleanChunked::new(ca.name().to_owned(), bools).into_series();
Ok(Some(Column::new(ca.name().to_owned(), s)))
};
Ok(col(column).apply(function, GetOutput::from_type(DataType::Boolean)))
}
_ => bail!("contains error: invalid type {elem_type} for column '{column}'"),
}
}
fn string_contains(column: &str, pattern: &Expr) -> Result<PolarsExpr> {
if let Expr::String(re) = pattern {
regex::Regex::new(re)
.map_err(|_| anyhow!("invalid contains regex '{re}' for column '{column}'"))?;
Ok(col(column).str().contains(lit(re.to_owned()), false))
} else {
Err(anyhow!(
"contains predicate for column '{column}' must be a regex"
))
}
}