use std::collections::HashSet;
use polars::prelude::*;
use serde_json::Value as Json;
use crate::error::{Error, Result};
use crate::formula::compile_formula;
use crate::manifest::Metric;
use crate::output::{av_to_f64, av_to_json};
/// How an already-computed statistic column is re-aggregated when several
/// rows are folded into one (see `MetricPlan::other_expr`).
#[derive(Debug, Clone, Copy)]
pub enum Combine {
/// Re-aggregate by summing the column.
Sum,
/// Re-aggregate by taking the column minimum.
Min,
/// Re-aggregate by taking the column maximum.
Max,
}
/// Recipe used by `MetricPlan::finalize_row` to turn a row of statistics into
/// the metric's final JSON value.
#[derive(Clone)]
enum Finalize {
/// The value is the statistic stored under this alias, converted as-is.
Single(String),
/// The value is statistic `num` divided by statistic `den`; a missing or
/// zero `den` yields JSON null.
Ratio { num: String, den: String },
/// The value is the finalized value of one plan divided by the finalized
/// value of another; a missing operand or zero denominator yields JSON null.
MetricRatio {
num: Box<MetricPlan>,
den: Box<MetricPlan>,
},
}
/// Everything needed to compute one metric: the statistics to aggregate, how
/// each statistic can be re-aggregated, and how to derive the final value.
#[derive(Clone)]
pub struct MetricPlan {
/// Id of the metric this plan was built for.
pub id: String,
/// `(alias, expression)` pairs; each expression is already aliased to `alias`.
pub stats: Vec<(String, Expr)>,
/// Per statistic alias, the re-aggregation to apply; `None` marks a
/// statistic for which `other_expr` produces NULL instead of a value.
pub combine: Vec<(String, Option<Combine>)>,
/// Expression over the statistic aliases that yields the metric's value
/// (a NULL literal for ratio-of-metrics plans).
/// NOTE(review): presumably used to rank/order rows — the consumer is not in this file.
pub rank_expr: Expr,
/// How `finalize_row` derives the final JSON value from the statistics.
finalize: Finalize,
}
impl MetricPlan {
    /// Produces this metric's final JSON value for one row of statistics.
    ///
    /// `get` returns the statistic stored under a given alias for the row
    /// being finalized.
    ///
    /// * A single-statistic plan converts that statistic directly.
    /// * A ratio plan divides the numerator statistic (treated as `0.0` when
    ///   it has no numeric value) by the denominator statistic.
    /// * A ratio-of-metrics plan finalizes both inner plans and divides them.
    ///
    /// Any division whose denominator is missing or zero, or whose result
    /// `serde_json` cannot represent as a number, yields `Json::Null`.
    pub fn finalize_row(&self, get: &dyn Fn(&str) -> AnyValue<'static>) -> Json {
        // Shared by both ratio flavours: divide and wrap, null when not representable.
        fn quotient(top: f64, bottom: f64) -> Json {
            serde_json::Number::from_f64(top / bottom).map_or(Json::Null, Json::Number)
        }

        match &self.finalize {
            Finalize::Single(alias) => av_to_json(&get(alias)),
            Finalize::Ratio { num, den } => {
                // The denominator is looked up first; the numerator is only
                // fetched once the division is known to be possible.
                let divisor = match av_to_f64(&get(den)) {
                    Some(d) if d != 0.0 => d,
                    _ => return Json::Null,
                };
                let total = av_to_f64(&get(num)).unwrap_or(0.0);
                quotient(total, divisor)
            }
            Finalize::MetricRatio { num, den } => {
                let top = num.finalize_row(get).as_f64();
                let bottom = den.finalize_row(get).as_f64();
                top.zip(bottom)
                    .filter(|&(_, b)| b != 0.0)
                    .map_or(Json::Null, |(t, b)| quotient(t, b))
            }
        }
    }

