atelier_quant 0.0.12

Quantitative Finance Tools & Models for the atelier-rs engine
Documentation
//! `inter_serve` — Load a fitted artifact and serve arrival forecasts.
//!
//! Reads JSON queries from **stdin** (one per line), writes JSON
//! responses to **stdout**.  Diagnostic messages go to stderr.
//!
//! ## Usage
//!
//! ```text
//! cargo run -p atelier_quant --bin inter_serve -- \
//!     --artifact artifacts/inter_arrival/model.json \
//!     --mc-paths 1000 --mc-statistic median
//! ```
//!
//! ## Query format (JSON, one per line)
//!
//! ```json
//! {"events_ms": [0.0, 12.5, 25.1, 30.0], "n_forecast": 10}
//! ```
//!
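//! `events_ms`: recent event timestamps in milliseconds, sorted in ascending
//! order and relative to any common origin.  The last element is taken as the
//! most recent event and anchors the forecast.  Used as conditioning history.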
//!
//! `n_forecast`: number of future arrival gaps to predict.
//!
//! ## Response format
//!
//! ```json
//! {"forecast_gaps_ms": [8.1, 17.3, ...], "forecast_ts_ms": [38.1, 47.3, ...]}
//! ```
//!
//! `forecast_gaps_ms`: offset of each forecast event from the last history
//! event (cumulative — not per-event inter-arrival gaps).
//! `forecast_ts_ms`:   absolute timestamps in the same frame as input.
//!
//! Set `--mc-paths 0` for deterministic (conditional-mean) forecasts.

use std::io::{self, BufRead, Write};
use std::path::PathBuf;

use clap::Parser;
use serde::{Deserialize, Serialize};

use atelier_quant::artifact::ModelArtifact;
use atelier_quant::config::EnsembleStatistic;
use atelier_quant::forecast::ensemble_forecast;
use atelier_quant::hawkes::HawkesProcess;

// ── CLI ─────────────────────────────────────────────────────────────

// Command-line options.  clap's derive turns the `///` doc comments on each
// field into `--help` text, so those comments are user-facing output.
#[derive(Parser)]
#[command(
    name = "inter_serve",
    about = "Load a fitted model artifact and serve arrival forecasts via stdin/stdout."
)]
struct Cli {
    /// Path to the model artifact JSON produced by `inter_fit`.
    // Read once at startup and deserialized as a `ModelArtifact`; a read or
    // parse failure exits with status 1 before any query is served.
    #[arg(long)]
    artifact: PathBuf,

    /// Monte-Carlo paths per query (0 = deterministic conditional-mean).
    // Applied to every query; the query format has no per-query override.
    #[arg(long, default_value_t = 1000)]
    mc_paths: usize,

    /// Ensemble reduction statistic: median, mean, p25, p75.
    // Kept as a free-form string and validated case-insensitively at startup
    // by `parse_statistic` (an unknown name exits with status 1).  It is
    // parsed even when `mc_paths == 0`, but only affects forecasts when
    // `mc_paths > 0`.
    #[arg(long, default_value = "median")]
    mc_statistic: String,
}

// ── Query / Response ────────────────────────────────────────────────

/// One forecast request, deserialized from a single stdin line.
// Unknown JSON fields are ignored (no `deny_unknown_fields`); a missing
// field or a value of the wrong type is answered with a "bad query JSON"
// error response.
#[derive(Deserialize)]
struct Query {
    /// Recent event times in ms (sorted), used as conditioning history.
    /// The last element is treated as the most recent event.
    events_ms: Vec<f64>,
    /// Number of future events to forecast.
    /// Must be > 0; a negative value fails to deserialize into `usize`.
    n_forecast: usize,
}

/// Successful forecast reply, serialized as one JSON line on stdout.
// `forecast_ts_ms` has the same length as `forecast_gaps_ms`: each timestamp
// is the last history event plus the corresponding gap.
#[derive(Serialize)]
struct Response {
    /// Cumulative gaps from the last history event (ms).
    forecast_gaps_ms: Vec<f64>,
    /// Absolute forecast timestamps (same frame as `events_ms`).
    forecast_ts_ms: Vec<f64>,
}

/// Failure reply, serialized as `{"error": "..."}` in place of a `Response`.
// Returned for bad queries so the client still receives a reply line and
// the server keeps serving subsequent queries.
#[derive(Serialize)]
struct ErrorResponse {
    /// Human-readable description of why the query was rejected.
    error: String,
}

// ── Main ────────────────────────────────────────────────────────────

fn main() {
    let cli = Cli::parse();

    // 1. Load artifact
    let artifact_json = std::fs::read_to_string(&cli.artifact).unwrap_or_else(|e| {
        eprintln!("ERROR: cannot read artifact {:?}: {}", cli.artifact, e);
        std::process::exit(1);
    });

    let artifact: ModelArtifact =
        serde_json::from_str(&artifact_json).unwrap_or_else(|e| {
            eprintln!("ERROR: cannot parse artifact JSON: {}", e);
            std::process::exit(1);
        });

    // 2. Reconstruct Hawkes process
    let p = &artifact.parameters;
    let hp = HawkesProcess::new(p.mu, p.alpha, p.beta).unwrap_or_else(|e| {
        eprintln!("ERROR: invalid Hawkes parameters: {:?}", e);
        std::process::exit(1);
    });

    // 3. Parse ensemble statistic
    let statistic = parse_statistic(&cli.mc_statistic);

    eprintln!(
        "inter_serve: loaded artifact (μ={:.6e}, α={:.6e}, β={:.6e}), mc_paths={}, stat={:?}",
        p.mu, p.alpha, p.beta, cli.mc_paths, statistic
    );
    eprintln!("inter_serve: reading queries from stdin (one JSON per line)…");

    // 4. stdin → process → stdout loop
    let stdin = io::stdin();
    let stdout = io::stdout();
    let mut out = io::BufWriter::new(stdout.lock());

    for line in stdin.lock().lines() {
        let line = match line {
            Ok(l) => l,
            Err(e) => {
                eprintln!("stdin read error: {e}");
                break;
            }
        };

        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }

        let json = process_query(trimmed, &hp, cli.mc_paths, statistic);
        let _ = writeln!(out, "{json}");
        let _ = out.flush();
    }
}

// ── Helpers ─────────────────────────────────────────────────────────

/// Map the `--mc-statistic` flag value onto an [`EnsembleStatistic`].
///
/// Matching is case-insensitive (`"Median"`, `"P25"`, … are accepted).
/// On an unrecognised name a diagnostic is printed to stderr and the process
/// exits with status 1, so callers always receive a valid statistic.
fn parse_statistic(s: &str) -> EnsembleStatistic {
    let name = s.to_lowercase();

    let known = match name.as_str() {
        "median" => Some(EnsembleStatistic::Median),
        "mean" => Some(EnsembleStatistic::Mean),
        "p25" => Some(EnsembleStatistic::P25),
        "p75" => Some(EnsembleStatistic::P75),
        _ => None,
    };

    known.unwrap_or_else(|| {
        eprintln!("ERROR: unknown statistic '{name}', expected median|mean|p25|p75");
        std::process::exit(1);
    })
}

fn process_query(
    json_line: &str,
    hp: &HawkesProcess,
    mc_paths: usize,
    statistic: EnsembleStatistic,
) -> String {
    let query: Query = match serde_json::from_str(json_line) {
        Ok(q) => q,
        Err(e) => {
            return serde_json::to_string(&ErrorResponse {
                error: format!("bad query JSON: {e}"),
            })
            .unwrap();
        }
    };

    if query.events_ms.is_empty() || query.n_forecast == 0 {
        return serde_json::to_string(&ErrorResponse {
            error: "events_ms must be non-empty and n_forecast > 0".into(),
        })
        .unwrap();
    }

    let last_ts = *query.events_ms.last().unwrap();

    let forecast_gaps_ms = if mc_paths == 0 {
        // Deterministic: iterated conditional-mean forecast.
        let abs_ts =
            hp.forecast_conditional_means(last_ts, &query.events_ms, query.n_forecast);
        abs_ts.iter().map(|&t| t - last_ts).collect()
    } else {
        // Stochastic MC ensemble.
        ensemble_forecast(
            hp,
            last_ts,
            &query.events_ms,
            query.n_forecast,
            mc_paths,
            statistic,
        )
    };

    let forecast_ts_ms: Vec<f64> =
        forecast_gaps_ms.iter().map(|&g| last_ts + g).collect();

    serde_json::to_string(&Response {
        forecast_gaps_ms,
        forecast_ts_ms,
    })
    .unwrap()
}