use std::sync::Arc;
use axum::{
extract::{Path, Query, State},
http::{header, HeaderMap, StatusCode, Uri},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use serde_json::{Map, Value};
use taxa_core::series::SeriesArgs;
use taxa_core::treemap::TreemapArgs;
use taxa_core::{boot_manifest, frontend_tree, Backend, FrameDataset};
pub mod geo;
/// Frontend files exposed through `rust_embed`, rooted at the `web` folder.
///
/// Files matching `**/*.test.js` and `package.json` are excluded from the set.
/// Handlers in this file look entries up with `Assets::get(path)`.
#[derive(rust_embed::RustEmbed)]
#[folder = "web"]
#[exclude = "**/*.test.js"]
#[exclude = "package.json"]
struct Assets;
/// Per-dataset state shared by every API handler of one dataset.
pub struct App {
/// Dataset description used by the treemap, geo, scatter, search and entity handlers.
pub ds: FrameDataset,
/// Query engine that answers requests against `ds`.
pub backend: Arc<dyn Backend>,
/// JSON manifest returned by `/api/manifest`; built in `make_app` from `boot_manifest`
/// plus an optional `views.geo` entry.
pub manifest: Value,
/// Optional dataset used for series queries; handlers fall back to `ds` when `None`.
pub series_ds: Option<FrameDataset>,
/// Optional backend used for series queries; handlers fall back to `backend` when `None`.
pub series_backend: Option<Arc<dyn Backend>>,
/// When true (and the requested axis also exists on `ds`), `h_series` asks `backend`
/// for a branch set and restricts the series query to those branches.
pub series_follows_treemap: bool,
}
/// Reference-counted handle to an [`App`]; used as the axum router state.
pub type Shared = Arc<App>;
/// Builds the router for a single dataset.
///
/// Equivalent to [`build_router_ext`] with `series_follows_treemap` switched off.
pub fn build_router(
    ds: FrameDataset,
    backend: Arc<dyn Backend>,
    series: Option<(FrameDataset, Arc<dyn Backend>)>,
) -> Router {
    let series_follows_treemap = false;
    build_router_ext(ds, backend, series, series_follows_treemap)
}
/// One dataset entry handed to [`build_router_multi`] / [`serve_multi`].
pub struct ServeDataset {
/// Identifier used in the `/d/{id}` route prefix and in the `/api/datasets` listing.
pub id: String,
/// Human-readable name reported in the `/api/datasets` listing.
pub label: String,
/// Shared state backing this dataset's API routes.
pub app: Shared,
}
fn make_app(
ds: FrameDataset,
backend: Arc<dyn Backend>,
series: Option<(FrameDataset, Arc<dyn Backend>)>,
series_follows_treemap: bool,
) -> Shared {
let mut manifest = boot_manifest(&ds);
if let Ok(cols) = backend.columns() {
let mut maps: Vec<serde_json::Value> = Vec::new();
if cols.contains("state_fips") {
maps.push(serde_json::json!({"id":"states","label":"States","topojson":"us-states","key_column":"state_fips","key_kind":"fips2"}));
}
if cols.contains("county_fips") {
maps.push(serde_json::json!({"id":"counties","label":"Counties","topojson":"us-counties","key_column":"county_fips","key_kind":"fips5"}));
}
if cols.contains("iso3") {
maps.push(serde_json::json!({"id":"countries","label":"Countries","topojson":"world","key_column":"iso3","key_kind":"iso3"}));
}
if !maps.is_empty() {
let default_map = maps
.iter()
.find(|m| m["id"] == "counties")
.or_else(|| maps.first())
.map(|m| m["id"].clone())
.unwrap();
let default_res = if maps.iter().any(|m| m["id"] == "counties") {
"low"
} else {
"med"
};
if let Some(views) = manifest.get_mut("views").and_then(|v| v.as_object_mut()) {
views.insert(
"geo".into(),
serde_json::json!({
"maps": maps,
"default_map": default_map,
"default_colormap": "Viridis",
"default_scale": "log",
"default_projection": "naturalEarth",
"default_res": default_res,
"default_outline": true,
}),
);
}
}
}
let (series_ds, series_backend) = match series {
Some((sds, sb)) => (Some(sds), Some(sb)),
None => (None, None),
};
Arc::new(App {
ds,
backend,
manifest,
series_ds,
series_backend,
series_follows_treemap,
})
}
/// Builds the per-dataset API routes (`/api/...`) with `app` attached as state.
///
/// `build_router_multi` nests this router under `/d/{id}` for each dataset and also
/// merges the first dataset's copy at the root.
fn api_router(app: Shared) -> Router {
Router::new()
.route("/api/manifest", get(h_manifest))
.route("/api/treemap", post(h_treemap))
.route("/api/geo", post(h_geo))
.route("/api/series", post(h_series))
.route("/api/scatter", post(h_scatter))
.route("/api/search", get(h_search))
.route("/api/filter-options/:facet", get(h_filter_options))
// Wildcard capture: `h_entity` inspects the tail for `/series` and `/ohlc` suffixes.
.route("/api/entity/*id", get(h_entity))
.with_state(app)
}
pub fn build_router_ext(
ds: FrameDataset,
backend: Arc<dyn Backend>,
series: Option<(FrameDataset, Arc<dyn Backend>)>,
series_follows_treemap: bool,
) -> Router {
let app = make_app(ds, backend, series, series_follows_treemap);
let label = app
.manifest
.get("title")
.and_then(|v| v.as_str())
.unwrap_or("dataset")
.to_string();
build_router_multi(vec![ServeDataset {
id: "main".into(),
label,
app,
}])
}
pub fn build_router_multi(entries: Vec<ServeDataset>) -> Router {
let list = Value::Array(
entries
.iter()
.map(|e| serde_json::json!({"id": e.id, "label": e.label}))
.collect(),
);
let mut r = Router::new();
for e in &entries {
r = r.nest(&format!("/d/{}", e.id), api_router(e.app.clone()));
}
if let Some(first) = entries.first() {
r = r.merge(api_router(first.app.clone()));
}
let r = r
.route(
"/api/datasets",
get(move || {
let l = list.clone();
async move { Json(l) }
}),
)
.route("/static/*path", get(h_asset))
.fallback(h_index);
let base = base_path();
if base.is_empty() {
r
} else {
Router::new().nest(base, r).fallback(h_index)
}
}
/// Serves a single dataset on `host:port`.
///
/// Equivalent to [`serve_ext`] with `series_follows_treemap` switched off.
pub async fn serve(
    ds: FrameDataset,
    backend: Arc<dyn Backend>,
    series: Option<(FrameDataset, Arc<dyn Backend>)>,
    host: &str,
    port: u16,
) -> std::io::Result<()> {
    let series_follows_treemap = false;
    serve_ext(ds, backend, series, series_follows_treemap, host, port).await
}
/// Formats the start-up banner printed once the listener is bound.
///
/// The first line carries the URL (with ANSI colours when stdout is a terminal).
/// Extra lines are appended when a dev assets directory is configured and when
/// the geo module reports a geo directory.
fn serving_banner(host: &str, port: u16) -> String {
    use std::io::IsTerminal;

    let mut out = match std::io::stdout().is_terminal() {
        true => format!("\x1b[1;32mtaxa\x1b[0m serving on \x1b[36mhttp://{host}:{port}\x1b[0m"),
        false => format!("taxa serving on http://{host}:{port}"),
    };

    if let Some(assets) = assets_dir() {
        out += &format!(
            "\n dev assets: serving frontend from {} (embedded bundle bypassed)",
            assets.display()
        );
    }

    let Some(geo_dir) = geo::geo_dir_path() else {
        return out;
    };
    let geo_line = match geo::check_registry(geo_dir) {
        Ok(Some(summary)) => format!("\n geo: {} — {}", geo_dir.display(), summary),
        Ok(None) => format!(
            "\n geo: {} (no registry.json — built-in maps only)",
            geo_dir.display()
        ),
        Err(problem) => format!("\n geo: {} — WARNING: {}", geo_dir.display(), problem),
    };
    out.push_str(&geo_line);
    out
}
/// Builds the single-dataset router, binds `host:port`, prints the banner and
/// serves until the server future completes.
///
/// # Errors
/// Returns the I/O error from binding the listener or from serving.
pub async fn serve_ext(
    ds: FrameDataset,
    backend: Arc<dyn Backend>,
    series: Option<(FrameDataset, Arc<dyn Backend>)>,
    series_follows_treemap: bool,
    host: &str,
    port: u16,
) -> std::io::Result<()> {
    let app = build_router_ext(ds, backend, series, series_follows_treemap);
    let socket = tokio::net::TcpListener::bind((host, port)).await?;
    // The banner is printed only after the bind succeeded.
    let banner = serving_banner(host, port);
    println!("{banner}");
    axum::serve(socket, app).await
}
/// Builds the multi-dataset router, binds `host:port`, prints the banner and
/// serves until the server future completes.
///
/// # Errors
/// Returns the I/O error from binding the listener or from serving.
pub async fn serve_multi(entries: Vec<ServeDataset>, host: &str, port: u16) -> std::io::Result<()> {
    let app = build_router_multi(entries);
    let socket = tokio::net::TcpListener::bind((host, port)).await?;
    // The banner is printed only after the bind succeeded.
    let banner = serving_banner(host, port);
    println!("{banner}");
    axum::serve(socket, app).await
}
/// Convenience constructor for a [`ServeDataset`]: converts `id` and `label`
/// and builds the shared state with `make_app`.
pub fn serve_dataset(
    id: impl Into<String>,
    label: impl Into<String>,
    ds: FrameDataset,
    backend: Arc<dyn Backend>,
    series: Option<(FrameDataset, Arc<dyn Backend>)>,
    series_follows_treemap: bool,
) -> ServeDataset {
    let id: String = id.into();
    let label: String = label.into();
    let app = make_app(ds, backend, series, series_follows_treemap);
    ServeDataset { id, label, app }
}
/// Return type of the JSON API handlers: a JSON body on success, or a status
/// code paired with a plain-text message on failure.
type ApiResult = Result<Json<Value>, (StatusCode, String)>;
/// Wraps any displayable error as a `400 Bad Request` carrying its text.
fn bad(e: impl std::fmt::Display) -> (StatusCode, String) {
    let message = format!("{e}");
    (StatusCode::BAD_REQUEST, message)
}
/// Aggregation names accepted in the `agg` field of `/api/series` requests.
const SERIES_AGGS: &[&str] = &["sum", "mean", "median", "min", "max", "count"];
/// Resolution codes accepted when the series dataset does not declare its own list.
/// NOTE(review): presumably day/week/month/quarter/year — confirm against the backend.
const RESOLUTIONS: &[&str] = &["d", "w", "m", "q", "y"];
/// Upper bound applied to a client-supplied `top_k` in `h_series`.
const MAX_TOP_K: i64 = 100;
/// Picks the series resolution to use for a request.
///
/// The allowed set is the dataset's own `series_resolutions` when that list is
/// present and non-empty, otherwise `RESOLUTIONS`. A missing or empty request
/// selects the first allowed entry (or `"d"` should the set be empty).
///
/// # Errors
/// `400 Bad Request` when a non-empty requested value is not in the allowed set.
fn resolve_resolution(
    sds: &FrameDataset,
    requested: Option<&str>,
) -> Result<String, (StatusCode, String)> {
    let declared = sds
        .series_resolutions
        .as_deref()
        .filter(|list| !list.is_empty());
    let allowed: Vec<&str> = match declared {
        Some(list) => list.iter().map(String::as_str).collect(),
        None => RESOLUTIONS.to_vec(),
    };

    let Some(wanted) = requested.filter(|r| !r.is_empty()) else {
        let fallback = allowed.first().copied().unwrap_or("d");
        return Ok(fallback.to_string());
    };

    if allowed.iter().any(|candidate| *candidate == wanted) {
        Ok(wanted.to_owned())
    } else {
        Err(bad(format!("unknown resolution {wanted:?}")))
    }
}
/// Verifies that every key in `filters` refers to a filter declared on `ds`.
///
/// A key is accepted when it equals a declared filter id, or when it is a
/// declared id followed by `_min` / `_max` (range bounds). The exact match is
/// tried first so that a filter whose own id happens to end in `_min` or
/// `_max` (e.g. `salary_min`) is not rejected after having its suffix stripped.
///
/// # Errors
/// `400 Bad Request` naming the first key that matches no declared filter.
fn check_filters(
    ds: &FrameDataset,
    filters: &Map<String, Value>,
) -> Result<(), (StatusCode, String)> {
    for key in filters.keys() {
        // Range-bound form: `<id>_min` or `<id>_max`.
        let base = key
            .strip_suffix("_min")
            .or_else(|| key.strip_suffix("_max"));
        let known = ds
            .filters
            .iter()
            .any(|f| f.id == key.as_str() || base.is_some_and(|b| f.id == b));
        if !known {
            return Err(bad(format!("unknown filter {key:?}")));
        }
    }
    Ok(())
}
/// Runs `f` on tokio's blocking thread pool and returns its result.
///
/// # Errors
/// Passes through whatever error `f` returns. If the task could not be joined,
/// yields `500 Internal Server Error` with the text `engine task panicked`.
async fn blocking<T, F>(f: F) -> Result<T, (StatusCode, String)>
where
    F: FnOnce() -> Result<T, (StatusCode, String)> + Send + 'static,
    T: Send + 'static,
{
    let joined = tokio::task::spawn_blocking(f).await;
    joined.unwrap_or_else(|_| {
        Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            "engine task panicked".into(),
        ))
    })
}
/// `GET /api/manifest`: returns a copy of the dataset manifest, cacheable for 60 seconds.
async fn h_manifest(State(app): State<Shared>) -> Response {
    let caching = [(header::CACHE_CONTROL, "public, max-age=60")];
    let body = Json(app.manifest.clone());
    (caching, body).into_response()
}
fn obj(body: &Value, key: &str) -> Map<String, Value> {
body.get(key)
.and_then(|v| v.as_object())
.cloned()
.unwrap_or_default()
}
async fn h_treemap(State(app): State<Shared>, Json(body): Json<Value>) -> ApiResult {
blocking(move || {
let axis = body
.get("axis")
.and_then(|v| v.as_str())
.map(str::to_string)
.or_else(|| app.ds.default_axis.clone())
.ok_or_else(|| bad("no axis"))?;
let levels = app
.ds
.axis(&axis)
.map(|a| a.levels.len())
.ok_or_else(|| bad("unknown axis"))?;
let mut a = TreemapArgs::new(axis.clone());
a.filters = obj(&body, "filters");
check_filters(&app.ds, &a.filters)?;
if let Some(sb) = body
.get("size_by")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
{
if app.ds.metric(sb).is_none() {
return Err(bad(format!("unknown size_by {sb:?}")));
}
a.size_by = Some(sb.to_string());
}
a.focus = body
.get("focus")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().skip(1).cloned().collect())
.unwrap_or_default();
a.depth = body
.get("depth")
.and_then(|v| v.as_i64())
.unwrap_or(levels as i64);
a.top_k = body.get("top_k").and_then(|v| v.as_i64()).unwrap_or(50);
let tree = app.backend.treemap(&app.ds, &a).map_err(bad)?;
frontend_tree(&app.ds, &axis, &tree).ok_or_else(|| bad("unknown axis"))
})
.await
.map(Json)
}
/// Translates a window label into a number of days.
///
/// Recognised labels: `1w`, `1m`, `ytd`, `1y`, `3y`, `5y`, `max`, `all`.
/// Matching is exact and case-sensitive; anything else yields `None`.
fn window_days(w: &str) -> Option<i64> {
    const WINDOWS: [(&str, i64); 8] = [
        ("1w", 7),
        ("1m", 31),
        ("ytd", 365),
        ("1y", 365),
        ("3y", 1095),
        ("5y", 1825),
        ("max", 36500),
        ("all", 36500),
    ];
    WINDOWS
        .iter()
        .find(|(label, _)| *label == w)
        .map(|&(_, days)| days)
}
/// `POST /api/series`: validates the request body and returns the backend's series result.
///
/// Queries run against the dedicated series dataset/backend when the app has them,
/// otherwise against the main ones. Body fields read: `axis`, `metric` (both required),
/// `agg`, `resolution`, `filters`, `window`, `size_by`, `top_k`, `focus`.
///
/// Errors are `400 Bad Request` for missing/unknown axis or metric, unknown `agg`,
/// resolution, filter key or window label, and for any backend error. The checks run in
/// the order written below, which determines which error a client sees first.
async fn h_series(State(app): State<Shared>, Json(body): Json<Value>) -> ApiResult {
blocking(move || {
// Prefer the dedicated series dataset/backend; fall back to the main pair.
let sds = app.series_ds.as_ref().unwrap_or(&app.ds);
let sbackend = app.series_backend.as_ref().unwrap_or(&app.backend);
let axis = body
.get("axis")
.and_then(|v| v.as_str())
.ok_or_else(|| bad("no axis"))?;
let metric = body
.get("metric")
.and_then(|v| v.as_str())
.ok_or_else(|| bad("no metric"))?;
if sds.axis(axis).is_none() {
return Err(bad(format!("unknown axis {axis:?}")));
}
if sds.metric(metric).is_none() {
return Err(bad(format!("unknown metric {metric:?}")));
}
let mut a = SeriesArgs::new(axis, metric);
// An absent or empty `agg` leaves the aggregation unset; anything else must be whitelisted.
a.agg = match body.get("agg").and_then(|v| v.as_str()) {
Some(g) if !g.is_empty() => {
if !SERIES_AGGS.contains(&g) {
return Err(bad(format!("unknown agg {g:?}")));
}
Some(g.to_string())
}
_ => None,
};
a.resolution = resolve_resolution(sds, body.get("resolution").and_then(|v| v.as_str()))?;
a.filters = obj(&body, "filters");
check_filters(sds, &a.filters)?;
// An absent or empty `window` means no window limit is passed to the backend.
a.window_days = match body.get("window").and_then(|v| v.as_str()) {
Some(w) if !w.is_empty() => {
Some(window_days(w).ok_or_else(|| bad(format!("unknown window {w:?}")))?)
}
_ => None,
};
// Unlike `h_treemap`, an unknown `size_by` is not an error here: it is replaced by `metric`.
if let Some(sb) = body
.get("size_by")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
{
a.size_by = Some(if sds.metric(sb).is_some() {
sb.to_string()
} else {
metric.to_string()
});
}
// Only positive `top_k` values are honoured, capped at MAX_TOP_K; otherwise the
// value set by `SeriesArgs::new` is kept.
if let Some(k) = body
.get("top_k")
.and_then(|v| v.as_i64())
.filter(|k| *k > 0)
{
a.top_k = k.min(MAX_TOP_K) as usize;
}
// The first element of `focus` is dropped, matching `h_treemap`.
a.focus = body
.get("focus")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().skip(1).cloned().collect())
.unwrap_or_default();
// Optionally mirror the treemap: ask the main backend which branches it would keep for
// the same axis/filters/focus/top_k and restrict the series query to those.
if app.series_follows_treemap && app.ds.axis(axis).is_some() {
let mut ta = TreemapArgs::new(axis);
ta.filters = a.filters.clone();
ta.focus = a.focus.clone();
ta.top_k = a.top_k as i64;
// Here `size_by` is checked against the main dataset, and silently skipped if unknown.
if let Some(sb) = body
.get("size_by")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
{
if app.ds.metric(sb).is_some() {
ta.size_by = Some(sb.to_string());
}
}
let bs = app.backend.branch_set(&app.ds, &ta).map_err(bad)?;
a.branches = Some(bs.keep);
a.include_other = bs.has_other;
}
sbackend.series(sds, &a).map_err(bad)
})
.await
.map(Json)
}
async fn h_geo(State(app): State<Shared>, Json(body): Json<Value>) -> ApiResult {
blocking(move || {
let key = body
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| bad("no key column"))?
.to_string();
let metric = body
.get("metric")
.and_then(|v| v.as_str())
.map(str::to_string)
.or_else(|| app.ds.default_size_by.clone())
.ok_or_else(|| bad("no metric"))?;
if app.ds.metric(&metric).is_none() {
return Err(bad(format!("unknown metric {metric:?}")));
}
let filters = obj(&body, "filters");
check_filters(&app.ds, &filters)?;
app.backend
.geo(&app.ds, &key, &metric, &filters)
.map_err(bad)
})
.await
.map(Json)
}
async fn h_scatter(State(app): State<Shared>, Json(body): Json<Value>) -> ApiResult {
blocking(move || {
let tv = app.manifest["views"]["scatter"].clone();
let x = body
.get("x")
.and_then(|v| v.as_str())
.or_else(|| tv["default_x"].as_str())
.ok_or_else(|| bad("no x"))?
.to_string();
let y = body
.get("y")
.and_then(|v| v.as_str())
.or_else(|| tv["default_y"].as_str())
.ok_or_else(|| bad("no y"))?
.to_string();
app.backend
.scatter(&app.ds, &x, &y, &Map::new(), None, None)
.map_err(bad)
})
.await
.map(Json)
}
async fn h_search(
State(app): State<Shared>,
Query(q): Query<std::collections::HashMap<String, String>>,
) -> ApiResult {
let needle = q.get("q").cloned().unwrap_or_default();
let limit = q.get("limit").and_then(|s| s.parse().ok()).unwrap_or(20u32);
let axis = q.get("axis").cloned();
blocking(move || {
app.backend
.search(&app.ds, &needle, axis.as_deref(), limit)
.map_err(bad)
})
.await
.map(Json)
}
async fn h_filter_options(
State(app): State<Shared>,
Path(facet): Path<String>,
Query(q): Query<std::collections::HashMap<String, String>>,
) -> ApiResult {
let query = q.get("q").cloned();
blocking(move || {
app.backend
.filter_options(&app.ds, &facet, query.as_deref(), 200)
.map(Value::Array)
.map_err(bad)
})
.await
.map(Json)
}
/// `GET /api/entity/*id`: entity detail, or a sub-resource selected by suffix.
///
/// * `<id>/ohlc` — always answered with `404 Not Found`.
/// * `<id>/series` — per-entity series for the `metric` query parameter, with
///   optional `resolution` and `window`; uses the series dataset/backend when
///   the app has them.
/// * anything else — the backend's detail record, or `404` when it has none.
///
/// A suffix only counts when a non-empty id precedes it.
///
/// # Errors
/// `400 Bad Request` for a missing/unknown metric, unknown resolution or window,
/// and backend errors; `404 Not Found` as described above.
async fn h_entity(
    State(app): State<Shared>,
    Path(id): Path<String>,
    Query(q): Query<std::collections::HashMap<String, String>>,
) -> ApiResult {
    if let Some(eid) = id.strip_suffix("/ohlc") {
        if !eid.is_empty() {
            return Err((StatusCode::NOT_FOUND, format!("no ohlc for {eid:?}")));
        }
    }

    let series_of = match id.strip_suffix("/series") {
        Some(eid) if !eid.is_empty() => Some(eid.to_owned()),
        _ => None,
    };

    // Plain entity lookup when no `/series` sub-resource was requested.
    let Some(eid) = series_of else {
        return blocking(move || {
            let found = app.backend.detail(&app.ds, &id).map_err(bad)?;
            found.ok_or_else(|| (StatusCode::NOT_FOUND, format!("no entity {id:?}")))
        })
        .await
        .map(Json);
    };

    blocking(move || {
        let sds = app.series_ds.as_ref().unwrap_or(&app.ds);
        let sbackend = app.series_backend.as_ref().unwrap_or(&app.backend);

        let Some(metric) = q.get("metric").map(String::as_str) else {
            return Err(bad("no metric"));
        };
        if sds.metric(metric).is_none() {
            return Err(bad(format!("unknown metric {metric:?}")));
        }

        let resolution = resolve_resolution(sds, q.get("resolution").map(String::as_str))?;

        let requested_window = q
            .get("window")
            .map(String::as_str)
            .filter(|w| !w.is_empty());
        let window = match requested_window {
            Some(w) => Some(window_days(w).ok_or_else(|| bad(format!("unknown window {w:?}")))?),
            None => None,
        };

        sbackend
            .entity_series(sds, &eid, metric, window, &resolution)
            .map_err(bad)
    })
    .await
    .map(Json)
}
/// Process-wide, set-once override for where frontend assets are read from.
/// Initialised either by `set_assets_dir` or lazily from `TAXA_ASSETS_DIR`.
static ASSETS_DIR: std::sync::OnceLock<Option<std::path::PathBuf>> = std::sync::OnceLock::new();

