use ndarray::{Array2, ArrayView2};
/// Computes the per-column sample standard deviation of `x`, for use as input scales.
///
/// Each column's standard deviation uses the unbiased `n - 1` denominator.
///
/// Returns `None` (no standardization) when `x` has at most one column, or fewer than two
/// rows (the sample variance is undefined there).
///
/// Columns whose standard deviation is below `1e-12` or not finite get a scale of `1.0`.
/// This covers constant columns and columns containing NaN or infinite values, which are
/// left unscaled. Flooring them at a tiny value instead would have two bad effects:
/// - `apply_input_standardization` would multiply those columns by roughly `1e12`;
/// - the geometric mean used by `compensate_length_scale_for_standardization` would be
///   dragged toward zero, inflating the compensated length scale by orders of magnitude.
pub(crate) fn compute_spatial_input_scales(x: ArrayView2<'_, f64>) -> Option<Vec<f64>> {
    // Below this standard deviation a column is treated as degenerate.
    const MIN_SCALE: f64 = 1e-12;

    let d = x.ncols();
    if d <= 1 || x.nrows() < 2 {
        return None;
    }
    let n = x.nrows() as f64;

    let scales = (0..d)
        .map(|j| {
            let col = x.column(j);
            let mean = col.sum() / n;
            let var = col.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / (n - 1.0);
            let std = var.sqrt();
            // `std >= MIN_SCALE` is false for NaN, so NaN also falls back to 1.0.
            if std.is_finite() && std >= MIN_SCALE {
                std
            } else {
                1.0
            }
        })
        .collect();
    Some(scales)
}
/// Rescales `x` in place, column by column, by the reciprocal of the matching entry in `scales`.
///
/// Every element of column `j` is multiplied by `1.0 / scales[j]`.
///
/// # Panics
/// Panics (index out of bounds) if `scales` has fewer entries than `x` has columns.
/// Entries beyond `x.ncols()` are ignored.
pub(crate) fn apply_input_standardization(x: &mut Array2<f64>, scales: &[f64]) {
    let width = x.ncols();
    (0..width).for_each(|col_idx| {
        // `recip()` is exactly `1.0 / self`, computed once per column.
        let factor = scales[col_idx].recip();
        x.column_mut(col_idx).mapv_inplace(|value| value * factor);
    });
}
/// Returns the geometric mean of `scales`, or `1.0` when `scales` is empty.
///
/// The mean is computed in log space (average of `ln(s)`, then `exp`), which avoids the
/// overflow/underflow a direct product could hit.
///
/// Non-positive entries propagate: a `0.0` entry typically yields `0.0` and a negative
/// entry yields NaN. Callers are expected to validate the result.
fn geometric_mean_scale(scales: &[f64]) -> f64 {
    let count = scales.len();
    if count == 0 {
        return 1.0;
    }
    let total_log: f64 = scales.iter().copied().map(f64::ln).sum();
    (total_log / count as f64).exp()
}
/// Divides `length_scale` by the geometric mean of `scales`.
///
/// This appears intended to express a length scale in the coordinate system of inputs that
/// were divided column-wise by `scales`. NOTE(review): confirm against callers.
///
/// If the geometric mean is not a finite, strictly positive number (zero, NaN or infinite),
/// `length_scale` is returned unchanged. An empty `scales` has geometric mean `1.0`, so the
/// result then equals `length_scale`.
pub(crate) fn compensate_length_scale_for_standardization(
    length_scale: f64,
    scales: &[f64],
) -> f64 {
    let divisor = geometric_mean_scale(scales);
    // NaN fails both checks, so it also takes the fallback path.
    let usable = divisor.is_finite() && divisor > 0.0;
    if !usable {
        return length_scale;
    }
    length_scale / divisor
}
/// `Option`-lifting wrapper around `compensate_length_scale_for_standardization`.
///
/// Returns `None` when `length_scale` is `None`. Otherwise returns the compensated value
/// wrapped in `Some`.
pub(crate) fn compensate_optional_length_scale_for_standardization(
    length_scale: Option<f64>,
    scales: &[f64],
) -> Option<f64> {
    let raw = length_scale?;
    Some(compensate_length_scale_for_standardization(raw, scales))
}