/// Pearson correlation coefficient between two equally long series.
///
/// Means, covariance and variances are accumulated in `f64`; only the final
/// coefficient is narrowed to `f32`.
///
/// Returns `None` when the slices are empty, differ in length, or when
/// `sqrt(var_x * var_y)` falls below `1e-15` (for example when either series is
/// constant). NaN inputs are not rejected and propagate as `Some(NaN)`.
pub fn pearson_r(x: &[f32], y: &[f32]) -> Option<f32> {
    if x.is_empty() || x.len() != y.len() {
        return None;
    }
    let count = x.len() as f64;
    let average = |series: &[f32]| series.iter().map(|&v| v as f64).sum::<f64>() / count;
    let (centre_x, centre_y) = (average(x), average(y));

    // Single pass: (sum dx*dy, sum dx^2, sum dy^2) around the means.
    let (sxy, sxx, syy) = x.iter().zip(y).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(sxy, sxx, syy), (&a, &b)| {
            let da = a as f64 - centre_x;
            let db = b as f64 - centre_y;
            (sxy + da * db, sxx + da * da, syy + db * db)
        },
    );

    let norm = (sxx * syy).sqrt();
    // Written as `<` (not `>=`) so a NaN norm falls through to `Some(NaN)`.
    if norm < 1e-15 {
        None
    } else {
        Some((sxy / norm) as f32)
    }
}
/// Computes the Pearson correlation between `pred` and `truth` separately for each vertex.
///
/// Both inputs are indexed as `[timestep][vertex]`. Only the first
/// `min(pred.len(), truth.len())` timesteps are used, and only vertex indices that
/// exist in *every* one of those rows (in both inputs) are scored, so ragged rows
/// are truncated instead of triggering an out-of-bounds panic.
///
/// Vertices whose correlation is undefined (see [`pearson_r`], e.g. a constant
/// series) score `0.0`. NaN inputs can still produce NaN entries.
pub fn pearson_per_vertex(pred: &[Vec<f32>], truth: &[Vec<f32>]) -> Vec<f32> {
    let n_t = pred.len().min(truth.len());
    let (pred, truth) = (&pred[..n_t], &truth[..n_t]);
    // Shortest row over all used rows of both inputs. Using only row 0 here would
    // index past the end of any later, shorter row. Empty input yields 0 vertices.
    let n_v = pred.iter().chain(truth).map(Vec::len).min().unwrap_or(0);

    // Column buffers reused across vertices to avoid two allocations per vertex.
    let mut x: Vec<f32> = Vec::with_capacity(n_t);
    let mut y: Vec<f32> = Vec::with_capacity(n_t);
    (0..n_v)
        .map(|vi| {
            x.clear();
            y.clear();
            x.extend(pred.iter().map(|row| row[vi]));
            y.extend(truth.iter().map(|row| row[vi]));
            pearson_r(&x, &y).unwrap_or(0.0)
        })
        .collect()
}
/// Mean of the finite per-vertex correlations produced by [`pearson_per_vertex`].
///
/// Non-finite (NaN / infinite) scores are skipped. The sum is accumulated in
/// `f64`, like the other reductions in this module, so precision does not degrade
/// with large vertex counts. Returns `0.0` when there is no finite score.
pub fn mean_pearson(pred: &[Vec<f32>], truth: &[Vec<f32>]) -> f32 {
    let (sum, count) = pearson_per_vertex(pred, truth)
        .into_iter()
        .filter(|r| r.is_finite())
        .fold((0.0f64, 0usize), |(sum, count), r| (sum + r as f64, count + 1));
    if count == 0 {
        0.0
    } else {
        (sum / count as f64) as f32
    }
}
/// Median of the finite per-vertex correlations produced by [`pearson_per_vertex`].
///
/// Non-finite scores are discarded first. For an even number of remaining scores
/// the two middle values are averaged. Returns `0.0` when no finite score remains.
pub fn median_pearson(pred: &[Vec<f32>], truth: &[Vec<f32>]) -> f32 {
    let mut scores = pearson_per_vertex(pred, truth);
    scores.retain(|r| r.is_finite());
    if scores.is_empty() {
        return 0.0;
    }
    scores.sort_by(|a, b| {
        a.partial_cmp(b)
            .expect("non-finite scores were removed before sorting")
    });
    let half = scores.len() / 2;
    match scores.len() % 2 {
        0 => (scores[half - 1] + scores[half]) / 2.0,
        _ => scores[half],
    }
}
/// Mean squared error between `pred` and `truth`, both indexed `[timestep][vertex]`.
///
/// Only the first `min(pred.len(), truth.len())` timesteps are compared, and only
/// vertex indices present in every one of those rows (in both inputs), so ragged
/// input is truncated rather than panicking. Squared differences are accumulated
/// in `f64`. Returns `0.0` when there is nothing to compare.
pub fn mse(pred: &[Vec<f32>], truth: &[Vec<f32>]) -> f32 {
    let n_t = pred.len().min(truth.len());
    let (pred, truth) = (&pred[..n_t], &truth[..n_t]);
    // Shortest row over all used rows. Using only row 0 here would index past
    // the end of any later, shorter row.
    let n_v = pred.iter().chain(truth).map(Vec::len).min().unwrap_or(0);
    let count = n_t * n_v;
    if count == 0 {
        return 0.0;
    }
    let total: f64 = pred
        .iter()
        .zip(truth)
        .flat_map(|(p, t)| p[..n_v].iter().zip(t[..n_v].iter()))
        .map(|(&a, &b)| {
            let diff = a as f64 - b as f64;
            diff * diff
        })
        .sum();
    (total / count as f64) as f32
}
/// Fraction of timesteps whose own prediction is retrieved among the top `k` candidates.
///
/// For each timestep `ti` (up to `min(pred.len(), truth.len())`), every prediction
/// row `pred[pi]` is scored by cosine similarity against `truth[ti]`. Timestep `ti`
/// counts as correct when `pred[ti]` is among the `k` highest-scoring rows. Ties
/// are broken in favour of the lower row index, which matches a stable descending
/// sort of the scores.
///
/// Returns `0.0` when there are no timesteps or when `k == 0`.
pub fn topk_accuracy(pred: &[Vec<f32>], truth: &[Vec<f32>], k: usize) -> f32 {
    let n = pred.len().min(truth.len());
    if n == 0 || k == 0 {
        return 0.0;
    }
    // NaN similarities (possible when inputs contain NaN) are mapped to -inf so they
    // rank last and every score is totally ordered. A `partial_cmp(..).unwrap_or(Equal)`
    // comparator is not a total order with NaN, and `sort_by` may panic on such a
    // comparator.
    let score = |pi: usize, ti: usize| -> f32 {
        let s = cosine_similarity(&pred[pi], &truth[ti]);
        if s.is_nan() {
            f32::NEG_INFINITY
        } else {
            s
        }
    };
    let correct = (0..n)
        .filter(|&ti| {
            let target = score(ti, ti);
            // Number of rows a stable descending sort would place before `pred[ti]`:
            // strictly higher scores, plus equal scores with a lower index.
            let ranked_ahead = (0..n)
                .filter(|&pi| pi != ti)
                .filter(|&pi| {
                    let s = score(pi, ti);
                    s > target || (s == target && pi < ti)
                })
                .count();
            ranked_ahead < k
        })
        .count();
    correct as f32 / n as f32
}
/// Cosine similarity of `a` and `b`, computed in `f64` over their common prefix.
///
/// Elements beyond the shorter slice's length are ignored. Returns `0.0` when
/// `sqrt(|a|^2 * |b|^2)` is below `1e-15` (e.g. either vector is all zeros).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&u, &v)| {
            let (u, v) = (u as f64, v as f64);
            (dot + u * v, sq_a + u * u, sq_b + v * v)
        },
    );
    let magnitude = (sq_a * sq_b).sqrt();
    if magnitude < 1e-15 {
        0.0
    } else {
        (dot / magnitude) as f32
    }
}
/// Renders the evaluation metrics as a newline-separated, human-readable report.
///
/// Pearson values and MSE are printed with 6 decimals. The top-k line is included
/// only when `topk_acc` is `Some((k, accuracy))`; `accuracy` is printed both as-is
/// (4 decimals) and multiplied by 100 as a percentage, so it is expected to be a
/// fraction in `[0, 1]`.
pub fn format_metrics_report(
    mean_r: f32,
    median_r: f32,
    mse_val: f32,
    topk_acc: Option<(usize, f32)>,
    n_timesteps: usize,
    n_vertices: usize,
) -> String {
    let mut lines = vec![
        "Evaluation Metrics".to_string(),
        "=".repeat(45),
        format!(" Timesteps: {}", n_timesteps),
        format!(" Vertices: {}", n_vertices),
        format!(" Mean Pearson r: {:.6}", mean_r),
        format!(" Median Pearson r: {:.6}", median_r),
        format!(" MSE: {:.6}", mse_val),
    ];
    if let Some((k, acc)) = topk_acc {
        lines.push(format!(" Top-{} accuracy: {:.4} ({:.1}%)", k, acc, acc * 100.0));
    }
    lines.join("\n")
}
/// Loads a ground-truth matrix stored as raw little-endian `f32` values.
///
/// The file is read as a flat sequence of 4-byte little-endian floats and split
/// into consecutive rows of `n_vertices` values each; the number of rows is the
/// number of timesteps.
///
/// # Errors
///
/// Returns an error if `n_vertices` is zero, if the file cannot be read, if its
/// length is not a whole number of 4-byte floats, or if the float count is not
/// divisible by `n_vertices`.
pub fn load_ground_truth(path: &str, n_vertices: usize) -> anyhow::Result<Vec<Vec<f32>>> {
    // Guard up front: `data.len() % 0` below would otherwise panic.
    if n_vertices == 0 {
        anyhow::bail!("n_vertices must be non-zero to load ground truth from {}", path);
    }
    let bytes = std::fs::read(path)
        .map_err(|e| anyhow::anyhow!("failed to read ground truth: {}: {}", path, e))?;
    // `chunks_exact(4)` silently drops a trailing partial float; treat that as corruption.
    if bytes.len() % 4 != 0 {
        anyhow::bail!(
            "Ground truth file {} has {} bytes, not a multiple of 4 (f32 size)",
            path,
            bytes.len()
        );
    }
    let data: Vec<f32> = bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    if data.len() % n_vertices != 0 {
        anyhow::bail!(
            "Ground truth file has {} floats, not divisible by {} vertices",
            data.len(), n_vertices
        );
    }
    // Divisibility was checked above, so every chunk is a full row.
    Ok(data.chunks_exact(n_vertices).map(<[f32]>::to_vec).collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pearson_perfect_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson_r(&x, &y).unwrap();
        assert!((r - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_pearson_negative_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        let r = pearson_r(&x, &y).unwrap();
        assert!((r - (-1.0)).abs() < 1e-5);
    }

    #[test]
    fn test_pearson_zero_variance() {
        let x = vec![1.0, 1.0, 1.0];
        let y = vec![1.0, 2.0, 3.0];
        assert!(pearson_r(&x, &y).is_none());
    }

    #[test]
    fn test_pearson_invalid_lengths() {
        let empty: [f32; 0] = [];
        assert!(pearson_r(&empty, &empty).is_none());
        assert!(pearson_r(&[1.0, 2.0], &[1.0]).is_none());
    }

    #[test]
    fn test_mse_zero() {
        let pred = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let truth = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert!((mse(&pred, &truth)).abs() < 1e-6);
    }

    #[test]
    fn test_mse_nonzero() {
        let pred = vec![vec![1.0, 2.0]];
        let truth = vec![vec![2.0, 4.0]];
        assert!((mse(&pred, &truth) - 2.5).abs() < 1e-5);
    }

    #[test]
    fn test_topk_accuracy_perfect() {
        let data = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let acc = topk_accuracy(&data, &data, 1);
        assert!((acc - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_topk_accuracy_misses_and_k_bounds() {
        let truth = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        // Each prediction matches the *other* timestep best.
        let swapped = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        assert_eq!(topk_accuracy(&swapped, &truth, 0), 0.0);
        assert_eq!(topk_accuracy(&swapped, &truth, 1), 0.0);
        assert_eq!(topk_accuracy(&swapped, &truth, 2), 1.0);
    }

    #[test]
    fn test_mean_pearson() {
        let pred = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let truth = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let r = mean_pearson(&pred, &truth);
        assert!((r - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_pearson_per_vertex_constant_vertex_scores_zero() {
        let pred = vec![vec![1.0, 5.0], vec![2.0, 5.0], vec![3.0, 5.0]];
        let truth = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![3.0, 3.0]];
        let r = pearson_per_vertex(&pred, &truth);
        assert_eq!(r.len(), 2);
        assert!((r[0] - 1.0).abs() < 1e-5);
        assert_eq!(r[1], 0.0);
    }

    #[test]
    fn test_mean_and_median_pearson_even_count() {
        // Vertex 0 correlates perfectly (+1), vertex 1 anti-correlates (-1).
        let pred = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![3.0, 3.0]];
        let truth = vec![vec![1.0, 3.0], vec![2.0, 2.0], vec![3.0, 1.0]];
        assert!(mean_pearson(&pred, &truth).abs() < 1e-5);
        assert!(median_pearson(&pred, &truth).abs() < 1e-5);
    }

    #[test]
    fn test_median_pearson_odd_count() {
        // Per-vertex correlations are +1, -1, +1, so the median is +1.
        let pred = vec![vec![1.0, 1.0, 1.0], vec![2.0, 2.0, 2.0], vec![3.0, 3.0, 3.0]];
        let truth = vec![vec![1.0, 3.0, 2.0], vec![2.0, 2.0, 4.0], vec![3.0, 1.0, 6.0]];
        assert!((median_pearson(&pred, &truth) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_inputs_return_zero() {
        let empty: Vec<Vec<f32>> = Vec::new();
        assert!(pearson_per_vertex(&empty, &empty).is_empty());
        assert_eq!(mean_pearson(&empty, &empty), 0.0);
        assert_eq!(median_pearson(&empty, &empty), 0.0);
        assert_eq!(mse(&empty, &empty), 0.0);
        assert_eq!(topk_accuracy(&empty, &empty, 1), 0.0);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_format_metrics_report_topk_line_is_optional() {
        let with = format_metrics_report(0.5, 0.25, 1.0, Some((5, 0.8)), 10, 20);
        assert!(with.starts_with("Evaluation Metrics\n"));
        assert!(with.contains("Top-5"));
        assert!(with.contains("80.0%"));
        let without = format_metrics_report(0.5, 0.25, 1.0, None, 10, 20);
        assert!(!without.contains("Top-"));
        assert_eq!(without.lines().count(), 7);
    }

    #[test]
    fn test_load_ground_truth_roundtrip_and_errors() {
        let path = std::env::temp_dir().join(format!(
            "metrics_gt_roundtrip_{}.bin",
            std::process::id()
        ));
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        std::fs::write(&path, &bytes).unwrap();
        let path_str = path.to_str().unwrap();
        let as_rows = load_ground_truth(path_str, 3);
        let bad_split = load_ground_truth(path_str, 4);
        let _ = std::fs::remove_file(&path);

        assert_eq!(
            as_rows.unwrap(),
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]
        );
        assert!(bad_split.is_err());
        assert!(load_ground_truth(path_str, 3).is_err());
    }
}