taxa-core 0.1.0

taxa engine core: manifest model, formula AST→Polars Expr, bounded query generators over Polars.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
//! Per-metric aggregation plan via sufficient statistics — Polars edition.
//! Each metric declares the aggregate stat columns it needs, how each folds
//! into "Other" (Sum/Min/Max, or None = non-decomposable → NULL), a rank
//! expression (the size value used for top-K ordering), and a finalizer from a
//! row's stat cells to the displayed measure.

use std::collections::HashSet;

use polars::prelude::*;
use serde_json::Value as Json;

use crate::error::{Error, Result};
use crate::formula::compile_formula;
use crate::manifest::Metric;
use crate::output::{av_to_f64, av_to_json};

/// How a per-group stat column folds across several groups when the tail of a
/// ranked level is collapsed into a single "Other" row (see
/// `MetricPlan::other_expr`). A stat with no fold op (`None` at the use site)
/// is non-decomposable and becomes NULL in "Other".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Combine {
    /// Add the per-group values (sums, counts).
    Sum,
    /// Keep the smallest per-group value.
    Min,
    /// Keep the largest per-group value.
    Max,
}

/// Recipe that turns one row's aggregated stat cells into the displayed value.
#[derive(Clone)]
enum Finalize {
    /// Pass a single stat cell through unchanged: `value = cell[alias]`.
    /// Used by sum / count / min / max / count_distinct / median.
    Single(String),
    /// Divide two stat cells of the same plan: `value = num / den`, NULL when
    /// `den` is missing or zero. Used by mean / weighted_mean.
    Ratio { num: String, den: String },
    /// Divide two whole metrics (a `kind:"ratio"` metric). Each operand plan is
    /// finalized from its OWN aggregated stats at the current node and only
    /// then divided, so the result is consistent at the root, at inner
    /// branches and for the "+N others" row: the operands' stats are folded
    /// first and the division is applied last.
    MetricRatio { num: Box<MetricPlan>, den: Box<MetricPlan> },
}

/// Aggregation plan for one metric, expressed as sufficient statistics.
///
/// `stats` are computed per group in a level's `.agg([...])`; `combine` says
/// how each stat folds into an "Other" row; `rank_expr` orders groups for
/// top-K; the private finalizer maps a row's stat cells to the shown value.
#[derive(Clone)]
pub struct MetricPlan {
    /// Id of the metric this plan was built for.
    pub id: String,
    /// `(alias, aliased aggregate expr)` pairs spliced into the level `.agg([...])`.
    pub stats: Vec<(String, Expr)>,
    /// `(alias, fold op)` pairs for the "Other" branch; `None` means the stat
    /// cannot be folded and becomes NULL there.
    pub combine: Vec<(String, Option<Combine>)>,
    /// Size value computed over the stat-alias columns, used for top-K ranking.
    pub rank_expr: Expr,
    finalize: Finalize,
}

impl MetricPlan {
    /// Turn a row's stat cells into the displayed measure.
    ///
    /// `get` looks a stat cell up by alias. `Single` cells are converted as-is
    /// (integer results stay integers, NULL stays NULL). Ratios produce a float,
    /// or NULL when the divisor is missing/zero or the quotient is not
    /// representable as a JSON number.
    pub fn finalize_row(&self, get: &dyn Fn(&str) -> AnyValue<'static>) -> Json {
        // serde_json cannot encode NaN / ±inf; those collapse to NULL.
        let float_json =
            |x: f64| serde_json::Number::from_f64(x).map_or(Json::Null, Json::Number);

        match &self.finalize {
            Finalize::Single(alias) => av_to_json(&get(alias)),
            Finalize::Ratio { num, den } => av_to_f64(&get(den))
                .filter(|&d| d != 0.0)
                .map_or(Json::Null, |d| {
                    // A missing numerator over a usable divisor counts as 0.
                    float_json(av_to_f64(&get(num)).unwrap_or(0.0) / d)
                }),
            Finalize::MetricRatio { num, den } => {
                // Both operands are always finalized, numerator first.
                let operands = num
                    .finalize_row(get)
                    .as_f64()
                    .zip(den.finalize_row(get).as_f64());
                match operands {
                    Some((n, d)) if d != 0.0 => float_json(n / d),
                    _ => Json::Null,
                }
            }
        }
    }