    /// Builds the expression that re-aggregates the statistic column `alias`.
    ///
    /// With `Some(op)` the column is summed / min'ed / max'ed; with `None`
    /// the result is a NULL literal. Either way the output keeps the name
    /// `alias`.
    pub fn other_expr(alias: &str, combine: Option<Combine>) -> Expr {
        let folded = match combine {
            None => lit(NULL),
            Some(Combine::Sum) => col(alias).sum(),
            Some(Combine::Min) => col(alias).min(),
            Some(Combine::Max) => col(alias).max(),
        };
        folded.alias(alias)
    }
}
/// Builds the per-row value expression for a metric.
///
/// A `formula` takes precedence and is compiled against `columns`; otherwise
/// the metric's `column` is referenced directly. When the metric's
/// `null_policy` is `"zero"`, nulls in the value are replaced by `0`.
///
/// # Errors
///
/// Returns `Error::Schema` when the referenced column is not in `columns` or
/// when the metric has neither a column nor a formula, and propagates any
/// error from `compile_formula`.
pub fn value_expr(m: &Metric, columns: &HashSet<String>) -> Result<Expr> {
    let raw = match (&m.formula, &m.column) {
        (Some(formula), _) => compile_formula(formula, columns)?,
        (None, Some(name)) if columns.contains(name) => col(name.as_str()),
        (None, Some(name)) => {
            return Err(Error::Schema(format!(
                "metric {:?} column {:?} not in source",
                m.id, name
            )));
        }
        (None, None) => {
            return Err(Error::Schema(format!(
                "metric {:?} needs a column or formula",
                m.id
            )));
        }
    };

    if m.null_policy == "zero" {
        Ok(coalesce(&[raw, lit(0)]))
    } else {
        Ok(raw)
    }
}
pub fn entity_mask_exprs(
metrics: &[Metric],
row_key: &str,
columns: &HashSet<String>,
) -> Result<Vec<Expr>> {
let mut out = Vec::new();
for m in metrics {
if !m.is_entity() {
continue;
}
let val = m
.column
.clone()
.ok_or_else(|| Error::Schema(format!("entity metric {:?} needs a column", m.id)))?;
if m.grain.is_empty() {
return Err(Error::Schema(format!(
"entity metric {:?} needs a non-empty grain",
m.id
)));
}
let (ord, take) = match &m.pick {
Some(p) => (p.by.clone(), p.take.clone()),
None => (row_key.to_string(), "last".to_string()),
};
for c in m.grain.iter().chain([&val, &ord]) {
if !columns.contains(c) {
return Err(Error::Schema(format!(
"entity metric {:?} column {:?} not in source",
m.id, c
)));
}
}
let okey = concat_str(
[
col(ord.as_str()).cast(DataType::String).fill_null(lit("")),
col(row_key).cast(DataType::String),
],
"\u{1}",
false,
);
let key = when(col(val.as_str()).is_not_null())
.then(okey)
.otherwise(lit(NULL));
let grain_cols: Vec<Expr> = m.grain.iter().map(|c| col(c.as_str())).collect();
let want = match take.as_str() {
"first" | "min" => key.clone().min().over(grain_cols),
_ => key.clone().max().over(grain_cols), };
out.push(
when(key.eq(want))
.then(col(val.as_str()))
.otherwise(lit(NULL))
.alias(format!("{}__eff", m.id).as_str()),
);
}
Ok(out)
}
pub fn metric_plans(metrics: &[Metric], columns: &HashSet<String>) -> Result<Vec<MetricPlan>> {
use std::collections::HashMap;
let mut base: HashMap<&str, MetricPlan> = HashMap::new();
for m in metrics {
if !m.is_ratio() {
base.insert(m.id.as_str(), metric_plan(m, columns)?);
}
}
let mut out = Vec::with_capacity(metrics.len());
for m in metrics {
if !m.is_ratio() {
out.push(base[m.id.as_str()].clone());
continue;
}
let resolve = |which: &str, id: &Option<String>| -> Result<MetricPlan> {
let id = id
.as_deref()
.ok_or_else(|| Error::Schema(format!("ratio metric {:?} needs a {which}", m.id)))?;
base.get(id).cloned().ok_or_else(|| {
Error::Schema(format!(
"ratio metric {:?} {which} {id:?} is not a (non-ratio) metric in this set",
m.id
))
})
};
let num = resolve("numerator", &m.numerator)?;
let den = resolve("denominator", &m.denominator)?;
let mut stats = num.stats.clone();
stats.extend(den.stats.clone());
let mut combine = num.combine.clone();
combine.extend(den.combine.clone());
out.push(MetricPlan {
id: m.id.clone(),
stats,
combine,
rank_expr: lit(NULL),
finalize: Finalize::MetricRatio {
num: Box::new(num),
den: Box::new(den),
},
});
}
Ok(out)
}
/// Builds the aggregation plan for a single non-ratio metric.
///
/// The aggregate named by `m.cross_agg()` decides which statistics are
/// computed and how they are finalized:
///
/// | agg              | statistics                | re-aggregation | final value      |
/// |------------------|---------------------------|----------------|------------------|
/// | `count`          | `{id}__v` = row count     | sum            | `{id}__v`        |
/// | `sum`/`min`/`max`| `{id}__v`                 | same op        | `{id}__v`        |
/// | `mean`           | `{id}__s` sum, `{id}__n` count | sum, sum  | `__s / __n`      |
/// | `weighted_mean`  | `{id}__wx`, `{id}__w`     | sum, sum       | `__wx / __w`     |
/// | `count_distinct` | `{id}__v`                 | none           | `{id}__v`        |
/// | `median`         | `{id}__v`                 | none           | `{id}__v`        |
///
/// Entity metrics aggregate the masked `{id}__eff` column instead of the raw
/// value expression. `count` needs no value column at all.
///
/// For entity metrics, `count_distinct` drops nulls before counting: the
/// `{id}__eff` column is NULL on every row that was not picked, and Polars'
/// `n_unique` counts null as a distinct value, which would otherwise inflate
/// the result by one.
/// NOTE(review): for non-entity metrics nulls are still counted as one
/// distinct value, exactly as before — confirm that is intended.
///
/// # Errors
///
/// Returns `Error::Schema` for ratio metrics (use `metric_plans`), for an
/// unsupported aggregate, for `weighted_mean` without a usable
/// `weight_column`, and propagates errors from `value_expr`.
pub fn metric_plan(m: &Metric, columns: &HashSet<String>) -> Result<MetricPlan> {
    if m.is_ratio() {
        return Err(Error::Schema(format!(
            "ratio metric {:?} must be built via metric_plans (it composes other metrics)",
            m.id
        )));
    }

    let id = m.id.as_str();
    let agg = m.cross_agg();

    // Counting rows must not require a column or formula, so it is handled
    // before the value expression is built.
    if agg == "count" {
        return Ok(single_stat_plan(id, len(), Some(Combine::Sum)));
    }

    let value = if m.is_entity() {
        col(format!("{id}__eff").as_str())
    } else {
        value_expr(m, columns)?
    };

    let plan = match agg {
        "sum" => single_stat_plan(id, value.sum(), Some(Combine::Sum)),
        "min" => single_stat_plan(id, value.min(), Some(Combine::Min)),
        "max" => single_stat_plan(id, value.max(), Some(Combine::Max)),
        "mean" => ratio_stat_plan(
            id,
            (format!("{id}__s"), value.clone().sum()),
            (format!("{id}__n"), value.count()),
        ),
        "weighted_mean" => {
            let wcol = m
                .weight_column
                .as_ref()
                .ok_or_else(|| Error::Schema("weighted_mean requires weight_column".into()))?;
            if !columns.contains(wcol) {
                return Err(Error::Schema(format!(
                    "weight_column {wcol:?} not in source"
                )));
            }
            let weight = col(wcol.as_str());
            // Only the weights of rows that actually have a value count
            // toward the denominator.
            let counted_weight = when(value.clone().is_not_null())
                .then(weight.clone())
                .otherwise(lit(NULL))
                .sum();
            ratio_stat_plan(
                id,
                (format!("{id}__wx"), (value * weight).sum()),
                (format!("{id}__w"), counted_weight),
            )
        }
        "count_distinct" => {
            let distinct = if m.is_entity() {
                // Masked-out rows are NULL; they are not values of the metric.
                value.drop_nulls().n_unique()
            } else {
                value.n_unique()
            };
            single_stat_plan(id, distinct, None)
        }
        "median" => single_stat_plan(id, value.median(), None),
        other => return Err(Error::Schema(format!("unsupported agg {other:?}"))),
    };
    Ok(plan)
}

