sklears_utils/
validation.rs

1//! Input validation utilities
2
3use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Dimension, OwnedRepr};
5use sklears_core::types::{Float, Int};
6
7/// Check that all arrays have consistent first dimension (number of samples)
8pub fn check_consistent_length<T>(arrays: &[&Array1<T>]) -> UtilsResult<()> {
9    if arrays.is_empty() {
10        return Ok(());
11    }
12
13    let first_length = arrays[0].len();
14    for (_i, array) in arrays.iter().enumerate().skip(1) {
15        if array.len() != first_length {
16            return Err(UtilsError::ShapeMismatch {
17                expected: vec![first_length],
18                actual: vec![array.len()],
19            });
20        }
21    }
22    Ok(())
23}
24
25/// Check that X and y have consistent number of samples (generic version)
26pub fn check_consistent_length_xy<T, U>(x: &Array2<T>, y: &Array1<U>) -> UtilsResult<()> {
27    if x.nrows() != y.len() {
28        return Err(UtilsError::ShapeMismatch {
29            expected: vec![x.nrows()],
30            actual: vec![y.len()],
31        });
32    }
33    Ok(())
34}
35
36/// Check that a 2D array has valid shape and properties
37pub fn check_array_2d<T>(array: &Array2<T>) -> UtilsResult<()> {
38    check_non_empty(array)?;
39
40    if array.ncols() == 0 {
41        return Err(UtilsError::InvalidParameter(
42            "Array must have at least one column".to_string(),
43        ));
44    }
45
46    Ok(())
47}
48
49/// Check that X and y have consistent number of samples
50pub fn check_x_y(x: &Array2<Float>, y: &Array1<Int>) -> UtilsResult<()> {
51    if x.nrows() != y.len() {
52        return Err(UtilsError::ShapeMismatch {
53            expected: vec![x.nrows()],
54            actual: vec![y.len()],
55        });
56    }
57
58    if x.is_empty() || y.is_empty() {
59        return Err(UtilsError::EmptyInput);
60    }
61
62    Ok(())
63}
64
65/// Check that X and y have consistent number of samples (regression version)
66pub fn check_x_y_regression(x: &Array2<Float>, y: &Array1<Float>) -> UtilsResult<()> {
67    if x.nrows() != y.len() {
68        return Err(UtilsError::ShapeMismatch {
69            expected: vec![x.nrows()],
70            actual: vec![y.len()],
71        });
72    }
73
74    if x.is_empty() || y.is_empty() {
75        return Err(UtilsError::EmptyInput);
76    }
77
78    Ok(())
79}
80
81/// Check that an array is not empty
82pub fn check_non_empty<T, D: Dimension>(array: &ArrayBase<OwnedRepr<T>, D>) -> UtilsResult<()> {
83    if array.is_empty() {
84        return Err(UtilsError::EmptyInput);
85    }
86    Ok(())
87}
88
89/// Check that a parameter is positive
90pub fn check_positive(value: Float, name: &str) -> UtilsResult<()> {
91    if value <= 0.0 {
92        return Err(UtilsError::InvalidParameter(format!(
93            "{name} must be positive, got {value}"
94        )));
95    }
96    Ok(())
97}
98
99/// Check that a parameter is non-negative
100pub fn check_non_negative(value: Float, name: &str) -> UtilsResult<()> {
101    if value < 0.0 {
102        return Err(UtilsError::InvalidParameter(format!(
103            "{name} must be non-negative, got {value}"
104        )));
105    }
106    Ok(())
107}
108
109/// Check that a parameter is in a valid range
110pub fn check_range(value: Float, min: Float, max: Float, name: &str) -> UtilsResult<()> {
111    if value < min || value > max {
112        return Err(UtilsError::InvalidParameter(format!(
113            "{name} must be in range [{min}, {max}], got {value}"
114        )));
115    }
116    Ok(())
117}
118
119/// Check that integer parameter is positive
120pub fn check_positive_int(value: usize, name: &str) -> UtilsResult<()> {
121    if value == 0 {
122        return Err(UtilsError::InvalidParameter(format!(
123            "{name} must be positive, got {value}"
124        )));
125    }
126    Ok(())
127}
128
129/// Check that we have enough samples for the operation
130pub fn check_min_samples(n_samples: usize, min_samples: usize) -> UtilsResult<()> {
131    if n_samples < min_samples {
132        return Err(UtilsError::InsufficientData {
133            min: min_samples,
134            actual: n_samples,
135        });
136    }
137    Ok(())
138}
139
140/// Check that array contains only finite values (no NaN or infinity)
141pub fn check_finite(array: &Array2<Float>) -> UtilsResult<()> {
142    for &value in array.iter() {
143        if !value.is_finite() {
144            return Err(UtilsError::InvalidParameter(
145                "Array contains non-finite values (NaN or infinity)".to_string(),
146            ));
147        }
148    }
149    Ok(())
150}
151
152/// Check that array contains only finite values (1D version)
153pub fn check_finite_1d(array: &Array1<Float>) -> UtilsResult<()> {
154    for &value in array.iter() {
155        if !value.is_finite() {
156            return Err(UtilsError::InvalidParameter(
157                "Array contains non-finite values (NaN or infinity)".to_string(),
158            ));
159        }
160    }
161    Ok(())
162}
163
164/// Validate feature matrix shape and contents
165pub fn validate_features(x: &Array2<Float>) -> UtilsResult<()> {
166    check_non_empty(x)?;
167    check_finite(x)?;
168
169    if x.ncols() == 0 {
170        return Err(UtilsError::InvalidParameter(
171            "Feature matrix must have at least one feature".to_string(),
172        ));
173    }
174
175    Ok(())
176}
177
178/// Validate target array
179pub fn validate_target(y: &Array1<Int>) -> UtilsResult<()> {
180    check_non_empty(y)?;
181    Ok(())
182}
183
184/// Validate target array (regression version)
185pub fn validate_target_regression(y: &Array1<Float>) -> UtilsResult<()> {
186    check_non_empty(y)?;
187    check_finite_1d(y)?;
188    Ok(())
189}
190
191/// Check that class labels are valid (non-negative integers)
192pub fn validate_class_labels(y: &Array1<Int>) -> UtilsResult<Vec<Int>> {
193    validate_target(y)?;
194
195    let mut classes: Vec<Int> = y.iter().copied().collect();
196    classes.sort_unstable();
197    classes.dedup();
198
199    for &class in &classes {
200        if class < 0 {
201            return Err(UtilsError::InvalidParameter(format!(
202                "Class labels must be non-negative, found {class}"
203            )));
204        }
205    }
206
207    Ok(classes)
208}
209
210/// Check that we have at least min_classes distinct classes
211pub fn check_min_classes(classes: &[Int], min_classes: usize) -> UtilsResult<()> {
212    if classes.len() < min_classes {
213        return Err(UtilsError::InvalidParameter(format!(
214            "Need at least {min_classes} classes, found {}",
215            classes.len()
216        )));
217    }
218    Ok(())
219}
220
221/// Validate sample weights
222pub fn validate_sample_weights(sample_weight: &Array1<Float>, n_samples: usize) -> UtilsResult<()> {
223    if sample_weight.len() != n_samples {
224        return Err(UtilsError::ShapeMismatch {
225            expected: vec![n_samples],
226            actual: vec![sample_weight.len()],
227        });
228    }
229
230    check_finite_1d(sample_weight)?;
231
232    for &weight in sample_weight.iter() {
233        if weight < 0.0 {
234            return Err(UtilsError::InvalidParameter(
235                "Sample weights must be non-negative".to_string(),
236            ));
237        }
238    }
239
240    if sample_weight.sum() <= 0.0 {
241        return Err(UtilsError::InvalidParameter(
242            "Sum of sample weights must be positive".to_string(),
243        ));
244    }
245
246    Ok(())
247}
248
249/// Check that matrices have compatible shapes for matrix multiplication
250pub fn check_matmul_shapes(a: &Array2<Float>, b: &Array2<Float>) -> UtilsResult<()> {
251    if a.ncols() != b.nrows() {
252        return Err(UtilsError::ShapeMismatch {
253            expected: vec![a.nrows(), a.ncols(), b.ncols()],
254            actual: vec![a.nrows(), a.ncols(), b.nrows(), b.ncols()],
255        });
256    }
257    Ok(())
258}
259
260/// Validate learning rate parameter
261pub fn validate_learning_rate(learning_rate: Float) -> UtilsResult<()> {
262    check_positive(learning_rate, "learning_rate")?;
263    check_range(learning_rate, 0.0, 1.0, "learning_rate")?;
264    Ok(())
265}
266
267/// Validate regularization parameter
268pub fn validate_regularization(alpha: Float) -> UtilsResult<()> {
269    check_non_negative(alpha, "alpha")?;
270    Ok(())
271}
272
273/// Validate tolerance parameter
274pub fn validate_tolerance(tol: Float) -> UtilsResult<()> {
275    check_positive(tol, "tol")?;
276    Ok(())
277}
278
279/// Validate maximum iterations parameter
280pub fn validate_max_iter(max_iter: usize) -> UtilsResult<()> {
281    check_positive_int(max_iter, "max_iter")?;
282    Ok(())
283}
284
285/// Validate cross-validation fold indices
286pub fn validate_cv_folds(folds: &Array1<i32>, n_samples: usize, n_folds: usize) -> UtilsResult<()> {
287    if folds.len() != n_samples {
288        return Err(UtilsError::ShapeMismatch {
289            expected: vec![n_samples],
290            actual: vec![folds.len()],
291        });
292    }
293
294    // Check that fold indices are in valid range
295    for &fold_idx in folds.iter() {
296        if fold_idx < 0 || fold_idx >= n_folds as i32 {
297            return Err(UtilsError::InvalidParameter(format!(
298                "Fold index {fold_idx} is out of range [0, {n_folds})"
299            )));
300        }
301    }
302
303    // Check that all folds are represented
304    let mut fold_counts = vec![0; n_folds];
305    for &fold_idx in folds.iter() {
306        fold_counts[fold_idx as usize] += 1;
307    }
308
309    for (i, &count) in fold_counts.iter().enumerate() {
310        if count == 0 {
311            return Err(UtilsError::InvalidParameter(format!(
312                "Fold {i} has no samples assigned"
313            )));
314        }
315    }
316
317    Ok(())
318}
319
320/// Validate feature importance values
321pub fn validate_feature_importance(
322    importance: &Array1<Float>,
323    n_features: usize,
324) -> UtilsResult<()> {
325    if importance.len() != n_features {
326        return Err(UtilsError::ShapeMismatch {
327            expected: vec![n_features],
328            actual: vec![importance.len()],
329        });
330    }
331
332    // Check for non-negative values
333    for (i, &value) in importance.iter().enumerate() {
334        if value < 0.0 {
335            return Err(UtilsError::InvalidParameter(format!(
336                "Feature importance at index {i} is negative: {value}"
337            )));
338        }
339        if !value.is_finite() {
340            return Err(UtilsError::InvalidParameter(format!(
341                "Feature importance at index {i} is not finite: {value}"
342            )));
343        }
344    }
345
346    // Check if all importance values are zero (usually indicates an error)
347    if importance.iter().all(|&x| x == 0.0) {
348        return Err(UtilsError::InvalidParameter(
349            "All feature importance values are zero".to_string(),
350        ));
351    }
352
353    Ok(())
354}
355
356/// Validate model prediction format for classification
357pub fn validate_classification_predictions(
358    predictions: &Array1<i32>,
359    n_samples: usize,
360    valid_classes: &[i32],
361) -> UtilsResult<()> {
362    if predictions.len() != n_samples {
363        return Err(UtilsError::ShapeMismatch {
364            expected: vec![n_samples],
365            actual: vec![predictions.len()],
366        });
367    }
368
369    // Check that all predictions are valid class labels
370    for (i, &pred) in predictions.iter().enumerate() {
371        if !valid_classes.contains(&pred) {
372            return Err(UtilsError::InvalidParameter(format!(
373                "Prediction at index {i} ({pred}) is not a valid class label"
374            )));
375        }
376    }
377
378    Ok(())
379}
380
381/// Validate model prediction format for regression
382pub fn validate_regression_predictions(
383    predictions: &Array1<Float>,
384    n_samples: usize,
385) -> UtilsResult<()> {
386    if predictions.len() != n_samples {
387        return Err(UtilsError::ShapeMismatch {
388            expected: vec![n_samples],
389            actual: vec![predictions.len()],
390        });
391    }
392
393    // Check for finite values
394    for (i, &value) in predictions.iter().enumerate() {
395        if !value.is_finite() {
396            return Err(UtilsError::InvalidParameter(format!(
397                "Prediction at index {i} is not finite: {value}"
398            )));
399        }
400    }
401
402    Ok(())
403}
404
405/// Validate sparse matrix properties
406pub fn validate_sparse_matrix(
407    data: &Array1<Float>,
408    indices: &Array1<usize>,
409    indptr: &Array1<usize>,
410    n_rows: usize,
411    n_cols: usize,
412) -> UtilsResult<()> {
413    // Check basic consistency
414    if data.len() != indices.len() {
415        return Err(UtilsError::ShapeMismatch {
416            expected: vec![data.len()],
417            actual: vec![indices.len()],
418        });
419    }
420
421    if indptr.len() != n_rows + 1 {
422        return Err(UtilsError::ShapeMismatch {
423            expected: vec![n_rows + 1],
424            actual: vec![indptr.len()],
425        });
426    }
427
428    // Check indptr is non-decreasing and starts at 0
429    if indptr[0] != 0 {
430        return Err(UtilsError::InvalidParameter(
431            "indptr must start with 0".to_string(),
432        ));
433    }
434
435    for i in 1..indptr.len() {
436        if indptr[i] < indptr[i - 1] {
437            return Err(UtilsError::InvalidParameter(
438                "indptr must be non-decreasing".to_string(),
439            ));
440        }
441    }
442
443    // Check that last indptr value matches data length
444    if indptr[indptr.len() - 1] != data.len() {
445        return Err(UtilsError::InvalidParameter(
446            "Last indptr value must equal data length".to_string(),
447        ));
448    }
449
450    // Check column indices are valid
451    for &col_idx in indices.iter() {
452        if col_idx >= n_cols {
453            return Err(UtilsError::InvalidParameter(format!(
454                "Column index {col_idx} is out of bounds for matrix with {n_cols} columns"
455            )));
456        }
457    }
458
459    // Check for finite data values
460    for (i, &value) in data.iter().enumerate() {
461        if !value.is_finite() {
462            return Err(UtilsError::InvalidParameter(format!(
463                "Data value at index {i} is not finite: {value}"
464            )));
465        }
466    }
467
468    Ok(())
469}
470
471/// Validate time series data for temporal consistency
472pub fn validate_time_series(
473    data: &Array2<Float>,
474    timestamps: &Array1<Float>,
475    min_samples: usize,
476) -> UtilsResult<()> {
477    // Check consistent dimensions
478    if data.nrows() != timestamps.len() {
479        return Err(UtilsError::ShapeMismatch {
480            expected: vec![data.nrows()],
481            actual: vec![timestamps.len()],
482        });
483    }
484
485    // Check minimum number of samples
486    if data.nrows() < min_samples {
487        return Err(UtilsError::InsufficientData {
488            min: min_samples,
489            actual: data.nrows(),
490        });
491    }
492
493    // Check that timestamps are strictly increasing
494    for i in 1..timestamps.len() {
495        if timestamps[i] <= timestamps[i - 1] {
496            return Err(UtilsError::InvalidParameter(format!(
497                "Timestamps must be strictly increasing. Found {} <= {} at index {}",
498                timestamps[i],
499                timestamps[i - 1],
500                i
501            )));
502        }
503    }
504
505    // Check for finite values in both data and timestamps
506    for (i, &ts) in timestamps.iter().enumerate() {
507        if !ts.is_finite() {
508            return Err(UtilsError::InvalidParameter(format!(
509                "Timestamp at index {i} is not finite: {ts}"
510            )));
511        }
512    }
513
514    for ((i, j), &value) in data.indexed_iter() {
515        if !value.is_finite() {
516            return Err(UtilsError::InvalidParameter(format!(
517                "Data value at index ({i}, {j}) is not finite: {value}"
518            )));
519        }
520    }
521
522    Ok(())
523}
524
525/// Validate probability distribution (must sum to 1, all non-negative)
526pub fn validate_probability_distribution(
527    probabilities: &Array1<Float>,
528    tolerance: Float,
529) -> UtilsResult<()> {
530    // Check for non-negative values
531    for (i, &prob) in probabilities.iter().enumerate() {
532        if prob < 0.0 {
533            return Err(UtilsError::InvalidParameter(format!(
534                "Probability at index {i} is negative: {prob}"
535            )));
536        }
537        if !prob.is_finite() {
538            return Err(UtilsError::InvalidParameter(format!(
539                "Probability at index {i} is not finite: {prob}"
540            )));
541        }
542    }
543
544    // Check that probabilities sum to 1
545    let sum: Float = probabilities.sum();
546    if (sum - 1.0).abs() > tolerance {
547        return Err(UtilsError::InvalidParameter(format!(
548            "Probabilities must sum to 1.0 (±{tolerance}), got {sum}"
549        )));
550    }
551
552    Ok(())
553}
554
#[cfg(test)]
mod tests {
    //! Unit tests for the validation helpers. Every assertion here only checks
    //! whether a call succeeds or fails, so it is independent of the exact
    //! error payloads.

    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    #[test]
    fn test_check_consistent_length() {
        let a = array![1, 2, 3];
        let b = array![4, 5, 6];
        let c = array![7, 8];

        assert!(check_consistent_length(&[&a, &b]).is_ok());
        assert!(check_consistent_length(&[&a, &c]).is_err());
    }

