use faer::{Col, ColRef, Mat, MatRef};
use crate::error::{PlsKitError, PlsKitResult};
use crate::linalg::{
compute_n_eff, normalize_weights, standardize1_weighted, standardize_weighted,
};
/// Borrowed, all-optional inputs accepted by [`preprocess`].
///
/// Every field may be omitted independently. When more than one is supplied,
/// `preprocess` requires their row counts to agree.
#[derive(Debug, Clone, Copy)]
pub struct PreprocessInput<'a> {
/// Optional matrix view; `preprocess` compares its row count against `y` and `weights`.
// NOTE(review): presumably rows are observations and columns are predictors — confirm.
pub x: Option<MatRef<'a, f64>>,
/// Optional column view; its length must match the row count of `x` when both are given.
pub y: Option<ColRef<'a, f64>>,
/// Optional per-row weights. `preprocess` rejects non-finite or negative entries.
pub weights: Option<ColRef<'a, f64>>,
}
/// Owned outputs produced by [`preprocess`].
///
/// Each field is `Some` only when the corresponding input was supplied.
#[derive(Debug, Clone)]
pub struct PreprocessResult {
/// Value returned by `standardize_weighted` for `x`, when `x` was given.
// NOTE(review): presumably (standardized matrix, per-column centers, per-column scales) — confirm
// against `linalg::standardize_weighted`.
pub x_std: Option<(Mat<f64>, Col<f64>, Col<f64>)>,
/// Value returned by `standardize1_weighted` for `y`, when `y` was given.
// NOTE(review): presumably (standardized vector, center, scale) — confirm against
// `linalg::standardize1_weighted`.
pub y_std: Option<(Col<f64>, f64, f64)>,
/// Output of `normalize_weights` applied to the supplied weights.
pub weights_normalized: Option<Col<f64>>,
/// Output of `compute_n_eff` applied to the supplied (un-normalized) weights.
pub n_eff: Option<f64>,
}
pub fn preprocess(input: PreprocessInput<'_>) -> PlsKitResult<PreprocessResult> {
let n_from_x = input.x.map(|x| x.nrows());
let n_from_y = input.y.map(|y| y.nrows());
let n_from_w = input.weights.map(|w| w.nrows());
if let (Some(nx), Some(ny)) = (n_from_x, n_from_y) {
if nx != ny {
return Err(PlsKitError::DimensionMismatch { x: (nx, 0), y: ny });
}
}
if let Some(nw) = n_from_w {
if let Some(nx) = n_from_x {
if nx != nw {
return Err(PlsKitError::DimensionMismatch { x: (nx, 0), y: nw });
}
}
if let Some(ny) = n_from_y {
if ny != nw {
return Err(PlsKitError::DimensionMismatch { x: (nw, 0), y: ny });
}
}
}
let (w_norm, n_eff_val) = match input.weights {
None => (None, None),
Some(w) => {
for i in 0..w.nrows() {
if !w[i].is_finite() {
return Err(PlsKitError::NonFiniteInput);
}
if w[i] < 0.0 {
return Err(PlsKitError::InvalidWeights { reason: "negative" });
}
}
let wn =
normalize_weights(w).ok_or(PlsKitError::InvalidWeights { reason: "all_zero" })?;
let neff = compute_n_eff(w);
(Some(wn), Some(neff))
}
};
let wref = w_norm.as_ref().map(Col::as_ref);
let x_std = input.x.map(|x| standardize_weighted(x, wref));
let y_std = input.y.map(|y| standardize1_weighted(y, wref));
Ok(PreprocessResult {
x_std,
y_std,
weights_normalized: w_norm,
n_eff: n_eff_val,
})
}