Skip to main content

nabled_core/
validation.rs

//! Shared ndarray validation helpers.
2
3use ndarray::{Array1, Array2};
4
5use crate::errors::ShapeError;
6
7/// Validate a square matrix input.
8///
9/// # Errors
10///
11/// Returns [`ShapeError::EmptyInput`] when `matrix` has zero elements, and
12/// [`ShapeError::NotSquare`] when row and column counts differ.
13pub fn validate_square_matrix<T>(matrix: &Array2<T>) -> Result<(), ShapeError> {
14    if matrix.is_empty() {
15        return Err(ShapeError::EmptyInput);
16    }
17
18    if matrix.nrows() != matrix.ncols() {
19        return Err(ShapeError::NotSquare);
20    }
21
22    Ok(())
23}
24
25/// Validate a square matrix and right-hand-side vector pair.
26///
27/// # Errors
28///
29/// Returns [`ShapeError::EmptyInput`] when `matrix` or `rhs` is empty,
30/// [`ShapeError::NotSquare`] when `matrix` is not square, and
31/// [`ShapeError::DimensionMismatch`] when `rhs.len() != matrix.nrows()`.
32pub fn validate_square_system<T>(matrix: &Array2<T>, rhs: &Array1<T>) -> Result<(), ShapeError> {
33    validate_square_matrix(matrix)?;
34
35    if rhs.is_empty() {
36        return Err(ShapeError::EmptyInput);
37    }
38
39    if matrix.nrows() != rhs.len() {
40        return Err(ShapeError::DimensionMismatch);
41    }
42
43    Ok(())
44}
45
#[cfg(test)]
mod tests {
    use ndarray::{Array1, Array2};

    use super::{validate_square_matrix, validate_square_system};
    use crate::errors::ShapeError;

    #[test]
    fn validate_square_matrix_rejects_non_square() {
        let matrix = Array2::<f64>::zeros((2, 3));
        assert!(matches!(validate_square_matrix(&matrix), Err(ShapeError::NotSquare)));
    }

    #[test]
    fn validate_square_matrix_rejects_empty_input() {
        // The emptiness check runs before the squareness check, so degenerate
        // non-square shapes such as 0x3 must still report `EmptyInput`.
        for shape in [(0, 0), (0, 3), (3, 0)] {
            let matrix = Array2::<f64>::zeros(shape);
            assert!(
                matches!(validate_square_matrix(&matrix), Err(ShapeError::EmptyInput)),
                "shape {shape:?} should be rejected as empty"
            );
        }
    }

    #[test]
    fn validate_square_matrix_accepts_square() {
        // Cover the 1x1 boundary as well as a larger square matrix.
        for n in [1, 4] {
            let matrix = Array2::<f64>::eye(n);
            assert!(validate_square_matrix(&matrix).is_ok(), "{n}x{n} should be accepted");
        }
    }

    #[test]
    fn validate_square_system_propagates_matrix_errors() {
        let rhs = Array1::<f64>::ones(2);

        let empty_matrix = Array2::<f64>::zeros((0, 0));
        assert!(matches!(
            validate_square_system(&empty_matrix, &rhs),
            Err(ShapeError::EmptyInput)
        ));

        let non_square = Array2::<f64>::zeros((2, 3));
        assert!(matches!(
            validate_square_system(&non_square, &rhs),
            Err(ShapeError::NotSquare)
        ));
    }

    #[test]
    fn validate_square_system_rejects_empty_rhs_and_dimension_mismatch() {
        let matrix = Array2::<f64>::eye(2);
        let empty_rhs = Array1::<f64>::zeros(0);
        assert!(matches!(validate_square_system(&matrix, &empty_rhs), Err(ShapeError::EmptyInput)));

        let bad_rhs = Array1::<f64>::zeros(3);
        assert!(matches!(
            validate_square_system(&matrix, &bad_rhs),
            Err(ShapeError::DimensionMismatch)
        ));
    }

    #[test]
    fn validate_square_system_accepts_matching_shapes() {
        let matrix = Array2::<f64>::eye(3);
        let rhs = Array1::<f64>::ones(3);
        assert!(validate_square_system(&matrix, &rhs).is_ok());
    }
}