/// Returns the dev assets directory, if one is configured.
///
/// On first use the value is taken from the `TAXA_ASSETS_DIR` environment
/// variable (unless it was already set); the result is cached for the rest of
/// the process, so later changes to the variable are not observed.
fn assets_dir() -> Option<&'static std::path::Path> {
    let configured = ASSETS_DIR.get_or_init(|| {
        let from_env = std::env::var_os("TAXA_ASSETS_DIR");
        from_env.map(std::path::PathBuf::from)
    });
    configured.as_deref()
}
/// Returns the URL prefix the app is mounted under, normalised from the
/// `TAXA_BASE_PATH` environment variable.
///
/// Surrounding whitespace and leading/trailing slashes are removed; a non-empty
/// remainder is returned with a single leading `/`, otherwise the empty string.
/// The value is computed once and cached for the rest of the process.
fn base_path() -> &'static str {
    static BASE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    BASE.get_or_init(|| {
        let raw = std::env::var("TAXA_BASE_PATH").unwrap_or_default();
        match raw.trim().trim_matches('/') {
            "" => String::new(),
            inner => format!("/{inner}"),
        }
    })
}
/// Inserts base-path bootstrap markup right after the first `<head>` tag.
///
/// Three elements are injected: a `<base href>`, a script assigning
/// `window.__BASE__`, and an import map redirecting `/static/` to
/// `{base}/static/`. When no literal `<head>` is found the input bytes are
/// returned unchanged. The document is decoded lossily as UTF-8 before editing.
// NOTE(review): `base` is interpolated into the `href` attribute and the import map
// without escaping; this assumes it contains no quotes or markup characters.
fn inject_base(html: &[u8], base: &str) -> Vec<u8> {
    const HEAD: &str = "<head>";

    let text = String::from_utf8_lossy(html);
    let Some(head_at) = text.find(HEAD) else {
        return html.to_vec();
    };
    let (before, after) = text.split_at(head_at + HEAD.len());

    let base_json = serde_json::to_string(base).unwrap_or_else(|_| "\"\"".into());
    let injected = [
        format!("<base href=\"{base}/\">"),
        format!("<script>window.__BASE__={base_json};</script>"),
        format!(
            "<script type=\"importmap\">{{\"imports\":{{\"/static/\":\"{base}/static/\"}}}}</script>"
        ),
    ]
    .concat();

    [before, injected.as_str(), after].concat().into_bytes()
}
/// Configures the dev assets directory programmatically.
///
/// Only the first initialisation of the underlying cell takes effect: if the
/// directory was already set (or already resolved from the environment), this
/// call is silently ignored.
pub fn set_assets_dir(dir: impl Into<std::path::PathBuf>) {
    let chosen: std::path::PathBuf = dir.into();
    // A failed `set` just means a value is already in place; that is not an error here.
    let _already_initialised = ASSETS_DIR.set(Some(chosen)).is_err();
}
/// Reads `rel` from the dev assets directory, if one is configured.
///
/// Returns `None` when no directory is configured, when `rel` is not a plain
/// relative path confined to that directory, or when the file cannot be read.
///
/// `rel` comes from the request URL, so it is treated as untrusted:
/// * any empty or `..` segment (split on `/`) is rejected;
/// * every platform path component must be a normal name (or `.`). This also
///   rejects roots, Windows drive prefixes and backslash-separated `..`
///   segments, which the `/`-only check does not see and which would let
///   `Path::join` escape or replace the base directory.
fn dev_asset(rel: &str) -> Option<Vec<u8>> {
    let dir = assets_dir()?;
    if rel.split('/').any(|seg| seg == ".." || seg.is_empty()) {
        return None;
    }
    let confined = std::path::Path::new(rel).components().all(|part| {
        matches!(
            part,
            std::path::Component::Normal(_) | std::path::Component::CurDir
        )
    });
    if !confined {
        return None;
    }
    std::fs::read(dir.join(rel)).ok()
}
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[usize::from(byte >> 4)] as char);
        out.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
