scirs2_series/
validation.rs

1//! Validation utilities for time series module
2//!
3//! Provides centralized validation functions for parameters and data
4
5use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
6use scirs2_core::numeric::{Float, FromPrimitive};
7use std::fmt::Display;
8
9use crate::error::{Result, TimeSeriesError};
10use statrs::statistics::Statistics;
11
12/// Validate that a value is positive
13#[allow(dead_code)]
14pub fn check_positive<F: Float + Display>(value: F, name: &str) -> Result<()> {
15    if value <= F::zero() {
16        return Err(TimeSeriesError::InvalidParameter {
17            name: name.to_string(),
18            message: format!("Must be positive, got {value}"),
19        });
20    }
21    Ok(())
22}
23
24/// Validate that a value is non-negative
25#[allow(dead_code)]
26pub fn check_non_negative<F: Float + Display>(value: F, name: &str) -> Result<()> {
27    if value < F::zero() {
28        return Err(TimeSeriesError::InvalidParameter {
29            name: name.to_string(),
30            message: format!("Must be non-negative, got {value}"),
31        });
32    }
33    Ok(())
34}
35
36/// Validate that a value is in range [0, 1]
37#[allow(dead_code)]
38pub fn check_probability<F: Float + Display>(value: F, name: &str) -> Result<()> {
39    if value < F::zero() || value > F::one() {
40        return Err(TimeSeriesError::InvalidParameter {
41            name: name.to_string(),
42            message: format!("Must be in [0, 1], got {value}"),
43        });
44    }
45    Ok(())
46}
47
48/// Validate that a value is in a given range
49#[allow(dead_code)]
50pub fn check_in_range<F: Float + Display>(value: F, min: F, max: F, name: &str) -> Result<()> {
51    if value < min || value > max {
52        return Err(TimeSeriesError::InvalidParameter {
53            name: name.to_string(),
54            message: format!("Must be in [{min}, {max}], got {value}"),
55        });
56    }
57    Ok(())
58}
59
60/// Validate that an array has sufficient length
61#[allow(dead_code)]
62pub fn check_array_length<S, F>(
63    data: &ArrayBase<S, Ix1>,
64    min_length: usize,
65    operation: &str,
66) -> Result<()>
67where
68    S: Data<Elem = F>,
69    F: Float,
70{
71    if data.len() < min_length {
72        return Err(TimeSeriesError::InsufficientData {
73            message: format!("for {operation}"),
74            required: min_length,
75            actual: data.len(),
76        });
77    }
78    Ok(())
79}
80
81/// Validate that two arrays have the same length
82#[allow(dead_code)]
83pub fn check_same_length<S1, S2, F>(
84    arr1: &ArrayBase<S1, Ix1>,
85    arr2: &ArrayBase<S2, Ix1>,
86    _name1: &str,
87    name2: &str,
88) -> Result<()>
89where
90    S1: Data<Elem = F>,
91    S2: Data<Elem = F>,
92    F: Float,
93{
94    if arr1.len() != arr2.len() {
95        return Err(TimeSeriesError::DimensionMismatch {
96            expected: arr1.len(),
97            actual: arr2.len(),
98        });
99    }
100    Ok(())
101}
102
103/// Validate ARIMA orders
104#[allow(dead_code)]
105pub fn validate_arima_orders(p: usize, d: usize, q: usize) -> Result<()> {
106    if p > 10 {
107        return Err(TimeSeriesError::InvalidParameter {
108            name: "p".to_string(),
109            message: format!("AR order too large: {p}"),
110        });
111    }
112    if d > 3 {
113        return Err(TimeSeriesError::InvalidParameter {
114            name: "d".to_string(),
115            message: format!("Differencing order too large: {d}"),
116        });
117    }
118    if q > 10 {
119        return Err(TimeSeriesError::InvalidParameter {
120            name: "q".to_string(),
121            message: format!("MA order too large: {q}"),
122        });
123    }
124    Ok(())
125}
126
127/// Validate seasonal ARIMA orders
128#[allow(dead_code)]
129pub fn validate_seasonal_arima_orders(
130    p: usize,
131    d: usize,
132    q: usize,
133    p_seasonal: usize,
134    d_seasonal: usize,
135    q_seasonal: usize,
136    period: usize,
137) -> Result<()> {
138    validate_arima_orders(p, d, q)?;
139
140    if p_seasonal > 5 {
141        return Err(TimeSeriesError::InvalidParameter {
142            name: "p_seasonal".to_string(),
143            message: format!("Seasonal AR order too large: {p_seasonal}"),
144        });
145    }
146    if d_seasonal > 2 {
147        return Err(TimeSeriesError::InvalidParameter {
148            name: "d_seasonal".to_string(),
149            message: format!("Seasonal differencing order too large: {d_seasonal}"),
150        });
151    }
152    if q_seasonal > 5 {
153        return Err(TimeSeriesError::InvalidParameter {
154            name: "q_seasonal".to_string(),
155            message: format!("Seasonal MA order too large: {q_seasonal}"),
156        });
157    }
158    if period < 2 {
159        return Err(TimeSeriesError::InvalidParameter {
160            name: "period".to_string(),
161            message: format!("Period must be at least 2, got {period}"),
162        });
163    }
164    if period > 365 {
165        return Err(TimeSeriesError::InvalidParameter {
166            name: "period".to_string(),
167            message: format!("Period too large: {period}"),
168        });
169    }
170
171    Ok(())
172}
173
174/// Validate forecast horizon
175#[allow(dead_code)]
176pub fn validate_forecast_horizon(_steps: usize, maxreasonable: Option<usize>) -> Result<()> {
177    if _steps == 0 {
178        return Err(TimeSeriesError::InvalidParameter {
179            name: "_steps".to_string(),
180            message: "Forecast horizon must be positive".to_string(),
181        });
182    }
183
184    let max = maxreasonable.unwrap_or(10000);
185    if _steps > max {
186        return Err(TimeSeriesError::InvalidParameter {
187            name: "_steps".to_string(),
188            message: format!("Forecast horizon too large: {_steps}"),
189        });
190    }
191
192    Ok(())
193}
194
195/// Validate window size for rolling operations
196#[allow(dead_code)]
197pub fn validate_window_size(_window: usize, datalength: usize) -> Result<()> {
198    if _window == 0 {
199        return Err(TimeSeriesError::InvalidParameter {
200            name: "_window".to_string(),
201            message: "Window size must be positive".to_string(),
202        });
203    }
204
205    if _window > datalength {
206        return Err(TimeSeriesError::InvalidParameter {
207            name: "_window".to_string(),
208            message: format!("Window size {_window} exceeds data _length {datalength}"),
209        });
210    }
211
212    Ok(())
213}
214
215/// Validate lag for time series operations
216#[allow(dead_code)]
217pub fn validate_lag(_lag: usize, datalength: usize) -> Result<()> {
218    if _lag >= datalength {
219        return Err(TimeSeriesError::InvalidParameter {
220            name: "_lag".to_string(),
221            message: format!("Lag {_lag} must be less than data _length {datalength}"),
222        });
223    }
224    Ok(())
225}
226
227/// Check if array has no missing values
228#[allow(dead_code)]
229pub fn check_no_missing<S, F>(data: &ArrayBase<S, Ix1>) -> Result<()>
230where
231    S: Data<Elem = F>,
232    F: Float,
233{
234    for (i, &x) in data.iter().enumerate() {
235        if x.is_nan() || x.is_infinite() {
236            return Err(TimeSeriesError::InvalidInput(format!(
237                "Non-finite value at index {i}"
238            )));
239        }
240    }
241    Ok(())
242}
243
244/// Check if array is stationary (basic check)
245#[allow(dead_code)]
246pub fn check_stationarity_basic<S, F>(data: &ArrayBase<S, Ix1>) -> Result<bool>
247where
248    S: Data<Elem = F>,
249    F: Float + FromPrimitive,
250{
251    check_array_length(data, 10, "stationarity check")?;
252
253    // Split _data into two halves
254    let mid = data.len() / 2;
255    let first_half = data.slice(scirs2_core::ndarray::s![..mid]);
256    let second_half = data.slice(scirs2_core::ndarray::s![mid..]);
257
258    // Compare means and variances
259    let mean1 = first_half.mean().unwrap_or(F::zero());
260    let mean2 = second_half.mean().unwrap_or(F::zero());
261
262    let var1 = first_half
263        .mapv(|x| (x - mean1) * (x - mean1))
264        .mean()
265        .unwrap_or(F::zero());
266    let var2 = second_half
267        .mapv(|x| (x - mean2) * (x - mean2))
268        .mean()
269        .unwrap_or(F::zero());
270
271    // Check if means and variances are similar
272    let mean_diff = (mean1 - mean2).abs();
273    let var_ratio = if var1 > F::zero() && var2 > F::zero() {
274        (var1 / var2).max(var2 / var1)
275    } else {
276        F::one()
277    };
278
279    // Rough thresholds
280    let mean_threshold =
281        F::from(0.2).unwrap() * (var1.sqrt() + var2.sqrt()) / F::from(2.0).unwrap();
282    let var_threshold = F::from(2.0).unwrap();
283
284    Ok(mean_diff < mean_threshold && var_ratio < var_threshold)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_check_positive() {
        assert!(check_positive(1.0, "value").is_ok());
        assert!(check_positive(0.0, "value").is_err());
        assert!(check_positive(-1.0, "value").is_err());
    }

    #[test]
    fn test_check_non_negative() {
        assert!(check_non_negative(1.0, "value").is_ok());
        assert!(check_non_negative(0.0, "value").is_ok());
        assert!(check_non_negative(-1.0, "value").is_err());
    }

    #[test]
    fn test_check_probability() {
        assert!(check_probability(0.5, "prob").is_ok());
        assert!(check_probability(0.0, "prob").is_ok());
        assert!(check_probability(1.0, "prob").is_ok());
        assert!(check_probability(1.1, "prob").is_err());
        assert!(check_probability(-0.1, "prob").is_err());
    }

    #[test]
    fn test_check_in_range() {
        // Both bounds are inclusive.
        assert!(check_in_range(5.0, 0.0, 10.0, "x").is_ok());
        assert!(check_in_range(0.0, 0.0, 10.0, "x").is_ok());
        assert!(check_in_range(10.0, 0.0, 10.0, "x").is_ok());
        assert!(check_in_range(-0.1, 0.0, 10.0, "x").is_err());
        assert!(check_in_range(10.1, 0.0, 10.0, "x").is_err());
    }

    #[test]
    fn test_check_array_length() {
        let arr = array![1.0, 2.0, 3.0];
        assert!(check_array_length(&arr, 3, "test").is_ok());
        assert!(check_array_length(&arr, 4, "test").is_err());
    }

    #[test]
    fn test_check_same_length() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let c = array![7.0, 8.0];
        assert!(check_same_length(&a, &b, "a", "b").is_ok());
        assert!(check_same_length(&a, &c, "a", "c").is_err());
    }

    #[test]
    fn test_validate_arima_orders() {
        assert!(validate_arima_orders(2, 1, 2).is_ok());
        assert!(validate_arima_orders(11, 1, 1).is_err());
        assert!(validate_arima_orders(1, 4, 1).is_err());
        assert!(validate_arima_orders(1, 1, 11).is_err());
    }

    #[test]
    fn test_validate_seasonal_arima_orders() {
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 12).is_ok());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 2).is_ok());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 365).is_ok());
        // A non-seasonal violation is propagated.
        assert!(validate_seasonal_arima_orders(11, 1, 1, 1, 1, 1, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 6, 1, 1, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 3, 1, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 6, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 1).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 366).is_err());
    }

    #[test]
    fn test_validate_forecast_horizon() {
        assert!(validate_forecast_horizon(1, None).is_ok());
        assert!(validate_forecast_horizon(10000, None).is_ok());
        assert!(validate_forecast_horizon(0, None).is_err());
        assert!(validate_forecast_horizon(10001, None).is_err());
        assert!(validate_forecast_horizon(5, Some(5)).is_ok());
        assert!(validate_forecast_horizon(6, Some(5)).is_err());
    }

    #[test]
    fn test_validate_window_size() {
        assert!(validate_window_size(1, 10).is_ok());
        assert!(validate_window_size(10, 10).is_ok());
        assert!(validate_window_size(0, 10).is_err());
        assert!(validate_window_size(11, 10).is_err());
    }

    #[test]
    fn test_validate_lag() {
        assert!(validate_lag(0, 10).is_ok());
        assert!(validate_lag(9, 10).is_ok());
        assert!(validate_lag(10, 10).is_err());
        assert!(validate_lag(0, 0).is_err());
    }

    #[test]
    fn test_check_no_missing() {
        assert!(check_no_missing(&array![1.0, 2.0, 3.0]).is_ok());
        assert!(check_no_missing(&array![1.0, f64::NAN, 3.0]).is_err());
        assert!(check_no_missing(&array![f64::INFINITY]).is_err());
        assert!(check_no_missing(&array![f64::NEG_INFINITY]).is_err());
    }

    #[test]
    fn test_check_stationarity_basic() {
        // Fewer than 10 observations is an error.
        assert!(check_stationarity_basic(&array![1.0, 2.0, 3.0]).is_err());

        // A strong linear trend shifts the mean between the halves.
        let trend = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];
        assert!(matches!(check_stationarity_basic(&trend), Ok(false)));

        // An alternating series has identical mean and variance in both halves.
        let alternating =
            array![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        assert!(matches!(check_stationarity_basic(&alternating), Ok(true)));
    }
}