use crate::error::FdarError;
use crate::matrix::FdMatrix;
/// Validates a functional-data matrix against its evaluation grid.
///
/// `data` must have at least one row and at least one column, and `argvals`
/// must supply exactly one value per column of `data`.
///
/// On success returns `(n, m)`, the `(rows, columns)` pair reported by
/// [`FdMatrix::shape`].
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] when:
/// - `data` has zero rows (`parameter = "data"`),
/// - `data` has zero columns (`parameter = "data"`),
/// - `argvals.len()` differs from the column count (`parameter = "argvals"`).
///
/// The checks run in that order and the first failure is reported.
pub fn validate_fdata(data: &FdMatrix, argvals: &[f64]) -> Result<(usize, usize), FdarError> {
    let (rows, cols) = data.shape();

    // Work out which (if any) check fails, then build the error in one place.
    let failure = if rows == 0 {
        Some(("data", "n > 0 rows".to_owned(), format!("n = {rows}")))
    } else if cols == 0 {
        Some(("data", "m > 0 columns".to_owned(), format!("m = {cols}")))
    } else if argvals.len() != cols {
        Some((
            "argvals",
            format!("{cols} elements"),
            format!("{} elements", argvals.len()),
        ))
    } else {
        None
    };

    match failure {
        None => Ok((rows, cols)),
        Some((parameter, expected, actual)) => Err(FdarError::InvalidDimension {
            parameter,
            expected,
            actual,
        }),
    }
}
/// Checks that the scalar response vector `y` has exactly `n` entries.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] (`parameter = "y"`) when
/// `y.len() != n`.
pub fn validate_response(y: &[f64], n: usize) -> Result<(), FdarError> {
    let got = y.len();
    if got == n {
        return Ok(());
    }
    Err(FdarError::InvalidDimension {
        parameter: "y",
        expected: format!("{n} elements"),
        actual: format!("{got} elements"),
    })
}
/// Validates a vector of integer class labels and returns the class count.
///
/// Labels are treated as 0-based class indices, so the class count is
/// `max(y) + 1` (or `0` when `y` is empty). Indices below the maximum that
/// never occur are still counted: `[1, 1, 1]` yields 2 classes.
///
/// # Errors
///
/// - [`FdarError::InvalidDimension`] (`parameter = "y"`) when `y.len() != n`.
/// - [`FdarError::InvalidParameter`] (`parameter = "y"`) when the class count
///   is below `min_classes`.
pub fn validate_labels(y: &[usize], n: usize, min_classes: usize) -> Result<usize, FdarError> {
    let len = y.len();
    if len != n {
        return Err(FdarError::InvalidDimension {
            parameter: "y",
            expected: format!("{n} elements"),
            actual: format!("{len} elements"),
        });
    }

    let n_classes = match y.iter().max() {
        Some(&largest) => largest + 1,
        None => 0,
    };

    if n_classes >= min_classes {
        Ok(n_classes)
    } else {
        Err(FdarError::InvalidParameter {
            parameter: "y",
            message: format!("need at least {min_classes} classes, got {n_classes}"),
        })
    }
}
/// Validates a pairwise distance matrix and returns its side length.
///
/// The matrix must be square. When `expected_n` is `Some(k)`, it must also be
/// exactly `k x k`.
///
/// # Errors
///
/// Returns [`FdarError::InvalidDimension`] (`parameter = "dist_mat"`) when the
/// matrix is not square, or when its size differs from `expected_n`.
pub fn validate_dist_mat(
    dist_mat: &FdMatrix,
    expected_n: Option<usize>,
) -> Result<usize, FdarError> {
    let (rows, cols) = (dist_mat.nrows(), dist_mat.ncols());
    if rows != cols {
        return Err(FdarError::InvalidDimension {
            parameter: "dist_mat",
            expected: format!("{rows} x {rows} (square)"),
            actual: format!("{rows} x {cols}"),
        });
    }

    match expected_n {
        Some(want) if want != rows => Err(FdarError::InvalidDimension {
            parameter: "dist_mat",
            expected: format!("{want} x {want}"),
            actual: format!("{rows} x {rows}"),
        }),
        _ => Ok(rows),
    }
}
/// Validates a requested number of components and clamps it to the data size.
///
/// Returns `min(ncomp, n, m)`. A request larger than either dimension is
/// silently reduced rather than rejected.
///
/// # Errors
///
/// Returns [`FdarError::InvalidParameter`] (`parameter = "ncomp"`) when
/// `ncomp == 0`.
pub fn validate_ncomp(ncomp: usize, n: usize, m: usize) -> Result<usize, FdarError> {
    match ncomp {
        0 => Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: String::from("must be >= 1"),
        }),
        requested => Ok(requested.min(n.min(m))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Evenly spaced grid `0.0, 1.0, ..., (m - 1) as f64` used as `argvals`.
    fn grid(m: usize) -> Vec<f64> {
        (0..m).map(|i| i as f64).collect()
    }

    #[test]
    fn fdata_ok() {
        let data = FdMatrix::zeros(10, 50);
        let (n, m) = validate_fdata(&data, &grid(50)).unwrap();
        assert_eq!((n, m), (10, 50));
    }

    #[test]
    fn fdata_zero_rows() {
        let data = FdMatrix::zeros(0, 5);
        // The row check must be the one that fires, so the error names `data`.
        assert!(matches!(
            validate_fdata(&data, &grid(5)),
            Err(FdarError::InvalidDimension { parameter: "data", .. })
        ));
    }

    #[test]
    fn fdata_zero_cols() {
        let data = FdMatrix::zeros(5, 0);
        // argvals is also empty, so only the column check can reject this input.
        assert!(matches!(
            validate_fdata(&data, &[]),
            Err(FdarError::InvalidDimension { parameter: "data", .. })
        ));
    }

    #[test]
    fn fdata_argvals_mismatch() {
        let data = FdMatrix::zeros(5, 10);
        assert!(matches!(
            validate_fdata(&data, &grid(8)),
            Err(FdarError::InvalidDimension { parameter: "argvals", .. })
        ));
    }

    #[test]
    fn response_ok() {
        validate_response(&[1.0, 2.0, 3.0], 3).unwrap();
    }

    #[test]
    fn response_empty_ok() {
        validate_response(&[], 0).unwrap();
    }

    #[test]
    fn response_mismatch() {
        assert!(matches!(
            validate_response(&[1.0, 2.0], 3),
            Err(FdarError::InvalidDimension { parameter: "y", .. })
        ));
    }

    #[test]
    fn labels_ok() {
        assert_eq!(validate_labels(&[0, 1, 0, 1], 4, 2).unwrap(), 2);
    }

    #[test]
    fn labels_count_is_max_plus_one() {
        // Class count is derived from the largest label, not from distinct values.
        assert_eq!(validate_labels(&[2, 2], 2, 1).unwrap(), 3);
    }

    #[test]
    fn labels_empty() {
        assert_eq!(validate_labels(&[], 0, 0).unwrap(), 0);
        assert!(matches!(
            validate_labels(&[], 0, 1),
            Err(FdarError::InvalidParameter { parameter: "y", .. })
        ));
    }

    #[test]
    fn labels_too_few_classes() {
        assert!(matches!(
            validate_labels(&[0, 0, 0], 3, 2),
            Err(FdarError::InvalidParameter { parameter: "y", .. })
        ));
    }

    #[test]
    fn labels_length_mismatch() {
        // The length check runs before the class-count check.
        assert!(matches!(
            validate_labels(&[0, 1], 3, 2),
            Err(FdarError::InvalidDimension { parameter: "y", .. })
        ));
    }

    #[test]
    fn dist_mat_ok() {
        let dm = FdMatrix::zeros(5, 5);
        assert_eq!(validate_dist_mat(&dm, None).unwrap(), 5);
        assert_eq!(validate_dist_mat(&dm, Some(5)).unwrap(), 5);
    }

    #[test]
    fn dist_mat_not_square() {
        let dm = FdMatrix::zeros(5, 3);
        assert!(matches!(
            validate_dist_mat(&dm, None),
            Err(FdarError::InvalidDimension { parameter: "dist_mat", .. })
        ));
    }

    #[test]
    fn dist_mat_wrong_size() {
        let dm = FdMatrix::zeros(5, 5);
        assert!(matches!(
            validate_dist_mat(&dm, Some(4)),
            Err(FdarError::InvalidDimension { parameter: "dist_mat", .. })
        ));
    }

    #[test]
    fn ncomp_ok() {
        assert_eq!(validate_ncomp(5, 10, 20).unwrap(), 5);
    }

    #[test]
    fn ncomp_clamped() {
        assert_eq!(validate_ncomp(100, 10, 20).unwrap(), 10);
        assert_eq!(validate_ncomp(100, 20, 10).unwrap(), 10);
    }

    #[test]
    fn ncomp_zero() {
        assert!(matches!(
            validate_ncomp(0, 10, 20),
            Err(FdarError::InvalidParameter { parameter: "ncomp", .. })
        ));
    }
}