Skip to main content

csaps/ndg/
validate.rs

1use ndarray::{
2    ArrayView,
3    ArrayView1,
4    Dimension,
5};
6
7use crate::{Real, Result, CsapsError::InvalidInputData};
8use crate::validate::{validate_data_sites, validate_smooth_value};
9
10use super::GridCubicSmoothingSpline;
11
12
13impl<'a, T, D> GridCubicSmoothingSpline<'a, T, D>
14    where
15        T: Real,
16        D: Dimension
17{
18    pub(super) fn make_validate(&self) -> Result<()> {
19        validate_xy(&self.x, self.y.view())?;
20        validate_weights(&self.x, &self.weights)?;
21        validate_smooth(&self.x, &self.smooth)?;
22
23        Ok(())
24    }
25
26    pub(super) fn evaluate_validate(&self, xi: &[ArrayView1<'a, T>]) -> Result<()> {
27        let x_len = self.x.len();
28        let xi_len = xi.len();
29
30        if xi_len != x_len {
31            return Err(
32                InvalidInputData(
33                    format!("The number of `xi` vectors ({}) is not equal to the number of dimensions ({})",
34                            xi_len, x_len)
35                )
36            )
37        }
38
39        for xi_ax in xi.iter() {
40            if xi_ax.is_empty() {
41                return Err(
42                    InvalidInputData(
43                        "The sizes of `xi` vectors must be greater or equal to 1".to_string()
44                    )
45                )
46            }
47        }
48
49        Ok(())
50    }
51}
52
53
54pub(super) fn validate_xy<T, D>(x: &[ArrayView1<'_, T>], y: ArrayView<'_, T, D>) -> Result<()>
55    where
56        T: Real,
57        D: Dimension
58{
59    if x.len() != y.ndim() {
60        return Err(
61            InvalidInputData(
62                format!("The number of `x` data site vectors ({}) is not equal to `y` data dimensionality ({})",
63                        x.len(), y.ndim())
64            )
65        )
66    }
67
68    for (ax, (&xi, &ys)) in x
69        .iter()
70        .zip(y.shape().iter())
71        .enumerate()
72    {
73        let xi_len = xi.len();
74
75        if xi_len < 2 {
76            return Err(
77                InvalidInputData(
78                    format!("The size of `x` site vectors must be greater or equal to 2, axis {}", ax)
79                )
80            )
81        }
82
83        validate_data_sites(xi.view())?;
84
85        if xi_len != ys {
86            return Err(
87                InvalidInputData(
88                    format!("`x` data sites vector size ({}) is not equal to `y` data size ({}) for axis {}",
89                            xi_len, ys, ax)
90                )
91            )
92        }
93    }
94
95    Ok(())
96}
97
98
99pub(super) fn validate_weights<T>(x: &[ArrayView1<'_, T>], w: &[Option<ArrayView1<'_, T>>]) -> Result<()>
100    where
101        T: Real
102{
103    let x_len = x.len();
104    let w_len = w.len();
105
106    if w_len != x_len {
107        return Err(
108            InvalidInputData(
109                format!("The number of `weights` vectors ({}) is not equal to the number of dimensions ({})",
110                        w_len, x_len)
111            )
112        )
113    }
114
115    for (ax, (xi, wi)) in x.iter().zip(w.iter()).enumerate() {
116        if let Some(wi_view) = wi {
117            let xi_len = xi.len();
118            let wi_len = wi_view.len();
119
120            if wi_len != xi_len {
121                return Err(
122                    InvalidInputData(
123                        format!("`weights` vector size ({}) is not equal to `x` vector size ({}) for axis {}",
124                                wi_len, xi_len, ax)
125                    )
126                )
127            }
128        }
129    }
130
131    Ok(())
132}
133
134
135pub(super) fn validate_smooth<T>(x: &[ArrayView1<'_, T>], smooth: &[Option<T>]) -> Result<()>
136    where
137        T: Real
138{
139    let x_len = x.len();
140    let s_len = smooth.len();
141
142    if s_len != x_len {
143        return Err(
144            InvalidInputData(
145                format!("The number of `smooth` values ({}) is not equal to the number of dimensions ({})",
146                        s_len, x_len)
147            )
148        )
149    }
150
151    for (ax, s_opt) in smooth.iter().enumerate() {
152        if let Some(s) = s_opt {
153            if let Err(err) = validate_smooth_value(*s) {
154                return Err(InvalidInputData(format!("{} for axis {}", err, ax)))
155            };
156        }
157    }
158
159    Ok(())
160}