gam_problem/
finite_validation.rs1use crate::EstimationError;
4use ndarray::Array1;
5
6pub fn ensure_finite_scalar_estimation(name: &str, value: f64) -> Result<(), EstimationError> {
7 if value.is_finite() {
8 Ok(())
9 } else {
10 Err(EstimationError::InvalidInput(format!(
11 "{name} must be finite, got {value}"
12 )))
13 }
14}
15
16pub fn validate_all_finite_estimation<I>(label: &str, values: I) -> Result<(), EstimationError>
17where
18 I: IntoIterator<Item = f64>,
19{
20 for (idx, value) in values.into_iter().enumerate() {
21 if !value.is_finite() {
22 return Err(EstimationError::InvalidInput(format!(
23 "{label}[{idx}] must be finite, got {value}"
24 )));
25 }
26 }
27 Ok(())
28}
29
30#[inline]
31pub fn bail_if_cached_beta_non_finite(beta: &Array1<f64>) -> Result<(), EstimationError> {
32 if beta.iter().any(|v| !v.is_finite()) {
33 return Err(EstimationError::InvalidInput(
34 "cached inner beta contains non-finite entries".to_string(),
35 ));
36 }
37 Ok(())
38}
39
40pub fn ensure_finite_scalar(name: &str, value: f64) -> Result<(), String> {
42 ensure_finite_scalar_estimation(name, value).map_err(|err| err.to_string())
43}
44
45pub fn validate_all_finite<I: IntoIterator<Item = f64>>(
47 label: &str,
48 values: I,
49) -> Result<(), String> {
50 validate_all_finite_estimation(label, values).map_err(|err| err.to_string())
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero, a negative value and the most negative `f64` are all accepted.
    #[test]
    fn ensure_finite_scalar_ok_for_finite() {
        for finite in [0.0, -1.5, f64::MIN] {
            assert!(ensure_finite_scalar("x", finite).is_ok());
        }
    }

    /// NaN is rejected and the error text names the offending scalar.
    #[test]
    fn ensure_finite_scalar_err_for_nan() {
        let outcome = ensure_finite_scalar("my_value", f64::NAN);
        let e = outcome.unwrap_err();
        assert!(e.contains("my_value"), "error should mention the name: {e}");
    }

    /// Both signs of infinity are rejected.
    #[test]
    fn ensure_finite_scalar_err_for_inf() {
        for infinite in [f64::INFINITY, f64::NEG_INFINITY] {
            assert!(ensure_finite_scalar("v", infinite).is_err());
        }
    }

    /// A fully finite input and an empty input both pass.
    #[test]
    fn validate_all_finite_ok_for_finite_slice() {
        let all_finite = [1.0, 2.0, 3.0];
        assert!(validate_all_finite("vec", all_finite).is_ok());

        let nothing = std::iter::empty::<f64>();
        assert!(validate_all_finite("empty", nothing).is_ok());
    }

    /// The error text carries the zero-based index of the NaN entry.
    #[test]
    fn validate_all_finite_err_reports_index() {
        let with_nan = [1.0, f64::NAN, 3.0];
        let e = validate_all_finite("arr", with_nan).unwrap_err();
        assert!(e.contains("arr[1]"), "error should mention arr[1]: {e}");
    }

    /// The error text carries the zero-based index of the infinite entry.
    #[test]
    fn validate_all_finite_err_reports_inf() {
        let with_inf = [0.0, f64::INFINITY];
        let e = validate_all_finite("data", with_inf).unwrap_err();
        assert!(e.contains("data[1]"), "error should mention data[1]: {e}");
    }

    /// A beta vector with only finite entries is accepted.
    #[test]
    fn bail_if_cached_beta_ok_for_finite() {
        let finite_beta = ndarray::array![1.0, -2.0, 3.0];
        let outcome = bail_if_cached_beta_non_finite(&finite_beta);
        assert!(outcome.is_ok());
    }

    /// A beta vector containing NaN is rejected.
    #[test]
    fn bail_if_cached_beta_err_for_nan() {
        let nan_beta = ndarray::array![1.0, f64::NAN];
        let outcome = bail_if_cached_beta_non_finite(&nan_beta);
        assert!(outcome.is_err());
    }
}