use polars::prelude::*;
use serde_json::{json, Value as Json};
use crate::error::Result;
use crate::manifest::Dataset;
use crate::source::Source;
/// Largest distinct-value count a text column may have and still be used as a
/// category (axis level) by `infer_dataset`.
const MAX_LEVEL_CARD: u64 = 2000;
/// Largest distinct-value count a category may have and still be offered as a
/// categorical filter by `infer_dataset`.
const MAX_FILTER_CARD: u64 = 200;
/// Coarse classification of a column, derived from the name of its dtype.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Kind {
    /// Integer, float or decimal column.
    Numeric,
    /// Date, time, datetime or duration column.
    Temporal,
    /// Anything that is neither numeric nor temporal.
    Text,
}

/// Classifies a dtype name into a [`Kind`].
///
/// Matching is case-insensitive and purely textual:
/// - names containing `date`, `time` or `duration` are [`Kind::Temporal`];
/// - names containing `decimal` are [`Kind::Numeric`];
/// - names made of a single `i`, `u` or `f` followed by one or more ASCII digits
///   (presumably polars-style names such as `i64`, `u8`, `f32`) are [`Kind::Numeric`];
/// - everything else, including the empty string, is [`Kind::Text`].
fn kind(dtype: &str) -> Kind {
    let t = dtype.to_lowercase();

    if t.contains("date") || t.contains("time") || t.contains("duration") {
        return Kind::Temporal;
    }
    if t.contains("decimal") {
        return Kind::Numeric;
    }

    // `strip_prefix` only ever cuts on a character boundary, so a name that starts
    // with a multi-byte character can never trigger a slicing panic here.
    let is_sized_number = match t.strip_prefix(['i', 'u', 'f']) {
        Some(width) => !width.is_empty() && width.chars().all(|c| c.is_ascii_digit()),
        None => false,
    };

    if is_sized_number {
        Kind::Numeric
    } else {
        Kind::Text
    }
}
/// Turns a column name into a display label: underscores become spaces and the
/// first character is upper-cased; the rest of the name is left untouched.
fn title(col: &str) -> String {
    let spaced = col.replace('_', " ");
    let mut label = String::with_capacity(spaced.len());
    let mut rest = spaced.chars();
    if let Some(first) = rest.next() {
        // Upper-casing one char can yield several (e.g. 'ß' -> "SS").
        label.extend(first.to_uppercase());
        label.push_str(rest.as_str());
    }
    label
}
/// Builds a best-guess [`Dataset`] manifest from the columns exposed by `source`.
///
/// Every schema column is classified with `kind`, then:
/// - the text column with the most distinct values becomes the id/label column
///   (falling back to the first schema column, then to the literal `"id"`);
/// - the other text columns with `1..=MAX_LEVEL_CARD` distinct values become
///   categories, ordered from fewest to most distinct values;
/// - the first three categories plus the id column form the `"hierarchy"` axis and
///   every further category gets its own two-level axis; when there are no
///   categories a single `"all"` axis over the id column is emitted instead;
/// - a `"count"` metric is always present and each numeric column adds a `"sum"` metric;
/// - categories with at most `MAX_FILTER_CARD` distinct values become categorical filters;
/// - the first temporal column, if any, is recorded as the frame's `timestamp`.
///
/// `title_arg` overrides the default title `"taxa"`. `entity_noun` is stored as given
/// and its plural is formed by appending `s`.
///
/// # Errors
///
/// Propagates failures from `source.schema()`, `source.frame()` and the cardinality
/// query, and returns `Error::Schema` when the assembled JSON does not deserialize
/// into a [`Dataset`].
pub fn infer_dataset(
    source: &dyn Source,
    title_arg: Option<&str>,
    entity_noun: &str,
) -> Result<Dataset> {
    let schema = source.schema()?;

    // Bucket column names by kind, preserving schema order within each bucket.
    let mut numeric = vec![];
    let mut temporal = vec![];
    let mut text = vec![];
    for (name, dt) in &schema {
        match kind(dt) {
            Kind::Numeric => numeric.push(name.clone()),
            Kind::Temporal => temporal.push(name.clone()),
            Kind::Text => text.push(name.clone()),
        }
    }

    // Distinct-value count of every text column, computed in ONE query so the source
    // is evaluated once rather than once per column. Skipped entirely when there is
    // nothing to count.
    let mut card: std::collections::HashMap<String, u64> =
        std::collections::HashMap::with_capacity(text.len());
    if !text.is_empty() {
        let exprs: Vec<_> = text
            .iter()
            .map(|c| col(c.as_str()).n_unique().alias(c.as_str()))
            .collect();
        let counts = source.frame()?.select(exprs).collect()?;
        for c in &text {
            // Best effort: a count that cannot be read as u64 is treated as 0, which
            // keeps the column out of the categories below.
            let n = counts
                .column(c.as_str())?
                .get(0)?
                .try_extract::<u64>()
                .unwrap_or(0);
            card.insert(c.clone(), n);
        }
    }
    let card_of = |name: &str| -> u64 { card.get(name).copied().unwrap_or(0) };

    // Highest-cardinality text column identifies a row. `max_by_key` returns the LAST
    // maximum, so iterating in reverse makes ties resolve to the earliest schema column.
    let id_col = text
        .iter()
        .rev()
        .max_by_key(|c| card_of(c.as_str()))
        .cloned()
        .unwrap_or_else(|| {
            schema
                .first()
                .map(|(n, _)| n.clone())
                .unwrap_or_else(|| "id".into())
        });

    // Categories: remaining text columns with a usable number of distinct values,
    // coarsest (fewest values) first. The sort is stable, so ties keep schema order.
    let mut cats: Vec<String> = text
        .iter()
        .filter(|c| **c != id_col && (1..=MAX_LEVEL_CARD).contains(&card_of(c.as_str())))
        .cloned()
        .collect();
    cats.sort_by_key(|c| card_of(c.as_str()));

    // The three coarsest categories share the hierarchy axis; the rest get one axis each.
    // (Column names are assumed unique, so "the rest" is simply everything after them.)
    let (head, tail) = cats.split_at(cats.len().min(3));
    let axes: Vec<Json> = if cats.is_empty() {
        vec![json!({"id": "all", "levels": [id_col], "label": title(&id_col)})]
    } else {
        let primary: Vec<String> = head
            .iter()
            .cloned()
            .chain(std::iter::once(id_col.clone()))
            .collect();
        let mut axes = vec![json!({
            "id": "hierarchy",
            "levels": primary,
            "label": primary.iter().map(|c| title(c)).collect::<Vec<_>>().join(" → "),
        })];
        axes.extend(
            tail.iter()
                .map(|c| json!({"id": c, "levels": [c, id_col], "label": title(c)})),
        );
        axes
    };

    // Row count is always available; every numeric column is additionally summed.
    let mut metrics: Vec<Json> =
        vec![json!({"id": "count", "agg": "count", "unit": "count", "label": "Count"})];
    metrics.extend(numeric.iter().map(
        |c| json!({"id": c, "agg": "sum", "column": c, "unit": "number", "label": title(c)}),
    ));

    let filters: Vec<Json> = cats
        .iter()
        .filter(|c| card_of(c.as_str()) <= MAX_FILTER_CARD)
        .map(|c| json!({"id": c, "column": c, "type": "categorical", "label": title(c)}))
        .collect();

    let default_size = numeric.first().cloned().unwrap_or_else(|| "count".into());
    let default_axis = axes
        .first()
        .and_then(|a| a["id"].as_str())
        .unwrap_or("hierarchy")
        .to_string();

    // Assembled by hand because `timestamp` is only present when a temporal column exists.
    let mut main_frame = serde_json::Map::new();
    main_frame.insert("source".into(), json!("data"));
    main_frame.insert("id_column".into(), json!(id_col));
    main_frame.insert("label_column".into(), json!(id_col));
    main_frame.insert("metrics".into(), json!(metrics));
    if let Some(ts) = temporal.first() {
        main_frame.insert("timestamp".into(), json!(ts));
    }

    let ds = json!({
        "title": title_arg.unwrap_or("taxa"),
        "entity_noun": entity_noun,
        "entity_noun_plural": format!("{entity_noun}s"),
        "axes": axes, "filters": filters,
        "default_axis": default_axis, "default_size_by": default_size,
        "frames": {"main": Json::Object(main_frame)},
    });

    // Round-trip through the manifest type so an invalid inference is rejected here.
    serde_json::from_value(ds)
        .map_err(|e| crate::error::Error::Schema(format!("inferred manifest is invalid: {e}")))
}