Skip to main content

voirs_dataset/validation/
datasets.rs

1//! Dataset validation utilities
2//!
3//! This module provides comprehensive validation tools for speech synthesis datasets,
4//! including format checking, integrity validation, and quality assessment.
5
6use crate::{DatasetSample, LanguageCode, Result, ValidationReport};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9
/// Comprehensive dataset validator.
///
/// Wraps a [`ValidationConfig`] and applies its thresholds both to individual
/// samples and to whole datasets.
#[derive(Debug, Clone)]
pub struct DatasetValidator {
    /// Thresholds and switches consulted by the validation methods.
    config: ValidationConfig,
}
15
/// Configuration for dataset validation.
///
/// Every field is public so callers can start from [`Default::default`] and
/// override individual thresholds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
    /// Minimum audio duration in seconds; shorter audio is reported as an error
    pub min_duration: f32,
    /// Maximum audio duration in seconds; longer audio is reported as a warning
    pub max_duration: f32,
    /// Minimum text length in characters; shorter text is reported as an error
    pub min_text_length: usize,
    /// Maximum text length in characters; longer text is reported as a warning
    pub max_text_length: usize,
    /// Required sample rates (if empty, any is allowed)
    pub allowed_sample_rates: Vec<u32>,
    /// Required channel counts (if empty, any is allowed)
    pub allowed_channels: Vec<u32>,
    /// Whether to enforce ID uniqueness
    pub enforce_unique_ids: bool,
    /// Whether to check audio-text alignment (speaking-rate heuristic)
    pub check_alignment: bool,
    /// Whether to validate audio quality metrics attached to each sample
    pub validate_quality: bool,
    /// Minimum quality score (0.0-1.0); lower scores are flagged as low quality
    pub min_quality_score: f32,
    /// Maximum clipping percentage (0-100) tolerated before a warning is raised
    pub max_clipping_percent: f32,
    /// Minimum SNR in dB; lower values are reported as a warning
    pub min_snr_db: f32,
}
44
45impl Default for ValidationConfig {
46    fn default() -> Self {
47        Self {
48            min_duration: 0.1,
49            max_duration: 30.0,
50            min_text_length: 1,
51            max_text_length: 1000,
52            allowed_sample_rates: vec![],
53            allowed_channels: vec![],
54            enforce_unique_ids: true,
55            check_alignment: true,
56            validate_quality: true,
57            min_quality_score: 0.3,
58            max_clipping_percent: 5.0,
59            min_snr_db: 10.0,
60        }
61    }
62}
63
/// Detailed validation report with comprehensive analysis.
///
/// Produced by `DatasetValidator::validate_dataset`; bundles the aggregated
/// error/warning lists with the per-category results they were derived from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetailedValidationReport {
    /// Basic validation info: overall validity plus all collected errors and warnings
    pub basic_report: ValidationReport,
    /// Format validation results
    pub format_validation: FormatValidationResult,
    /// Integrity validation results
    pub integrity_validation: IntegrityValidationResult,
    /// Quality validation results
    pub quality_validation: QualityValidationResult,
    /// Statistics about the validation (sample counts and elapsed time)
    pub validation_stats: ValidationStatistics,
}
78
/// Audio and text format validation results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormatValidationResult {
    /// Audio format consistency: true when at most one distinct sample rate
    /// and at most one distinct channel count were seen
    pub audio_format_consistent: bool,
    /// Detected sample rates (distinct values)
    pub sample_rates: Vec<u32>,
    /// Detected channel counts (distinct values)
    pub channel_counts: Vec<u32>,
    /// Text encoding issues: one message per sample whose text contains U+FFFD
    pub text_encoding_issues: Vec<String>,
    /// Character set analysis: character-set label -> number of samples
    pub character_sets: HashMap<String, usize>,
    /// Language consistency: true when at most one distinct language was seen
    pub language_consistent: bool,
    /// Detected languages (distinct values)
    pub detected_languages: Vec<LanguageCode>,
}
97
/// Data integrity validation results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntegrityValidationResult {
    /// Whether all IDs are unique
    pub unique_ids: bool,
    /// Duplicate ID information: every ID that occurs more than once
    pub duplicate_ids: Vec<String>,
    /// Empty or corrupted samples: indices with empty audio or whitespace-only text
    pub corrupted_samples: Vec<usize>,
    /// Missing required fields: field name -> indices of samples where it is empty
    pub missing_fields: HashMap<String, Vec<usize>>,
    /// Metadata consistency: true when no metadata key was seen with more than one JSON type
    pub metadata_consistent: bool,
    /// Type mismatches in metadata: key -> the JSON type names observed for it
    pub metadata_type_issues: HashMap<String, Vec<String>>,
}
114
/// Quality validation results.
///
/// All `Vec<usize>` fields hold sample indices into the validated slice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityValidationResult {
    /// Overall quality assessment: true when there are neither low-quality
    /// samples nor corrupted-audio samples
    pub overall_quality_ok: bool,
    /// Samples failing quality thresholds (score below the configured minimum)
    pub low_quality_samples: Vec<usize>,
    /// Audio corruption detection: samples containing NaN or infinite values
    pub corrupted_audio_samples: Vec<usize>,
    /// Silent audio detection: peak absolute amplitude below 0.001
    pub silent_samples: Vec<usize>,
    /// Clipped audio detection: share of values with |amplitude| >= 0.999
    /// exceeds the configured maximum clipping percentage
    pub clipped_samples: Vec<usize>,
    /// Audio-text alignment issues
    pub alignment_issues: Vec<usize>,
    /// Quality score distribution
    pub quality_distribution: QualityDistribution,
}
133
/// Quality score distribution analysis.
///
/// Bucket boundaries are exclusive on the lower edge: a score of exactly 0.8
/// counts as good, 0.6 as fair and 0.4 as poor. All fields are zero when no
/// scores were available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityDistribution {
    /// Excellent quality samples (score > 0.8)
    pub excellent_count: usize,
    /// Good quality samples (0.6 < score <= 0.8)
    pub good_count: usize,
    /// Fair quality samples (0.4 < score <= 0.6)
    pub fair_count: usize,
    /// Poor quality samples (score <= 0.4)
    pub poor_count: usize,
    /// Average quality score (arithmetic mean)
    pub average_score: f32,
    /// Quality score standard deviation (population form: divides by the sample count)
    pub score_std_dev: f32,
}
150
/// Validation statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationStatistics {
    /// Total samples validated
    pub total_samples: usize,
    /// Samples that passed validation (total minus samples with errors)
    pub passed_samples: usize,
    /// Samples with warnings (at least one per-sample warning)
    pub warning_samples: usize,
    /// Samples with errors (at least one per-sample error)
    pub error_samples: usize,
    /// Validation completion time in milliseconds
    pub validation_time_ms: u64,
}
165
166impl DatasetValidator {
167    /// Create a new validator with default configuration
168    pub fn new() -> Self {
169        Self {
170            config: ValidationConfig::default(),
171        }
172    }
173
174    /// Create a validator with custom configuration
175    pub fn with_config(config: ValidationConfig) -> Self {
176        Self { config }
177    }
178
179    /// Validate a single dataset sample
180    pub fn validate_sample(
181        &self,
182        sample: &DatasetSample,
183        index: usize,
184    ) -> Result<SampleValidationResult> {
185        let mut errors = Vec::new();
186        let mut warnings = Vec::new();
187
188        // Validate audio duration
189        let duration = sample.audio.duration();
190        if duration < self.config.min_duration {
191            let min_duration = self.config.min_duration;
192            errors.push(format!(
193                "Sample {index}: Audio too short ({duration:.3}s, minimum: {min_duration:.3}s)"
194            ));
195        }
196        if duration > self.config.max_duration {
197            let max_duration = self.config.max_duration;
198            warnings.push(format!(
199                "Sample {index}: Audio very long ({duration:.1}s, maximum: {max_duration:.1}s)"
200            ));
201        }
202
203        // Validate text length
204        let text_len = sample.text.len();
205        if text_len < self.config.min_text_length {
206            let min_text_length = self.config.min_text_length;
207            errors.push(format!(
208                "Sample {index}: Text too short ({text_len} chars, minimum: {min_text_length})"
209            ));
210        }
211        if text_len > self.config.max_text_length {
212            let max_text_length = self.config.max_text_length;
213            warnings.push(format!(
214                "Sample {index}: Text very long ({text_len} chars, maximum: {max_text_length})"
215            ));
216        }
217
218        // Validate audio format
219        if !self.config.allowed_sample_rates.is_empty()
220            && !self
221                .config
222                .allowed_sample_rates
223                .contains(&sample.audio.sample_rate())
224        {
225            errors.push(format!(
226                "Sample {index}: Invalid sample rate ({}Hz, allowed: {:?})",
227                sample.audio.sample_rate(),
228                self.config.allowed_sample_rates
229            ));
230        }
231
232        if !self.config.allowed_channels.is_empty()
233            && !self
234                .config
235                .allowed_channels
236                .contains(&sample.audio.channels())
237        {
238            errors.push(format!(
239                "Sample {index}: Invalid channel count ({}, allowed: {:?})",
240                sample.audio.channels(),
241                self.config.allowed_channels
242            ));
243        }
244
245        // Validate audio integrity
246        let audio_samples = sample.audio.samples();
247        if audio_samples.is_empty() {
248            errors.push(format!("Sample {index}: Empty audio"));
249        } else {
250            // Check for invalid audio values
251            let invalid_count = audio_samples.iter().filter(|&&s| !s.is_finite()).count();
252            if invalid_count > 0 {
253                errors.push(format!(
254                    "Sample {index}: Audio contains {invalid_count} invalid values (NaN/Infinity)"
255                ));
256            }
257
258            // Check for silent audio
259            let max_amplitude = audio_samples
260                .iter()
261                .fold(0.0f32, |max, &s| max.max(s.abs()));
262            if max_amplitude < 0.001 {
263                warnings.push(format!("Sample {index}: Audio appears to be silent"));
264            }
265
266            // Check for clipped audio
267            let clipped_count = audio_samples.iter().filter(|&&s| s.abs() >= 0.999).count();
268            let clipping_percent = (clipped_count as f32 / audio_samples.len() as f32) * 100.0;
269            if clipping_percent > self.config.max_clipping_percent {
270                warnings.push(format!(
271                    "Sample {index}: High clipping detected ({clipping_percent:.1}%)"
272                ));
273            }
274        }
275
276        // Validate text content
277        if sample.text.trim().is_empty() {
278            errors.push(format!("Sample {index}: Empty or whitespace-only text"));
279        }
280
281        // Check for unusual characters
282        if sample
283            .text
284            .chars()
285            .any(|c| c.is_control() && c != '\n' && c != '\t')
286        {
287            warnings.push(format!("Sample {index}: Text contains control characters"));
288        }
289
290        // Validate quality metrics if available
291        if self.config.validate_quality {
292            if let Some(quality_score) = sample.quality.overall_quality {
293                if quality_score < self.config.min_quality_score {
294                    warnings.push(format!(
295                        "Sample {index}: Low quality score ({quality_score:.2})"
296                    ));
297                }
298            }
299
300            if let Some(snr) = sample.quality.snr {
301                if snr < self.config.min_snr_db {
302                    warnings.push(format!("Sample {index}: Low SNR ({snr:.1}dB)"));
303                }
304            }
305
306            if let Some(clipping) = sample.quality.clipping {
307                if clipping > self.config.max_clipping_percent {
308                    warnings.push(format!(
309                        "Sample {index}: High clipping in quality metrics ({clipping:.1}%)"
310                    ));
311                }
312            }
313        }
314
315        // Check audio-text alignment if enabled
316        if self.config.check_alignment {
317            let alignment_result = self.check_audio_text_alignment(sample);
318            if let Some(issue) = alignment_result {
319                warnings.push(format!("Sample {index}: {issue}"));
320            }
321        }
322
323        Ok(SampleValidationResult {
324            index,
325            is_valid: errors.is_empty(),
326            errors,
327            warnings,
328        })
329    }
330
331    /// Validate an entire dataset
332    pub fn validate_dataset<T>(&self, samples: &[T]) -> Result<DetailedValidationReport>
333    where
334        T: AsRef<DatasetSample>,
335    {
336        let start_time = std::time::Instant::now();
337
338        let mut all_errors = Vec::new();
339        let mut all_warnings = Vec::new();
340        let mut sample_results = Vec::new();
341
342        // Validate individual samples
343        for (index, sample) in samples.iter().enumerate() {
344            let result = self.validate_sample(sample.as_ref(), index)?;
345            all_errors.extend(result.errors.clone());
346            all_warnings.extend(result.warnings.clone());
347            sample_results.push(result);
348        }
349
350        // Perform dataset-level validations
351        let format_result = self.validate_formats(samples);
352        let integrity_result = self.validate_integrity(samples);
353        let quality_result = self.validate_quality(samples);
354
355        // Collect additional errors and warnings from dataset-level validation
356        all_errors.extend(format_result.get_errors());
357        all_warnings.extend(format_result.get_warnings());
358        all_errors.extend(integrity_result.get_errors());
359        all_warnings.extend(integrity_result.get_warnings());
360        all_errors.extend(quality_result.get_errors());
361        all_warnings.extend(quality_result.get_warnings());
362
363        let validation_time = start_time.elapsed().as_millis() as u64;
364
365        // Calculate statistics
366        let error_samples = sample_results.iter().filter(|r| !r.is_valid).count();
367        let warning_samples = sample_results
368            .iter()
369            .filter(|r| !r.warnings.is_empty())
370            .count();
371
372        let validation_stats = ValidationStatistics {
373            total_samples: samples.len(),
374            passed_samples: samples.len() - error_samples,
375            warning_samples,
376            error_samples,
377            validation_time_ms: validation_time,
378        };
379
380        Ok(DetailedValidationReport {
381            basic_report: ValidationReport {
382                is_valid: all_errors.is_empty(),
383                errors: all_errors,
384                warnings: all_warnings,
385                items_validated: samples.len(),
386            },
387            format_validation: format_result,
388            integrity_validation: integrity_result,
389            quality_validation: quality_result,
390            validation_stats,
391        })
392    }
393
394    /// Validate audio and text formats
395    fn validate_formats<T>(&self, samples: &[T]) -> FormatValidationResult
396    where
397        T: AsRef<DatasetSample>,
398    {
399        let mut sample_rates = HashSet::new();
400        let mut channel_counts = HashSet::new();
401        let mut character_sets = HashMap::new();
402        let mut languages = HashSet::new();
403        let mut text_issues = Vec::new();
404
405        for (index, sample) in samples.iter().enumerate() {
406            let sample = sample.as_ref();
407
408            // Collect audio format info
409            sample_rates.insert(sample.audio.sample_rate());
410            channel_counts.insert(sample.audio.channels());
411
412            // Collect language info
413            languages.insert(sample.language);
414
415            // Analyze character sets
416            let charset = self.detect_character_set(&sample.text);
417            *character_sets.entry(charset).or_insert(0) += 1;
418
419            // Check for text encoding issues
420            if sample.text.chars().any(|c| c == '\u{FFFD}') {
421                text_issues.push(format!(
422                    "Sample {index}: Contains replacement characters (encoding issue)"
423                ));
424            }
425        }
426
427        FormatValidationResult {
428            audio_format_consistent: sample_rates.len() <= 1 && channel_counts.len() <= 1,
429            sample_rates: sample_rates.into_iter().collect(),
430            channel_counts: channel_counts.into_iter().collect(),
431            text_encoding_issues: text_issues,
432            character_sets,
433            language_consistent: languages.len() <= 1,
434            detected_languages: languages.into_iter().collect(),
435        }
436    }
437
438    /// Validate data integrity
439    fn validate_integrity<T>(&self, samples: &[T]) -> IntegrityValidationResult
440    where
441        T: AsRef<DatasetSample>,
442    {
443        let mut id_counts = HashMap::new();
444        let mut corrupted_samples = Vec::new();
445        let mut missing_fields: HashMap<String, Vec<usize>> = HashMap::new();
446        let mut metadata_types: HashMap<String, HashSet<String>> = HashMap::new();
447
448        for (index, sample) in samples.iter().enumerate() {
449            let sample = sample.as_ref();
450
451            // Check ID uniqueness
452            *id_counts.entry(sample.id.clone()).or_insert(0) += 1;
453
454            // Check for corrupted samples
455            if sample.audio.samples().is_empty() || sample.text.trim().is_empty() {
456                corrupted_samples.push(index);
457            }
458
459            // Check metadata consistency
460            for (key, value) in &sample.metadata {
461                let value_type = self.get_json_value_type(value);
462                metadata_types
463                    .entry(key.clone())
464                    .or_default()
465                    .insert(value_type);
466            }
467
468            // Check for missing critical fields
469            if sample.id.is_empty() {
470                missing_fields
471                    .entry("id".to_string())
472                    .or_default()
473                    .push(index);
474            }
475            if sample.text.is_empty() {
476                missing_fields
477                    .entry("text".to_string())
478                    .or_default()
479                    .push(index);
480            }
481        }
482
483        let duplicate_ids: Vec<String> = id_counts
484            .iter()
485            .filter(|(_, &count)| count > 1)
486            .map(|(id, _)| id.clone())
487            .collect();
488
489        let metadata_type_issues: HashMap<String, Vec<String>> = metadata_types
490            .iter()
491            .filter(|(_, types)| types.len() > 1)
492            .map(|(key, types)| (key.clone(), types.iter().cloned().collect()))
493            .collect();
494
495        IntegrityValidationResult {
496            unique_ids: duplicate_ids.is_empty(),
497            duplicate_ids,
498            corrupted_samples,
499            missing_fields,
500            metadata_consistent: metadata_type_issues.is_empty(),
501            metadata_type_issues,
502        }
503    }
504
505    /// Validate audio and text quality
506    fn validate_quality<T>(&self, samples: &[T]) -> QualityValidationResult
507    where
508        T: AsRef<DatasetSample>,
509    {
510        let mut low_quality_samples = Vec::new();
511        let mut corrupted_audio_samples = Vec::new();
512        let mut silent_samples = Vec::new();
513        let mut clipped_samples = Vec::new();
514        let mut alignment_issues = Vec::new();
515        let mut quality_scores = Vec::new();
516
517        for (index, sample) in samples.iter().enumerate() {
518            let sample = sample.as_ref();
519            let audio_samples = sample.audio.samples();
520
521            // Check audio corruption
522            if audio_samples.iter().any(|&s| !s.is_finite()) {
523                corrupted_audio_samples.push(index);
524                continue;
525            }
526
527            // Check for silence
528            let max_amplitude = audio_samples
529                .iter()
530                .fold(0.0f32, |max, &s| max.max(s.abs()));
531            if max_amplitude < 0.001 {
532                silent_samples.push(index);
533            }
534
535            // Check for clipping
536            let clipped_count = audio_samples.iter().filter(|&&s| s.abs() >= 0.999).count();
537            let clipping_percent = (clipped_count as f32 / audio_samples.len() as f32) * 100.0;
538            if clipping_percent > self.config.max_clipping_percent {
539                clipped_samples.push(index);
540            }
541
542            // Check quality metrics
543            if let Some(quality_score) = sample.quality.overall_quality {
544                quality_scores.push(quality_score);
545                if quality_score < self.config.min_quality_score {
546                    low_quality_samples.push(index);
547                }
548            }
549
550            // Check alignment
551            if self.config.check_alignment && self.check_audio_text_alignment(sample).is_some() {
552                alignment_issues.push(index);
553            }
554        }
555
556        let quality_distribution = self.calculate_quality_distribution(&quality_scores);
557
558        QualityValidationResult {
559            overall_quality_ok: low_quality_samples.is_empty()
560                && corrupted_audio_samples.is_empty(),
561            low_quality_samples,
562            corrupted_audio_samples,
563            silent_samples,
564            clipped_samples,
565            alignment_issues,
566            quality_distribution,
567        }
568    }
569
570    /// Check audio-text alignment for reasonable speaking rates
571    fn check_audio_text_alignment(&self, sample: &DatasetSample) -> Option<String> {
572        let duration = sample.audio.duration();
573        if duration <= 0.0 {
574            return Some("Zero duration audio".to_string());
575        }
576
577        let char_count = sample.text.len() as f32;
578        let word_count = sample.text.split_whitespace().count() as f32;
579
580        let chars_per_second = char_count / duration;
581        let words_per_second = word_count / duration;
582
583        // Reasonable speaking rates (empirically determined)
584        if chars_per_second > 50.0 {
585            return Some(format!(
586                "Speaking rate too fast ({chars_per_second:.1} chars/sec)"
587            ));
588        }
589        if chars_per_second < 1.0 && char_count > 0.0 {
590            return Some(format!(
591                "Speaking rate too slow ({chars_per_second:.1} chars/sec)"
592            ));
593        }
594        if words_per_second > 15.0 {
595            return Some(format!(
596                "Speaking rate too fast ({words_per_second:.1} words/sec)"
597            ));
598        }
599        if words_per_second < 0.5 && word_count > 0.0 {
600            return Some(format!(
601                "Speaking rate too slow ({words_per_second:.1} words/sec)"
602            ));
603        }
604
605        None
606    }
607
608    /// Detect character set of text
609    fn detect_character_set(&self, text: &str) -> String {
610        if text.is_ascii() {
611            "ASCII".to_string()
612        } else if text.chars().any(|c| {
613            let code = c as u32;
614            (0x4E00..=0x9FFF).contains(&code) || // CJK Unified Ideographs
615            (0x3040..=0x309F).contains(&code) || // Hiragana
616            (0x30A0..=0x30FF).contains(&code) // Katakana
617        }) {
618            "CJK".to_string()
619        } else if text.chars().any(|c| (c as u32) > 255) {
620            "Unicode".to_string()
621        } else {
622            "Latin-1".to_string()
623        }
624    }
625
626    /// Get JSON value type as string
627    fn get_json_value_type(&self, value: &serde_json::Value) -> String {
628        match value {
629            serde_json::Value::String(_) => "string".to_string(),
630            serde_json::Value::Number(_) => "number".to_string(),
631            serde_json::Value::Bool(_) => "boolean".to_string(),
632            serde_json::Value::Array(_) => "array".to_string(),
633            serde_json::Value::Object(_) => "object".to_string(),
634            serde_json::Value::Null => "null".to_string(),
635        }
636    }
637
638    /// Calculate quality score distribution
639    fn calculate_quality_distribution(&self, scores: &[f32]) -> QualityDistribution {
640        if scores.is_empty() {
641            return QualityDistribution {
642                excellent_count: 0,
643                good_count: 0,
644                fair_count: 0,
645                poor_count: 0,
646                average_score: 0.0,
647                score_std_dev: 0.0,
648            };
649        }
650
651        let mut excellent = 0;
652        let mut good = 0;
653        let mut fair = 0;
654        let mut poor = 0;
655
656        for &score in scores {
657            if score > 0.8 {
658                excellent += 1;
659            } else if score > 0.6 {
660                good += 1;
661            } else if score > 0.4 {
662                fair += 1;
663            } else {
664                poor += 1;
665            }
666        }
667
668        let average = scores.iter().sum::<f32>() / scores.len() as f32;
669        let variance =
670            scores.iter().map(|&x| (x - average).powi(2)).sum::<f32>() / scores.len() as f32;
671        let std_dev = variance.sqrt();
672
673        QualityDistribution {
674            excellent_count: excellent,
675            good_count: good,
676            fair_count: fair,
677            poor_count: poor,
678            average_score: average,
679            score_std_dev: std_dev,
680        }
681    }
682}
683
/// Single sample validation result
#[derive(Debug, Clone)]
pub struct SampleValidationResult {
    /// Index of the sample, as passed to `validate_sample`
    pub index: usize,
    /// True when no errors were recorded (warnings do not affect validity)
    pub is_valid: bool,
    /// Error messages for this sample
    pub errors: Vec<String>,
    /// Warning messages for this sample
    pub warnings: Vec<String>,
}
692
693// Helper trait implementations for getting errors and warnings from validation results
694impl FormatValidationResult {
695    fn get_errors(&self) -> Vec<String> {
696        let mut errors = Vec::new();
697        if !self.audio_format_consistent {
698            errors.push("Inconsistent audio formats detected".to_string());
699        }
700        if !self.language_consistent {
701            errors.push("Multiple languages detected in dataset".to_string());
702        }
703        errors.extend(self.text_encoding_issues.clone());
704        errors
705    }
706
707    fn get_warnings(&self) -> Vec<String> {
708        let mut warnings = Vec::new();
709        if self.sample_rates.len() > 1 {
710            warnings.push(format!(
711                "Multiple sample rates detected: {:?}",
712                self.sample_rates
713            ));
714        }
715        if self.channel_counts.len() > 1 {
716            warnings.push(format!(
717                "Multiple channel counts detected: {:?}",
718                self.channel_counts
719            ));
720        }
721        warnings
722    }
723}
724
725impl IntegrityValidationResult {
726    fn get_errors(&self) -> Vec<String> {
727        let mut errors = Vec::new();
728        if !self.unique_ids {
729            let duplicate_ids = &self.duplicate_ids;
730            errors.push(format!("Duplicate IDs found: {duplicate_ids:?}"));
731        }
732        if !self.corrupted_samples.is_empty() {
733            let corrupted_samples = &self.corrupted_samples;
734            errors.push(format!("Corrupted samples found: {corrupted_samples:?}"));
735        }
736        for (field, indices) in &self.missing_fields {
737            errors.push(format!("Missing field '{field}' in samples: {indices:?}"));
738        }
739        errors
740    }
741
742    fn get_warnings(&self) -> Vec<String> {
743        let mut warnings = Vec::new();
744        if !self.metadata_consistent {
745            warnings.push("Inconsistent metadata types detected".to_string());
746        }
747        warnings
748    }
749}
750
751impl QualityValidationResult {
752    fn get_errors(&self) -> Vec<String> {
753        let mut errors = Vec::new();
754        if !self.corrupted_audio_samples.is_empty() {
755            errors.push(format!(
756                "Corrupted audio samples: {:?}",
757                self.corrupted_audio_samples
758            ));
759        }
760        errors
761    }
762
763    fn get_warnings(&self) -> Vec<String> {
764        let mut warnings = Vec::new();
765        if !self.low_quality_samples.is_empty() {
766            warnings.push(format!(
767                "Low quality samples: {:?}",
768                self.low_quality_samples
769            ));
770        }
771        if !self.silent_samples.is_empty() {
772            warnings.push(format!(
773                "Silent samples detected: {:?}",
774                self.silent_samples
775            ));
776        }
777        if !self.clipped_samples.is_empty() {
778            warnings.push(format!(
779                "Clipped samples detected: {:?}",
780                self.clipped_samples
781            ));
782        }
783        if !self.alignment_issues.is_empty() {
784            warnings.push(format!(
785                "Audio-text alignment issues: {:?}",
786                self.alignment_issues
787            ));
788        }
789        warnings
790    }
791}
792
793impl Default for DatasetValidator {
794    fn default() -> Self {
795        Self::new()
796    }
797}
798
// Reflexive `AsRef` impl so that a plain slice of `DatasetSample` satisfies the
// `T: AsRef<DatasetSample>` bound used by the dataset-level validation methods.
impl AsRef<DatasetSample> for DatasetSample {
    fn as_ref(&self) -> &DatasetSample {
        self
    }
}
805
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AudioData, LanguageCode};

    /// Builds an `EnUs` sample with the given id and text whose audio is
    /// `duration` seconds of constant 0.1 values at 22050 Hz.
    /// NOTE(review): the trailing `1` passed to `AudioData::new` is presumably
    /// the channel count — confirm against `AudioData`.
    fn create_test_sample(id: &str, text: &str, duration: f32) -> DatasetSample {
        let sample_rate = 22050;
        let num_samples = (duration * sample_rate as f32) as usize;
        let audio = AudioData::new(vec![0.1; num_samples], sample_rate, 1);

        DatasetSample::new(id.to_string(), text.to_string(), audio, LanguageCode::EnUs)
    }

    // A well-formed sample produces no errors with the default configuration.
    #[test]
    fn test_sample_validation() {
        let validator = DatasetValidator::new();
        let sample = create_test_sample("test-001", "Hello world", 2.0);

        let result = validator.validate_sample(&sample, 0).unwrap();
        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    // Audio below the minimum duration and empty text must both be rejected.
    #[test]
    fn test_sample_validation_errors() {
        let validator = DatasetValidator::new();
        let sample = create_test_sample("test-001", "", 0.05); // Too short, empty text

        let result = validator.validate_sample(&sample, 0).unwrap();
        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
    }

    // A small consistent dataset validates cleanly and every sample passes.
    #[test]
    fn test_dataset_validation() {
        let validator = DatasetValidator::new();
        let samples = vec![
            create_test_sample("test-001", "Hello world", 2.0),
            create_test_sample("test-002", "Good morning", 1.5),
            create_test_sample("test-003", "How are you?", 1.8),
        ];

        let result = validator.validate_dataset(&samples).unwrap();
        assert!(result.basic_report.is_valid);
        assert_eq!(result.validation_stats.total_samples, 3);
        assert_eq!(result.validation_stats.passed_samples, 3);
    }

    // Two samples sharing an ID invalidate the dataset and are reported by ID.
    #[test]
    fn test_duplicate_id_detection() {
        let validator = DatasetValidator::new();
        let samples = vec![
            create_test_sample("test-001", "Hello world", 2.0),
            create_test_sample("test-001", "Duplicate ID", 1.5), // Duplicate ID
            create_test_sample("test-003", "How are you?", 1.8),
        ];

        let result = validator.validate_dataset(&samples).unwrap();
        assert!(!result.basic_report.is_valid);
        assert!(!result.integrity_validation.unique_ids);
        assert!(result
            .integrity_validation
            .duplicate_ids
            .contains(&"test-001".to_string()));
    }

    // Far more text than 0.1 s of audio can carry trips the "too fast" check.
    #[test]
    fn test_audio_text_alignment() {
        let validator = DatasetValidator::new();

        // Create sample with unrealistic speaking rate
        let sample = create_test_sample(
            "test-001",
            "This is way too much text for the duration",
            0.1,
        );

        let alignment_issue = validator.check_audio_text_alignment(&sample);
        assert!(alignment_issue.is_some());
        assert!(alignment_issue.unwrap().contains("Speaking rate too fast"));
    }

    // Bucket boundaries are exclusive on the lower edge, so 0.8 lands in "good".
    #[test]
    fn test_quality_distribution() {
        let validator = DatasetValidator::new();
        let scores = vec![0.9, 0.8, 0.7, 0.5, 0.3, 0.1];

        let distribution = validator.calculate_quality_distribution(&scores);
        assert_eq!(distribution.excellent_count, 1); // 0.9
        assert_eq!(distribution.good_count, 2); // 0.8, 0.7
        assert_eq!(distribution.fair_count, 1); // 0.5
        assert_eq!(distribution.poor_count, 2); // 0.3, 0.1
    }
}
899}