use std::collections::HashMap;
use std::sync::Arc;
use polars::prelude::*;
use serde_json::{json, Value as Json};
use taxa_core::series::SeriesArgs;
use taxa_core::treemap::{branch_set, TreemapArgs};
use taxa_core::{series, FrameDataset, FrameSource, JoinSource, Source};
/// 2026-01-01 expressed as days since 1970-01-01 (the physical value cast to
/// polars `DataType::Date` in the fixtures below).
const D0: i32 = 20_454;
/// 2026-01-08 — exactly one week after [`D0`].
const D1: i32 = 20_461;
/// Index the `series` array of a `series(...)` response by each entry's `key`.
///
/// # Panics
/// Panics if `out["series"]` is not an array, if an entry's `key` is not a
/// string, or if two entries share the same key. Collecting straight into a
/// map would silently keep only the last duplicate, which could let a test
/// pass even though a branch was emitted twice.
fn by_key(out: &Json) -> HashMap<String, Json> {
    let series = out["series"]
        .as_array()
        .unwrap_or_else(|| panic!("`series` is not an array in output: {out}"));
    let mut map = HashMap::with_capacity(series.len());
    for s in series {
        let key = s["key"]
            .as_str()
            .unwrap_or_else(|| panic!("series entry has no string `key`: {s}"));
        let previous = map.insert(key.to_string(), s.clone());
        assert!(previous.is_none(), "duplicate series key {key:?} in output");
    }
    map
}
/// Extract the `values` array of one series entry as `f64`s.
///
/// Panics if `values` is not an array or if any element is not a number.
fn vals(s: &Json) -> Vec<f64> {
    let raw = s["values"].as_array().unwrap();
    let mut nums = Vec::with_capacity(raw.len());
    for item in raw {
        nums.push(item.as_f64().unwrap());
    }
    nums
}
/// The `sector` level exists only in the snapshot. After `JoinSource::from_snapshot`
/// joins it onto the narrow facts table, `series` can aggregate `mcap` along the
/// `sector` axis.
#[test]
fn dims_from_joins_levels_for_series() {
    // Narrow facts: one row per (symbol, date) and no `sector` column.
    let dates = Series::new("date".into(), [D0, D1, D0, D1, D0, D1])
        .cast(&DataType::Date)
        .unwrap();
    let facts = df![
        "symbol" => ["AAA", "AAA", "BBB", "BBB", "CCC", "CCC"],
        "date" => dates,
        "mcap" => [100.0f64, 110.0, 200.0, 210.0, 50.0, 55.0],
    ]
    .unwrap();
    // Snapshot: one row per symbol, supplying the `sector` dimension.
    let snapshot = df![
        "symbol" => ["AAA", "BBB", "CCC"],
        "sector" => ["tech", "tech", "energy"],
        "mcap" => [110.0f64, 210.0, 55.0],
    ]
    .unwrap();
    let snapshot_source: Arc<dyn Source> = Arc::new(FrameSource::new(snapshot));
    let facts_source: Arc<dyn Source> = Arc::new(FrameSource::new(facts));
    let joined = JoinSource::from_snapshot(
        facts_source,
        &*snapshot_source,
        "symbol",
        &["sector".into(), "symbol".into()],
    )
    .unwrap();

    let spec = json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "mcap", "agg": "sum", "column": "mcap", "unit": "money"}],
        "default_axis": "sector", "default_size_by": "mcap"
    });
    let dataset: FrameDataset = serde_json::from_value(spec).unwrap();
    let response = series(&dataset, &joined, &SeriesArgs::new("sector", "mcap")).unwrap();
    let by_branch = by_key(&response);

    // tech = AAA + BBB, energy = CCC, summed per date.
    assert_eq!(by_branch["tech"]["dates"], json!(["2026-01-01", "2026-01-08"]));
    assert_eq!(vals(&by_branch["tech"]), vec![300.0, 320.0]);
    assert_eq!(vals(&by_branch["energy"]), vec![50.0, 55.0]);
}
/// Columns that come only from the joined snapshot (`mcap`, `founding_year`)
/// can drive both an axis `row_filter` and a dataset-level range filter.
#[test]
fn dims_from_joins_filter_and_row_filter_columns() {
    // Narrow facts: `flow` per (symbol, date). There are no sector, mcap or
    // founding_year columns here.
    let series_df = df![
        "symbol" => ["AAA", "AAA", "BBB", "BBB", "CCC", "CCC"],
        "date" => Series::new("date".into(), [D0, D1, D0, D1, D0, D1]).cast(&DataType::Date).unwrap(),
        "flow" => [100.0f64, 110.0, 200.0, 210.0, 50.0, 55.0],
    ]
    .unwrap();
    // NOTE(review): CCC appears twice in this snapshot, presumably on purpose to
    // exercise duplicate snapshot keys. CCC's mcap (5e6) is below the 1e7
    // row_filter threshold, so `energy` is dropped either way and the duplicate
    // cannot affect the assertions below.
    let snap_df = df![
        "symbol" => ["AAA", "BBB", "CCC", "CCC"],
        "sector" => ["tech", "tech", "energy", "energy"],
        "founding_year" => [1999i64, 2010, 1980, 1980],
        "mcap" => [5e7f64, 2e7, 5e6, 5e6],
    ]
    .unwrap();
    let snap_src: Arc<dyn Source> = Arc::new(FrameSource::new(snap_df));
    let narrow: Arc<dyn Source> = Arc::new(FrameSource::new(series_df));
    let joined = JoinSource::from_snapshot(
        narrow,
        &*snap_src,
        "symbol",
        &[
            "sector".into(),
            "symbol".into(),
            "founding_year".into(),
            "mcap".into(),
        ],
    )
    .unwrap();
    let ds: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"],
            "row_filter": {"op": ">=", "args": [{"col": "mcap"}, {"lit": 1e7}]}}],
        "filters": [{"id": "founding_year", "column": "founding_year", "type": "range"}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    }))
    .unwrap();

    // Only the row_filter is active: mcap >= 1e7 keeps AAA and BBB.
    let out = series(&ds, &joined, &SeriesArgs::new("sector", "flow")).unwrap();
    let by = by_key(&out);
    assert!(by.contains_key("tech"), "tech kept (AAA+BBB ≥ 1e7 mcap)");
    assert!(
        !by.contains_key("energy"),
        "energy dropped by the mcap row_filter"
    );
    assert_eq!(vals(&by["tech"]), vec![300.0, 320.0]);

    // Add founding_year_min = 2005: AAA (1999) drops out, so BBB (2010) is the
    // only tech survivor. The row_filter still applies.
    let mut a = SeriesArgs::new("sector", "flow");
    a.filters.insert("founding_year_min".into(), json!(2005));
    let out2 = series(&ds, &joined, &a).unwrap();
    let by2 = by_key(&out2);
    assert_eq!(
        vals(&by2["tech"]),
        vec![200.0, 210.0],
        "only BBB survives the founding_year filter"
    );
    assert!(
        !by2.contains_key("energy"),
        "energy stays dropped when the founding_year filter is added"
    );
}
/// Requesting a snapshot column (`mcap`) that the narrow source already has
/// must make `JoinSource::from_snapshot` fail. The error must name the column
/// and explain that it exists on both sides.
#[test]
fn dims_from_colliding_column_errors() {
    let narrow_frame = df![
        "symbol" => ["AAA", "BBB"],
        "date" => Series::new("date".into(), [D0, D0]).cast(&DataType::Date).unwrap(),
        "mcap" => [100.0f64, 200.0],
    ]
    .unwrap();
    let snapshot_frame = df![
        "symbol" => ["AAA", "BBB"],
        "sector" => ["tech", "energy"],
        "mcap" => [110.0f64, 210.0],
    ]
    .unwrap();
    let snapshot: Arc<dyn Source> = Arc::new(FrameSource::new(snapshot_frame));
    let narrow: Arc<dyn Source> = Arc::new(FrameSource::new(narrow_frame));

    let Err(err) = JoinSource::from_snapshot(
        narrow,
        &*snapshot,
        "symbol",
        &["sector".into(), "mcap".into()],
    ) else {
        panic!("expected a collision error, got Ok");
    };
    let msg = err.to_string();

    assert!(
        msg.contains("mcap"),
        "error names the colliding column: {msg}"
    );
    let mentions_both = msg.contains("BOTH") || msg.contains("both");
    assert!(mentions_both, "error explains the collision: {msg}");
}
/// The treemap's top-K selection (`branch_set`) decides which branches `series`
/// emits. Kept branches appear by name; when the treemap reports `has_other`,
/// the remaining rows are folded into the `__other__` series.
#[test]
fn branch_set_matches_treemap_top_k() {
    // Treemap side: keep the top 2 sectors by snapshot mcap.
    let snapshot_frame = df![
        "symbol" => ["A", "B", "C"],
        "sector" => ["tech", "fin", "energy"],
        "mcap" => [1000.0f64, 500.0, 100.0],
    ]
    .unwrap();
    let snapshot_spec = json!({
        "source": "snap", "id_column": "symbol", "label_column": "symbol",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "mcap", "agg": "sum", "column": "mcap", "unit": "money"}],
        "default_axis": "sector", "default_size_by": "mcap"
    });
    let snapshot_ds: FrameDataset = serde_json::from_value(snapshot_spec).unwrap();
    let snapshot_source = FrameSource::new(snapshot_frame);
    let mut treemap_args = TreemapArgs::new("sector");
    treemap_args.top_k = 2;
    treemap_args.size_by = Some("mcap".into());
    let branches = branch_set(&snapshot_ds, &snapshot_source, &treemap_args).unwrap();
    assert_eq!(branches.keep, vec!["tech".to_string(), "fin".to_string()]);
    assert!(branches.has_other, "energy folds into Other");

    // Series side: `energy` has by far the largest flow but is outside the top-K.
    let facts_frame = df![
        "symbol" => ["A", "B", "C"],
        "sector" => ["tech", "fin", "energy"],
        "date" => Series::new("date".into(), [D0, D0, D0]).cast(&DataType::Date).unwrap(),
        "flow" => [1.0f64, 2.0, 999.0],
    ]
    .unwrap();
    let facts_spec = json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    });
    let facts_ds: FrameDataset = serde_json::from_value(facts_spec).unwrap();
    let facts_source = FrameSource::new(facts_frame);
    let mut series_args = SeriesArgs::new("sector", "flow");
    series_args.branches = Some(branches.keep.clone());
    series_args.include_other = branches.has_other;
    let response = series(&facts_ds, &facts_source, &series_args).unwrap();
    let by_branch = by_key(&response);

    assert!(by_branch.contains_key("tech"), "tech is in the treemap top-K");
    assert!(by_branch.contains_key("fin"), "fin is in the treemap top-K");
    assert!(
        !by_branch.contains_key("energy"),
        "energy is NOT in the treemap top-K → folded"
    );
    assert!(
        by_branch.contains_key("__other__"),
        "the tail folds into Other (has_other)"
    );
    assert_eq!(vals(&by_branch["__other__"]), vec![999.0]);

    // Branch order in the metadata follows the treemap's keep order, then Other.
    let branch_keys: Vec<&str> = response["meta"]["branch_keys"]
        .as_array()
        .unwrap()
        .iter()
        .map(Json::as_str)
        .map(Option::unwrap)
        .collect();
    assert_eq!(branch_keys, ["tech", "fin", "__other__"]);
}
/// When `top_k` exceeds the number of sectors, the treemap keeps every branch
/// and reports `has_other == false`. `series` then emits exactly the kept
/// branches and no `__other__` series.
#[test]
fn branch_set_no_other_when_treemap_keeps_all() {
    let treemap_ds: FrameDataset = serde_json::from_value(json!({
        "source": "snap", "id_column": "symbol", "label_column": "symbol",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "mcap", "agg": "sum", "column": "mcap", "unit": "money"}],
        "default_axis": "sector", "default_size_by": "mcap"
    }))
    .unwrap();
    let treemap_frame = df![
        "symbol" => ["A", "B"],
        "sector" => ["tech", "fin"],
        "mcap" => [1000.0f64, 500.0],
    ]
    .unwrap();
    let mut treemap_args = TreemapArgs::new("sector");
    treemap_args.top_k = 5;
    treemap_args.size_by = Some("mcap".into());
    let kept = branch_set(&treemap_ds, &FrameSource::new(treemap_frame), &treemap_args).unwrap();
    assert_eq!(kept.keep.len(), 2);
    assert!(!kept.has_other);

    let facts_ds: FrameDataset = serde_json::from_value(json!({
        "source": "facts", "id_column": "symbol", "label_column": "symbol", "timestamp_column": "date",
        "axes": [{"id": "sector", "levels": ["sector", "symbol"]}],
        "metrics": [{"id": "flow", "agg": "sum", "column": "flow", "unit": "number"}],
        "default_axis": "sector", "default_size_by": "flow"
    }))
    .unwrap();
    let facts_frame = df![
        "symbol" => ["A", "B"],
        "sector" => ["tech", "fin"],
        "date" => Series::new("date".into(), [D0, D0]).cast(&DataType::Date).unwrap(),
        "flow" => [1.0f64, 2.0],
    ]
    .unwrap();
    let mut args = SeriesArgs::new("sector", "flow");
    args.branches = Some(kept.keep.clone());
    args.include_other = kept.has_other;
    let response = series(&facts_ds, &FrameSource::new(facts_frame), &args).unwrap();
    let emitted = by_key(&response);

    assert!(
        !emitted.contains_key("__other__"),
        "no Other when treemap kept all branches"
    );
    assert_eq!(response["series"].as_array().unwrap().len(), 2);
}