1use crate::data::VerifiedTuple;
7use crate::ml::CommitFeatures;
8use crate::Language;
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct TrainingCorpus {
15 pub tuples: Vec<VerifiedTuple>,
17 pub features: Vec<CommitFeatures>,
19 pub metadata: CorpusMetadata,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct CorpusMetadata {
26 pub version: String,
28 pub source_language: Language,
30 pub target_language: Language,
32 pub count: usize,
34 pub correct_count: usize,
36 pub incorrect_count: usize,
38 pub timestamp: u64,
40}
41
42impl Default for CorpusMetadata {
43 fn default() -> Self {
44 Self {
45 version: String::new(),
46 source_language: Language::Python,
47 target_language: Language::Rust,
48 count: 0,
49 correct_count: 0,
50 incorrect_count: 0,
51 timestamp: 0,
52 }
53 }
54}
55
56impl CorpusMetadata {
57 #[must_use]
59 pub fn accuracy(&self) -> f64 {
60 if self.count == 0 {
61 0.0
62 } else {
63 self.correct_count as f64 / self.count as f64
64 }
65 }
66}
67
/// File formats accepted by `CorpusManager::export`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CorpusFormat {
    /// A single pretty-printed JSON document containing the whole corpus.
    Json,
    /// One JSON record per line; this is the default format.
    #[default]
    Jsonl,
    /// Parquet. `CorpusManager::export` currently writes JSONL for this variant.
    Parquet,
}
79
80#[derive(Debug, Default)]
82pub struct CorpusManager {
83 corpus: TrainingCorpus,
84}
85
86impl CorpusManager {
87 #[must_use]
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 #[must_use]
95 pub fn from_corpus(corpus: TrainingCorpus) -> Self {
96 Self { corpus }
97 }
98
99 #[must_use]
101 pub fn corpus(&self) -> &TrainingCorpus {
102 &self.corpus
103 }
104
105 pub fn add(&mut self, tuple: VerifiedTuple, features: CommitFeatures) {
107 if tuple.is_correct {
108 self.corpus.metadata.correct_count += 1;
109 } else {
110 self.corpus.metadata.incorrect_count += 1;
111 }
112 self.corpus.metadata.count += 1;
113 self.corpus.tuples.push(tuple);
114 self.corpus.features.push(features);
115 }
116
117 pub fn add_tuples(&mut self, tuples: Vec<VerifiedTuple>) {
119 for tuple in tuples {
120 let features = CommitFeatures::default();
121 self.add(tuple, features);
122 }
123 }
124
125 pub fn set_metadata(&mut self, metadata: CorpusMetadata) {
127 self.corpus.metadata = metadata;
128 }
129
130 pub fn export(&self, path: &Path, format: CorpusFormat) -> std::io::Result<()> {
136 match format {
137 CorpusFormat::Json => self.export_json(path),
138 CorpusFormat::Jsonl => self.export_jsonl(path),
139 CorpusFormat::Parquet => {
140 self.export_jsonl(path)
142 }
143 }
144 }
145
146 pub fn load(path: &Path) -> std::io::Result<Self> {
152 let content = std::fs::read_to_string(path)?;
153
154 if let Ok(corpus) = serde_json::from_str::<TrainingCorpus>(&content) {
156 return Ok(Self::from_corpus(corpus));
157 }
158
159 let mut corpus = TrainingCorpus::default();
161 for line in content.lines() {
162 if line.trim().is_empty() {
163 continue;
164 }
165 if let Ok(record) = serde_json::from_str::<CorpusRecord>(line) {
166 let is_correct = record.tuple.is_correct;
167 corpus.tuples.push(record.tuple);
168 corpus.features.push(record.features);
169 corpus.metadata.count += 1;
170 if is_correct {
171 corpus.metadata.correct_count += 1;
172 } else {
173 corpus.metadata.incorrect_count += 1;
174 }
175 }
176 }
177
178 Ok(Self::from_corpus(corpus))
179 }
180
181 fn export_json(&self, path: &Path) -> std::io::Result<()> {
182 let json = serde_json::to_string_pretty(&self.corpus)
183 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
184 std::fs::write(path, json)
185 }
186
187 fn export_jsonl(&self, path: &Path) -> std::io::Result<()> {
188 use std::io::Write;
189 let mut file = std::fs::File::create(path)?;
190
191 for (tuple, features) in self.corpus.tuples.iter().zip(self.corpus.features.iter()) {
192 let record = CorpusRecord {
193 tuple: tuple.clone(),
194 features: features.clone(),
195 };
196 let line = serde_json::to_string(&record)
197 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
198 writeln!(file, "{line}")?;
199 }
200
201 Ok(())
202 }
203
204 #[must_use]
208 pub fn to_training_data(&self) -> (Vec<[f32; 8]>, Vec<u8>) {
209 let features: Vec<[f32; 8]> = self
210 .corpus
211 .features
212 .iter()
213 .map(CommitFeatures::to_array)
214 .collect();
215
216 let labels: Vec<u8> = self
217 .corpus
218 .tuples
219 .iter()
220 .map(|t| u8::from(t.is_correct))
221 .collect();
222
223 (features, labels)
224 }
225
226 #[must_use]
228 pub fn train_test_split(
229 &self,
230 train_ratio: f64,
231 seed: u64,
232 ) -> (TrainingCorpus, TrainingCorpus) {
233 use std::collections::hash_map::DefaultHasher;
234 use std::hash::{Hash, Hasher};
235
236 let mut train = TrainingCorpus::default();
237 let mut test = TrainingCorpus::default();
238
239 for (i, (tuple, features)) in self
240 .corpus
241 .tuples
242 .iter()
243 .zip(self.corpus.features.iter())
244 .enumerate()
245 {
246 let mut hasher = DefaultHasher::new();
247 (seed, i).hash(&mut hasher);
248 let hash = hasher.finish();
249
250 #[allow(clippy::cast_sign_loss)]
251 let threshold = (train_ratio * u64::MAX as f64) as u64;
252
253 let target = if hash < threshold {
254 &mut train
255 } else {
256 &mut test
257 };
258
259 target.tuples.push(tuple.clone());
260 target.features.push(features.clone());
261 target.metadata.count += 1;
262 if tuple.is_correct {
263 target.metadata.correct_count += 1;
264 } else {
265 target.metadata.incorrect_count += 1;
266 }
267 }
268
269 (train, test)
270 }
271
272 #[must_use]
274 pub fn filter_correct(&self, correct: bool) -> TrainingCorpus {
275 let mut filtered = TrainingCorpus::default();
276
277 for (tuple, features) in self.corpus.tuples.iter().zip(self.corpus.features.iter()) {
278 if tuple.is_correct == correct {
279 filtered.tuples.push(tuple.clone());
280 filtered.features.push(features.clone());
281 filtered.metadata.count += 1;
282 if correct {
283 filtered.metadata.correct_count += 1;
284 } else {
285 filtered.metadata.incorrect_count += 1;
286 }
287 }
288 }
289
290 filtered
291 }
292}
293
294#[derive(Debug, Clone, Serialize, Deserialize)]
296struct CorpusRecord {
297 tuple: VerifiedTuple,
298 features: CommitFeatures,
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal Python → Rust tuple with the given correctness flag.
    fn sample_tuple(correct: bool) -> VerifiedTuple {
        VerifiedTuple {
            source_language: Language::Python,
            target_language: Language::Rust,
            source_code: "x = 1".to_string(),
            target_code: "let x = 1;".to_string(),
            is_correct: correct,
            execution_time_ms: 10,
        }
    }

    /// Fixed feature values whose first three array slots are 5, 2 and 1.
    fn sample_features() -> CommitFeatures {
        CommitFeatures {
            lines_added: 5,
            lines_deleted: 2,
            files_changed: 1,
            churn_ratio: 0.7,
            has_test_changes: false,
            complexity_delta: 1.0,
            author_experience: 0.5,
            days_since_last_change: 7.0,
        }
    }

    /// Builds a manager holding one sample tuple per flag, in order.
    fn manager_with(flags: &[bool]) -> CorpusManager {
        let mut manager = CorpusManager::new();
        for &flag in flags {
            manager.add(sample_tuple(flag), sample_features());
        }
        manager
    }

    #[test]
    fn test_corpus_manager_new() {
        let manager = CorpusManager::new();
        assert_eq!(manager.corpus().tuples.len(), 0);
        assert_eq!(manager.corpus().metadata.count, 0);
    }

    #[test]
    fn test_corpus_add_tuple() {
        let manager = manager_with(&[true]);
        let corpus = manager.corpus();

        assert_eq!(corpus.tuples.len(), 1);
        assert_eq!(corpus.features.len(), 1);
        assert_eq!(corpus.metadata.count, 1);
        assert_eq!(corpus.metadata.correct_count, 1);
    }

    #[test]
    fn test_corpus_add_incorrect() {
        let manager = manager_with(&[false]);
        let meta = &manager.corpus().metadata;

        assert_eq!(meta.count, 1);
        assert_eq!(meta.incorrect_count, 1);
        assert_eq!(meta.correct_count, 0);
    }

    #[test]
    fn test_corpus_add_tuples_batch() {
        let mut manager = CorpusManager::new();
        manager.add_tuples(vec![sample_tuple(true), sample_tuple(false), sample_tuple(true)]);
        let corpus = manager.corpus();

        assert_eq!(corpus.tuples.len(), 3);
        assert_eq!(corpus.features.len(), 3);
        assert_eq!(corpus.metadata.correct_count, 2);
        assert_eq!(corpus.metadata.incorrect_count, 1);
    }

    #[test]
    fn test_corpus_metadata_accuracy() {
        let metadata = CorpusMetadata {
            count: 100,
            correct_count: 80,
            incorrect_count: 20,
            ..Default::default()
        };

        assert!((metadata.accuracy() - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_corpus_metadata_accuracy_empty() {
        let metadata = CorpusMetadata::default();
        assert!(metadata.accuracy().abs() < f64::EPSILON);
    }

    #[test]
    fn test_to_training_data() {
        let manager = manager_with(&[true, false]);

        let (features, labels) = manager.to_training_data();

        assert_eq!(features.len(), 2);
        assert_eq!(labels, vec![1, 0]);
    }

    // The compared values are small integers converted to f32, which is exact.
    #[allow(clippy::float_cmp)]
    #[test]
    fn test_to_training_data_feature_values() {
        let manager = manager_with(&[true]);

        let (feature_matrix, _) = manager.to_training_data();

        assert_eq!(feature_matrix[0][0], 5.0);
        assert_eq!(feature_matrix[0][1], 2.0);
        assert_eq!(feature_matrix[0][2], 1.0);
    }

    #[test]
    fn test_train_test_split() {
        let manager = manager_with(&[true; 100]);

        let (train, test) = manager.train_test_split(0.8, 42);

        assert!(train.metadata.count > 0);
        assert!(test.metadata.count > 0);
        assert_eq!(train.metadata.count + test.metadata.count, 100);

        // Hash-based assignment is only approximately proportional.
        let train_ratio = train.metadata.count as f64 / 100.0;
        assert!(train_ratio > 0.7 && train_ratio < 0.9);
    }

    #[test]
    fn test_train_test_split_deterministic() {
        // Distinct execution times make each record identifiable.
        let mut manager = CorpusManager::new();
        for i in 0..50u64 {
            let tuple = VerifiedTuple {
                execution_time_ms: i,
                ..sample_tuple(true)
            };
            manager.add(tuple, sample_features());
        }
        let ids = |corpus: &TrainingCorpus| -> Vec<u64> {
            corpus.tuples.iter().map(|t| t.execution_time_ms).collect()
        };

        let (train1, test1) = manager.train_test_split(0.8, 42);
        let (train2, test2) = manager.train_test_split(0.8, 42);

        assert_eq!(train1.metadata.count, train2.metadata.count);
        assert_eq!(ids(&train1), ids(&train2));
        assert_eq!(ids(&test1), ids(&test2));
    }

    #[test]
    fn test_train_test_split_extreme_ratios() {
        let manager = manager_with(&[true; 20]);

        let (train, test) = manager.train_test_split(0.0, 7);
        assert_eq!(train.metadata.count, 0);
        assert_eq!(test.metadata.count, 20);

        let (train, test) = manager.train_test_split(1.0, 7);
        assert_eq!(train.metadata.count, 20);
        assert_eq!(test.metadata.count, 0);
    }

    #[test]
    fn test_filter_correct() {
        let manager = manager_with(&[true, false, true]);

        let correct_only = manager.filter_correct(true);

        assert_eq!(correct_only.metadata.count, 2);
        assert_eq!(correct_only.metadata.correct_count, 2);
        assert_eq!(correct_only.metadata.incorrect_count, 0);
        assert_eq!(correct_only.features.len(), 2);
    }

    #[test]
    fn test_filter_incorrect() {
        let manager = manager_with(&[true, false]);

        let incorrect_only = manager.filter_correct(false);

        assert_eq!(incorrect_only.metadata.count, 1);
        assert_eq!(incorrect_only.metadata.incorrect_count, 1);
        assert_eq!(incorrect_only.metadata.correct_count, 0);
    }

    #[test]
    fn test_export_load_jsonl() {
        let manager = manager_with(&[true, false]);
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("corpus.jsonl");

        manager.export(&path, CorpusFormat::Jsonl).unwrap();
        assert!(path.exists());

        let loaded = CorpusManager::load(&path).unwrap();
        let meta = &loaded.corpus().metadata;
        assert_eq!(loaded.corpus().tuples.len(), 2);
        assert_eq!(loaded.corpus().features.len(), 2);
        assert_eq!(meta.count, 2);
        assert_eq!(meta.correct_count, 1);
        assert_eq!(meta.incorrect_count, 1);
        // Record order survives the round trip.
        assert_eq!(loaded.to_training_data().1, vec![1, 0]);
    }

    #[test]
    fn test_load_jsonl_skips_blank_lines() {
        let manager = manager_with(&[true, false]);
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("corpus.jsonl");
        manager.export(&path, CorpusFormat::Jsonl).unwrap();

        // Surround and separate the records with empty and whitespace-only lines.
        let content = std::fs::read_to_string(&path).unwrap();
        let spaced = content.lines().collect::<Vec<_>>().join("\n\n   \n");
        std::fs::write(&path, format!("\n{spaced}\n\n")).unwrap();

        let loaded = CorpusManager::load(&path).unwrap();
        assert_eq!(loaded.corpus().tuples.len(), 2);
        assert_eq!(loaded.corpus().metadata.count, 2);
    }

    #[test]
    fn test_export_load_json() {
        let mut manager = CorpusManager::new();
        manager.set_metadata(CorpusMetadata {
            version: "2.0".to_string(),
            ..Default::default()
        });
        manager.add(sample_tuple(true), sample_features());
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("corpus.json");

        manager.export(&path, CorpusFormat::Json).unwrap();
        let loaded = CorpusManager::load(&path).unwrap();

        // The JSON format carries metadata, unlike JSONL.
        assert_eq!(loaded.corpus().tuples.len(), 1);
        assert_eq!(loaded.corpus().metadata.count, 1);
        assert_eq!(loaded.corpus().metadata.correct_count, 1);
        assert_eq!(loaded.corpus().metadata.version, "2.0");
    }

    #[test]
    fn test_export_parquet_falls_back_to_jsonl() {
        let manager = manager_with(&[true, false]);
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("corpus.parquet");

        manager.export(&path, CorpusFormat::Parquet).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        assert_eq!(content.lines().filter(|l| !l.trim().is_empty()).count(), 2);
        let loaded = CorpusManager::load(&path).unwrap();
        assert_eq!(loaded.corpus().tuples.len(), 2);
    }

    #[test]
    fn test_set_metadata() {
        let mut manager = CorpusManager::new();
        let metadata = CorpusMetadata {
            version: "1.0.0".to_string(),
            source_language: Language::Python,
            target_language: Language::Rust,
            count: 0,
            correct_count: 0,
            incorrect_count: 0,
            timestamp: 1700000000,
        };

        manager.set_metadata(metadata);

        assert_eq!(manager.corpus().metadata.version, "1.0.0");
        assert_eq!(manager.corpus().metadata.timestamp, 1700000000);
    }

    #[test]
    fn test_corpus_format_default() {
        assert_eq!(CorpusFormat::default(), CorpusFormat::Jsonl);
    }

    #[test]
    fn test_training_corpus_default() {
        let corpus = TrainingCorpus::default();
        assert!(corpus.tuples.is_empty());
        assert!(corpus.features.is_empty());
        assert_eq!(corpus.metadata.count, 0);
    }

    #[test]
    fn test_corpus_manager_debug() {
        let manager = CorpusManager::new();
        let debug = format!("{manager:?}");
        assert!(debug.contains("CorpusManager"));
    }

    #[test]
    fn test_corpus_metadata_debug() {
        let metadata = CorpusMetadata::default();
        let debug = format!("{metadata:?}");
        assert!(debug.contains("CorpusMetadata"));
    }

    #[test]
    fn test_from_corpus() {
        let corpus = TrainingCorpus {
            tuples: vec![sample_tuple(true)],
            features: vec![sample_features()],
            metadata: CorpusMetadata {
                count: 1,
                correct_count: 1,
                ..Default::default()
            },
        };

        let manager = CorpusManager::from_corpus(corpus);
        assert_eq!(manager.corpus().tuples.len(), 1);
        assert_eq!(manager.corpus().metadata.count, 1);
    }
}
575
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// A fixed Python → Rust tuple; only the correctness flag varies.
    fn fixed_tuple(is_correct: bool) -> VerifiedTuple {
        VerifiedTuple {
            source_language: Language::Python,
            target_language: Language::Rust,
            source_code: "x = 1".to_string(),
            target_code: "let x = 1;".to_string(),
            is_correct,
            execution_time_ms: 10,
        }
    }

    /// Builds a manager with one default-featured fixed tuple per flag, in order.
    fn manager_from_flags(flags: &[bool]) -> CorpusManager {
        let mut manager = CorpusManager::new();
        for &is_correct in flags {
            manager.add(fixed_tuple(is_correct), CommitFeatures::default());
        }
        manager
    }

    /// Builds a manager from generated `(tuple, features)` pairs, in order.
    fn manager_from_records(records: Vec<(VerifiedTuple, CommitFeatures)>) -> CorpusManager {
        let mut manager = CorpusManager::new();
        for (tuple, features) in records {
            manager.add(tuple, features);
        }
        manager
    }

    fn arb_tuple() -> impl Strategy<Value = VerifiedTuple> {
        (any::<bool>(), 1u64..1000).prop_map(|(correct, time)| VerifiedTuple {
            source_language: Language::Python,
            target_language: Language::Rust,
            source_code: "x = 1".to_string(),
            target_code: "let x = 1;".to_string(),
            is_correct: correct,
            execution_time_ms: time,
        })
    }

    fn arb_features() -> impl Strategy<Value = CommitFeatures> {
        (0u32..100, 0u32..100, 1u32..10, any::<bool>()).prop_map(
            |(added, deleted, files, has_tests)| CommitFeatures {
                lines_added: added,
                lines_deleted: deleted,
                files_changed: files,
                churn_ratio: added as f32 / (added + deleted + 1) as f32,
                has_test_changes: has_tests,
                complexity_delta: 0.0,
                author_experience: 0.5,
                days_since_last_change: 7.0,
            },
        )
    }

    /// Non-empty lists (`1..max_len` elements) of generated records.
    fn arb_records(max_len: usize) -> impl Strategy<Value = Vec<(VerifiedTuple, CommitFeatures)>> {
        proptest::collection::vec((arb_tuple(), arb_features()), 1..max_len)
    }

    proptest! {
        // Every `add` grows the count and both parallel vectors by exactly one.
        #[test]
        fn prop_count_matches_tuples(records in arb_records(50)) {
            let n = records.len();
            let manager = manager_from_records(records);
            let corpus = manager.corpus();

            prop_assert_eq!(corpus.metadata.count, n);
            prop_assert_eq!(corpus.tuples.len(), n);
            prop_assert_eq!(corpus.features.len(), n);
        }

        // The correct and incorrect counters partition the total.
        #[test]
        fn prop_correct_incorrect_sum(
            correct_count in 0usize..25,
            incorrect_count in 0usize..25,
        ) {
            let mut flags = vec![true; correct_count];
            flags.resize(correct_count + incorrect_count, false);
            let manager = manager_from_flags(&flags);

            let meta = &manager.corpus().metadata;
            prop_assert_eq!(meta.correct_count + meta.incorrect_count, meta.count);
            prop_assert_eq!(meta.correct_count, correct_count);
            prop_assert_eq!(meta.incorrect_count, incorrect_count);
        }

        // A split neither loses nor duplicates records, for any seed.
        #[test]
        fn prop_split_preserves_count(
            records in arb_records(100),
            ratio in 0.1f64..0.9,
            seed in any::<u64>(),
        ) {
            let n = records.len();
            let manager = manager_from_records(records);
            let total_correct = manager.corpus().metadata.correct_count;

            let (train, test) = manager.train_test_split(ratio, seed);

            prop_assert_eq!(train.metadata.count + test.metadata.count, n);
            prop_assert_eq!(train.tuples.len() + test.tuples.len(), n);
            prop_assert_eq!(train.features.len() + test.features.len(), n);
            prop_assert_eq!(
                train.metadata.correct_count + test.metadata.correct_count,
                total_correct
            );
        }

        // One feature row and one label per record.
        #[test]
        fn prop_training_data_length(records in arb_records(50)) {
            let n = records.len();
            let manager = manager_from_records(records);

            let (features, labels) = manager.to_training_data();

            prop_assert_eq!(features.len(), n);
            prop_assert_eq!(labels.len(), n);
        }

        // Labels are 1/0 copies of the correctness flags, in insertion order.
        #[test]
        fn prop_labels_match_correctness(correct in proptest::collection::vec(any::<bool>(), 1..20)) {
            let manager = manager_from_flags(&correct);

            let (_, labels) = manager.to_training_data();

            prop_assert_eq!(labels.len(), correct.len());
            for (&expected, &actual) in correct.iter().zip(labels.iter()) {
                prop_assert_eq!(u8::from(expected), actual);
            }
        }

        // Filtering keeps exactly the correct records.
        #[test]
        fn prop_filter_correct_only(n in 1usize..20) {
            let flags: Vec<bool> = (0..n).map(|i| i % 2 == 0).collect();
            let expected = flags.iter().filter(|&&c| c).count();
            let manager = manager_from_flags(&flags);

            let filtered = manager.filter_correct(true);

            prop_assert!(filtered.tuples.iter().all(|t| t.is_correct));
            prop_assert_eq!(filtered.tuples.len(), expected);
            prop_assert_eq!(filtered.metadata.count, expected);
        }

        // With consistent counters, accuracy stays within [0, 1].
        #[test]
        fn prop_accuracy_bounded(correct in 0usize..100, total in 1usize..200) {
            let metadata = CorpusMetadata {
                count: total,
                correct_count: correct.min(total),
                incorrect_count: total.saturating_sub(correct.min(total)),
                ..Default::default()
            };

            let acc = metadata.accuracy();
            prop_assert!(acc >= 0.0);
            prop_assert!(acc <= 1.0);
        }
    }
}