/// Reports whether an `If-None-Match` header value matches `etag`.
///
/// The header is split on commas and each candidate is trimmed. Comparison is
/// weak: a single leading `W/` marker is ignored on both sides. A bare `*`
/// candidate matches any `etag`.
///
/// At most one `W/` prefix is removed, so malformed validators such as
/// `W/W/"x"` do not match `"x"`.
pub(crate) fn etag_matches(if_none_match: &str, etag: &str) -> bool {
    /// Drops one weak-validator marker, if present.
    fn opaque(tag: &str) -> &str {
        tag.strip_prefix("W/").unwrap_or(tag)
    }

    let want = opaque(etag);
    if_none_match
        .split(',')
        .map(str::trim)
        .any(|candidate| candidate == "*" || opaque(candidate) == want)
}
/// Builds the HTTP response for an embedded asset.
///
/// The ETag is the quoted hex of the file's SHA-256 from its embed metadata.
/// When the request's `If-None-Match` matches it, a bodiless `304 Not Modified`
/// is returned; otherwise the file is sent with `Cache-Control: no-cache`.
fn embedded_response(
    content: rust_embed::EmbeddedFile,
    mime: &str,
    req_headers: &HeaderMap,
) -> Response {
    let digest = content.metadata.sha256_hash();
    let etag = format!("\"{}\"", hex(digest.as_ref()));

    // A header value that is not valid visible ASCII is treated as absent.
    let unchanged = match req_headers.get(header::IF_NONE_MATCH).map(|v| v.to_str()) {
        Some(Ok(sent)) => etag_matches(sent, &etag),
        _ => false,
    };
    if unchanged {
        return (StatusCode::NOT_MODIFIED, [(header::ETAG, etag)]).into_response();
    }

    let headers = [
        (header::CONTENT_TYPE, mime),
        (header::CACHE_CONTROL, "no-cache"),
        (header::ETAG, etag.as_str()),
    ];
    (headers, content.data).into_response()
}
/// `GET /static/*path`: serves a static asset.
///
/// Sources are tried in order: the dev assets directory (sent with
/// `Cache-Control: no-store`), the geo module, then the embedded bundle.
/// Responds `404 Not Found` when none of them has the file. The content type
/// is guessed from the path's extension.
async fn h_asset(Path(path): Path<String>, req_headers: HeaderMap) -> Response {
    let mime = mime_guess::from_path(&path).first_or_octet_stream();
    let content_type: &str = mime.as_ref();

    if let Some(bytes) = dev_asset(&path) {
        let headers = [
            (header::CONTENT_TYPE, content_type),
            (header::CACHE_CONTROL, "no-store"),
        ];
        return (headers, bytes).into_response();
    }

    if let Some(resp) = geo::serve_geo_asset(&path, &req_headers) {
        return resp;
    }

    let Some(content) = Assets::get(&path) else {
        return (StatusCode::NOT_FOUND, "not found").into_response();
    };
    embedded_response(content, content_type, &req_headers)
}
/// Unit tests for `etag_matches`.
#[cfg(test)]
mod etag_tests {
use super::etag_matches;
#[test]
fn strong_and_weak_validators_match() {
// Identical strong validators match.
assert!(etag_matches("\"abc123\"", "\"abc123\""));
// Identical weak validators match.
assert!(etag_matches(
"W/\"2415-1781732769\"",
"W/\"2415-1781732769\""
));
// A strong candidate matches a weak ETag with the same opaque value.
assert!(etag_matches("\"2415-1\"", "W/\"2415-1\""));
// The wildcard matches anything.
assert!(etag_matches("*", "W/\"whatever\""));
// Different opaque values do not match.
assert!(!etag_matches("\"other\"", "\"abc123\""));
// Any entry of a comma-separated list may match.
assert!(etag_matches(
"\"x\", W/\"2415-1781732769\"",
"W/\"2415-1781732769\""
));
}
}
/// Fallback handler: serves `index.html` with the base-path markup injected.
///
/// The page is read from the dev assets directory when available, otherwise
/// from the embedded bundle; `404 Not Found` if neither has it. The ETag is a
/// hash of the final (injected) bytes, and a matching `If-None-Match` yields
/// `304 Not Modified`. The page is sent with `Cache-Control: no-cache`.
async fn h_index(_uri: Uri, req_headers: HeaderMap) -> Response {
    use std::hash::{Hash, Hasher};

    let source = match dev_asset("index.html") {
        Some(bytes) => bytes,
        None => match Assets::get("index.html") {
            Some(embedded) => embedded.data.into_owned(),
            None => {
                return (StatusCode::NOT_FOUND, "index.html missing from embed").into_response();
            }
        },
    };
    let html = inject_base(&source, base_path());

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    html.hash(&mut hasher);
    let etag = format!("\"{:016x}\"", hasher.finish());

    let unchanged = req_headers
        .get(header::IF_NONE_MATCH)
        .and_then(|v| v.to_str().ok())
        .is_some_and(|sent| etag_matches(sent, &etag));
    if unchanged {
        return (StatusCode::NOT_MODIFIED, [(header::ETAG, etag)]).into_response();
    }

    let headers = [
        (header::CONTENT_TYPE, "text/html"),
        (header::CACHE_CONTROL, "no-cache"),
        (header::ETAG, etag.as_str()),
    ];
    (headers, html).into_response()
}