taxa-core 0.1.0

taxa engine core: manifest model, formula AST→Polars Expr, bounded query generators over Polars.
//! frames/views engine behaviors:
//!  - `dims_from`: a narrow series frame `{id,date,metric}` left-joined to a
//!    snapshot's axis-level columns yields the right per-branch lines.
//!  - `branch_set: "treemap"`: the series branch set equals the treemap's top-K
//!    branch keys for the same axis/focus/size_by — even when ranking by a metric
//!    the series frame lacks.

use std::collections::HashMap;
use std::sync::Arc;

use polars::prelude::*;
use serde_json::{json, Value as Json};
use taxa_core::series::SeriesArgs;
use taxa_core::treemap::{branch_set, TreemapArgs};
use taxa_core::{series, FrameDataset, FrameSource, JoinSource, Source};

/// 2026-01-01 expressed as a Polars `Date` physical value (days since 1970-01-01).
const D0: i32 = 20454;
/// Exactly one week after [`D0`]: 2026-01-08.
const D1: i32 = D0 + 7;

/// Indexes the `series` array of a `series()` result by each entry's `key`.
///
/// Later entries with a repeated key overwrite earlier ones. Panics if
/// `out["series"]` is not an array or if an entry has no string `key`.
fn by_key(out: &Json) -> HashMap<String, Json> {
    let entries = out["series"].as_array().unwrap();
    let mut index = HashMap::new();
    for entry in entries {
        let key = entry["key"].as_str().unwrap().to_string();
        index.insert(key, entry.clone());
    }
    index
}
/// Reads one series entry's `values` array as a `Vec<f64>`.
///
/// Panics if `values` is not an array or contains a non-numeric element.
fn vals(s: &Json) -> Vec<f64> {
    let raw = s["values"].as_array().unwrap();
    let mut nums = Vec::with_capacity(raw.len());
    for v in raw {
        nums.push(v.as_f64().unwrap());
    }
    nums
}

// ── dims_from: the series frame is narrow ({symbol, date, mcap}) and the sector
//    level exists only on the snapshot. JoinSource::from_snapshot copies the
//    snapshot's `sector` column (plus the `symbol` key) onto the series rows so
//    that series() can roll them up into per-sector branches. ──
#[test]
fn dims_from_joins_levels_for_series() {
    // Time series WITHOUT any sector column.
    let facts: Arc<dyn Source> = Arc::new(FrameSource::new(
        df![
            "symbol" => ["AAA", "AAA", "BBB", "BBB", "CCC", "CCC"],
            "date" => Series::new("date".into(), [D0, D1, D0, D1, D0, D1]).cast(&DataType::Date).unwrap(),
            "mcap" => [100.0f64, 110.0, 200.0, 210.0, 50.0, 55.0],
        ]
        .unwrap(),
    ));
    // One snapshot row per symbol, carrying its sector.
    let snapshot: Arc<dyn Source> = Arc::new(FrameSource::new(
        df![
            "symbol" => ["AAA", "BBB", "CCC"],
            "sector" => ["tech", "tech", "energy"],
            "mcap" => [110.0f64, 210.0, 55.0],
        ]
        .unwrap(),
    ));

    // Series source enriched with the snapshot's `sector` (+ `symbol`) columns.
    let enriched = JoinSource::from_snapshot(
        facts,
        &*snapshot,
        "symbol",
        &["sector".into(), "symbol".into()],
    )
    .unwrap();

    // The axis levels name `sector`, which only exists after the join; time is `date`.
    let dataset: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "mcap", "agg": "sum", "column": "mcap", "unit": "money"}],
        "default_axis": "sector", "default_size_by": "mcap"
    }))
    .unwrap();

    let result = series(&dataset, &enriched, &SeriesArgs::new("sector", "mcap")).unwrap();
    let branches = by_key(&result);

    // Per date: tech = AAA + BBB, energy = CCC — a grouping only the snapshot knew.
    let tech = &branches["tech"];
    assert_eq!(tech["dates"], json!(["2026-01-01", "2026-01-08"]));
    assert_eq!(vals(tech), vec![300.0, 320.0]);
    assert_eq!(vals(&branches["energy"]), vec![50.0, 55.0]);
}

