#![cfg(feature = "enable-eigensnp-diagnostics")]
use crate::linalg_backends::BackendSVD; use crate::linalg_backends::LinAlgBackendProvider; use ndarray::{Array2, ArrayView1, ArrayView2};
use serde::{Deserialize, Serialize};
use std::f64::INFINITY;
/// Diagnostic record for one named step of an rSVD computation.
///
/// Every metric is optional, so a producer can fill in only what it measured;
/// `Default` yields empty strings and `None` for every metric.
///
/// NOTE(review): nothing in this file populates these fields, so the per-field
/// descriptions below are inferred from the field names — confirm against the
/// code that fills them in.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RsvdStepDetail {
    /// Label identifying the step.
    pub step_name: String,
    /// Dimensions of the step's input matrix (presumably `(rows, cols)`).
    pub input_matrix_dims: Option<(usize, usize)>,
    /// Dimensions of the step's output matrix (presumably `(rows, cols)`).
    pub output_matrix_dims: Option<(usize, usize)>,
    /// Frobenius norm recorded for the step.
    pub fro_norm: Option<f64>,
    /// Condition number recorded for the step.
    pub condition_number: Option<f64>,
    /// Orthogonality error recorded for the step.
    pub orthogonality_error: Option<f64>,
    /// Absolute SVD reconstruction error.
    pub svd_reconstruction_error_abs: Option<f64>,
    /// Relative SVD reconstruction error.
    pub svd_reconstruction_error_rel: Option<f64>,
    /// Number of singular values produced by the step.
    pub num_singular_values: Option<usize>,
    /// Subset of the singular values kept for inspection.
    pub singular_values_sample: Option<Vec<f64>>,
    /// Free-form remarks.
    pub notes: String,
}
/// Diagnostics gathered while building the local basis of a single block.
///
/// NOTE(review): nothing in this file populates these fields, so the per-field
/// descriptions are inferred from the field names — confirm against the
/// producer. Dimension tuples are presumably `(rows, cols)`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerBlockLocalBasisDiagnostics {
    /// Identifier of the block these diagnostics belong to.
    pub block_id: String,
    /// Per-step rSVD records collected for this block.
    pub rsvd_stages: Vec<RsvdStepDetail>,
    /// Dimensions of the `x_s_p` input matrix.
    pub input_x_s_p_dims: Option<(usize, usize)>,
    /// Frobenius norm of the `x_s_p` input matrix.
    pub input_x_s_p_fro_norm: Option<f64>,
    /// Condition number of the `x_s_p` input matrix.
    pub input_x_s_p_condition_number: Option<f64>,
    /// Correlations of `U` against an f64 reference ("truth") result.
    pub u_correlation_vs_f64_truth: Option<Vec<f64>>,
    /// Dimensions of the `u_p` matrix.
    pub u_p_dims: Option<(usize, usize)>,
    /// Frobenius norm of the `u_p` matrix.
    pub u_p_fro_norm: Option<f64>,
    /// Condition number of the `u_p` matrix.
    pub u_p_condition_number: Option<f64>,
    /// Orthogonality error of the `u_p` matrix.
    pub u_p_orthogonality_error: Option<f64>,
    /// Free-form remarks.
    pub notes: String,
}
/// Diagnostics for a global PCA stage.
///
/// NOTE(review): field meanings are inferred from their names; this file does
/// not populate them — confirm against the producer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GlobalPcaDiagnostics {
    /// Label identifying the stage.
    pub stage_name: String,
    /// Per-step rSVD records collected for this stage.
    pub rsvd_stages: Vec<RsvdStepDetail>,
    /// Correlations of the initial scores against a reference result
    /// (presumably produced by a Python implementation, per the `py` suffix).
    pub initial_scores_correlation_vs_py_truth: Option<Vec<f64>>,
    /// Free-form remarks.
    pub notes: String,
}
/// Diagnostics for a single numbered "SR" pass.
///
/// NOTE(review): the meaning of "SR" and of each matrix (`v_hat`,
/// `s_intermediate`, `u_s`) is not established in this file; descriptions are
/// inferred from field names — confirm against the producer. Dimension tuples
/// are presumably `(rows, cols)`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SrPassDetail {
    /// Index of the pass.
    pub pass_num: usize,
    /// Dimensions of `v_hat`.
    pub v_hat_dims: Option<(usize, usize)>,
    /// Orthogonality error of `v_hat`.
    pub v_hat_orthogonality_error: Option<f64>,
    /// Dimensions of the intermediate `s` matrix.
    pub s_intermediate_dims: Option<(usize, usize)>,
    /// Frobenius norm of the intermediate `s` matrix.
    pub s_intermediate_fro_norm: Option<f64>,
    /// Condition number of the intermediate `s` matrix.
    pub s_intermediate_condition_number: Option<f64>,
    /// Absolute SVD reconstruction error of the intermediate `s` matrix.
    pub s_intermediate_svd_reconstruction_error_abs: Option<f64>,
    /// Relative SVD reconstruction error of the intermediate `s` matrix.
    pub s_intermediate_svd_reconstruction_error_rel: Option<f64>,
    /// Number of singular values of the intermediate `s` matrix.
    pub s_intermediate_num_singular_values: Option<usize>,
    /// Subset of those singular values kept for inspection.
    pub s_intermediate_singular_values_sample: Option<Vec<f64>>,
    /// Orthogonality error of `u_s`.
    pub u_s_orthogonality_error: Option<f64>,
    /// Free-form remarks.
    pub notes: String,
}
/// Top-level container aggregating the diagnostics of one full PCA run.
///
/// NOTE(review): the roles of the `c` / `c_std` matrices are not established
/// in this file; descriptions are inferred from field names — confirm against
/// the producer. Dimension tuples are presumably `(rows, cols)`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FullPcaRunDetailedDiagnostics {
    /// One entry per processed block.
    pub per_block_diagnostics: Vec<PerBlockLocalBasisDiagnostics>,
    /// Dimensions of the `c` matrix.
    pub c_matrix_dims: Option<(usize, usize)>,
    /// Frobenius norm of the `c` matrix.
    pub c_matrix_fro_norm: Option<f64>,
    /// Dimensions of the `c_std` matrix.
    pub c_std_matrix_dims: Option<(usize, usize)>,
    /// Frobenius norm of the `c_std` matrix.
    pub c_std_matrix_fro_norm: Option<f64>,
    /// Sample of the column means of `c_std`.
    pub c_std_col_means_sample: Option<Vec<f64>>,
    /// Sample of the column standard deviations of `c_std`.
    pub c_std_col_std_devs_sample: Option<Vec<f64>>,
    /// Diagnostics of the global PCA stage, boxed and optional.
    pub global_pca_diag: Option<Box<GlobalPcaDiagnostics>>,
    /// One entry per SR pass.
    pub sr_pass_details: Vec<SrPassDetail>,
    /// Total runtime in seconds, if measured.
    pub total_runtime_seconds: Option<f64>,
    /// Free-form remarks.
    pub notes: String,
}
/// Returns the Frobenius norm of `matrix`: the square root of the sum of the
/// squares of all entries, accumulated in `f32`.
///
/// An empty matrix yields `0.0`.
pub fn compute_frob_norm_f32(matrix: &ArrayView2<f32>) -> f32 {
    // Folding over an empty view leaves the accumulator at 0.0, and
    // sqrt(0.0) == 0.0, so the empty case needs no separate branch.
    let sum_of_squares = matrix
        .iter()
        .fold(0.0_f32, |acc, &entry| acc + entry * entry);
    sum_of_squares.sqrt()
}
/// Returns the Frobenius norm of `matrix`: the square root of the sum of the
/// squares of all entries.
///
/// An empty matrix yields `0.0`.
pub fn compute_frob_norm_f64(matrix: &ArrayView2<f64>) -> f64 {
    // Folding over an empty view leaves the accumulator at 0.0, and
    // sqrt(0.0) == 0.0, so the empty case needs no separate branch.
    let sum_of_squares = matrix
        .iter()
        .fold(0.0_f64, |acc, &entry| acc + entry * entry);
    sum_of_squares.sqrt()
}
/// Condition number of an `f32` matrix, computed in `f64`.
///
/// The entries are widened to `f64` and the work is delegated to
/// [`compute_condition_number_via_svd_f64`]. Returns `None` when the matrix
/// has zero rows or zero columns.
pub fn compute_condition_number_via_svd_f32(matrix: &ArrayView2<f32>) -> Option<f64> {
    let (rows, cols) = matrix.dim();
    if rows == 0 || cols == 0 {
        return None;
    }
    let widened = matrix.mapv(f64::from);
    compute_condition_number_via_svd_f64(&widened.view())
}
/// Condition number of `matrix`, estimated from its singular values as
/// `sigma_max / sigma_min`, where `sigma_min` is the smallest singular value
/// strictly greater than `1e-12`.
///
/// # Returns
/// * `None` if the matrix has zero rows or zero columns, or if the SVD
///   backend reports an error.
/// * `Some(0.0)` if the backend returns no singular values at all.
/// * `Some(INFINITY)` if no singular value exceeds `1e-12` (the matrix is
///   numerically zero).
/// * `Some(sigma_max / sigma_min)` otherwise.
pub fn compute_condition_number_via_svd_f64(matrix: &ArrayView2<f64>) -> Option<f64> {
    // Singular values at or below this absolute threshold are treated as zero.
    const NEGLIGIBLE_SIGMA: f64 = 1e-12;

    if matrix.nrows() == 0 || matrix.ncols() == 0 {
        return None;
    }

    let backend_f64 = LinAlgBackendProvider::<f64>::new();
    // NOTE(review): the two `false` flags presumably tell the backend not to
    // compute U and V^T, since only the singular values are read — confirm
    // against the `BackendSVD` trait.
    let singular_values = backend_f64
        .svd_into(matrix.to_owned(), false, false)
        .ok()?
        .s;
    if singular_values.is_empty() {
        return Some(0.0);
    }

    let sigma_min = singular_values
        .iter()
        .copied()
        .filter(|&sigma| sigma > NEGLIGIBLE_SIGMA)
        .fold(f64::INFINITY, f64::min);
    // The fold seed survives only when nothing passed the filter.
    if sigma_min == f64::INFINITY {
        return Some(INFINITY);
    }

    // At least one value above the threshold exists here, so the maximum is
    // a real singular value and the ratio is well defined.
    let sigma_max = singular_values
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);
    Some(sigma_max / sigma_min)
}
/// Orthogonality error of an `f32` matrix, computed in `f64`.
///
/// The entries are widened to `f64` and the work is delegated to
/// [`compute_orthogonality_error_f64`]. Returns `None` when the matrix has
/// zero rows or zero columns.
pub fn compute_orthogonality_error_f32(q_matrix: &ArrayView2<f32>) -> Option<f64> {
    let (rows, cols) = q_matrix.dim();
    if rows == 0 || cols == 0 {
        return None;
    }
    let widened = q_matrix.mapv(f64::from);
    compute_orthogonality_error_f64(&widened.view())
}
/// Measures how far the columns of `q_matrix` are from being orthonormal,
/// as the Frobenius norm of `I - QᵀQ`.
///
/// Returns `None` when the matrix has zero rows or zero columns.
///
/// A matrix with fewer rows than columns cannot have orthonormal columns:
/// `QᵀQ` then has rank at most `nrows`, so the result is at least
/// `sqrt(ncols - nrows)`. Such inputs are not special-cased; the error is
/// reported as computed.
pub fn compute_orthogonality_error_f64(q_matrix: &ArrayView2<f64>) -> Option<f64> {
    if q_matrix.nrows() == 0 || q_matrix.ncols() == 0 {
        return None;
    }
    // Gram matrix of the columns; equals the identity for orthonormal columns.
    let gram = q_matrix.t().dot(q_matrix);
    let residual = Array2::<f64>::eye(gram.nrows()) - gram;
    Some(compute_frob_norm_f64(&residual.view()))
}
/// SVD reconstruction error for `f32` inputs, computed in `f64`.
///
/// All four inputs are widened to `f64` and the work is delegated to
/// [`compute_svd_reconstruction_error_f64`]. Returns `None` when
/// `original_matrix` is empty.
pub fn compute_svd_reconstruction_error_f32(
    original_matrix: &ArrayView2<f32>,
    u: &ArrayView2<f32>,
    s_vec: &ArrayView1<f32>,
    vt: &ArrayView2<f32>,
) -> Option<f64> {
    if original_matrix.is_empty() {
        return None;
    }
    let wide_original = original_matrix.mapv(f64::from);
    let wide_u = u.mapv(f64::from);
    let wide_s = s_vec.mapv(f64::from);
    let wide_vt = vt.mapv(f64::from);
    compute_svd_reconstruction_error_f64(
        &wide_original.view(),
        &wide_u.view(),
        &wide_s.view(),
        &wide_vt.view(),
    )
}
/// Relative Frobenius-norm error of an SVD reconstruction:
/// `‖A − U·diag(s)·Vᵀ‖_F / ‖A‖_F`.
///
/// # Arguments
/// * `original_matrix` - the matrix `A` that was decomposed.
/// * `u` - left factor; must have `A.nrows()` rows and `s_vec.len()` columns.
/// * `s_vec` - singular values.
/// * `vt` - right factor; must have `s_vec.len()` rows and `A.ncols()` columns.
///
/// # Returns
/// * `None` if `original_matrix` is empty or the factor shapes are
///   inconsistent with each other or with `original_matrix`.
/// * If `‖A‖_F < 1e-12`: `Some(0.0)` when the residual norm is also below
///   `1e-12`, otherwise `Some(INFINITY)`.
/// * `Some(residual_norm / original_norm)` otherwise.
pub fn compute_svd_reconstruction_error_f64(
    original_matrix: &ArrayView2<f64>,
    u: &ArrayView2<f64>,
    s_vec: &ArrayView1<f64>,
    vt: &ArrayView2<f64>,
) -> Option<f64> {
    if original_matrix.is_empty() {
        return None;
    }
    let rank = s_vec.len();
    if u.ncols() != rank || vt.nrows() != rank {
        return None;
    }
    if u.nrows() != original_matrix.nrows() || vt.ncols() != original_matrix.ncols() {
        return None;
    }

    // U·diag(s) only scales column j of U by s[j]. Broadcasting `s_vec`
    // across the rows of `u` does exactly that without materialising a dense
    // rank×rank diagonal matrix or paying for an extra matrix product.
    let scaled_u = u * s_vec;
    let reconstructed = scaled_u.dot(vt);
    let residual = original_matrix - &reconstructed;

    let residual_norm = compute_frob_norm_f64(&residual.view());
    let original_norm = compute_frob_norm_f64(original_matrix);

    if original_norm < 1e-12 {
        // A relative error is undefined for a (numerically) zero matrix:
        // report a perfect fit only if the reconstruction is also zero.
        if residual_norm < 1e-12 {
            Some(0.0)
        } else {
            Some(INFINITY)
        }
    } else {
        Some(residual_norm / original_norm)
    }
}
fn pearson_correlation_f64_single(vec_a: &ArrayView1<f64>, vec_b: &ArrayView1<f64>) -> Option<f64> {
let n = vec_a.len();
if n != vec_b.len() || n < 2 {
return None;
}
let mean_a = vec_a.mean().unwrap_or(0.0);
let mean_b = vec_b.mean().unwrap_or(0.0);
let mut cov_ab = 0.0;
let mut var_a = 0.0;
let mut var_b = 0.0;
for i in 0..n {
let diff_a = vec_a[i] - mean_a;
let diff_b = vec_b[i] - mean_b;
cov_ab += diff_a * diff_b;
var_a += diff_a * diff_a;
var_b += diff_b * diff_b;
}
if var_a < 1e-12 || var_b < 1e-12 {
if var_a < 1e-12 && var_b < 1e-12 {
if (mean_a - mean_b).abs() < 1e-9 {
return Some(1.0);
} else {
return Some(0.0);
} }
return Some(0.0); }
let r = cov_ab / (var_a.sqrt() * var_b.sqrt());
Some(r.clamp(-1.0, 1.0))
}
/// Absolute Pearson correlation between matching columns of two matrices.
///
/// `m1` is widened to `f64` and column `j` of it is correlated with column
/// `j` of `m2_f64`; the absolute value of each coefficient is returned in
/// column order.
///
/// # Returns
/// * `None` if the shapes differ, if there are columns but fewer than two
///   rows, or if any single-column correlation cannot be computed.
/// * `Some(vec![])` if the matrices have no columns.
pub fn compute_matrix_column_correlations_abs(
    m1: &ArrayView2<f32>,
    m2_f64: &ArrayView2<f64>,
) -> Option<Vec<f64>> {
    if m1.dim() != m2_f64.dim() {
        return None;
    }
    let (row_count, col_count) = m1.dim();
    if col_count == 0 {
        return Some(Vec::new());
    }
    if row_count < 2 {
        return None;
    }

    let m1_wide = m1.mapv(f64::from);
    // Collecting into `Option<Vec<_>>` stops at the first column that yields `None`.
    (0..col_count)
        .map(|col| {
            pearson_correlation_f64_single(&m1_wide.column(col), &m2_f64.column(col))
                .map(f64::abs)
        })
        .collect()
}
/// Picks up to `count` representative entries from `s_values`.
///
/// The first and last entries are always included (when `count >= 2`), with
/// the remaining picks taken at evenly spaced positions in between. Adjacent
/// picks that differ by less than `1e-7` are then collapsed, so the result
/// may contain fewer than `count` values.
///
/// # Returns
/// Always `Some`:
/// * an empty vector if `s_values` is empty or `count == 0`;
/// * a copy of all values if `count >= s_values.len()`;
/// * only the first value if `count == 1`;
/// * the evenly spaced, de-duplicated sample otherwise.
pub fn sample_singular_values(s_values: &ArrayView1<f32>, count: usize) -> Option<Vec<f32>> {
    let len = s_values.len();
    if len == 0 || count == 0 {
        return Some(Vec::new());
    }
    if count >= len {
        return Some(s_values.to_vec());
    }
    if count == 1 {
        return Some(vec![s_values[0]]);
    }

    // From here on 2 <= count < len, so `last >= 2` and `intervals >= 1`.
    let last = len - 1;
    let intervals = (count - 1) as f64;

    let mut sampled = Vec::with_capacity(count);
    sampled.push(s_values[0]);
    // Interior picks: i/(count-1) of the way through the index range. Since
    // i <= count - 2, the rounded position never exceeds `last`.
    sampled.extend((1..count - 1).map(|i| {
        let position = i as f64 * last as f64 / intervals;
        s_values[position.round() as usize]
    }));
    sampled.push(s_values[last]);

    // Collapse runs of (nearly) identical neighbouring samples.
    sampled.dedup_by(|a, b| (*a - *b).abs() < 1e-7);
    Some(sampled)
}
/// Picks up to `count` representative entries from `s_values`.
///
/// The first and last entries are always included (when `count >= 2`), with
/// the remaining picks taken at evenly spaced positions in between. Adjacent
/// picks that differ by less than `1e-9` are then collapsed, so the result
/// may contain fewer than `count` values.
///
/// # Returns
/// Always `Some`:
/// * an empty vector if `s_values` is empty or `count == 0`;
/// * a copy of all values if `count >= s_values.len()`;
/// * only the first value if `count == 1`;
/// * the evenly spaced, de-duplicated sample otherwise.
pub fn sample_singular_values_f64(s_values: &ArrayView1<f64>, count: usize) -> Option<Vec<f64>> {
    let len = s_values.len();
    if len == 0 || count == 0 {
        return Some(Vec::new());
    }
    if count >= len {
        return Some(s_values.to_vec());
    }
    if count == 1 {
        return Some(vec![s_values[0]]);
    }

    // From here on 2 <= count < len, so `last >= 2` and `intervals >= 1`.
    let last = len - 1;
    let intervals = (count - 1) as f64;

    let mut sampled = Vec::with_capacity(count);
    sampled.push(s_values[0]);
    // Interior picks: i/(count-1) of the way through the index range. The
    // `min` is purely defensive; the rounded position cannot exceed `last`
    // because i <= count - 2.
    sampled.extend((1..count - 1).map(|i| {
        let position = i as f64 * last as f64 / intervals;
        s_values[(position.round() as usize).min(last)]
    }));
    sampled.push(s_values[last]);

    // Collapse runs of (nearly) identical neighbouring samples.
    sampled.dedup_by(|a, b| (*a - *b).abs() < 1e-9);
    Some(sampled)
}