gam_problem/
finite_validation.rs1use crate::EstimationError;
4use ndarray::Array1;
5
6pub fn ensure_finite_scalar_estimation(
7 name: &str,
8 value: f64,
9) -> Result<(), EstimationError> {
10 if value.is_finite() {
11 Ok(())
12 } else {
13 Err(EstimationError::InvalidInput(format!(
14 "{name} must be finite, got {value}"
15 )))
16 }
17}
18
19pub fn validate_all_finite_estimation<I>(
20 label: &str,
21 values: I,
22) -> Result<(), EstimationError>
23where
24 I: IntoIterator<Item = f64>,
25{
26 for (idx, value) in values.into_iter().enumerate() {
27 if !value.is_finite() {
28 return Err(EstimationError::InvalidInput(format!(
29 "{label}[{idx}] must be finite, got {value}"
30 )));
31 }
32 }
33 Ok(())
34}
35
36#[inline]
37pub fn bail_if_cached_beta_non_finite(beta: &Array1<f64>) -> Result<(), EstimationError> {
38 if beta.iter().any(|v| !v.is_finite()) {
39 return Err(EstimationError::InvalidInput(
40 "cached inner beta contains non-finite entries".to_string(),
41 ));
42 }
43 Ok(())
44}
45
46pub fn ensure_finite_scalar(name: &str, value: f64) -> Result<(), String> {
48 ensure_finite_scalar_estimation(name, value).map_err(|err| err.to_string())
49}
50
51pub fn validate_all_finite<I: IntoIterator<Item = f64>>(
53 label: &str,
54 values: I,
55) -> Result<(), String> {
56 validate_all_finite_estimation(label, values).map_err(|err| err.to_string())
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ensure_finite_scalar_ok_for_finite() {
        assert!(ensure_finite_scalar("x", 0.0).is_ok());
        assert!(ensure_finite_scalar("x", -1.5).is_ok());
        assert!(ensure_finite_scalar("x", f64::MIN).is_ok());
        // The extremes of the finite range must still pass.
        assert!(ensure_finite_scalar("x", f64::MAX).is_ok());
        assert!(ensure_finite_scalar("x", f64::MIN_POSITIVE).is_ok());
    }

    #[test]
    fn ensure_finite_scalar_err_for_nan() {
        let e = ensure_finite_scalar("my_value", f64::NAN).unwrap_err();
        assert!(e.contains("my_value"), "error should mention the name: {e}");
    }

    #[test]
    fn ensure_finite_scalar_err_for_inf() {
        assert!(ensure_finite_scalar("v", f64::INFINITY).is_err());
        assert!(ensure_finite_scalar("v", f64::NEG_INFINITY).is_err());
    }

    /// The typed variant must surface `InvalidInput`, not just any error.
    #[test]
    fn ensure_finite_scalar_estimation_returns_invalid_input() {
        let err = ensure_finite_scalar_estimation("alpha", f64::NAN).unwrap_err();
        assert!(matches!(err, EstimationError::InvalidInput(_)));
        let text = err.to_string();
        assert!(text.contains("alpha"), "error should mention the name: {text}");
    }

    #[test]
    fn validate_all_finite_ok_for_finite_slice() {
        assert!(validate_all_finite("vec", [1.0, 2.0, 3.0]).is_ok());
        assert!(validate_all_finite("empty", std::iter::empty()).is_ok());
    }

    #[test]
    fn validate_all_finite_err_reports_index() {
        let e = validate_all_finite("arr", [1.0, f64::NAN, 3.0]).unwrap_err();
        assert!(e.contains("arr[1]"), "error should mention arr[1]: {e}");
    }

    #[test]
    fn validate_all_finite_err_reports_inf() {
        let e = validate_all_finite("data", [0.0, f64::INFINITY]).unwrap_err();
        assert!(e.contains("data[1]"), "error should mention data[1]: {e}");
    }

    /// With several bad entries, only the first one is reported.
    #[test]
    fn validate_all_finite_reports_first_offender() {
        let e = validate_all_finite("v", [f64::NEG_INFINITY, f64::NAN]).unwrap_err();
        assert!(e.contains("v[0]"), "error should mention v[0]: {e}");
        assert!(!e.contains("v[1]"), "error should not mention v[1]: {e}");
    }

    /// Validation must stop pulling from the iterator once it finds a bad value.
    #[test]
    fn validate_all_finite_stops_at_first_offender() {
        let values = vec![1.0, f64::NAN]
            .into_iter()
            .chain(std::iter::from_fn(|| -> Option<f64> {
                panic!("values after the first non-finite entry must not be consumed")
            }));
        assert!(validate_all_finite("lazy", values).is_err());
    }

    #[test]
    fn validate_all_finite_estimation_returns_invalid_input() {
        let err = validate_all_finite_estimation("arr", [0.0, f64::INFINITY]).unwrap_err();
        assert!(matches!(err, EstimationError::InvalidInput(_)));
        let text = err.to_string();
        assert!(text.contains("arr[1]"), "error should mention arr[1]: {text}");
    }

    #[test]
    fn bail_if_cached_beta_ok_for_finite() {
        let beta = ndarray::array![1.0, -2.0, 3.0];
        assert!(bail_if_cached_beta_non_finite(&beta).is_ok());
    }

    #[test]
    fn bail_if_cached_beta_ok_for_empty() {
        let beta = Array1::<f64>::zeros(0);
        assert!(bail_if_cached_beta_non_finite(&beta).is_ok());
    }

    #[test]
    fn bail_if_cached_beta_err_for_nan() {
        let beta = ndarray::array![1.0, f64::NAN];
        assert!(bail_if_cached_beta_non_finite(&beta).is_err());
    }

    #[test]
    fn bail_if_cached_beta_err_for_inf() {
        let beta = ndarray::array![0.0, f64::NEG_INFINITY];
        let err = bail_if_cached_beta_non_finite(&beta).unwrap_err();
        assert!(matches!(err, EstimationError::InvalidInput(_)));
    }
}