use anyhow::Context;
use anyhow::Result;
use clap::Parser;
use egobox_moe::GpMetric;
use egobox_moe::GpMixture;
use egobox_moe::MixtureGpSurrogate;
use rayon::prelude::*;
use std::collections::HashMap;
use std::fs;
/// Command-line arguments: the GP model file to inspect and the number of
/// cross-validation folds used to score the model(s).
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to a bincode-encoded file holding either a list of GP models or a
    /// single GP mixture
    filename: String,
    /// Number of cross-validation folds; 0 (the default) uses as many folds as
    /// there are training samples (i.e. leave-one-out)
    #[arg(short, long, default_value = "0")]
    kfold: usize,
}
/// Cross-validation scores computed for a single GP model.
#[derive(Clone, Debug)]
struct Metrics {
    /// Value of the `GpMetric::Q2` score.
    pub q2: f64,
    /// Value of the `GpMetric::Pva` score.
    pub pva: f64,
    /// Value of the `GpMetric::IAEAlphaWithPlot` score.
    pub iae_alpha: f64,
}
fn compute_metrics(gp_models: &[Box<dyn MixtureGpSurrogate>], kfold: usize) -> Vec<Metrics> {
let mut res: Vec<_> = gp_models
.par_iter()
.enumerate()
.map(|(i, gp)| {
let scores: Vec<_> = [GpMetric::Q2, GpMetric::Pva, GpMetric::IAEAlphaWithPlot]
.par_iter()
.map(|m| {
let score = gp.as_ref().score(*m, kfold);
(m, score)
})
.collect();
let scores: HashMap<_, _> = scores.into_iter().collect();
if i == 0
&& let Some(data) = &scores.get(&GpMetric::IAEAlphaWithPlot).unwrap().plot_data
{
println!("\nIAEα plot data for first GP model:");
println!("Alpha | Empirical coverage | Target coverage | Delta");
println!("---------------------------------------------------");
for i in 0..data.alphas.len() {
let alpha = data.alphas[i];
let delta = data.deltas[i];
println!(
"{:5.2}% | {:5.2}% | {:5.2}% | {:5.2}%",
alpha * 100.,
delta * 100.,
(1. - alpha) * 100.,
(delta - (1. - alpha)).abs() * 100.
);
}
println!();
}
(
i,
Metrics {
q2: scores.get(&GpMetric::Q2).unwrap().value,
pva: scores.get(&GpMetric::Pva).unwrap().value,
iae_alpha: scores.get(&GpMetric::IAEAlphaWithPlot).unwrap().value,
},
)
})
.collect();
res.sort_by_key(|(i, _)| *i);
res.into_iter().map(|(_, m)| m).collect::<Vec<_>>()
}
/// Loads GP model(s) from the file named on the command line, prints a short
/// summary of each one, then prints their Q2, PVA and IAEα scores.
///
/// # Errors
/// Returns an error naming the offending file when the file cannot be read, or
/// when its content decodes neither as a non-empty list of GP models nor as a
/// single `GpMixture`.
fn main() -> Result<()> {
    let args = Args::parse();
    let data: Vec<u8> = fs::read(&args.filename)
        .with_context(|| format!("failed to read GP model file '{}'", args.filename))?;

    // The file holds either a list of models or a single `GpMixture`. Try the
    // list first; a decode failure or an empty list deliberately falls back to
    // the single-model format, whose error (if any) is the one reported.
    let gp_models: Vec<Box<dyn MixtureGpSurrogate>> =
        bincode::serde::decode_from_slice(&data, bincode::config::standard())
            .map(|(res, _)| res)
            .unwrap_or_default();
    let gp_models = if gp_models.is_empty() {
        let gp: Box<GpMixture> =
            bincode::serde::decode_from_slice(&data, bincode::config::standard())
                .map(|(res, _)| res)
                .with_context(|| {
                    format!("failed to decode GP model(s) from '{}'", args.filename)
                })?;
        vec![gp as Box<dyn MixtureGpSurrogate>]
    } else {
        gp_models
    };

    println!(
        "Loaded {} GP model(s) from {}",
        gp_models.len(),
        args.filename
    );
    for gp in &gp_models {
        println!("Loaded GP model: {}", gp);
    }

    // Cannot fail: `gp_models` is either a non-empty decoded list or the
    // one-element vector built above.
    let (xt, _yt) = gp_models
        .first()
        .expect("at least one GP model is loaded")
        .training_data();
    println!("Training data: {} samples ({}-dim)", xt.nrows(), xt.ncols());

    // `--kfold 0` (the default) means one fold per training sample.
    let k = if args.kfold == 0 {
        xt.nrows()
    } else {
        args.kfold
    };

    let res = compute_metrics(&gp_models, k);
    for (i, m) in res.iter().enumerate() {
        println!(
            "GP({}): Q2 = {:.2}, PVA = {:.2}, IAEα = {:.2}",
            i, m.q2, m.pva, m.iae_alpha
        );
    }
    Ok(())
}