Skip to main content

tenflowers_dataset/
validation.rs

1use crate::Dataset;
2use std::collections::HashMap;
3use tenflowers_core::{Result, Tensor};
4
5// Type aliases for complex types
6#[allow(dead_code)]
7type SampleData<T> = (usize, (Tensor<T>, Tensor<T>));
8type SampleList<T> = [(usize, (Tensor<T>, Tensor<T>))];
9
/// Switches selecting which validation passes run, plus the outlier sensitivity.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Compare sample shapes against the configured schema.
    pub check_schema: bool,
    /// Compare feature and label values against the configured ranges.
    pub check_ranges: bool,
    /// Look for samples whose features repeat an earlier sample.
    pub check_duplicates: bool,
    /// Look for samples containing a statistically extreme feature value.
    pub check_outliers: bool,
    /// Z-score above which a feature value counts as an outlier.
    pub outlier_threshold: f64,
}

impl Default for ValidationConfig {
    /// Turns every check on and flags values more than 3 standard deviations from the mean.
    fn default() -> Self {
        ValidationConfig {
            outlier_threshold: 3.0,
            check_schema: true,
            check_ranges: true,
            check_duplicates: true,
            check_outliers: true,
        }
    }
}
30
/// Expected per-sample layout consulted by the schema check.
///
/// Derives `PartialEq`/`Eq` so two schemas can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaInfo {
    /// Dimensions every sample's feature tensor is expected to have.
    pub feature_shape: Vec<usize>,
    /// Dimensions every sample's label tensor is expected to have.
    pub label_shape: Vec<usize>,
    /// Name of the expected feature element type.
    // NOTE(review): the shape check in this file never reads this field — confirm
    // whether a type check is intended elsewhere.
    pub expected_feature_type: String,
    /// Name of the expected label element type.
    // NOTE(review): likewise not read by the shape check in this file.
    pub expected_label_type: String,
}
38
/// Optional lower and upper bound for tensor element values.
///
/// A bound left as `None` is simply not enforced on that side.
#[derive(Debug, Clone)]
pub struct RangeConstraint<T> {
    /// Smallest accepted value, if any.
    pub min_value: Option<T>,
    /// Largest accepted value, if any.
    pub max_value: Option<T>,
}

impl<T> RangeConstraint<T> {
    /// Builds a constraint from explicitly optional bounds.
    pub fn new(min_value: Option<T>, max_value: Option<T>) -> Self {
        RangeConstraint {
            min_value,
            max_value,
        }
    }

    /// Constraint with only a lower bound.
    pub fn min(min_value: T) -> Self {
        Self::new(Some(min_value), None)
    }

    /// Constraint with only an upper bound.
    pub fn max(max_value: T) -> Self {
        Self::new(None, Some(max_value))
    }

    /// Constraint with both a lower and an upper bound.
    ///
    /// The two bounds are stored as given; their ordering is not checked.
    pub fn range(min_value: T, max_value: T) -> Self {
        Self::new(Some(min_value), Some(max_value))
    }
}
74
/// Findings accumulated by a validation run.
///
/// `is_valid` starts out `true` and is cleared by every `add_*` method.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `false` once any finding has been recorded through an `add_*` method.
    pub is_valid: bool,
    /// Human-readable schema problems.
    pub schema_errors: Vec<String>,
    /// Human-readable range violations.
    pub range_errors: Vec<String>,
    /// Indices of samples recorded as duplicates.
    pub duplicate_indices: Vec<usize>,
    /// Indices of samples recorded as outliers.
    pub outlier_indices: Vec<usize>,
}

impl ValidationResult {
    /// Creates an empty, valid result.
    pub fn new() -> Self {
        ValidationResult {
            is_valid: true,
            schema_errors: Vec::new(),
            range_errors: Vec::new(),
            duplicate_indices: Vec::new(),
            outlier_indices: Vec::new(),
        }
    }

