Skip to main content

brainwires_datasets/quality/
validator.rs

1use crate::error::DatasetResult;
2use crate::types::{PreferencePair, TrainingExample, TrainingRole};
3
/// Severity of a [`ValidationIssue`].
///
/// Only [`IssueSeverity::Error`] makes an item count as invalid in a
/// [`ValidationReport`]; warnings are informational.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IssueSeverity {
    /// A blocking error that makes the example invalid.
    Error,
    /// A non-blocking warning about potential issues.
    Warning,
}
12
13/// A single validation issue found in a dataset example.
14#[derive(Debug, Clone)]
15pub struct ValidationIssue {
16    /// ID of the example where the issue was found.
17    pub example_id: String,
18    /// Severity of the issue.
19    pub severity: IssueSeverity,
20    /// Human-readable description of the issue.
21    pub message: String,
22    /// Optional line number where the issue was found.
23    pub line_number: Option<usize>,
24    /// Optional suggestion for how to fix the issue.
25    pub suggestion: Option<String>,
26}
27
28/// Result of validating a dataset.
29#[derive(Debug, Clone)]
30pub struct ValidationReport {
31    /// All issues found during validation.
32    pub issues: Vec<ValidationIssue>,
33    /// Total number of examples validated.
34    pub total_examples: usize,
35    /// Number of examples that passed without errors.
36    pub valid_examples: usize,
37}
38
39impl ValidationReport {
40    /// Return true if any error-level issues exist.
41    pub fn has_errors(&self) -> bool {
42        self.issues
43            .iter()
44            .any(|i| i.severity == IssueSeverity::Error)
45    }
46
47    /// Count the number of error-level issues.
48    pub fn error_count(&self) -> usize {
49        self.issues
50            .iter()
51            .filter(|i| i.severity == IssueSeverity::Error)
52            .count()
53    }
54
55    /// Count the number of warning-level issues.
56    pub fn warning_count(&self) -> usize {
57        self.issues
58            .iter()
59            .filter(|i| i.severity == IssueSeverity::Warning)
60            .count()
61    }
62}
63
/// Configuration for dataset validation performed by [`DataValidator`].
///
/// Each rule notes the severity reported when it is violated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatorConfig {
    /// Minimum messages per example (error when there are fewer).
    pub min_messages: usize,
    /// Maximum messages per example (warning when exceeded).
    pub max_messages: usize,
    /// Maximum estimated tokens per example or preference pair
    /// (warning when exceeded).
    pub max_tokens: usize,
    /// Require the last message to be from the assistant (error when violated).
    pub require_assistant_last: bool,
    /// Require a system message (warning when missing).
    pub require_system_message: bool,
    /// Reject messages whose content is empty or whitespace-only (error).
    pub reject_empty_content: bool,
    /// Require user/assistant turns to alternate (warning on the first repeat).
    /// System and tool messages are skipped when checking.
    pub require_alternating_turns: bool,
}
82
83impl Default for ValidatorConfig {
84    fn default() -> Self {
85        Self {
86            min_messages: 2,
87            max_messages: 1000,
88            max_tokens: 32768,
89            require_assistant_last: true,
90            require_system_message: false,
91            reject_empty_content: true,
92            require_alternating_turns: false,
93        }
94    }
95}
96
97/// Validates training examples against configurable rules.
98pub struct DataValidator {
99    config: ValidatorConfig,
100}
101
102impl DataValidator {
103    /// Create a new validator with the given configuration.
104    pub fn new(config: ValidatorConfig) -> Self {
105        Self { config }
106    }
107
108    /// Create a new validator with default configuration.
109    pub fn with_defaults() -> Self {
110        Self::new(ValidatorConfig::default())
111    }
112
113    /// Validate a single training example.
114    pub fn validate_example(&self, example: &TrainingExample) -> Vec<ValidationIssue> {
115        let mut issues = Vec::new();
116        let id = &example.id;
117
118        // Check message count
119        if example.messages.len() < self.config.min_messages {
120            issues.push(ValidationIssue {
121                example_id: id.clone(),
122                severity: IssueSeverity::Error,
123                message: format!(
124                    "Too few messages: {} (min: {})",
125                    example.messages.len(),
126                    self.config.min_messages
127                ),
128                line_number: None,
129                suggestion: None,
130            });
131        }
132
133        if example.messages.len() > self.config.max_messages {
134            issues.push(ValidationIssue {
135                example_id: id.clone(),
136                severity: IssueSeverity::Warning,
137                message: format!(
138                    "Too many messages: {} (max: {})",
139                    example.messages.len(),
140                    self.config.max_messages
141                ),
142                line_number: None,
143                suggestion: None,
144            });
145        }
146
147        // Check token count
148        let tokens = example.estimated_tokens();
149        if tokens > self.config.max_tokens {
150            issues.push(ValidationIssue {
151                example_id: id.clone(),
152                severity: IssueSeverity::Warning,
153                message: format!(
154                    "Estimated tokens ({}) exceeds max ({})",
155                    tokens, self.config.max_tokens
156                ),
157                line_number: None,
158                suggestion: None,
159            });
160        }
161
162        // Check system message requirement
163        if self.config.require_system_message && !example.has_system_message() {
164            issues.push(ValidationIssue {
165                example_id: id.clone(),
166                severity: IssueSeverity::Warning,
167                message: "Missing system message".to_string(),
168                line_number: None,
169                suggestion: None,
170            });
171        }
172
173        // Check last message is assistant
174        if self.config.require_assistant_last && !example.ends_with_assistant() {
175            issues.push(ValidationIssue {
176                example_id: id.clone(),
177                severity: IssueSeverity::Error,
178                message: "Last message must be from assistant".to_string(),
179                line_number: None,
180                suggestion: None,
181            });
182        }
183
184        // Check empty content
185        if self.config.reject_empty_content {
186            for (i, msg) in example.messages.iter().enumerate() {
187                if msg.content.trim().is_empty() && msg.tool_calls.is_none() {
188                    issues.push(ValidationIssue {
189                        example_id: id.clone(),
190                        severity: IssueSeverity::Error,
191                        message: format!("Message {} has empty content", i),
192                        line_number: None,
193                        suggestion: None,
194                    });
195                }
196            }
197        }
198
199        // Check alternating turns
200        if self.config.require_alternating_turns {
201            let non_system: Vec<_> = example
202                .messages
203                .iter()
204                .filter(|m| m.role != TrainingRole::System && m.role != TrainingRole::Tool)
205                .collect();
206            for window in non_system.windows(2) {
207                if window[0].role == window[1].role {
208                    issues.push(ValidationIssue {
209                        example_id: id.clone(),
210                        severity: IssueSeverity::Warning,
211                        message: format!(
212                            "Consecutive {} messages (expected alternating)",
213                            window[0].role
214                        ),
215                        line_number: None,
216                        suggestion: None,
217                    });
218                    break;
219                }
220            }
221        }
222
223        issues
224    }
225
226    /// Validate a preference pair.
227    pub fn validate_preference(&self, pair: &PreferencePair) -> Vec<ValidationIssue> {
228        let mut issues = Vec::new();
229        let id = &pair.id;
230
231        if pair.prompt.is_empty() {
232            issues.push(ValidationIssue {
233                example_id: id.clone(),
234                severity: IssueSeverity::Error,
235                message: "Preference pair has empty prompt".to_string(),
236                line_number: None,
237                suggestion: Some("Add at least one prompt message".to_string()),
238            });
239        }
240
241        if pair.chosen.is_empty() {
242            issues.push(ValidationIssue {
243                example_id: id.clone(),
244                severity: IssueSeverity::Error,
245                message: "Preference pair has empty chosen response".to_string(),
246                line_number: None,
247                suggestion: Some("Add at least one chosen response message".to_string()),
248            });
249        }
250
251        if pair.rejected.is_empty() {
252            issues.push(ValidationIssue {
253                example_id: id.clone(),
254                severity: IssueSeverity::Error,
255                message: "Preference pair has empty rejected response".to_string(),
256                line_number: None,
257                suggestion: Some("Add at least one rejected response message".to_string()),
258            });
259        }
260
261        // Check empty content in messages
262        if self.config.reject_empty_content {
263            for (i, msg) in pair.prompt.iter().enumerate() {
264                if msg.content.trim().is_empty() {
265                    issues.push(ValidationIssue {
266                        example_id: id.clone(),
267                        severity: IssueSeverity::Error,
268                        message: format!("Prompt message {} has empty content", i),
269                        line_number: None,
270                        suggestion: None,
271                    });
272                }
273            }
274            for (i, msg) in pair.chosen.iter().enumerate() {
275                if msg.content.trim().is_empty() {
276                    issues.push(ValidationIssue {
277                        example_id: id.clone(),
278                        severity: IssueSeverity::Error,
279                        message: format!("Chosen message {} has empty content", i),
280                        line_number: None,
281                        suggestion: None,
282                    });
283                }
284            }
285            for (i, msg) in pair.rejected.iter().enumerate() {
286                if msg.content.trim().is_empty() {
287                    issues.push(ValidationIssue {
288                        example_id: id.clone(),
289                        severity: IssueSeverity::Error,
290                        message: format!("Rejected message {} has empty content", i),
291                        line_number: None,
292                        suggestion: None,
293                    });
294                }
295            }
296        }
297
298        // Warn if chosen == rejected
299        if !pair.chosen.is_empty() && !pair.rejected.is_empty() {
300            let chosen_text: String = pair
301                .chosen
302                .iter()
303                .map(|m| m.content.as_str())
304                .collect::<Vec<_>>()
305                .join("");
306            let rejected_text: String = pair
307                .rejected
308                .iter()
309                .map(|m| m.content.as_str())
310                .collect::<Vec<_>>()
311                .join("");
312            if chosen_text == rejected_text {
313                issues.push(ValidationIssue {
314                    example_id: id.clone(),
315                    severity: IssueSeverity::Warning,
316                    message: "Chosen and rejected responses are identical".to_string(),
317                    line_number: None,
318                    suggestion: Some("Ensure chosen and rejected responses differ".to_string()),
319                });
320            }
321
322            // Warn if length ratio > 10x
323            let chosen_len = chosen_text.len().max(1);
324            let rejected_len = rejected_text.len().max(1);
325            let ratio = chosen_len.max(rejected_len) as f64 / chosen_len.min(rejected_len) as f64;
326            if ratio > 10.0 {
327                issues.push(ValidationIssue {
328                    example_id: id.clone(),
329                    severity: IssueSeverity::Warning,
330                    message: format!(
331                        "Length ratio between chosen and rejected is {:.1}x (>10x)",
332                        ratio
333                    ),
334                    line_number: None,
335                    suggestion: Some(
336                        "Large length differences may indicate data quality issues".to_string(),
337                    ),
338                });
339            }
340        }
341
342        // Token count check
343        let tokens = pair.estimated_tokens();
344        if tokens > self.config.max_tokens {
345            issues.push(ValidationIssue {
346                example_id: id.clone(),
347                severity: IssueSeverity::Warning,
348                message: format!(
349                    "Estimated tokens ({}) exceeds max ({})",
350                    tokens, self.config.max_tokens
351                ),
352                line_number: None,
353                suggestion: None,
354            });
355        }
356
357        issues
358    }
359
360    /// Validate a full preference dataset, producing a report.
361    pub fn validate_preference_dataset(
362        &self,
363        pairs: &[PreferencePair],
364    ) -> DatasetResult<ValidationReport> {
365        let mut all_issues = Vec::new();
366        let mut valid_count = 0;
367
368        for pair in pairs {
369            let issues = self.validate_preference(pair);
370            if issues.iter().all(|i| i.severity != IssueSeverity::Error) {
371                valid_count += 1;
372            }
373            all_issues.extend(issues);
374        }
375
376        tracing::debug!(
377            "Validated {} preference pairs: {} valid, {} issues",
378            pairs.len(),
379            valid_count,
380            all_issues.len()
381        );
382
383        Ok(ValidationReport {
384            issues: all_issues,
385            total_examples: pairs.len(),
386            valid_examples: valid_count,
387        })
388    }
389
390    /// Validate a full dataset, producing a report.
391    pub fn validate_dataset(
392        &self,
393        examples: &[TrainingExample],
394    ) -> DatasetResult<ValidationReport> {
395        let mut all_issues = Vec::new();
396        let mut valid_count = 0;
397
398        for example in examples {
399            let issues = self.validate_example(example);
400            if issues.iter().all(|i| i.severity != IssueSeverity::Error) {
401                valid_count += 1;
402            }
403            all_issues.extend(issues);
404        }
405
406        tracing::debug!(
407            "Validated {} examples: {} valid, {} issues",
408            examples.len(),
409            valid_count,
410            all_issues.len()
411        );
412
413        Ok(ValidationReport {
414            issues: all_issues,
415            total_examples: examples.len(),
416            valid_examples: valid_count,
417        })
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    #[test]
    fn test_valid_example() {
        let validator = DataValidator::with_defaults();
        let example = TrainingExample::with_id(
            "test",
            vec![
                TrainingMessage::user("Hello"),
                TrainingMessage::assistant("Hi!"),
            ],
        );
        let issues = validator.validate_example(&example);
        assert!(issues.is_empty());
    }

    #[test]
    fn test_too_few_messages() {
        let validator = DataValidator::with_defaults();
        let example = TrainingExample::with_id("test", vec![TrainingMessage::user("Hello")]);
        let issues = validator.validate_example(&example);
        assert!(issues.iter().any(|i| i.message.contains("Too few")));
        assert!(
            issues
                .iter()
                .any(|i| i.message.contains("must be from assistant"))
        );
    }

    #[test]
    fn test_empty_content_rejected() {
        let validator = DataValidator::with_defaults();
        let example = TrainingExample::with_id(
            "test",
            vec![TrainingMessage::user(""), TrainingMessage::assistant("Hi")],
        );
        let issues = validator.validate_example(&example);
        assert!(issues.iter().any(|i| i.message.contains("empty content")));
    }

    #[test]
    fn test_empty_content_allowed_when_disabled() {
        let validator = DataValidator::new(ValidatorConfig {
            reject_empty_content: false,
            ..ValidatorConfig::default()
        });
        let example = TrainingExample::with_id(
            "test",
            vec![TrainingMessage::user(""), TrainingMessage::assistant("Hi")],
        );
        assert!(validator.validate_example(&example).is_empty());
    }

    #[test]
    fn test_alternating_turns_warning() {
        let validator = DataValidator::new(ValidatorConfig {
            require_alternating_turns: true,
            ..ValidatorConfig::default()
        });
        let example = TrainingExample::with_id(
            "test",
            vec![
                TrainingMessage::user("A"),
                TrainingMessage::user("B"),
                TrainingMessage::assistant("C"),
            ],
        );
        let issues = validator.validate_example(&example);
        assert!(issues.iter().any(|i| i.message.contains("Consecutive")));
        assert!(issues.iter().all(|i| i.severity == IssueSeverity::Warning));
    }

    #[test]
    fn test_validation_report() {
        let validator = DataValidator::with_defaults();
        let examples = vec![
            TrainingExample::with_id(
                "good",
                vec![TrainingMessage::user("Q"), TrainingMessage::assistant("A")],
            ),
            TrainingExample::with_id("bad", vec![TrainingMessage::user("Q")]),
        ];
        let report = validator.validate_dataset(&examples).unwrap();
        assert_eq!(report.total_examples, 2);
        assert_eq!(report.valid_examples, 1);
        assert!(report.has_errors());
    }

    #[test]
    fn test_warnings_do_not_invalidate_example() {
        let validator = DataValidator::new(ValidatorConfig {
            require_system_message: true,
            ..ValidatorConfig::default()
        });
        let examples = vec![TrainingExample::with_id(
            "no-system",
            vec![TrainingMessage::user("Q"), TrainingMessage::assistant("A")],
        )];
        let report = validator.validate_dataset(&examples).unwrap();
        assert_eq!(report.total_examples, 1);
        assert_eq!(report.valid_examples, 1);
        assert_eq!(report.warning_count(), 1);
        assert!(!report.has_errors());
    }

    #[test]
    fn test_report_counts() {
        let issue = |severity: IssueSeverity| ValidationIssue {
            example_id: "x".to_string(),
            severity,
            message: String::new(),
            line_number: None,
            suggestion: None,
        };
        let report = ValidationReport {
            issues: vec![
                issue(IssueSeverity::Error),
                issue(IssueSeverity::Warning),
                issue(IssueSeverity::Warning),
            ],
            total_examples: 1,
            valid_examples: 0,
        };
        assert!(report.has_errors());
        assert_eq!(report.error_count(), 1);
        assert_eq!(report.warning_count(), 2);
    }

    #[test]
    fn test_preference_validation_identical() {
        let validator = DataValidator::with_defaults();
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Same")],
            vec![TrainingMessage::assistant("Same")],
        );
        let issues = validator.validate_preference(&pair);
        assert!(issues.iter().any(|i| i.message.contains("identical")));
    }

    #[test]
    fn test_preference_validation_empty_content() {
        let validator = DataValidator::with_defaults();
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        );
        let issues = validator.validate_preference(&pair);
        assert!(issues.iter().any(|i| i.message.contains("empty content")));
    }

    #[test]
    fn test_preference_length_ratio_warning() {
        let validator = DataValidator::with_defaults();
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("a")],
            vec![TrainingMessage::assistant("bbbbbbbbbbbbbbbbbbbb")],
        );
        let issues = validator.validate_preference(&pair);
        assert!(issues.iter().any(|i| i.message.contains("Length ratio")));
        assert!(issues.iter().all(|i| i.severity == IssueSeverity::Warning));
    }

    #[test]
    fn test_validate_preference_dataset() {
        let validator = DataValidator::with_defaults();
        let pairs = vec![
            PreferencePair::new(
                vec![TrainingMessage::user("Q")],
                vec![TrainingMessage::assistant("Good")],
                vec![TrainingMessage::assistant("Bad")],
            ),
            PreferencePair::new(
                vec![],
                vec![TrainingMessage::assistant("Good")],
                vec![TrainingMessage::assistant("Bad")],
            ),
        ];
        let report = validator.validate_preference_dataset(&pairs).unwrap();
        assert_eq!(report.total_examples, 2);
        assert_eq!(report.valid_examples, 1);
    }
}