    /// Fold expression for stat column `alias` over an "Other" tail whose rows
    /// are already per-group stats. A non-decomposable stat (`None`) is NULL.
    pub fn other_expr(alias: &str, combine: Option<Combine>) -> Expr {
        let folded = match combine {
            None => lit(NULL),
            Some(Combine::Sum) => col(alias).sum(),
            Some(Combine::Min) => col(alias).min(),
            Some(Combine::Max) => col(alias).max(),
        };
        folded.alias(alias)
    }
}

/// Per-row value expression for a metric, evaluated before aggregation.
///
/// A `formula` takes precedence over `column`; a bare `column` must exist in
/// `columns`. With `null_policy == "zero"`, NULL values are replaced by `0`.
///
/// # Errors
/// `Error::Schema` when the column is not in the source or the metric has
/// neither a column nor a formula; `compile_formula` errors are propagated.
pub fn value_expr(m: &Metric, columns: &HashSet<String>) -> Result<Expr> {
    let base = match (&m.formula, &m.column) {
        (Some(formula), _) => compile_formula(formula, columns)?,
        (None, Some(c)) if columns.contains(c) => col(c.as_str()),
        (None, Some(c)) => {
            return Err(Error::Schema(format!(
                "metric {:?} column {:?} not in source",
                m.id, c
            )))
        }
        (None, None) => {
            return Err(Error::Schema(format!(
                "metric {:?} needs a column or formula",
                m.id
            )))
        }
    };
    let zero_fill = m.null_policy == "zero";
    Ok(if zero_fill { coalesce(&[base, lit(0)]) } else { base })
}

/// Derived `<id>__eff` columns for every `kind:"entity"` metric.
///
/// Within each grain group (the metric's `grain` columns) the metric value is
/// kept on exactly one picked row and NULL everywhere else. The picked row is
/// the row with a non-null value whose `pick.by` is greatest (`take` =
/// "last"/"max", the default) or smallest (`take` = "first"/"min"). Add these
/// to the already-FILTERED frame before grouping so the metric's `rollup`
/// (e.g. SUM) yields the per-entity total. The pick is recomputed over
/// whatever rows survive the filters, so it is filter-aware.
///
/// `pick.by` is compared in its native dtype (numbers numerically,
/// dates/times chronologically, strings lexicographically). It is never cast
/// to text, which would order 10 before 9. Rows whose `pick.by` is NULL win
/// only when no row in the group has a known ordering value. Ties on
/// `pick.by` are broken by `row_key` (the dataset id column, assumed unique
/// per row) in the same direction, so a single row wins. With no `pick`, the
/// source is assumed one-row-per-grain and `row_key` itself orders the rows,
/// giving an arbitrary but deterministic winner.
///
/// Returns `[]` when no entity metric is present, so callers can skip the work.
///
/// # Errors
/// `Error::Schema` when an entity metric has no `column` or an empty `grain`,
/// or when it references a column (grain, value, `pick.by` or `row_key`) that
/// is not in `columns`.
pub fn entity_mask_exprs(
    metrics: &[Metric],
    row_key: &str,
    columns: &HashSet<String>,
) -> Result<Vec<Expr>> {
    let mut out = Vec::new();
    for m in metrics.iter().filter(|m| m.is_entity()) {
        let val = m
            .column
            .clone()
            .ok_or_else(|| Error::Schema(format!("entity metric {:?} needs a column", m.id)))?;
        if m.grain.is_empty() {
            return Err(Error::Schema(format!(
                "entity metric {:?} needs a non-empty grain",
                m.id
            )));
        }
        let (ord, take) = match &m.pick {
            Some(p) => (p.by.clone(), p.take.clone()),
            None => (row_key.to_string(), "last".to_string()),
        };
        // Every referenced column must exist. That includes `row_key`: it is
        // the tie-breaker even when `pick.by` names a different column.
        for c in m
            .grain
            .iter()
            .map(String::as_str)
            .chain([val.as_str(), ord.as_str(), row_key])
        {
            if !columns.contains(c) {
                return Err(Error::Schema(format!(
                    "entity metric {:?} column {:?} not in source",
                    m.id, c
                )));
            }
        }

        let pick_first = matches!(take.as_str(), "first" | "min");
        // Group-wise extreme in the requested direction (NULLs are ignored).
        let extreme = |e: Expr| if pick_first { e.min() } else { e.max() };
        let grain_cols: Vec<Expr> = m.grain.iter().map(|c| col(c.as_str())).collect();

        let has_val = col(val.as_str()).is_not_null();
        // Stage 1: best ordering value among rows that carry a value, compared
        // in the column's native dtype.
        let cand_ord = when(has_val.clone())
            .then(col(ord.as_str()))
            .otherwise(lit(NULL));
        let best_ord = extreme(cand_ord.clone()).over(grain_cols.clone());
        // `eq_missing` treats NULL == NULL as true, so a group whose candidate
        // rows all have NULL ordering values still picks a row.
        let tied = has_val.and(cand_ord.eq_missing(best_ord));
        // Stage 2: break ties on the (unique) row_key, same direction.
        let cand_key = when(tied).then(col(row_key)).otherwise(lit(NULL));
        let best_key = extreme(cand_key.clone()).over(grain_cols);
        // Non-candidate rows compare as NULL and fall through to `otherwise`.
        out.push(
            when(cand_key.eq(best_key))
                .then(col(val.as_str()))
                .otherwise(lit(NULL))
                .alias(format!("{}__eff", m.id).as_str()),
        );
    }
    Ok(out)
}