    /// Returns `true` when at least one of the four finding lists is non-empty.
    pub fn has_errors(&self) -> bool {
        let all_clear = self.schema_errors.is_empty()
            && self.range_errors.is_empty()
            && self.duplicate_indices.is_empty()
            && self.outlier_indices.is_empty();
        !all_clear
    }

    /// Records a schema problem and marks the result invalid.
    pub fn add_schema_error(&mut self, error: String) {
        self.is_valid = false;
        self.schema_errors.push(error);
    }

    /// Records a range violation and marks the result invalid.
    pub fn add_range_error(&mut self, error: String) {
        self.is_valid = false;
        self.range_errors.push(error);
    }

    /// Records the index of a duplicate sample and marks the result invalid.
    pub fn add_duplicate(&mut self, index: usize) {
        self.is_valid = false;
        self.duplicate_indices.push(index);
    }

    /// Records the index of an outlier sample and marks the result invalid.
    pub fn add_outlier(&mut self, index: usize) {
        self.is_valid = false;
        self.outlier_indices.push(index);
    }
}

impl Default for ValidationResult {
    /// Same as [`ValidationResult::new`]: empty and valid.
    fn default() -> Self {
        ValidationResult::new()
    }
}
128
129pub struct DataValidator<T> {
130    config: ValidationConfig,
131    schema_info: Option<SchemaInfo>,
132    feature_range: Option<RangeConstraint<T>>,
133    label_range: Option<RangeConstraint<T>>,
134}
135
136impl<T> DataValidator<T>
137where
138    T: Clone
139        + Default
140        + PartialEq
141        + PartialOrd
142        + std::fmt::Display
143        + scirs2_core::numeric::Float
144        + Send
145        + Sync
146        + 'static,
147{
148    pub fn new(config: ValidationConfig) -> Self {
149        Self {
150            config,
151            schema_info: None,
152            feature_range: None,
153            label_range: None,
154        }
155    }
156
157    pub fn with_schema(mut self, schema: SchemaInfo) -> Self {
158        self.schema_info = Some(schema);
159        self
160    }
161
162    pub fn with_feature_range(mut self, range: RangeConstraint<T>) -> Self {
163        self.feature_range = Some(range);
164        self
165    }
166
167    pub fn with_label_range(mut self, range: RangeConstraint<T>) -> Self {
168        self.label_range = Some(range);
169        self
170    }
171
172    pub fn validate<D: Dataset<T>>(&self, dataset: &D) -> Result<ValidationResult> {
173        let mut result = ValidationResult::new();
174
175        if dataset.is_empty() {
176            result.add_schema_error("Dataset is empty".to_string());
177            return Ok(result);
178        }
179
180        // Collect all samples for validation
181        let mut samples = Vec::new();
182        for i in 0..dataset.len() {
183            match dataset.get(i) {
184                Ok(sample) => samples.push((i, sample)),
185                Err(e) => {
186                    result.add_schema_error(format!("Failed to get sample {i}: {e:?}"));
187                }
188            }
189        }
190
191        if self.config.check_schema {
192            self.validate_schema(&samples, &mut result)?;
193        }
194
195        if self.config.check_ranges {
196            self.validate_ranges(&samples, &mut result)?;
197        }
198
199        if self.config.check_duplicates {
200            self.validate_duplicates(&samples, &mut result)?;
201        }
202
203        if self.config.check_outliers {
204            self.validate_outliers(&samples, &mut result)?;
205        }
206
207        Ok(result)
208    }
209
210    fn validate_schema(
211        &self,
212        samples: &SampleList<T>,
213        result: &mut ValidationResult,
214    ) -> Result<()> {
215        if let Some(ref schema) = self.schema_info {
216            for (index, (features, labels)) in samples {
217                // Check feature shape
218                if features.shape().dims() != schema.feature_shape {
219                    result.add_schema_error(format!(
220                        "Sample {}: Feature shape mismatch. Expected {:?}, got {:?}",
221                        index,
222                        schema.feature_shape,
223                        features.shape().dims()
224                    ));
225                }
226
227                // Check label shape
228                if labels.shape().dims() != schema.label_shape {
229                    result.add_schema_error(format!(
230                        "Sample {}: Label shape mismatch. Expected {:?}, got {:?}",
231                        index,
232                        schema.label_shape,
233                        labels.shape().dims()
234                    ));
235                }
236            }
237        }
238        Ok(())
239    }
240
241    fn validate_ranges(
242        &self,
243        samples: &SampleList<T>,
244        result: &mut ValidationResult,
245    ) -> Result<()> {
246        for (index, (features, labels)) in samples {
247            // Validate feature ranges
248            if let Some(ref range) = self.feature_range {
249                if let Some(feature_data) = features.as_slice() {
250                    for (i, &value) in feature_data.iter().enumerate() {
251                        if let Some(min_val) = &range.min_value {
252                            if value < *min_val {
253                                result.add_range_error(format!(
254                                    "Sample {index}: Feature {i} value {value} below minimum {min_val}"
255                                ));
256                            }
257                        }
258                        if let Some(max_val) = &range.max_value {
259                            if value > *max_val {
260                                result.add_range_error(format!(
261                                    "Sample {index}: Feature {i} value {value} above maximum {max_val}"
262                                ));
263                            }
264                        }
265                    }
266                }
267            }
268
269            // Validate label ranges
270            if let Some(ref range) = self.label_range {
271                if let Some(label_data) = labels.as_slice() {
272                    for (i, &value) in label_data.iter().enumerate() {
273                        if let Some(min_val) = &range.min_value {
274                            if value < *min_val {
275                                result.add_range_error(format!(
276                                    "Sample {index}: Label {i} value {value} below minimum {min_val}"
277                                ));
278                            }
279                        }
280                        if let Some(max_val) = &range.max_value {
281                            if value > *max_val {
282                                result.add_range_error(format!(
283                                    "Sample {index}: Label {i} value {value} above maximum {max_val}"
284                                ));
285                            }
286                        }
287                    }
288                }
289            }
290        }
291        Ok(())
292    }
293
294    fn validate_duplicates(
295        &self,
296        samples: &SampleList<T>,
297        result: &mut ValidationResult,
298    ) -> Result<()> {
299        let mut seen_features: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
300
301        for (index, (features, _)) in samples {
302            if let Some(feature_data) = features.as_slice() {
303                // Convert to string representation for comparison
304                let feature_key: Vec<String> = feature_data
305                    .iter()
306                    .map(|&x| format!("{x:.6}")) // Use 6 decimal places for float comparison
307                    .collect();
308
309                seen_features.entry(feature_key).or_default().push(*index);
310            }
311        }
312
313        // Find duplicates
314        for (_, indices) in seen_features {
315            if indices.len() > 1 {
316                for &index in &indices[1..] {
317                    // Skip first occurrence
318                    result.add_duplicate(index);
319                }
320            }
321        }
322
323        Ok(())
324    }
325
326    fn validate_outliers(
327        &self,
328        samples: &SampleList<T>,
329        result: &mut ValidationResult,
330    ) -> Result<()> {
331        if samples.is_empty() {
332            return Ok(());
333        }
334
335        // Collect feature values for statistical analysis
336        let mut feature_values: Vec<Vec<T>> = Vec::new();
337        let feature_size = if let Some((_, (features, _))) = samples.first() {
338            if let Some(data) = features.as_slice() {
339                data.len()
340            } else {
341                return Ok(()); // Can't analyze GPU tensors
342            }
343        } else {
344            return Ok(());
345        };
346
347        // Initialize feature value vectors
348        for _ in 0..feature_size {
349            feature_values.push(Vec::new());
350        }
351
352        // Collect all feature values
353        for (_, (features, _)) in samples {
354            if let Some(data) = features.as_slice() {
355                for (i, &value) in data.iter().enumerate() {
356                    if i < feature_values.len() {
357                        feature_values[i].push(value);
358                    }
359                }
360            }
361        }
362
363        // Calculate mean and std for each feature
364        let mut means = Vec::new();
365        let mut stds = Vec::new();
366
367        for values in &feature_values {
368            if values.is_empty() {
369                continue;
370            }
371
372            let mean = values.iter().copied().fold(T::zero(), |acc, x| acc + x)
373                / T::from(values.len()).expect("values length should convert to float");
374            means.push(mean);
375
376            let variance = values
377                .iter()
378                .map(|&x| {
379                    let diff = x - mean;
380                    diff * diff
381                })
382                .fold(T::zero(), |acc, x| acc + x)
383                / T::from(values.len()).expect("values length should convert to float");
384
385            let std = variance.sqrt();
386            stds.push(std);
387        }
388
389        // Check for outliers using Z-score
390        let threshold = T::from(self.config.outlier_threshold)
391            .expect("outlier threshold should convert to float");
392
393        for (index, (features, _)) in samples {
394            if let Some(data) = features.as_slice() {
395                for (i, &value) in data.iter().enumerate() {
396                    if i < means.len() && i < stds.len() {
397                        let mean = means[i];
398                        let std = stds[i];
399
400                        if std > T::zero() {
401                            let z_score = ((value - mean) / std).abs();
402                            if z_score > threshold {
403                                result.add_outlier(*index);
404                                break; // One outlier per sample is enough
405                            }
406                        }
407                    }
408                }
409            }
410        }
411
412        Ok(())
413    }
414}
415
416pub trait DatasetValidationExt<T> {
417    fn validate(&self, validator: &DataValidator<T>) -> Result<ValidationResult>;
418    fn validate_with_config(&self, config: ValidationConfig) -> Result<ValidationResult>;
419    fn is_valid(&self) -> Result<bool>;
420}
421
422impl<T, D: Dataset<T>> DatasetValidationExt<T> for D
423where
424    T: Clone
425        + Default
426        + PartialEq
427        + PartialOrd
428        + std::fmt::Display
429        + scirs2_core::numeric::Float
430        + Send
431        + Sync
432        + 'static,
433{
434    fn validate(&self, validator: &DataValidator<T>) -> Result<ValidationResult> {
435        validator.validate(self)
436    }
437
438    fn validate_with_config(&self, config: ValidationConfig) -> Result<ValidationResult> {
439        let validator = DataValidator::new(config);
440        validator.validate(self)
441    }
442
443    fn is_valid(&self) -> Result<bool> {
444        let config = ValidationConfig::default();
445        let result = self.validate_with_config(config)?;
446        Ok(!result.has_errors())
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    /// The default configuration enables every check with a threshold of 3.0.
    #[test]
    fn test_validation_config() {
        let ValidationConfig {
            check_schema,
            check_ranges,
            check_duplicates,
            check_outliers,
            outlier_threshold,
        } = ValidationConfig::default();

        assert!(check_schema);
        assert!(check_ranges);
        assert!(check_duplicates);
        assert!(check_outliers);
        assert_eq!(outlier_threshold, 3.0);
    }

    /// Each constructor fills in exactly the bounds it is named after.
    #[test]
    fn test_range_constraint() {
        let both = RangeConstraint::range(0.0f32, 1.0f32);
        assert_eq!(both.min_value, Some(0.0));
        assert_eq!(both.max_value, Some(1.0));

        let lower = RangeConstraint::min(-1.0f32);
        assert_eq!(lower.min_value, Some(-1.0));
        assert_eq!(lower.max_value, None);

        let upper = RangeConstraint::max(10.0f32);
        assert_eq!(upper.min_value, None);
        assert_eq!(upper.max_value, Some(10.0));
    }

    /// Adding findings flips `is_valid` and shows up in the matching list.
    #[test]
    fn test_validation_result() {
        let mut outcome = ValidationResult::new();
        assert!(outcome.is_valid);
        assert!(!outcome.has_errors());

        outcome.add_schema_error("Schema error".to_string());
        assert!(!outcome.is_valid);
        assert!(outcome.has_errors());
        assert_eq!(outcome.schema_errors.len(), 1);

        outcome.add_duplicate(5);
        assert_eq!(outcome.duplicate_indices.len(), 1);
        assert_eq!(outcome.duplicate_indices[0], 5);
    }

    /// A dataset whose samples match the schema passes validation.
    #[test]
    fn test_schema_validation() {
        let feature_tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let label_tensor = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_tensor, label_tensor);

        let validator = DataValidator::new(ValidationConfig::default()).with_schema(SchemaInfo {
            feature_shape: vec![2], // Expect shape [2] after squeezing
            label_shape: vec![],    // Expect scalar after squeezing
            expected_feature_type: "f32".to_string(),
            expected_label_type: "f32".to_string(),
        });

        let outcome = validator
            .validate(&dataset)
            .expect("test: operation should succeed");
        assert!(outcome.is_valid);
        assert!(!outcome.has_errors());
    }

    /// A feature value above the configured maximum produces a range error.
    #[test]
    fn test_range_validation() {
        // 1.2 is above range [0, 1]
        let raw_features = vec![0.5, 0.8, 1.2, 0.3];
        let feature_tensor = Tensor::<f32>::from_vec(raw_features, &[2, 2])
            .expect("test: operation should succeed");
        let label_tensor = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_tensor, label_tensor);

        let validator = DataValidator::new(ValidationConfig::default())
            .with_feature_range(RangeConstraint::range(0.0f32, 1.0f32));

        let outcome = validator
            .validate(&dataset)
            .expect("test: operation should succeed");
        assert!(!outcome.is_valid);
        assert!(outcome.has_errors());
        assert!(!outcome.range_errors.is_empty());
    }

    /// Two samples with identical features are reported as duplicates.
    #[test]
    fn test_duplicate_detection() {
        // First two samples are duplicates
        let raw_features = vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0];
        let feature_tensor = Tensor::<f32>::from_vec(raw_features, &[3, 2])
            .expect("test: operation should succeed");
        let label_tensor = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_tensor, label_tensor);

        // Only the duplicate pass is enabled; the threshold keeps its default of 3.0.
        let duplicates_only = ValidationConfig {
            check_schema: false,
            check_ranges: false,
            check_outliers: false,
            ..ValidationConfig::default()
        };

        let outcome = DataValidator::new(duplicates_only)
            .validate(&dataset)
            .expect("test: operation should succeed");

        assert!(!outcome.is_valid);
        assert!(outcome.has_errors());
        assert!(!outcome.duplicate_indices.is_empty());
    }

    /// A single extreme feature value is reported as an outlier.
    #[test]
    fn test_outlier_detection() {
        // 100.0 is an outlier with more stable baseline
        let raw_features = vec![1.0, 1.0, 1.1, 1.0, 1.2, 1.0, 1.0, 1.0, 100.0, 1.0];
        let feature_tensor = Tensor::<f32>::from_vec(raw_features, &[5, 2])
            .expect("test: operation should succeed");
        let label_tensor = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0], &[5])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_tensor, label_tensor);

        // Only the outlier pass is enabled, with a very low threshold to catch the outlier.
        let outliers_only = ValidationConfig {
            check_schema: false,
            check_ranges: false,
            check_duplicates: false,
            outlier_threshold: 1.0,
            ..ValidationConfig::default()
        };

        let outcome = DataValidator::new(outliers_only)
            .validate(&dataset)
            .expect("test: operation should succeed");

        assert!(!outcome.is_valid);
        assert!(outcome.has_errors());
        assert!(!outcome.outlier_indices.is_empty());
    }

    /// The extension trait validates a clean dataset without findings.
    #[test]
    fn test_dataset_validation_ext() {
        let feature_tensor = Tensor::<f32>::from_vec(vec![0.5, 0.8, 0.3, 0.7], &[2, 2])
            .expect("test: tensor creation should succeed");
        let label_tensor = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(feature_tensor, label_tensor);

        let clean = dataset.is_valid().expect("test: operation should succeed");
        assert!(clean);

        let outcome = dataset
            .validate_with_config(ValidationConfig::default())
            .expect("test: operation should succeed");
        assert!(outcome.is_valid);
    }
}