// ── dims_from joins more than level columns: filter-facet columns and the input
//    columns of per-axis row_filters are joined too. The snapshot is also
//    deduped on the JOIN KEY (not on whole rows), so a duplicated snapshot id
//    cannot multiply that entity's series. The snapshot carries `sector`
//    (level), `founding_year` (filter facet) and `mcap` (row_filter input); the
//    narrow series frame has none of them. A shared filter on founding_year and
//    a per-axis row_filter on mcap must both apply to the joined series. ──
#[test]
fn dims_from_joins_filter_and_row_filter_columns() {
    // Narrow series frame: only {symbol, date, flow}.
    let series_df = df![
        "symbol" => ["AAA", "AAA", "BBB", "BBB", "CCC", "CCC"],
        "date" => Series::new("date".into(), [D0, D1, D0, D1, D0, D1]).cast(&DataType::Date).unwrap(),
        "flow" => [100.0f64, 110.0, 200.0, 210.0, 50.0, 55.0],
    ]
    .unwrap();
    // Snapshot: sector (level) + founding_year (filter) + mcap (row_filter input).
    // CCC appears TWICE. Its two rows deliberately differ in `mcap` (both still
    // below the 1e7 row_filter floor). Whole-row dedup would therefore keep BOTH
    // rows and double CCC's series; only join-key dedup collapses them to one.
    // Identical rows would not tell the two strategies apart.
    let snap_df = df![
        "symbol" => ["AAA", "BBB", "CCC", "CCC"],
        "sector" => ["tech", "tech", "energy", "energy"],
        "founding_year" => [1999i64, 2010, 1980, 1980],
        "mcap" => [5e7f64, 2e7, 5e6, 6e6], // CCC below a 1e7 row_filter floor
    ]
    .unwrap();

    let snap_src: Arc<dyn Source> = Arc::new(FrameSource::new(snap_df));
    let narrow: Arc<dyn Source> = Arc::new(FrameSource::new(series_df));
    // Mirror the CLI loader's dims set: level ∪ filter facet ∪ row_filter inputs.
    let joined = JoinSource::from_snapshot(
        narrow,
        &*snap_src,
        "symbol",
        &[
            "sector".into(),
            "symbol".into(),
            "founding_year".into(),
            "mcap".into(),
        ],
    )
    .unwrap();

    // Axis carries a row_filter on mcap (snapshot-only column); a shared range
    // filter targets founding_year.
    let ds: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"],
                  "row_filter": {"op": ">=", "args": [{"col": "mcap"}, {"lit": 1e7}]}}],
        "filters": [{"id": "founding_year", "column": "founding_year", "type": "range"}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    }))
    .unwrap();

    // (a) row_filter on mcap drops CCC (mcap < 1e7) → no `energy` branch.
    let out = series(&ds, &joined, &SeriesArgs::new("sector", "flow")).unwrap();
    let by = by_key(&out);
    assert!(by.contains_key("tech"), "tech kept (AAA+BBB ≥ 1e7 mcap)");
    assert!(
        !by.contains_key("energy"),
        "energy dropped by the mcap row_filter"
    );
    // tech = AAA + BBB per date. CCC is filtered out here, so this assertion says
    // nothing about dedup; (c) checks that.
    assert_eq!(vals(&by["tech"]), vec![300.0, 320.0]);

    // (b) a founding_year filter applies on the joined snapshot column. Keep only
    // founding_year >= 2005 → only BBB (2010); AAA (1999) drops; CCC already gone.
    let mut a = SeriesArgs::new("sector", "flow");
    a.filters.insert("founding_year_min".into(), json!(2005));
    let out2 = series(&ds, &joined, &a).unwrap();
    let by2 = by_key(&out2);
    assert_eq!(
        vals(&by2["tech"]),
        vec![200.0, 210.0],
        "only BBB survives the founding_year filter"
    );

    // (c) join-key dedup. Re-read the SAME joined source through a dataset with no
    // row_filter, so CCC stays visible. Its duplicated snapshot id must still
    // contribute its flow exactly once. A join that did not dedup on the key
    // would report energy as [100.0, 110.0].
    let ds_unfiltered: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    }))
    .unwrap();
    let out3 = series(&ds_unfiltered, &joined, &SeriesArgs::new("sector", "flow")).unwrap();
    let by3 = by_key(&out3);
    assert_eq!(
        vals(&by3["energy"]),
        vec![50.0, 55.0],
        "CCC counted once despite its duplicated snapshot id"
    );
    assert_eq!(vals(&by3["tech"]), vec![300.0, 320.0]);
}