/// Plan with one statistic, `{id}__v`, that is also the metric's final value.
fn single_stat_plan(id: &str, stat: Expr, combine: Option<Combine>) -> MetricPlan {
    let alias = format!("{id}__v");
    MetricPlan {
        id: id.to_string(),
        stats: vec![(alias.clone(), stat.alias(alias.as_str()))],
        combine: vec![(alias.clone(), combine)],
        rank_expr: col(alias.as_str()),
        finalize: Finalize::Single(alias),
    }
}

/// Plan whose final value is numerator statistic / denominator statistic,
/// where both statistics are re-aggregated by summing.
fn ratio_stat_plan(id: &str, num: (String, Expr), den: (String, Expr)) -> MetricPlan {
    let (num_alias, num_stat) = num;
    let (den_alias, den_stat) = den;
    // The ranking value is NULL unless the denominator is strictly positive.
    let rank = when(col(den_alias.as_str()).gt(lit(0)))
        .then(
            col(num_alias.as_str()).cast(DataType::Float64)
                / col(den_alias.as_str()).cast(DataType::Float64),
        )
        .otherwise(lit(NULL));
    MetricPlan {
        id: id.to_string(),
        stats: vec![
            (num_alias.clone(), num_stat.alias(num_alias.as_str())),
            (den_alias.clone(), den_stat.alias(den_alias.as_str())),
        ],
        combine: vec![
            (num_alias.clone(), Some(Combine::Sum)),
            (den_alias.clone(), Some(Combine::Sum)),
        ],
        rank_expr: rank,
        finalize: Finalize::Ratio {
            num: num_alias,
            den: den_alias,
        },
    }
}
/// Collects the statistic expressions of all plans, keeping only the first
/// expression seen for each alias and preserving encounter order.
pub fn stat_exprs(plans: &[MetricPlan]) -> Vec<Expr> {
    let mut emitted = HashSet::new();
    plans
        .iter()
        .flat_map(|plan| plan.stats.iter())
        .filter_map(|(alias, expr)| emitted.insert(alias.as_str()).then(|| expr.clone()))
        .collect()
}
/// Collects the `(alias, combine)` rules of all plans, keeping only the first
/// rule seen for each alias and preserving encounter order.
pub fn combines(plans: &[MetricPlan]) -> Vec<(String, Option<Combine>)> {
    let mut emitted = HashSet::new();
    plans
        .iter()
        .flat_map(|plan| plan.combine.iter())
        .filter_map(|(alias, how)| {
            emitted
                .insert(alias.as_str())
                .then(|| (alias.clone(), *how))
        })
        .collect()
}