/// Build plans for a whole metric set, resolving `kind:"ratio"` metrics against
/// their operands. A ratio's plan reuses its numerator/denominator plans' stats
/// (so both aggregate at every node) and divides their finalized values last.
/// Operands must be non-ratio metrics declared in the same set.
pub fn metric_plans(metrics: &[Metric], columns: &HashSet<String>) -> Result<Vec<MetricPlan>> {
    use std::collections::HashMap;
    // Non-ratio plans first, indexed by id — the pool ratios compose from.
    let mut base: HashMap<&str, MetricPlan> = HashMap::new();
    for m in metrics {
        if !m.is_ratio() {
            base.insert(m.id.as_str(), metric_plan(m, columns)?);
        }
    }
    let mut out = Vec::with_capacity(metrics.len());
    for m in metrics {
        if !m.is_ratio() {
            out.push(base[m.id.as_str()].clone());
            continue;
        }
        let resolve = |which: &str, id: &Option<String>| -> Result<MetricPlan> {
            let id = id
                .as_deref()
                .ok_or_else(|| Error::Schema(format!("ratio metric {:?} needs a {which}", m.id)))?;
            base.get(id).cloned().ok_or_else(|| {
                Error::Schema(format!(
                    "ratio metric {:?} {which} {id:?} is not a (non-ratio) metric in this set",
                    m.id
                ))
            })
        };
        let num = resolve("numerator", &m.numerator)?;
        let den = resolve("denominator", &m.denominator)?;
        let mut stats = num.stats.clone();
        stats.extend(den.stats.clone());
        let mut combine = num.combine.clone();
        combine.extend(den.combine.clone());
        out.push(MetricPlan {
            id: m.id.clone(),
            stats,
            combine,
            // A ratio is not an area; treemap rejects it as size_by, so this rank
            // is only a safe placeholder.
            rank_expr: lit(NULL),
            finalize: Finalize::MetricRatio {
                num: Box::new(num),
                den: Box::new(den),
            },
        });
    }
    Ok(out)
}

