use std::io::{self, BufRead, Write};
use std::path::PathBuf;
use clap::Parser;
use serde::{Deserialize, Serialize};
use atelier_quant::artifact::ModelArtifact;
use atelier_quant::config::EnsembleStatistic;
use atelier_quant::forecast::ensemble_forecast;
use atelier_quant::hawkes::HawkesProcess;
// Command-line options for the `inter_serve` binary.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into
// `--help` text, and this block must not change the program's observable output.
#[derive(Parser)]
#[command(
    name = "inter_serve",
    about = "Load a fitted model artifact and serve arrival forecasts via stdin/stdout."
)]
struct Cli {
    // Path of the fitted model artifact; `main` reads it and parses it as JSON into a
    // `ModelArtifact`.
    #[arg(long)]
    artifact: PathBuf,

    // Monte Carlo paths per forecast. `process_query` treats 0 as "use the
    // conditional-mean forecast instead of an ensemble".
    #[arg(long, default_value_t = 1_000)]
    mc_paths: usize,

    // Name of the ensemble statistic; resolved (case-insensitively) by `parse_statistic`.
    #[arg(long, default_value = "median")]
    mc_statistic: String,
}
/// A single forecast request, decoded from one JSON line read on stdin.
#[derive(Deserialize)]
struct Query {
    /// Observed event timestamps. The final element is used as the forecast origin.
    /// The `_ms` suffix suggests milliseconds — confirm against the fitting data.
    events_ms: Vec<f64>,
    /// How many future events to forecast; a value of 0 is rejected by `process_query`.
    n_forecast: usize,
}
/// Successful answer to a query, written to stdout as one JSON line.
///
/// The two vectors are index-aligned: `forecast_ts_ms[i]` equals the query's last event
/// timestamp plus `forecast_gaps_ms[i]`. Declaration order below is the key order of the
/// emitted JSON object, so keep it stable.
#[derive(Serialize)]
struct Response {
    /// Forecast offsets measured from the last observed event.
    forecast_gaps_ms: Vec<f64>,
    /// The same forecasts expressed as absolute timestamps.
    forecast_ts_ms: Vec<f64>,
}
/// Reply emitted instead of a forecast when a query cannot be answered.
///
/// Serializes as `{"error": "<message>"}`.
#[derive(Serialize)]
struct ErrorResponse {
    /// Human-readable explanation of why the query was rejected.
    error: String,
}
/// Entry point for `inter_serve`.
///
/// Loads the model artifact given by `--artifact` and builds a [`HawkesProcess`] from its
/// parameters. It then answers each non-blank stdin line (a JSON query) with one JSON
/// response line on stdout, via [`process_query`]. Diagnostics go to stderr, so stdout
/// carries only responses.
///
/// Startup failures print an `ERROR:` line and exit with status 1. These are: an unknown
/// statistic, an unreadable or unparsable artifact, and invalid Hawkes parameters. The
/// serving loop ends at stdin EOF, on a stdin read error, or as soon as writing or flushing
/// stdout fails.
fn main() {
    let cli = Cli::parse();

    // Resolve the statistic option first: it is cheap to check, so a typo is reported
    // before any artifact I/O or JSON parsing takes place.
    let statistic = parse_statistic(&cli.mc_statistic);

    let artifact_json = std::fs::read_to_string(&cli.artifact).unwrap_or_else(|e| {
        eprintln!("ERROR: cannot read artifact {:?}: {}", cli.artifact, e);
        std::process::exit(1);
    });
    let artifact: ModelArtifact = serde_json::from_str(&artifact_json).unwrap_or_else(|e| {
        eprintln!("ERROR: cannot parse artifact JSON: {}", e);
        std::process::exit(1);
    });
    let p = &artifact.parameters;
    let hp = HawkesProcess::new(p.mu, p.alpha, p.beta).unwrap_or_else(|e| {
        eprintln!("ERROR: invalid Hawkes parameters: {:?}", e);
        std::process::exit(1);
    });
    eprintln!(
        "inter_serve: loaded artifact (μ={:.6e}, α={:.6e}, β={:.6e}), mc_paths={}, stat={:?}",
        p.mu, p.alpha, p.beta, cli.mc_paths, statistic
    );
    eprintln!("inter_serve: reading queries from stdin (one JSON per line)…");

    let stdin = io::stdin();
    let stdout = io::stdout();
    // Lock stdout once for the whole session. Each response is flushed individually so a
    // client waiting for its reply is never left behind the buffer.
    let mut out = io::BufWriter::new(stdout.lock());
    for line in stdin.lock().lines() {
        let line = match line {
            Ok(l) => l,
            Err(e) => {
                eprintln!("stdin read error: {e}");
                break;
            }
        };
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        let json = process_query(trimmed, &hp, cli.mc_paths, statistic);
        // Stop serving once stdout is unusable (e.g. the reader closed the pipe).
        // Previously these errors were discarded, so every remaining query still ran a
        // full forecast whose answer could never be delivered.
        if let Err(e) = writeln!(out, "{json}").and_then(|()| out.flush()) {
            eprintln!("stdout write error: {e}");
            break;
        }
    }
}
/// Resolve the `mc_statistic` CLI value to an [`EnsembleStatistic`], ignoring case.
///
/// Accepted names are `median`, `mean`, `p25` and `p75`. Any other value is treated as a
/// fatal misconfiguration: an `ERROR:` line naming the (lower-cased) value is printed to
/// stderr and the process exits with status 1.
fn parse_statistic(s: &str) -> EnsembleStatistic {
    let other = s.to_lowercase();
    // Name -> statistic lookup table; the first exact match wins.
    let known = [
        ("median", EnsembleStatistic::Median),
        ("mean", EnsembleStatistic::Mean),
        ("p25", EnsembleStatistic::P25),
        ("p75", EnsembleStatistic::P75),
    ];
    known
        .iter()
        .find_map(|&(name, stat)| if name == other { Some(stat) } else { None })
        .unwrap_or_else(|| {
            eprintln!("ERROR: unknown statistic '{other}', expected median|mean|p25|p75");
            std::process::exit(1);
        })
}
fn process_query(
json_line: &str,
hp: &HawkesProcess,
mc_paths: usize,
statistic: EnsembleStatistic,
) -> String {
let query: Query = match serde_json::from_str(json_line) {
Ok(q) => q,
Err(e) => {
return serde_json::to_string(&ErrorResponse {
error: format!("bad query JSON: {e}"),
})
.unwrap();
}
};
if query.events_ms.is_empty() || query.n_forecast == 0 {
return serde_json::to_string(&ErrorResponse {
error: "events_ms must be non-empty and n_forecast > 0".into(),
})
.unwrap();
}
let last_ts = *query.events_ms.last().unwrap();
let forecast_gaps_ms = if mc_paths == 0 {
let abs_ts =
hp.forecast_conditional_means(last_ts, &query.events_ms, query.n_forecast);
abs_ts.iter().map(|&t| t - last_ts).collect()
} else {
ensemble_forecast(
hp,
last_ts,
&query.events_ms,
query.n_forecast,
mc_paths,
statistic,
)
};
let forecast_ts_ms: Vec<f64> =
forecast_gaps_ms.iter().map(|&g| last_ts + g).collect();
serde_json::to_string(&Response {
forecast_gaps_ms,
forecast_ts_ms,
})
.unwrap()
}