// ── A `dims_from` column present on BOTH the series frame and the snapshot is a
//    COLLISION. Joining it would silently bind a filter/row_filter to the
//    series' time-varying value instead of the snapshot dim, so `from_snapshot`
//    must fail closed with an error that names the offending column. ──
#[test]
fn dims_from_colliding_column_errors() {
    // `mcap` is a time-varying metric on the series frame AND a per-entity dim on
    // the snapshot, so requesting it as a dim is ambiguous.
    let facts: Arc<dyn Source> = Arc::new(FrameSource::new(
        df![
            "symbol" => ["AAA", "BBB"],
            "date" => Series::new("date".into(), [D0, D0]).cast(&DataType::Date).unwrap(),
            "mcap" => [100.0f64, 200.0],
        ]
        .unwrap(),
    ));
    let snapshot: Arc<dyn Source> = Arc::new(FrameSource::new(
        df![
            "symbol" => ["AAA", "BBB"],
            "sector" => ["tech", "energy"],
            "mcap" => [110.0f64, 210.0],
        ]
        .unwrap(),
    ));

    let Err(err) = JoinSource::from_snapshot(
        facts,
        &*snapshot,
        "symbol",
        &["sector".into(), "mcap".into()], // mcap collides
    ) else {
        panic!("expected a collision error, got Ok");
    };
    let msg = err.to_string();

    assert!(
        msg.contains("mcap"),
        "error names the colliding column: {msg}"
    );
    let explains = ["BOTH", "both"].iter().any(|&word| msg.contains(word));
    assert!(explains, "error explains the collision: {msg}");
}