    #[test]
    fn test_check_consistent_length_trivial_inputs() {
        let a = array![1, 2, 3];

        assert!(check_consistent_length::<i32>(&[]).is_ok());
        assert!(check_consistent_length(&[&a]).is_ok());
    }

    #[test]
    fn test_check_consistent_length_xy() {
        let x = Array2::<Float>::zeros((3, 2));
        let y_good = array![0, 1, 0];
        let y_bad = array![0, 1];

        assert!(check_consistent_length_xy(&x, &y_good).is_ok());
        assert!(check_consistent_length_xy(&x, &y_bad).is_err());
    }

    #[test]
    fn test_check_array_2d() {
        let good = Array2::<Float>::zeros((2, 3));
        let no_rows = Array2::<Float>::zeros((0, 3));
        let no_cols = Array2::<Float>::zeros((2, 0));

        assert!(check_array_2d(&good).is_ok());
        assert!(check_array_2d(&no_rows).is_err());
        assert!(check_array_2d(&no_cols).is_err());
    }

    #[test]
    fn test_check_x_y() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y_good = array![0, 1, 0];
        let y_bad = array![0, 1];

        assert!(check_x_y(&x, &y_good).is_ok());
        assert!(check_x_y(&x, &y_bad).is_err());
    }

    #[test]
    fn test_check_x_y_rejects_empty_input() {
        let x = Array2::<Float>::zeros((0, 2));
        let y: Array1<Int> = Array1::from_vec(Vec::new());

        assert!(matches!(check_x_y(&x, &y), Err(UtilsError::EmptyInput)));
    }

    #[test]
    fn test_check_positive() {
        assert!(check_positive(1.0, "test").is_ok());
        assert!(check_positive(0.0, "test").is_err());
        assert!(check_positive(-1.0, "test").is_err());
    }

    #[test]
    fn test_check_range() {
        assert!(check_range(0.5, 0.0, 1.0, "test").is_ok());
        assert!(check_range(-0.1, 0.0, 1.0, "test").is_err());
        assert!(check_range(1.1, 0.0, 1.0, "test").is_err());
    }

    #[test]
    fn test_check_min_samples() {
        assert!(check_min_samples(5, 3).is_ok());
        assert!(check_min_samples(3, 3).is_ok());
        assert!(check_min_samples(2, 3).is_err());
    }

    #[test]
    fn test_validate_max_iter() {
        assert!(validate_max_iter(1).is_ok());
        assert!(validate_max_iter(0).is_err());
    }

    #[test]
    fn test_validate_class_labels() {
        let y_good = array![0, 1, 2, 1, 0];
        let y_bad = array![0, 1, -1, 1, 0];

        let classes = validate_class_labels(&y_good).unwrap();
        assert_eq!(classes, vec![0, 1, 2]);

        assert!(validate_class_labels(&y_bad).is_err());
    }

    #[test]
    fn test_check_finite() {
        let good = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let bad = Array2::from_shape_vec((2, 2), vec![1.0, Float::NAN, 3.0, 4.0]).unwrap();

        assert!(check_finite(&good).is_ok());
        assert!(check_finite(&bad).is_err());
    }

    #[test]
    fn test_validate_sample_weights() {
        let good_weights = array![1.0, 2.0, 1.5];
        let bad_weights = array![1.0, -1.0, 1.5];
        let zero_sum_weights = array![0.0, 0.0, 0.0];

        assert!(validate_sample_weights(&good_weights, 3).is_ok());
        assert!(validate_sample_weights(&bad_weights, 3).is_err());
        assert!(validate_sample_weights(&zero_sum_weights, 3).is_err());
    }

    #[test]
    fn test_check_matmul_shapes() {
        let a = Array2::<Float>::zeros((2, 3));
        let b_good = Array2::<Float>::zeros((3, 4));
        let b_bad = Array2::<Float>::zeros((2, 4));

        assert!(check_matmul_shapes(&a, &b_good).is_ok());
        assert!(check_matmul_shapes(&a, &b_bad).is_err());
    }

    #[test]
    fn test_validate_cv_folds() {
        let good_folds = array![0, 1, 2, 0, 1, 2];
        let bad_folds_range = array![0, 1, 3, 0, 1, 2]; // 3 is out of range for 3 folds
        let bad_folds_missing = array![0, 0, 1, 1, 1, 1]; // missing fold 2

        assert!(validate_cv_folds(&good_folds, 6, 3).is_ok());
        assert!(validate_cv_folds(&bad_folds_range, 6, 3).is_err());
        assert!(validate_cv_folds(&bad_folds_missing, 6, 3).is_err());
    }

    #[test]
    fn test_validate_cv_folds_more_folds_than_samples() {
        // Two samples cannot populate three folds.
        assert!(validate_cv_folds(&array![0, 1], 2, 3).is_err());
    }

    #[test]
    fn test_validate_feature_importance() {
        let good_importance = array![0.5, 0.3, 0.2];
        let bad_importance_negative = array![0.5, -0.1, 0.2];
        let bad_importance_all_zero = array![0.0, 0.0, 0.0];

        assert!(validate_feature_importance(&good_importance, 3).is_ok());
        assert!(validate_feature_importance(&bad_importance_negative, 3).is_err());
        assert!(validate_feature_importance(&bad_importance_all_zero, 3).is_err());
    }

    #[test]
    fn test_validate_classification_predictions() {
        let good_predictions = array![0, 1, 2, 1, 0];
        let bad_predictions = array![0, 1, 3, 1, 0]; // 3 is not a valid class
        let valid_classes = vec![0, 1, 2];

        assert!(validate_classification_predictions(&good_predictions, 5, &valid_classes).is_ok());
        assert!(validate_classification_predictions(&bad_predictions, 5, &valid_classes).is_err());
    }

    #[test]
    fn test_validate_regression_predictions() {
        let good_predictions = array![1.5, 2.3, -0.5, 10.0];
        let bad_predictions = array![1.5, Float::NAN, -0.5, 10.0];

        assert!(validate_regression_predictions(&good_predictions, 4).is_ok());
        assert!(validate_regression_predictions(&bad_predictions, 4).is_err());
    }

    #[test]
    fn test_validate_sparse_matrix() {
        // Valid CSR matrix: [[1, 0, 2], [0, 0, 3], [4, 5, 6]]
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let indices = array![0, 2, 2, 0, 1, 2];
        let indptr = array![0, 2, 3, 6];

        assert!(validate_sparse_matrix(&data, &indices, &indptr, 3, 3).is_ok());

        // Invalid: indptr doesn't start with 0
        let bad_indptr = array![1, 2, 3, 6];
        assert!(validate_sparse_matrix(&data, &indices, &bad_indptr, 3, 3).is_err());
    }

    #[test]
    fn test_validate_time_series() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let good_timestamps = array![1.0, 2.0, 3.0];
        let bad_timestamps = array![1.0, 1.5, 1.2]; // not increasing

        assert!(validate_time_series(&data, &good_timestamps, 2).is_ok());
        assert!(validate_time_series(&data, &bad_timestamps, 2).is_err());
    }

    #[test]
    fn test_validate_probability_distribution() {
        let good_probs = array![0.3, 0.5, 0.2];
        let bad_probs_negative = array![0.3, -0.1, 0.8];
        let bad_probs_sum = array![0.3, 0.5, 0.3]; // sums to 1.1

        assert!(validate_probability_distribution(&good_probs, 1e-6).is_ok());
        assert!(validate_probability_distribution(&bad_probs_negative, 1e-6).is_err());
        assert!(validate_probability_distribution(&bad_probs_sum, 1e-6).is_err());
    }
}