use std::fs;
use std::path::{Path, PathBuf};
use atelier_data::orderbooks::io::ob_parquet::load_parquet_to_ob;
use atelier_data::temporal::{self, TimeResolution};
use atelier_data::trades::io::read_trades_parquet;
use atelier_quant::arrivals::extract::{
extract_orderbook_timestamps, extract_trade_timestamps,
};
use atelier_quant::arrivals::validate::{
MonotonicityResult, check_monotonicity, detect_gaps,
};
use atelier_quant::arrivals::{compute_interarrivals, descriptive_stats};
use atelier_quant::artifact::{
DataMeta, Diagnostics, HawkesParams, ModelArtifact, PoissonBaseline,
};
use atelier_quant::config::FitConfig;
use atelier_quant::forecast::{
ensemble_forecast, forecast_errors, likelihood_ratio_test,
};
use atelier_quant::hawkes::HawkesProcess;
use atelier_quant::hawkes::estimation::{
HawkesEstimationConfig, compensator, estimate_hawkes_mle, time_rescaling_residuals,
};
use atelier_quant::poisson::PoissonProcess;
use atelier_quant::poisson::estimation::{PoissonEstimationConfig, estimate_poisson_mle};
use chrono::Utc;
use clap::Parser;
// Command-line arguments of the `inter_fit` batch fitting binary.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macro
// turns `///` doc comments into help/about text, so doc comments on this
// struct or its fields would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(name = "inter_fit", version, about = "Batch Hawkes model fit")]
struct Cli {
// Path to the configuration file. `main` reads this file and parses its
// contents as TOML into a `FitConfig`. The `short, long` attribute asks
// clap to expose it as both a short and a long flag derived from the
// field name.
#[arg(short, long)]
config: PathBuf,
}
/// Prints a section banner to stdout.
///
/// The output is a blank line followed by `label` (wrapped in one space on
/// each side) centred in a 72-column rule made of `═` characters.
fn separator(label: &str) {
    let padded_title = format!(" {} ", label);
    println!("\n{:═^72}", padded_title);
}
/// Prints one report line to stdout: a leading space, `label` left-aligned
/// and padded to 30 columns, a space, then `value` rendered via `Display`.
fn row(label: &str, value: impl std::fmt::Display) {
    let line = format!(" {:<30} {}", label, value);
    println!("{}", line);
}
/// Returns the most recently modified `*.parquet` file in `dir` whose file
/// name contains `tag`.
///
/// Returns `None` when `dir` cannot be read or when no entry qualifies.
/// Entries whose metadata or modification time cannot be obtained are
/// skipped. When several candidates share the newest modification time, the
/// one encountered first in directory order is returned.
fn find_latest_parquet(dir: &Path, tag: &str) -> Option<PathBuf> {
    fs::read_dir(dir)
        .ok()?
        .flatten()
        .filter_map(|entry| {
            let candidate = entry.path();
            let is_parquet =
                candidate.extension().and_then(|ext| ext.to_str()) == Some("parquet");
            // A non-UTF-8 file name is treated as the empty string.
            let name_matches = candidate
                .file_name()
                .and_then(|name| name.to_str())
                .unwrap_or("")
                .contains(tag);
            if !(is_parquet && name_matches) {
                return None;
            }
            let modified = entry.metadata().ok()?.modified().ok()?;
            Some((candidate, modified))
        })
        // Only a strictly newer file displaces the current winner, so ties
        // keep the earlier entry.
        .reduce(|newest, next| if next.1 > newest.1 { next } else { newest })
        .map(|(path, _)| path)
}
fn main() {
let cli = Cli::parse();
let config_text = fs::read_to_string(&cli.config).unwrap_or_else(|e| {
eprintln!("ERROR: cannot read config {:?}: {}", cli.config, e);
std::process::exit(1);
});
let cfg: FitConfig = toml::from_str(&config_text).unwrap_or_else(|e| {
eprintln!("ERROR: invalid config: {}", e);
std::process::exit(1);
});
println!("Config loaded from {:?}", cli.config);
separator("1. Resolve Input");
let tag = match cfg.input.data_type.as_str() {
"trades" => "trades",
"orderbook" | "ob" => "ob",
other => {
eprintln!(
"ERROR: unknown data_type {:?}, expected trades|orderbook",
other
);
std::process::exit(1);
}
};
let parquet_path: PathBuf = match cfg.input.selection.as_str() {
"latest" => find_latest_parquet(&cfg.input.path, tag).unwrap_or_else(|| {
eprintln!(
"ERROR: no *{}*.parquet files found in {:?}",
tag, cfg.input.path
);
std::process::exit(1);
}),
filename => cfg.input.path.join(filename),
};
println!(" File: {}", parquet_path.display());
separator("2. Load Parquet");
let is_trades = tag == "trades";
let (timestamps_ns, n_loaded) = if is_trades {
let trades = read_trades_parquet(&parquet_path).unwrap_or_else(|e| {
eprintln!(" ERROR: Failed to load trades parquet: {}", e);
std::process::exit(1);
});
let n = trades.len();
println!(" Loaded {} trades", n);
(extract_trade_timestamps(&trades), n)
} else {
let orderbooks = load_parquet_to_ob(&parquet_path).unwrap_or_else(|e| {
eprintln!(" ERROR: Failed to load OB parquet: {}", e);
std::process::exit(1);
});
let n = orderbooks.len();
println!(" Loaded {} orderbook snapshots", n);
(extract_orderbook_timestamps(&orderbooks), n)
};
if n_loaded < 12 {
eprintln!(
" ERROR: Need at least 12 events (10 test + 2 train), got {}",
n_loaded
);
std::process::exit(1);
}
separator("3. Extract Timestamps");
let mut timestamps_ns = timestamps_ns;
timestamps_ns.sort_unstable();
let before_dedup = timestamps_ns.len();
timestamps_ns.dedup();
let n_dupes = before_dedup - timestamps_ns.len();
println!(" Removed {} duplicate timestamps", n_dupes);
println!(" Total arrivals: {}", timestamps_ns.len());
let first_ms = temporal::from_nanos(
*timestamps_ns.first().unwrap(),
TimeResolution::Milliseconds,
);
let last_ms = temporal::from_nanos(
*timestamps_ns.last().unwrap(),
TimeResolution::Milliseconds,
);
let span_s = (last_ms - first_ms) / 1000.0;
let observation_window_ms = last_ms - first_ms;
println!(
" Observation window: {:.3}s ({:.1} ms)",
span_s, observation_window_ms
);
separator("4. Validation & Gap Detection");
match check_monotonicity(×tamps_ns) {
MonotonicityResult::StrictlyMonotonic => {
println!(" Timestamps are strictly monotonic");
}
MonotonicityResult::Violation { index, prev, curr } => {
eprintln!(
" Monotonicity violation at index {}: {} <= {}",
index, curr, prev
);
std::process::exit(1);
}
}
let gap_threshold_ns = (cfg.model.gap_threshold_secs * 1e9) as u64;
let large_gaps = detect_gaps(×tamps_ns, gap_threshold_ns);
if large_gaps.is_empty() {
println!(
" No gaps exceeding {:.1}s detected",
cfg.model.gap_threshold_secs
);
} else {
println!(" Gaps exceeding {:.1}s:", cfg.model.gap_threshold_secs);
for g in &large_gaps {
println!(
" index {}: gap = {:.3} ms",
g.index,
g.gap_ns as f64 / 1e6
);
}
println!(" Total large gaps: {}", large_gaps.len());
}
separator("5. Train / Test Split");
let n_total = timestamps_ns.len();
let n_train = (n_total as f64 * cfg.model.train_ratio) as usize;
let n_test = n_total - n_train;
let train_ts = ×tamps_ns[..n_train];
let test_ts = ×tamps_ns[n_train..];
println!(" Training set: {} arrivals", train_ts.len());
println!(" Test set: {} arrivals", test_ts.len());
separator("6. Interarrival Statistics (Training)");
let ia_result = compute_interarrivals(train_ts, TimeResolution::Milliseconds)
.unwrap_or_else(|e| {
eprintln!(" ERROR: {}", e);
std::process::exit(1);
});
let stats = descriptive_stats(&ia_result.deltas_f64).unwrap();
row("Count (gaps)", format!("{}", stats.count));
row("Mean (ms)", format!("{:.6}", stats.mean));
row("Std dev (ms)", format!("{:.6}", stats.std_dev));
row("Variance (ms^2)", format!("{:.6}", stats.variance));
row("Min (ms)", format!("{:.6}", stats.min));
row("Max (ms)", format!("{:.6}", stats.max));
row("Skewness", format!("{:.4}", stats.skewness));
row("Excess kurtosis", format!("{:.4}", stats.kurtosis));
row("CV (std/mean)", format!("{:.4}", stats.cv));
if stats.cv > 1.0 {
println!("\n CV > 1: clustering (super-Poisson), consistent with Hawkes.");
} else if (stats.cv - 1.0).abs() < 0.15 {
println!("\n CV ~ 1: near-Poisson (memoryless) arrivals.");
} else {
println!("\n CV < 1: regularity (sub-Poisson).");
}
separator("7. Hawkes MLE Estimation");
let t0_ns = train_ts[0];
let train_events_ms: Vec<f64> = train_ts
.iter()
.map(|&t| temporal::from_nanos(t - t0_ns, TimeResolution::Milliseconds))
.collect();
let hawkes_config = HawkesEstimationConfig {
max_iter: cfg.model.max_iter,
tol: cfg.model.tolerance,
learning_rate: cfg.model.learning_rate,
initial_params: None,
};
let mle = estimate_hawkes_mle(&train_events_ms, &hawkes_config).unwrap_or_else(|e| {
eprintln!(" ERROR: MLE failed: {}", e);
std::process::exit(1);
});
row("mu (events/ms)", format!("{:.8}", mle.mu));
row("alpha (excitation)", format!("{:.8}", mle.alpha));
row("beta (decay 1/ms)", format!("{:.8}", mle.beta));
row("Branching ratio a/b", format!("{:.6}", mle.branching_ratio));
row("Log-likelihood", format!("{:.4}", mle.log_likelihood));
row("AIC", format!("{:.4}", mle.aic));
row("BIC", format!("{:.4}", mle.bic));
row("Iterations", format!("{}", mle.iterations));
row("Converged", format!("{}", mle.converged));
let theoretical_rate = mle.mu / (1.0 - mle.branching_ratio);
row(
"Stationary rate (ev/ms)",
format!("{:.8}", theoretical_rate),
);
row(
"Stationary mean gap (ms)",
format!("{:.6}", 1.0 / theoretical_rate),
);
separator("8. Goodness-of-Fit (Time-Rescaling)");
let residuals =
time_rescaling_residuals(mle.mu, mle.alpha, mle.beta, &train_events_ms);
let res_stats = descriptive_stats(&residuals);
let (residuals_mean, residuals_std) = if let Some(ref rs) = res_stats {
row("Residuals count", format!("{}", rs.count));
row("Residuals mean", format!("{:.6}", rs.mean));
row("Residuals std dev", format!("{:.6}", rs.std_dev));
println!(
"\n Under correct specification, residuals ~ Exp(1): mean ~ 1.0, std ~ 1.0"
);
(rs.mean, rs.std_dev)
} else {
(f64::NAN, f64::NAN)
};
let t_end = *train_events_ms.last().unwrap();
let comp_end = compensator(mle.mu, mle.alpha, mle.beta, &train_events_ms, t_end);
let expected_comp = (train_events_ms.len() - 1) as f64;
let compensator_ratio = comp_end / expected_comp;
row(
"Compensator(T) / (n-1)",
format!("{:.4}", compensator_ratio),
);
separator("9. Forecast vs Actual");
let last_train_ms = *train_events_ms.last().unwrap();
let hp = HawkesProcess::new(mle.mu, mle.alpha, mle.beta).unwrap_or_else(|e| {
eprintln!(" ERROR: Could not create HawkesProcess: {:?}", e);
std::process::exit(1);
});
let actual_arrivals_ms: Vec<f64> = test_ts
.iter()
.map(|&t| temporal::from_nanos(t - t0_ns, TimeResolution::Milliseconds))
.collect();
let actual_gaps: Vec<f64> = actual_arrivals_ms
.iter()
.map(|&t| t - last_train_ms)
.collect();
let mc_paths = cfg.forecast.mc_paths.max(1);
if mc_paths == 1 {
println!(" ── Stochastic (single path, Ogata thinning) ──");
} else {
println!(
" ── Stochastic ensemble ({} paths, {:?}) ──",
mc_paths, cfg.forecast.mc_statistic
);
}
let stoch_gaps = ensemble_forecast(
&hp,
last_train_ms,
&train_events_ms,
n_test,
mc_paths,
cfg.forecast.mc_statistic,
);
let stoch_metrics = forecast_errors(&actual_gaps, &stoch_gaps);
row("MAE (ms)", format!("{:.6}", stoch_metrics.mae));
row("RMSE (ms)", format!("{:.6}", stoch_metrics.rmse));
println!("\n ── Deterministic (conditional mean) ──");
let det_forecasted_ms =
hp.forecast_conditional_means(last_train_ms, &train_events_ms, n_test);
let det_forecast_gaps: Vec<f64> = det_forecasted_ms
.iter()
.map(|&t| t - last_train_ms)
.collect();
let det_metrics = forecast_errors(&actual_gaps, &det_forecast_gaps);
row("MAE (ms)", format!("{:.6}", det_metrics.mae));
row("RMSE (ms)", format!("{:.6}", det_metrics.rmse));
separator("10. Poisson Baseline");
let poisson_config = PoissonEstimationConfig;
let poisson_mle = estimate_poisson_mle(&train_events_ms, &poisson_config)
.unwrap_or_else(|e| {
eprintln!(" ERROR: Poisson MLE failed: {}", e);
std::process::exit(1);
});
row("lambda (events/ms)", format!("{:.8}", poisson_mle.lambda));
row(
"Log-likelihood",
format!("{:.4}", poisson_mle.log_likelihood),
);
row("AIC", format!("{:.4}", poisson_mle.aic));
row("BIC", format!("{:.4}", poisson_mle.bic));
let lr = likelihood_ratio_test(mle.log_likelihood, poisson_mle.log_likelihood, 2);
row("LR statistic", format!("{:.4}", lr.statistic));
row("chi2(2) crit (0.05)", format!("{:.3}", lr.critical_value));
if lr.reject_h0 {
println!(" -> REJECT H0: Hawkes excitation is statistically significant.");
} else {
println!(" -> FAIL TO REJECT H0: Poisson is sufficient.");
}
separator("11. Model Comparison");
println!(
" {:<24} {:>14} {:>14} {:>14}",
"Metric", "Hawkes(stoch)", "Hawkes(det)", "Poisson"
);
println!(" {}", "-".repeat(66));
println!(
" {:<24} {:>14.4} {:>14} {:>14.4}",
"Log-likelihood", mle.log_likelihood, "\u{2014}", poisson_mle.log_likelihood
);
println!(
" {:<24} {:>14.4} {:>14} {:>14.4}",
"AIC", mle.aic, "\u{2014}", poisson_mle.aic
);
println!(
" {:<24} {:>14.4} {:>14} {:>14.4}",
"BIC", mle.bic, "\u{2014}", poisson_mle.bic
);
let pp = PoissonProcess::new(poisson_mle.lambda).unwrap_or_else(|e| {
eprintln!(" ERROR: Could not create PoissonProcess: {:?}", e);
std::process::exit(1);
});
let poisson_forecast_ms = pp.generate_values(last_train_ms, n_test);
let poisson_forecast_gaps: Vec<f64> = poisson_forecast_ms
.iter()
.map(|&t| t - last_train_ms)
.collect();
let poisson_metrics = forecast_errors(&actual_gaps, &poisson_forecast_gaps);
println!(
" {:<24} {:>14.4} {:>14.4} {:>14.4}",
"Forecast MAE (ms)", stoch_metrics.mae, det_metrics.mae, poisson_metrics.mae
);
println!(
" {:<24} {:>14.4} {:>14.4} {:>14.4}",
"Forecast RMSE (ms)", stoch_metrics.rmse, det_metrics.rmse, poisson_metrics.rmse
);
separator("12. Write Model Artifact");
let diagnostics = if cfg.output.include_diagnostics {
Some(Diagnostics {
branching_ratio: mle.branching_ratio,
log_likelihood: mle.log_likelihood,
aic: mle.aic,
bic: mle.bic,
converged: mle.converged,
iterations: mle.iterations,
residuals_mean,
residuals_std,
compensator_ratio,
})
} else {
None
};
let artifact = ModelArtifact {
model: "hawkes_exponential".into(),
version: 1,
fitted_at: Utc::now(),
parameters: HawkesParams {
mu: mle.mu,
alpha: mle.alpha,
beta: mle.beta,
},
diagnostics,
data: DataMeta {
source: parquet_path.display().to_string(),
n_events_train: n_train,
n_events_test: n_test,
observation_window_ms,
time_resolution: "Milliseconds".into(),
},
baseline: Some(PoissonBaseline {
model: "poisson".into(),
lambda: poisson_mle.lambda,
aic: poisson_mle.aic,
lr_statistic: lr.statistic,
lr_reject_h0: lr.reject_h0,
}),
};
fs::create_dir_all(&cfg.output.artifact_dir).unwrap_or_else(|e| {
eprintln!(
" ERROR: cannot create artifact dir {:?}: {}",
cfg.output.artifact_dir, e
);
std::process::exit(1);
});
let ts_tag = Utc::now().format("%Y%m%d_%H%M%S");
let artifact_filename = format!("hawkes_fit_{}.json", ts_tag);
let artifact_path = cfg.output.artifact_dir.join(&artifact_filename);
let json = serde_json::to_string_pretty(&artifact).unwrap_or_else(|e| {
eprintln!(" ERROR: serialization failed: {}", e);
std::process::exit(1);
});
fs::write(&artifact_path, &json).unwrap_or_else(|e| {
eprintln!(" ERROR: cannot write artifact: {}", e);
std::process::exit(1);
});
println!(" Artifact written to: {}", artifact_path.display());
println!(" Size: {} bytes", json.len());
separator("Done");
}