// ── branch_set: the series branch set equals the treemap's top-K branch keys.
//    The treemap ranks by `mcap`, a snapshot metric the series frame does NOT
//    have, while the series plots `flow`. With top_k=2 the treemap keeps the two
//    largest sectors by mcap and folds the rest into Other; series() must plot
//    exactly those branches. ──
#[test]
fn branch_set_matches_treemap_top_k() {
    // Snapshot sectors by mcap: tech=1000 > fin=500 > energy=100, so top_k=2
    // keeps {tech, fin} and folds {energy} into Other.
    let snapshot = FrameSource::new(
        df![
            "symbol" => ["A", "B", "C"],
            "sector" => ["tech", "fin", "energy"],
            "mcap" => [1000.0f64, 500.0, 100.0],
        ]
        .unwrap(),
    );
    let snapshot_ds: FrameDataset = serde_json::from_value(json!({
        "source": "snap", "id_column": "symbol", "label_column": "symbol",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "mcap", "agg": "sum", "column": "mcap", "unit": "money"}],
        "default_axis": "sector", "default_size_by": "mcap"
    }))
    .unwrap();

    // The treemap's kept branch set at the sector level, ranked by mcap.
    let mut treemap_args = TreemapArgs::new("sector");
    treemap_args.top_k = 2;
    treemap_args.size_by = Some("mcap".into());
    let kept = branch_set(&snapshot_ds, &snapshot, &treemap_args).unwrap();
    assert_eq!(kept.keep, vec!["tech".to_string(), "fin".to_string()]);
    assert!(kept.has_other, "energy folds into Other");

    // Series frame: sector + date + flow, and NO mcap. Ranking by flow would keep
    // {energy, fin}, a different set, so a match below proves the branches came
    // from the treemap's mcap ranking rather than from the series metric.
    let facts = FrameSource::new(
        df![
            "symbol" => ["A", "B", "C"],
            "sector" => ["tech", "fin", "energy"],
            "date" => Series::new("date".into(), [D0, D0, D0]).cast(&DataType::Date).unwrap(),
            "flow" => [1.0f64, 2.0, 999.0], // energy has the biggest flow, but isn't in the treemap set
        ]
        .unwrap(),
    );
    let facts_ds: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    }))
    .unwrap();

    let mut args = SeriesArgs::new("sector", "flow");
    args.branches = Some(kept.keep.clone());
    args.include_other = kept.has_other;
    let result = series(&facts_ds, &facts, &args).unwrap();
    let branches = by_key(&result);

    // Plotted branches == the treemap's kept set; energy (huge flow) is folded.
    for (key, why) in [
        ("tech", "tech is in the treemap top-K"),
        ("fin", "fin is in the treemap top-K"),
    ] {
        assert!(branches.contains_key(key), "{why}");
    }
    assert!(
        !branches.contains_key("energy"),
        "energy is NOT in the treemap top-K → folded"
    );
    assert!(
        branches.contains_key("__other__"),
        "the tail folds into Other (has_other)"
    );
    // Other is re-aggregated from the folded rows: just energy's 999.
    assert_eq!(vals(&branches["__other__"]), vec![999.0]);

    // meta.branch_keys lists branches in the treemap's order, Other last.
    let raw_keys = result["meta"]["branch_keys"].as_array().unwrap();
    let order: Vec<&str> = raw_keys.iter().map(|k| k.as_str().unwrap()).collect();
    assert_eq!(order, ["tech", "fin", "__other__"]);
}

// ── branch_set without an "Other" (the treemap kept EVERY branch) → the series
//    has no Other line either, because there is nothing left over to fold. ──
#[test]
fn branch_set_no_other_when_treemap_keeps_all() {
    let snapshot_ds: FrameDataset = serde_json::from_value(json!({
        "source": "snap", "id_column": "symbol", "label_column": "symbol",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "mcap", "agg": "sum", "column": "mcap", "unit": "money"}],
        "default_axis": "sector", "default_size_by": "mcap"
    }))
    .unwrap();
    let snapshot = FrameSource::new(
        df![
            "symbol" => ["A", "B"],
            "sector" => ["tech", "fin"],
            "mcap" => [1000.0f64, 500.0],
        ]
        .unwrap(),
    );

    let mut treemap_args = TreemapArgs::new("sector");
    treemap_args.top_k = 5; // more than the 2 sectors → no Other
    treemap_args.size_by = Some("mcap".into());
    let kept = branch_set(&snapshot_ds, &snapshot, &treemap_args).unwrap();
    assert_eq!(kept.keep.len(), 2);
    assert!(!kept.has_other);

    let facts_ds: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    }))
    .unwrap();
    let facts = FrameSource::new(
        df![
            "symbol" => ["A", "B"],
            "sector" => ["tech", "fin"],
            "date" => Series::new("date".into(), [D0, D0]).cast(&DataType::Date).unwrap(),
            "flow" => [1.0f64, 2.0],
        ]
        .unwrap(),
    );

    let mut args = SeriesArgs::new("sector", "flow");
    args.branches = Some(kept.keep.clone());
    args.include_other = kept.has_other;
    let result = series(&facts_ds, &facts, &args).unwrap();

    // Exactly the two kept sectors are plotted, with no synthetic Other line.
    assert!(
        !by_key(&result).contains_key("__other__"),
        "no Other when treemap kept all branches"
    );
    assert_eq!(result["series"].as_array().unwrap().len(), 2);
}