Skip to main content

scirs2_series/
diagnostics.rs

//! Model diagnostics and validation tools for time series models
//!
//! Implements residual analysis, diagnostic tests (Ljung-Box, Jarque-Bera and
//! ARCH), in-sample fit statistics and expanding-window cross-validation.
4
5use scirs2_core::ndarray::ArrayStatCompat;
6use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1, ScalarOperand};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::{Debug, Display};
9
10use crate::error::{Result, TimeSeriesError};
11use crate::utils::{autocorrelation, partial_autocorrelation};
12use statrs::statistics::Statistics;
13
/// Residual diagnostics for time series models
///
/// Produced by [`residual_diagnostics`]: summary statistics of the residuals
/// together with their ACF/PACF and the Ljung-Box and Jarque-Bera tests.
#[derive(Debug, Clone)]
pub struct ResidualDiagnostics<F> {
    /// Residuals (an owned copy of the input series)
    pub residuals: Array1<F>,
    /// Standardized residuals `(r - mean) / std_dev`; the raw residuals are
    /// stored instead when `std_dev` is zero
    pub standardized_residuals: Array1<F>,
    /// Mean of residuals
    pub mean: F,
    /// Standard deviation of residuals (population form, dividing by `n`)
    pub std_dev: F,
    /// Skewness (third standardized central moment)
    pub skewness: F,
    /// Excess kurtosis (fourth standardized central moment minus 3)
    pub kurtosis: F,
    /// Ljung-Box test results
    pub ljung_box: LjungBoxTest<F>,
    /// Jarque-Bera test for normality (run on the standardized residuals)
    pub jarque_bera: JarqueBeraTest<F>,
    /// ACF of residuals, as returned by `autocorrelation` up to the lag limit
    /// used by [`residual_diagnostics`]
    pub acf: Array1<F>,
    /// PACF of residuals, as returned by `partial_autocorrelation` up to the
    /// same lag limit
    pub pacf: Array1<F>,
}
38
/// Ljung-Box test for autocorrelation
///
/// Produced by [`ljung_box_test`]. The statistic is
/// `Q = n (n + 2) * sum_{k=1}^{lags} r_k^2 / (n - k)`, where `r_k` is the lag-`k`
/// sample autocorrelation.
#[derive(Debug, Clone)]
pub struct LjungBoxTest<F> {
    /// Test statistic `Q`
    pub statistic: F,
    /// P-value of `statistic` against a chi-squared distribution with `df`
    /// degrees of freedom
    pub p_value: F,
    /// Degrees of freedom (equal to `lags`; no correction is made for the
    /// number of fitted model parameters)
    pub df: usize,
    /// Number of lags tested
    pub lags: usize,
    /// `true` when `p_value > alpha`, i.e. the null hypothesis of no
    /// autocorrelation is not rejected
    pub is_white_noise: bool,
}
53
/// Jarque-Bera test for normality
///
/// Produced by [`jarque_bera_test`]. The statistic is
/// `JB = n / 6 * (S^2 + K^2 / 4)`, with `S` the skewness and `K` the excess
/// kurtosis, tested against a chi-squared distribution with 2 degrees of
/// freedom.
#[derive(Debug, Clone)]
pub struct JarqueBeraTest<F> {
    /// Test statistic `JB`
    pub statistic: F,
    /// P-value of `statistic` (chi-squared, 2 degrees of freedom)
    pub p_value: F,
    /// `true` when `p_value > alpha`, i.e. normality is not rejected
    pub is_normal: bool,
}
64
/// ARCH test for heteroskedasticity
///
/// Produced by [`arch_test`], which regresses the squared residuals on a
/// constant and their own lags.
#[derive(Debug, Clone)]
pub struct ArchTest<F> {
    /// Lagrange-multiplier test statistic from the auxiliary regression
    pub statistic: F,
    /// P-value of `statistic` against a chi-squared distribution with `lags`
    /// degrees of freedom
    pub p_value: F,
    /// Number of lags of the squared residuals used as regressors
    pub lags: usize,
    /// `true` when `p_value < alpha`, i.e. conditional heteroskedasticity is
    /// detected
    pub has_arch_effect: bool,
}
77
/// Model validation results
///
/// Groups in-sample fit statistics with optional out-of-sample and
/// cross-validation summaries.
#[derive(Debug, Clone)]
pub struct ModelValidation<F> {
    /// In-sample fit statistics
    pub in_sample: FitStatistics<F>,
    /// Out-of-sample statistics (`None` if not available)
    pub out_of_sample: Option<FitStatistics<F>>,
    /// Cross-validation results (`None` if cross-validation was not run)
    pub cross_validation: Option<CrossValidationResults<F>>,
}
88
/// Fit statistics
///
/// Produced by [`calculate_fit_statistics`]. Errors are `actual - predicted`.
#[derive(Debug, Clone)]
pub struct FitStatistics<F> {
    /// Mean Absolute Error: mean of `|actual - predicted|`
    pub mae: F,
    /// Mean Squared Error: mean of `(actual - predicted)^2`
    pub mse: F,
    /// Root Mean Squared Error: `sqrt(mse)`
    pub rmse: F,
    /// Mean Absolute Percentage Error: mean of `|error / actual|`, expressed as
    /// a fraction (not multiplied by 100); `None` when any actual value is zero
    pub mape: Option<F>,
    /// Symmetric Mean Absolute Percentage Error: mean of
    /// `2|error| / (|actual| + |predicted|)`, as a fraction; `None` under the
    /// same condition as `mape`
    pub smape: Option<F>,
    /// R-squared (coefficient of determination, `1 - SS_res / SS_tot`)
    pub r2: F,
    /// Adjusted R-squared for the supplied parameter count; equal to `r2` when
    /// no parameter count is given
    pub adj_r2: F,
}
107
/// Cross-validation results
///
/// Produced by [`time_series_cv`]; one entry per evaluated fold in the
/// per-fold vectors.
#[derive(Debug, Clone)]
pub struct CrossValidationResults<F> {
    /// Average MAE across folds
    pub avg_mae: F,
    /// Average RMSE across folds
    pub avg_rmse: F,
    /// Average MAPE over the folds that produced one; `None` if no fold did
    /// (MAPE is skipped for folds whose test window contains a zero)
    pub avg_mape: Option<F>,
    /// MAE for each fold
    pub fold_mae: Vec<F>,
    /// RMSE for each fold
    pub fold_rmse: Vec<F>,
}
122
123/// Perform residual diagnostics
124#[allow(dead_code)]
125pub fn residual_diagnostics<S, F>(
126    residuals: &ArrayBase<S, Ix1>,
127    max_lag: Option<usize>,
128    alpha: F,
129) -> Result<ResidualDiagnostics<F>>
130where
131    S: Data<Elem = F>,
132    F: Float + FromPrimitive + Debug + Display + ScalarOperand,
133{
134    scirs2_core::validation::checkarray_finite(residuals, "residuals")?;
135
136    if residuals.len() < 4 {
137        return Err(TimeSeriesError::InvalidInput(
138            "Need at least 4 residuals for diagnostics".to_string(),
139        ));
140    }
141
142    // Basic statistics
143    let mean = residuals.mean_or(F::zero());
144    let variance = residuals
145        .mapv(|x| (x - mean) * (x - mean))
146        .mean()
147        .unwrap_or(F::zero());
148    let std_dev = scirs2_core::numeric::Float::sqrt(variance);
149
150    // Standardized residuals
151    let standardized_residuals = if std_dev > F::zero() {
152        residuals.mapv(|x| (x - mean) / std_dev)
153    } else {
154        residuals.to_owned()
155    };
156
157    // Skewness and kurtosis
158    let (skewness, kurtosis) = calculate_moments(&standardized_residuals)?;
159
160    // ACF and PACF
161    let lag_max = max_lag.unwrap_or((residuals.len() as f64).sqrt() as usize);
162    let acf = autocorrelation(&residuals.to_owned(), Some(lag_max))?;
163    let pacf = partial_autocorrelation(&residuals.to_owned(), Some(lag_max))?;
164
165    // Ljung-Box test
166    let ljung_box = ljung_box_test(residuals, lag_max, alpha)?;
167
168    // Jarque-Bera test
169    let jarque_bera = jarque_bera_test(&standardized_residuals, alpha)?;
170
171    Ok(ResidualDiagnostics {
172        residuals: residuals.to_owned(),
173        standardized_residuals,
174        mean,
175        std_dev,
176        skewness,
177        kurtosis,
178        ljung_box,
179        jarque_bera,
180        acf,
181        pacf,
182    })
183}
184
185/// Calculate skewness and kurtosis
186#[allow(dead_code)]
187fn calculate_moments<S, F>(data: &ArrayBase<S, Ix1>) -> Result<(F, F)>
188where
189    S: Data<Elem = F>,
190    F: Float + FromPrimitive + Display,
191{
192    let n = F::from(data.len()).expect("Operation failed");
193    let mean = data.mean_or(F::zero());
194
195    let mut m2 = F::zero();
196    let mut m3 = F::zero();
197    let mut m4 = F::zero();
198
199    for &x in data.iter() {
200        let diff = x - mean;
201        let diff2 = diff * diff;
202        m2 = m2 + diff2;
203        m3 = m3 + diff2 * diff;
204        m4 = m4 + diff2 * diff2;
205    }
206
207    m2 = m2 / n;
208    m3 = m3 / n;
209    m4 = m4 / n;
210
211    let skewness = m3 / m2.powf(F::from(1.5).expect("Failed to convert constant to float"));
212    let kurtosis = m4 / (m2 * m2) - F::from(3.0).expect("Failed to convert constant to float"); // Excess kurtosis
213
214    Ok((skewness, kurtosis))
215}
216
217/// Ljung-Box test for autocorrelation
218#[allow(dead_code)]
219pub fn ljung_box_test<S, F>(
220    residuals: &ArrayBase<S, Ix1>,
221    lags: usize,
222    alpha: F,
223) -> Result<LjungBoxTest<F>>
224where
225    S: Data<Elem = F>,
226    F: Float + FromPrimitive + Debug + Display + ScalarOperand,
227{
228    scirs2_core::validation::checkarray_finite(residuals, "residuals")?;
229
230    let n = residuals.len();
231    if lags >= n {
232        return Err(TimeSeriesError::InvalidInput(
233            "Number of lags exceeds residual length".to_string(),
234        ));
235    }
236
237    // Calculate autocorrelations
238    let acf = autocorrelation(&residuals.to_owned(), Some(lags))?;
239
240    // Ljung-Box statistic
241    let mut statistic = F::zero();
242    for k in 1..=lags {
243        let rk = acf[k];
244        statistic = statistic + rk * rk / F::from(n - k).expect("Failed to convert to float");
245    }
246    statistic = F::from(n * (n + 2)).expect("Operation failed") * statistic;
247
248    // Calculate p-value using chi-squared distribution
249    let df = lags;
250    let p_value = chi_squared_pvalue(statistic, df)?;
251
252    let is_white_noise = p_value > alpha;
253
254    Ok(LjungBoxTest {
255        statistic,
256        p_value,
257        df,
258        lags,
259        is_white_noise,
260    })
261}
262
263/// Jarque-Bera test for normality
264#[allow(dead_code)]
265pub fn jarque_bera_test<S, F>(residuals: &ArrayBase<S, Ix1>, alpha: F) -> Result<JarqueBeraTest<F>>
266where
267    S: Data<Elem = F>,
268    F: Float + FromPrimitive + Display,
269{
270    let n = F::from(residuals.len()).expect("Operation failed");
271    let (skewness, kurtosis) = calculate_moments(residuals)?;
272
273    // Jarque-Bera statistic
274    let statistic = n / F::from(6.0).expect("Failed to convert constant to float")
275        * (skewness * skewness
276            + kurtosis * kurtosis / F::from(4.0).expect("Failed to convert constant to float"));
277
278    // P-value using chi-squared distribution with 2 df
279    let p_value = chi_squared_pvalue(statistic, 2)?;
280
281    let is_normal = p_value > alpha;
282
283    Ok(JarqueBeraTest {
284        statistic,
285        p_value,
286        is_normal,
287    })
288}
289
290/// ARCH test for heteroskedasticity
291#[allow(dead_code)]
292pub fn arch_test<S, F>(residuals: &ArrayBase<S, Ix1>, lags: usize, alpha: F) -> Result<ArchTest<F>>
293where
294    S: Data<Elem = F>,
295    F: Float + FromPrimitive + Debug + Display + ScalarOperand,
296{
297    scirs2_core::validation::checkarray_finite(residuals, "residuals")?;
298
299    let n = residuals.len();
300    if lags >= n {
301        return Err(TimeSeriesError::InvalidInput(
302            "Number of lags exceeds residual length".to_string(),
303        ));
304    }
305
306    // Square the _residuals
307    let squared_residuals = residuals.mapv(|x| x * x);
308
309    // Regress squared _residuals on their lags
310    use scirs2_core::ndarray::Array2;
311
312    let y = squared_residuals
313        .slice(scirs2_core::ndarray::s![lags..])
314        .to_owned();
315    let mut x = Array2::zeros((n - lags, lags + 1));
316
317    // Add constant
318    for i in 0..(n - lags) {
319        x[[i, 0]] = F::one();
320    }
321
322    // Add lags
323    for lag in 1..=lags {
324        for i in 0..(n - lags) {
325            x[[i, lag]] = squared_residuals[i + lags - lag];
326        }
327    }
328
329    // Perform regression
330    let xtx = x.t().dot(&x);
331    let xty = x.t().dot(&y);
332
333    // Simple matrix inversion for OLS
334    let n = xtx.shape()[0];
335    if n == 0 {
336        return Err(TimeSeriesError::ComputationError(
337            "Empty matrix".to_string(),
338        ));
339    }
340
341    // Regularized pseudo-inverse
342    let lambda = F::from(1e-6).expect("Failed to convert constant to float");
343    let mut xtx_reg = xtx.clone();
344    for i in 0..n {
345        xtx_reg[[i, i]] = xtx_reg[[i, i]] + lambda;
346    }
347
348    // Simple matrix solve
349    if let Ok(coeffs) = matrix_solve(&xtx_reg, &xty) {
350        let fitted = x.dot(&coeffs);
351        let residuals_arch = y.clone() - &fitted;
352
353        // Calculate R-squared
354        let y_mean = y.mean_or(F::zero());
355        let ss_tot = y.mapv(|yi| (yi - y_mean) * (yi - y_mean)).sum();
356        let ss_res = residuals_arch.dot(&residuals_arch);
357        let r2 = if ss_tot > F::zero() {
358            F::one() - ss_res / ss_tot
359        } else {
360            F::zero()
361        };
362
363        // LM statistic
364        let statistic = F::from(n - lags).expect("Failed to convert to float") * r2;
365
366        // P-value using chi-squared distribution
367        let p_value = chi_squared_pvalue(statistic, lags)?;
368
369        let has_arch_effect = p_value < alpha;
370
371        Ok(ArchTest {
372            statistic,
373            p_value,
374            lags,
375            has_arch_effect,
376        })
377    } else {
378        Err(TimeSeriesError::ComputationError(
379            "Failed to perform ARCH test regression".to_string(),
380        ))
381    }
382}
383
384/// Calculate fit statistics
385#[allow(dead_code)]
386pub fn calculate_fit_statistics<S, F>(
387    actual: &ArrayBase<S, Ix1>,
388    predicted: &ArrayBase<S, Ix1>,
389    n_params: Option<usize>,
390) -> Result<FitStatistics<F>>
391where
392    S: Data<Elem = F>,
393    F: Float + FromPrimitive + Display,
394{
395    scirs2_core::validation::checkarray_finite(actual, "actual")?;
396    scirs2_core::validation::checkarray_finite(predicted, "predicted")?;
397
398    if actual.len() != predicted.len() {
399        return Err(TimeSeriesError::InvalidInput(
400            "Actual and predicted arrays must have same length".to_string(),
401        ));
402    }
403
404    let n = actual.len();
405    if n == 0 {
406        return Err(TimeSeriesError::InvalidInput(
407            "Empty arrays provided".to_string(),
408        ));
409    }
410
411    // Calculate errors
412    let errors = actual - predicted;
413    let squared_errors = errors.mapv(|e| e * e);
414
415    // Basic metrics
416    let mae = errors.mapv(|e| e.abs()).mean().expect("Operation failed");
417    let mse = squared_errors.mean().expect("Operation failed");
418    let rmse = mse.sqrt();
419
420    // MAPE and SMAPE (if no zeros in actual)
421    let has_zeros = actual.iter().any(|&x| x == F::zero());
422    let (mape, smape) = if !has_zeros {
423        let mape = errors
424            .iter()
425            .zip(actual.iter())
426            .map(|(e, a)| (*e / *a).abs())
427            .fold(F::zero(), |acc, x| acc + x)
428            / F::from(n).expect("Failed to convert to float");
429
430        let smape = errors
431            .iter()
432            .zip(actual.iter())
433            .zip(predicted.iter())
434            .map(|((e, a), p)| {
435                F::from(2.0).expect("Failed to convert constant to float") * e.abs()
436                    / (a.abs() + p.abs())
437            })
438            .fold(F::zero(), |acc, x| acc + x)
439            / F::from(n).expect("Failed to convert to float");
440
441        (Some(mape), Some(smape))
442    } else {
443        (None, None)
444    };
445
446    // R-squared
447    let y_mean = actual.mean().expect("Operation failed");
448    let ss_tot = actual.mapv(|y| (y - y_mean) * (y - y_mean)).sum();
449    let ss_res = squared_errors.sum();
450    let r2 = if ss_tot > F::zero() {
451        F::one() - ss_res / ss_tot
452    } else {
453        F::one()
454    };
455
456    // Adjusted R-squared
457    let adj_r2 = if let Some(p) = n_params {
458        let n_f = F::from(n).expect("Failed to convert to float");
459        let p_f = F::from(p).expect("Failed to convert to float");
460        F::one() - (F::one() - r2) * (n_f - F::one()) / (n_f - p_f - F::one())
461    } else {
462        r2
463    };
464
465    Ok(FitStatistics {
466        mae,
467        mse,
468        rmse,
469        mape,
470        smape,
471        r2,
472        adj_r2,
473    })
474}
475
476/// Perform time series cross-validation
477#[allow(dead_code)]
478pub fn time_series_cv<S, F, Model, Fit, Predict>(
479    data: &ArrayBase<S, Ix1>,
480    n_folds: usize,
481    min_train_size: usize,
482    forecast_horizon: usize,
483    model: Model,
484    fit: Fit,
485    predict: Predict,
486) -> Result<CrossValidationResults<F>>
487where
488    S: Data<Elem = F>,
489    F: Float + FromPrimitive + Debug + Display,
490    Model: Clone,
491    Fit: Fn(&mut Model, &Array1<F>) -> Result<()>,
492    Predict: Fn(&Model, usize, &Array1<F>) -> Result<Array1<F>>,
493{
494    scirs2_core::validation::checkarray_finite(data, "data")?;
495
496    let n = data.len();
497    if n < min_train_size + forecast_horizon {
498        return Err(TimeSeriesError::InvalidInput(
499            "Insufficient data for cross-validation".to_string(),
500        ));
501    }
502
503    let fold_size = (n - min_train_size - forecast_horizon) / n_folds;
504    if fold_size == 0 {
505        return Err(TimeSeriesError::InvalidInput(
506            "Too many _folds for available data".to_string(),
507        ));
508    }
509
510    let mut fold_mae = Vec::new();
511    let mut fold_rmse = Vec::new();
512    let mut fold_mape = Vec::new();
513
514    for fold in 0..n_folds {
515        let train_end = min_train_size + fold * fold_size;
516        let test_end = train_end + forecast_horizon;
517
518        if test_end > n {
519            break;
520        }
521
522        // Split data
523        let train_data = data.slice(scirs2_core::ndarray::s![..train_end]).to_owned();
524        let test_data = data
525            .slice(scirs2_core::ndarray::s![train_end..test_end])
526            .to_owned();
527
528        // Fit model
529        let mut fold_model = model.clone();
530        fit(&mut fold_model, &train_data)?;
531
532        // Make predictions
533        let predictions = predict(&fold_model, forecast_horizon, &train_data)?;
534
535        // Calculate metrics
536        let stats = calculate_fit_statistics(&test_data, &predictions, None)?;
537        fold_mae.push(stats.mae);
538        fold_rmse.push(stats.rmse);
539        if let Some(mape) = stats.mape {
540            fold_mape.push(mape);
541        }
542    }
543
544    // Calculate averages
545    let avg_mae = fold_mae.iter().fold(F::zero(), |acc, x| acc + *x)
546        / F::from(fold_mae.len()).expect("Operation failed");
547    let avg_rmse = fold_rmse.iter().fold(F::zero(), |acc, x| acc + *x)
548        / F::from(fold_rmse.len()).expect("Operation failed");
549    let avg_mape = if !fold_mape.is_empty() {
550        Some(
551            fold_mape.iter().fold(F::zero(), |acc, x| acc + *x)
552                / F::from(fold_mape.len()).expect("Operation failed"),
553        )
554    } else {
555        None
556    };
557
558    Ok(CrossValidationResults {
559        avg_mae,
560        avg_rmse,
561        avg_mape,
562        fold_mae,
563        fold_rmse,
564    })
565}
566
567/// Simplified chi-squared p-value calculation
568#[allow(dead_code)]
569fn chi_squared_pvalue<F>(statistic: F, df: usize) -> Result<F>
570where
571    F: Float + FromPrimitive + Display,
572{
573    // This is a simplified implementation
574    // In practice, would use a proper statistical library
575
576    // Use normal approximation for large df
577    if df > 30 {
578        let mean = F::from(df).expect("Failed to convert to float");
579        let std_dev = (F::from(2 * df).expect("Failed to convert to float")).sqrt();
580        let z = (statistic - mean) / std_dev;
581
582        // Approximate p-value using standard normal
583        if z > F::from(3.0).expect("Failed to convert constant to float") {
584            Ok(F::from(0.001).expect("Failed to convert constant to float"))
585        } else if z > F::from(2.0).expect("Failed to convert constant to float") {
586            Ok(F::from(0.05).expect("Failed to convert constant to float"))
587        } else if z > F::from(1.0).expect("Failed to convert constant to float") {
588            Ok(F::from(0.16).expect("Failed to convert constant to float"))
589        } else {
590            Ok(F::from(0.5).expect("Failed to convert constant to float"))
591        }
592    } else {
593        // For small df, use simple approximation
594        let critical_values = match df {
595            1 => (
596                F::from(3.841).expect("Failed to convert constant to float"),
597                F::from(6.635).expect("Failed to convert constant to float"),
598            ),
599            2 => (
600                F::from(5.991).expect("Failed to convert constant to float"),
601                F::from(9.210).expect("Failed to convert constant to float"),
602            ),
603            3 => (
604                F::from(7.815).expect("Failed to convert constant to float"),
605                F::from(11.345).expect("Failed to convert constant to float"),
606            ),
607            4 => (
608                F::from(9.488).expect("Failed to convert constant to float"),
609                F::from(13.277).expect("Failed to convert constant to float"),
610            ),
611            5 => (
612                F::from(11.070).expect("Failed to convert constant to float"),
613                F::from(15.086).expect("Failed to convert constant to float"),
614            ),
615            10 => (
616                F::from(18.307).expect("Failed to convert constant to float"),
617                F::from(23.209).expect("Failed to convert constant to float"),
618            ),
619            _ => (
620                F::from(df).expect("Failed to convert to float")
621                    * F::from(1.5).expect("Failed to convert constant to float"),
622                F::from(df).expect("Failed to convert to float")
623                    * F::from(2.0).expect("Failed to convert constant to float"),
624            ),
625        };
626
627        if statistic > critical_values.1 {
628            Ok(F::from(0.01).expect("Failed to convert constant to float"))
629        } else if statistic > critical_values.0 {
630            Ok(F::from(0.05).expect("Failed to convert constant to float"))
631        } else {
632            Ok(F::from(0.1).expect("Failed to convert constant to float"))
633        }
634    }
635}
636
637/// Simple matrix solve using Gaussian elimination
638#[allow(dead_code)]
639fn matrix_solve<F>(a: &scirs2_core::ndarray::Array2<F>, b: &Array1<F>) -> Result<Array1<F>>
640where
641    F: Float + FromPrimitive + ScalarOperand,
642{
643    let n = a.shape()[0];
644    if n != a.shape()[1] || n != b.len() {
645        return Err(TimeSeriesError::InvalidInput(
646            "Matrix dimensions mismatch".to_string(),
647        ));
648    }
649
650    // Create augmented matrix [A | b]
651    let mut aug = a.clone();
652    let mut rhs = b.clone();
653
654    // Forward elimination
655    for i in 0..n {
656        // Find pivot
657        let mut max_row = i;
658        for k in (i + 1)..n {
659            if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
660                max_row = k;
661            }
662        }
663
664        // Swap rows
665        if max_row != i {
666            for j in 0..n {
667                let temp = aug[[i, j]];
668                aug[[i, j]] = aug[[max_row, j]];
669                aug[[max_row, j]] = temp;
670            }
671            let temp = rhs[i];
672            rhs[i] = rhs[max_row];
673            rhs[max_row] = temp;
674        }
675
676        // Check for singular matrix
677        if aug[[i, i]].abs() < F::from(1e-10).expect("Failed to convert constant to float") {
678            return Err(TimeSeriesError::ComputationError(
679                "Matrix is singular".to_string(),
680            ));
681        }
682
683        // Eliminate column
684        for k in (i + 1)..n {
685            let factor = aug[[k, i]] / aug[[i, i]];
686            for j in i..n {
687                aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
688            }
689            rhs[k] = rhs[k] - factor * rhs[i];
690        }
691    }
692
693    // Back substitution
694    let mut x = Array1::zeros(n);
695    for i in (0..n).rev() {
696        let mut sum = rhs[i];
697        for j in (i + 1)..n {
698            sum = sum - aug[[i, j]] * x[j];
699        }
700        x[i] = sum / aug[[i, i]];
701    }
702
703    Ok(x)
704}
705
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Alternating-sign residual series shared by the diagnostics tests.
    fn alternating_residuals() -> Array1<f64> {
        array![0.1, -0.2, 0.3, -0.1, 0.2, -0.3, 0.1, -0.2]
    }

    #[test]
    fn test_residual_diagnostics() {
        let residuals = alternating_residuals();
        let diag = residual_diagnostics(&residuals, None, 0.05).expect("Operation failed");

        assert!(diag.mean.abs() < 0.1);
        assert!(diag.std_dev > 0.0);
    }

    #[test]
    fn test_ljung_box() {
        let residuals = alternating_residuals();
        assert!(ljung_box_test(&residuals, 3, 0.05).is_ok());
    }

    #[test]
    fn test_fit_statistics() {
        let actual = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let predicted = array![1.1, 2.1, 2.9, 3.9, 5.1];

        let stats =
            calculate_fit_statistics(&actual, &predicted, Some(2)).expect("Operation failed");

        assert!(stats.mae > 0.0);
        assert!(stats.rmse > 0.0);
        assert!(stats.r2 > 0.9);
    }
}