1use crate::error::DatasetResult;
2use crate::types::{PreferencePair, TrainingExample, TrainingRole};
3
/// Severity attached to a [`ValidationIssue`].
///
/// This is a payload-free tag, so it is `Copy` (no `.clone()` needed when
/// passing it around) and `Hash` (usable as a map/set key, e.g. for grouping
/// issues by severity).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IssueSeverity {
    /// A blocking problem: dataset validation does not count an item that
    /// produced at least one `Error` issue as valid.
    Error,
    /// A non-blocking finding: it is reported, but the item is still counted
    /// as valid by dataset validation.
    Warning,
}
12
13#[derive(Debug, Clone)]
15pub struct ValidationIssue {
16 pub example_id: String,
18 pub severity: IssueSeverity,
20 pub message: String,
22 pub line_number: Option<usize>,
24 pub suggestion: Option<String>,
26}
27
28#[derive(Debug, Clone)]
30pub struct ValidationReport {
31 pub issues: Vec<ValidationIssue>,
33 pub total_examples: usize,
35 pub valid_examples: usize,
37}
38
39impl ValidationReport {
40 pub fn has_errors(&self) -> bool {
42 self.issues
43 .iter()
44 .any(|i| i.severity == IssueSeverity::Error)
45 }
46
47 pub fn error_count(&self) -> usize {
49 self.issues
50 .iter()
51 .filter(|i| i.severity == IssueSeverity::Error)
52 .count()
53 }
54
55 pub fn warning_count(&self) -> usize {
57 self.issues
58 .iter()
59 .filter(|i| i.severity == IssueSeverity::Warning)
60 .count()
61 }
62}
63
/// Rule set applied by `DataValidator`.
#[derive(Clone, Debug)]
pub struct ValidatorConfig {
    /// Examples with fewer messages than this are reported as an error.
    pub min_messages: usize,
    /// Examples with more messages than this are reported as a warning.
    pub max_messages: usize,
    /// Items whose estimated token count exceeds this are reported as a
    /// warning.
    pub max_tokens: usize,
    /// When `true`, an example that does not end with an assistant message is
    /// reported as an error.
    pub require_assistant_last: bool,
    /// When `true`, an example without a system message is reported as a
    /// warning.
    pub require_system_message: bool,
    /// When `true`, messages whose content is empty or whitespace-only are
    /// reported as errors (`DataValidator::validate_example` exempts messages
    /// that carry tool calls).
    pub reject_empty_content: bool,
    /// When `true`, two consecutive messages with the same role (ignoring
    /// system and tool messages) are reported as a warning.
    pub require_alternating_turns: bool,
}
82
83impl Default for ValidatorConfig {
84 fn default() -> Self {
85 Self {
86 min_messages: 2,
87 max_messages: 1000,
88 max_tokens: 32768,
89 require_assistant_last: true,
90 require_system_message: false,
91 reject_empty_content: true,
92 require_alternating_turns: false,
93 }
94 }
95}
96
97pub struct DataValidator {
99 config: ValidatorConfig,
100}
101
102impl DataValidator {
103 pub fn new(config: ValidatorConfig) -> Self {
105 Self { config }
106 }
107
108 pub fn with_defaults() -> Self {
110 Self::new(ValidatorConfig::default())
111 }
112
113 pub fn validate_example(&self, example: &TrainingExample) -> Vec<ValidationIssue> {
115 let mut issues = Vec::new();
116 let id = &example.id;
117
118 if example.messages.len() < self.config.min_messages {
120 issues.push(ValidationIssue {
121 example_id: id.clone(),
122 severity: IssueSeverity::Error,
123 message: format!(
124 "Too few messages: {} (min: {})",
125 example.messages.len(),
126 self.config.min_messages
127 ),
128 line_number: None,
129 suggestion: None,
130 });
131 }
132
133 if example.messages.len() > self.config.max_messages {
134 issues.push(ValidationIssue {
135 example_id: id.clone(),
136 severity: IssueSeverity::Warning,
137 message: format!(
138 "Too many messages: {} (max: {})",
139 example.messages.len(),
140 self.config.max_messages
141 ),
142 line_number: None,
143 suggestion: None,
144 });
145 }
146
147 let tokens = example.estimated_tokens();
149 if tokens > self.config.max_tokens {
150 issues.push(ValidationIssue {
151 example_id: id.clone(),
152 severity: IssueSeverity::Warning,
153 message: format!(
154 "Estimated tokens ({}) exceeds max ({})",
155 tokens, self.config.max_tokens
156 ),
157 line_number: None,
158 suggestion: None,
159 });
160 }
161
162 if self.config.require_system_message && !example.has_system_message() {
164 issues.push(ValidationIssue {
165 example_id: id.clone(),
166 severity: IssueSeverity::Warning,
167 message: "Missing system message".to_string(),
168 line_number: None,
169 suggestion: None,
170 });
171 }
172
173 if self.config.require_assistant_last && !example.ends_with_assistant() {
175 issues.push(ValidationIssue {
176 example_id: id.clone(),
177 severity: IssueSeverity::Error,
178 message: "Last message must be from assistant".to_string(),
179 line_number: None,
180 suggestion: None,
181 });
182 }
183
184 if self.config.reject_empty_content {
186 for (i, msg) in example.messages.iter().enumerate() {
187 if msg.content.trim().is_empty() && msg.tool_calls.is_none() {
188 issues.push(ValidationIssue {
189 example_id: id.clone(),
190 severity: IssueSeverity::Error,
191 message: format!("Message {} has empty content", i),
192 line_number: None,
193 suggestion: None,
194 });
195 }
196 }
197 }
198
199 if self.config.require_alternating_turns {
201 let non_system: Vec<_> = example
202 .messages
203 .iter()
204 .filter(|m| m.role != TrainingRole::System && m.role != TrainingRole::Tool)
205 .collect();
206 for window in non_system.windows(2) {
207 if window[0].role == window[1].role {
208 issues.push(ValidationIssue {
209 example_id: id.clone(),
210 severity: IssueSeverity::Warning,
211 message: format!(
212 "Consecutive {} messages (expected alternating)",
213 window[0].role
214 ),
215 line_number: None,
216 suggestion: None,
217 });
218 break;
219 }
220 }
221 }
222
223 issues
224 }
225
226 pub fn validate_preference(&self, pair: &PreferencePair) -> Vec<ValidationIssue> {
228 let mut issues = Vec::new();
229 let id = &pair.id;
230
231 if pair.prompt.is_empty() {
232 issues.push(ValidationIssue {
233 example_id: id.clone(),
234 severity: IssueSeverity::Error,
235 message: "Preference pair has empty prompt".to_string(),
236 line_number: None,
237 suggestion: Some("Add at least one prompt message".to_string()),
238 });
239 }
240
241 if pair.chosen.is_empty() {
242 issues.push(ValidationIssue {
243 example_id: id.clone(),
244 severity: IssueSeverity::Error,
245 message: "Preference pair has empty chosen response".to_string(),
246 line_number: None,
247 suggestion: Some("Add at least one chosen response message".to_string()),
248 });
249 }
250
251 if pair.rejected.is_empty() {
252 issues.push(ValidationIssue {
253 example_id: id.clone(),
254 severity: IssueSeverity::Error,
255 message: "Preference pair has empty rejected response".to_string(),
256 line_number: None,
257 suggestion: Some("Add at least one rejected response message".to_string()),
258 });
259 }
260
261 if self.config.reject_empty_content {
263 for (i, msg) in pair.prompt.iter().enumerate() {
264 if msg.content.trim().is_empty() {
265 issues.push(ValidationIssue {
266 example_id: id.clone(),
267 severity: IssueSeverity::Error,
268 message: format!("Prompt message {} has empty content", i),
269 line_number: None,
270 suggestion: None,
271 });
272 }
273 }
274 for (i, msg) in pair.chosen.iter().enumerate() {
275 if msg.content.trim().is_empty() {
276 issues.push(ValidationIssue {
277 example_id: id.clone(),
278 severity: IssueSeverity::Error,
279 message: format!("Chosen message {} has empty content", i),
280 line_number: None,
281 suggestion: None,
282 });
283 }
284 }
285 for (i, msg) in pair.rejected.iter().enumerate() {
286 if msg.content.trim().is_empty() {
287 issues.push(ValidationIssue {
288 example_id: id.clone(),
289 severity: IssueSeverity::Error,
290 message: format!("Rejected message {} has empty content", i),
291 line_number: None,
292 suggestion: None,
293 });
294 }
295 }
296 }
297
298 if !pair.chosen.is_empty() && !pair.rejected.is_empty() {
300 let chosen_text: String = pair
301 .chosen
302 .iter()
303 .map(|m| m.content.as_str())
304 .collect::<Vec<_>>()
305 .join("");
306 let rejected_text: String = pair
307 .rejected
308 .iter()
309 .map(|m| m.content.as_str())
310 .collect::<Vec<_>>()
311 .join("");
312 if chosen_text == rejected_text {
313 issues.push(ValidationIssue {
314 example_id: id.clone(),
315 severity: IssueSeverity::Warning,
316 message: "Chosen and rejected responses are identical".to_string(),
317 line_number: None,
318 suggestion: Some("Ensure chosen and rejected responses differ".to_string()),
319 });
320 }
321
322 let chosen_len = chosen_text.len().max(1);
324 let rejected_len = rejected_text.len().max(1);
325 let ratio = chosen_len.max(rejected_len) as f64 / chosen_len.min(rejected_len) as f64;
326 if ratio > 10.0 {
327 issues.push(ValidationIssue {
328 example_id: id.clone(),
329 severity: IssueSeverity::Warning,
330 message: format!(
331 "Length ratio between chosen and rejected is {:.1}x (>10x)",
332 ratio
333 ),
334 line_number: None,
335 suggestion: Some(
336 "Large length differences may indicate data quality issues".to_string(),
337 ),
338 });
339 }
340 }
341
342 let tokens = pair.estimated_tokens();
344 if tokens > self.config.max_tokens {
345 issues.push(ValidationIssue {
346 example_id: id.clone(),
347 severity: IssueSeverity::Warning,
348 message: format!(
349 "Estimated tokens ({}) exceeds max ({})",
350 tokens, self.config.max_tokens
351 ),
352 line_number: None,
353 suggestion: None,
354 });
355 }
356
357 issues
358 }
359
360 pub fn validate_preference_dataset(
362 &self,
363 pairs: &[PreferencePair],
364 ) -> DatasetResult<ValidationReport> {
365 let mut all_issues = Vec::new();
366 let mut valid_count = 0;
367
368 for pair in pairs {
369 let issues = self.validate_preference(pair);
370 if issues.iter().all(|i| i.severity != IssueSeverity::Error) {
371 valid_count += 1;
372 }
373 all_issues.extend(issues);
374 }
375
376 tracing::debug!(
377 "Validated {} preference pairs: {} valid, {} issues",
378 pairs.len(),
379 valid_count,
380 all_issues.len()
381 );
382
383 Ok(ValidationReport {
384 issues: all_issues,
385 total_examples: pairs.len(),
386 valid_examples: valid_count,
387 })
388 }
389
390 pub fn validate_dataset(
392 &self,
393 examples: &[TrainingExample],
394 ) -> DatasetResult<ValidationReport> {
395 let mut all_issues = Vec::new();
396 let mut valid_count = 0;
397
398 for example in examples {
399 let issues = self.validate_example(example);
400 if issues.iter().all(|i| i.severity != IssueSeverity::Error) {
401 valid_count += 1;
402 }
403 all_issues.extend(issues);
404 }
405
406 tracing::debug!(
407 "Validated {} examples: {} valid, {} issues",
408 examples.len(),
409 valid_count,
410 all_issues.len()
411 );
412
413 Ok(ValidationReport {
414 issues: all_issues,
415 total_examples: examples.len(),
416 valid_examples: valid_count,
417 })
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Returns `true` when at least one issue message contains `needle`.
    fn mentions(issues: &[ValidationIssue], needle: &str) -> bool {
        issues.iter().any(|issue| issue.message.contains(needle))
    }

    /// A user/assistant exchange passes the default rules without issues.
    #[test]
    fn test_valid_example() {
        let messages = vec![
            TrainingMessage::user("Hello"),
            TrainingMessage::assistant("Hi!"),
        ];
        let example = TrainingExample::with_id("test", messages);

        let found = DataValidator::with_defaults().validate_example(&example);

        assert!(found.is_empty());
    }

    /// A lone user message is both too short and not assistant-terminated.
    #[test]
    fn test_too_few_messages() {
        let example = TrainingExample::with_id("test", vec![TrainingMessage::user("Hello")]);

        let found = DataValidator::with_defaults().validate_example(&example);

        assert!(mentions(&found, "Too few"));
        assert!(mentions(&found, "must be from assistant"));
    }

    /// An empty user message is reported under the default rules.
    #[test]
    fn test_empty_content_rejected() {
        let messages = vec![TrainingMessage::user(""), TrainingMessage::assistant("Hi")];
        let example = TrainingExample::with_id("test", messages);

        let found = DataValidator::with_defaults().validate_example(&example);

        assert!(mentions(&found, "empty content"));
    }

    /// The dataset report counts totals, valid items, and detects errors.
    #[test]
    fn test_validation_report() {
        let good = TrainingExample::with_id(
            "good",
            vec![TrainingMessage::user("Q"), TrainingMessage::assistant("A")],
        );
        let bad = TrainingExample::with_id("bad", vec![TrainingMessage::user("Q")]);
        let dataset = vec![good, bad];

        let report = DataValidator::with_defaults()
            .validate_dataset(&dataset)
            .unwrap();

        assert_eq!(report.total_examples, 2);
        assert_eq!(report.valid_examples, 1);
        assert!(report.has_errors());
    }

    /// Identical chosen and rejected responses are flagged.
    #[test]
    fn test_preference_validation_identical() {
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Same")],
            vec![TrainingMessage::assistant("Same")],
        );

        let found = DataValidator::with_defaults().validate_preference(&pair);

        assert!(mentions(&found, "identical"));
    }

    /// An empty prompt message in a preference pair is flagged.
    #[test]
    fn test_preference_validation_empty_content() {
        let pair = PreferencePair::new(
            vec![TrainingMessage::user("")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        );

        let found = DataValidator::with_defaults().validate_preference(&pair);

        assert!(mentions(&found, "empty content"));
    }

    /// A pair with no prompt messages is excluded from the valid count.
    #[test]
    fn test_validate_preference_dataset() {
        let complete = PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        );
        let missing_prompt = PreferencePair::new(
            vec![],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        );
        let dataset = vec![complete, missing_prompt];

        let report = DataValidator::with_defaults()
            .validate_preference_dataset(&dataset)
            .unwrap();

        assert_eq!(report.total_examples, 2);
        assert_eq!(report.valid_examples, 1);
    }
}