pub fn metric_plan(m: &Metric, columns: &HashSet<String>) -> Result<MetricPlan> {
    let i = &m.id;

    if m.is_ratio() {
        return Err(Error::Schema(format!(
            "ratio metric {:?} must be built via metric_plans (it composes other metrics)",
            m.id
        )));
    }

    // The cross-sectional aggregation: an entity metric's `rollup`, else `agg`.
    let agg = m.cross_agg();

    // count is special: the stat is the group row-count, no value expr needed.
    if agg == "count" {
        let a = format!("{i}__v");
        return Ok(MetricPlan {
            id: i.clone(),
            stats: vec![(a.clone(), len().alias(a.as_str()))],
            combine: vec![(a.clone(), Some(Combine::Sum))],
            rank_expr: col(a.as_str()),
            finalize: Finalize::Single(a),
        });
    }

    // An entity metric aggregates its MASKED per-grain value (the `<id>__eff`
    // column built by `entity_mask_exprs`): one row per grain group survives as
    // non-null, so the `rollup` over `__eff` is the per-entity total — additive,
    // never double-counting an entity across sub-axes. A non-entity metric
    // aggregates its raw value expression.
    let v = if m.is_entity() {
        col(format!("{i}__eff").as_str())
    } else {
        value_expr(m, columns)?
    };
    let single = |a: String, stat: Expr, c: Combine| MetricPlan {
        id: i.clone(),
        stats: vec![(a.clone(), stat.alias(a.as_str()))],
        combine: vec![(a.clone(), Some(c))],
        rank_expr: col(a.as_str()),
        finalize: Finalize::Single(a),
    };

    let plan = match agg {
        "sum" => {
            let a = format!("{i}__v");
            single(a, v.sum(), Combine::Sum)
        }
        "min" => {
            let a = format!("{i}__v");
            single(a, v.min(), Combine::Min)
        }
        "max" => {
            let a = format!("{i}__v");
            single(a, v.max(), Combine::Max)
        }
        "mean" => {
            let (s, n) = (format!("{i}__s"), format!("{i}__n"));
            let rank = when(col(n.as_str()).gt(lit(0)))
                .then(
                    col(s.as_str()).cast(DataType::Float64)
                        / col(n.as_str()).cast(DataType::Float64),
                )
                .otherwise(lit(NULL));
            MetricPlan {
                id: i.clone(),
                stats: vec![
                    (s.clone(), v.clone().sum().alias(s.as_str())),
                    (n.clone(), v.count().alias(n.as_str())),
                ],
                combine: vec![
                    (s.clone(), Some(Combine::Sum)),
                    (n.clone(), Some(Combine::Sum)),
                ],
                rank_expr: rank,
                finalize: Finalize::Ratio { num: s, den: n },
            }
        }
        "weighted_mean" => {
            let wcol = m
                .weight_column
                .as_ref()
                .ok_or_else(|| Error::Schema("weighted_mean requires weight_column".into()))?;
            if !columns.contains(wcol) {
                return Err(Error::Schema(format!(
                    "weight_column {wcol:?} not in source"
                )));
            }
            let w = col(wcol.as_str());
            let (wx, ww) = (format!("{i}__wx"), format!("{i}__w"));
            let rank = when(col(ww.as_str()).gt(lit(0)))
                .then(
                    col(wx.as_str()).cast(DataType::Float64)
                        / col(ww.as_str()).cast(DataType::Float64),
                )
                .otherwise(lit(NULL));
            MetricPlan {
                id: i.clone(),
                stats: vec![
                    (wx.clone(), (v.clone() * w.clone()).sum().alias(wx.as_str())),
                    // Denominator must only count weights of rows whose value is
                    // present, else null values (null_policy="drop") dilute the
                    // mean. The numerator already drops them (v*w is null → SUM
                    // skips), so match it here.
                    (
                        ww.clone(),
                        when(v.clone().is_not_null())
                            .then(w.clone())
                            .otherwise(lit(NULL))
                            .sum()
                            .alias(ww.as_str()),
                    ),
                ],
                combine: vec![
                    (wx.clone(), Some(Combine::Sum)),
                    (ww.clone(), Some(Combine::Sum)),
                ],
                rank_expr: rank,
                finalize: Finalize::Ratio { num: wx, den: ww },
            }
        }
        "count_distinct" => {
            let a = format!("{i}__v");
            MetricPlan {
                id: i.clone(),
                stats: vec![(a.clone(), v.n_unique().alias(a.as_str()))],
                combine: vec![(a.clone(), None)],
                rank_expr: col(a.as_str()),
                finalize: Finalize::Single(a),
            }
        }
        "median" => {
            let a = format!("{i}__v");
            MetricPlan {
                id: i.clone(),
                stats: vec![(a.clone(), v.median().alias(a.as_str()))],
                combine: vec![(a.clone(), None)],
                rank_expr: col(a.as_str()),
                finalize: Finalize::Single(a),
            }
        }
        other => return Err(Error::Schema(format!("unsupported agg {other:?}"))),
    };
    Ok(plan)
}

/// Every aliased stat expression across `plans`, in plan order, keeping only
/// the first occurrence of each alias.
///
/// A ratio metric reuses its operands' stats, so one alias can appear both in
/// the operand's own plan and in the ratio's. Selecting it twice in one
/// `.agg([...])` would be a duplicate-output error.
pub fn stat_exprs(plans: &[MetricPlan]) -> Vec<Expr> {
    let mut emitted: HashSet<&str> = HashSet::new();
    plans
        .iter()
        .flat_map(|plan| &plan.stats)
        .filter_map(|(alias, expr)| emitted.insert(alias.as_str()).then(|| expr.clone()))
        .collect()
}

/// Every `(alias, fold op)` pair across `plans`, in plan order, keeping only
/// the first occurrence of each alias (ratio metrics share their operands'
/// stat aliases; see `stat_exprs`).
pub fn combines(plans: &[MetricPlan]) -> Vec<(String, Option<Combine>)> {
    let mut emitted: HashSet<&str> = HashSet::new();
    plans
        .iter()
        .flat_map(|plan| &plan.combine)
        .filter_map(|(alias, fold)| {
            emitted
                .insert(alias.as_str())
                .then(|| (alias.clone(), *fold))
        })
        .collect()
}