1use std::collections::HashMap;
23use std::path::PathBuf;
24
25use crate::ml_features::{AllocationKind, InferredOwnership, OwnershipFeaturesBuilder};
26use crate::retraining_pipeline::TrainingSample;
27
/// Where the label of a training sample came from.
///
/// Every variant carries exactly one descriptive string. The variant name alone
/// (without its payload) is what `source_type_name` reports for grouping.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataSource {
    /// Sample attributed to a Rust port of some project.
    RustPort {
        /// Name of the project (the tests use values such as `"rusqlite"`).
        project: String,
    },
    /// Sample attributed to compiler feedback.
    CompilerFeedback {
        /// Diagnostic code (the tests use values such as `"E0382"`).
        error_code: String,
    },
    /// Sample emitted by the synthetic generator.
    Synthetic {
        /// Identifier of the generator template (e.g. `"malloc_free_box"`).
        template: String,
    },
    /// Sample labelled by a person.
    HumanAnnotated {
        /// Identifier of the annotator.
        annotator: String,
    },
}
52
53#[derive(Debug, Clone)]
55pub struct LabeledSample {
56 pub sample: TrainingSample,
58 pub source: DataSource,
60 pub label_confidence: f64,
62 pub c_code: String,
64 pub rust_code: String,
66 pub metadata: HashMap<String, String>,
68}
69
70impl LabeledSample {
71 pub fn new(sample: TrainingSample, source: DataSource, c_code: &str, rust_code: &str) -> Self {
73 Self {
74 sample,
75 source,
76 label_confidence: 1.0,
77 c_code: c_code.to_string(),
78 rust_code: rust_code.to_string(),
79 metadata: HashMap::new(),
80 }
81 }
82
83 pub fn with_confidence(mut self, confidence: f64) -> Self {
85 self.label_confidence = confidence.clamp(0.0, 1.0);
86 self
87 }
88
89 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
91 self.metadata.insert(key.to_string(), value.to_string());
92 self
93 }
94}
95
/// Summary statistics over a dataset's labels, sources and label confidences.
#[derive(Debug, Clone, Default)]
pub struct DatasetStats {
    /// Number of samples summarised.
    pub total_samples: usize,
    /// Sample count per label key.
    pub label_distribution: HashMap<String, usize>,
    /// Sample count per source key.
    pub source_distribution: HashMap<String, usize>,
    /// Mean label confidence (`0.0` by default).
    pub avg_confidence: f64,
    /// Smallest label confidence (`0.0` by default).
    pub min_confidence: f64,
    /// Largest label confidence (`0.0` by default).
    pub max_confidence: f64,
}

impl DatasetStats {
    /// Returns `true` when no label is more than three times as frequent as the
    /// rarest label.
    ///
    /// An empty distribution counts as balanced. Any label with a zero count makes the
    /// distribution unbalanced.
    pub fn is_balanced(&self) -> bool {
        let counts = || self.label_distribution.values().copied();
        let (min_count, max_count) = match (counts().min(), counts().max()) {
            (Some(lo), Some(hi)) => (lo, hi),
            // No labels recorded at all: trivially balanced.
            _ => return true,
        };
        // `saturating_mul` avoids an overflow panic (debug) or wrap-around (release) for
        // huge counts. When `3 * min` exceeds `usize::MAX`, every `max` satisfies the
        // bound, which is exactly what the saturated value `usize::MAX` yields.
        min_count > 0 && max_count <= min_count.saturating_mul(3)
    }

    /// Returns the label with the highest count, or `None` if there are no labels.
    ///
    /// Ties go to the lexicographically smallest label. The result therefore does not
    /// depend on `HashMap` iteration order, which is unspecified and can differ between
    /// runs.
    pub fn dominant_label(&self) -> Option<String> {
        self.label_distribution
            .iter()
            .max_by(|(label_a, count_a), (label_b, count_b)| {
                // Higher count wins. On equal counts the *smaller* label compares greater.
                count_a.cmp(count_b).then_with(|| label_b.cmp(label_a))
            })
            .map(|(label, _)| label.clone())
    }
}
135
136#[derive(Debug, Clone, Default)]
138pub struct TrainingDataset {
139 samples: Vec<LabeledSample>,
141 name: String,
143 version: String,
145}
146
147impl TrainingDataset {
148 pub fn new(name: &str, version: &str) -> Self {
150 Self {
151 samples: Vec::new(),
152 name: name.to_string(),
153 version: version.to_string(),
154 }
155 }
156
157 pub fn add(&mut self, sample: LabeledSample) {
159 self.samples.push(sample);
160 }
161
162 pub fn add_all(&mut self, samples: impl IntoIterator<Item = LabeledSample>) {
164 self.samples.extend(samples);
165 }
166
167 pub fn len(&self) -> usize {
169 self.samples.len()
170 }
171
172 pub fn is_empty(&self) -> bool {
174 self.samples.is_empty()
175 }
176
177 pub fn samples(&self) -> &[LabeledSample] {
179 &self.samples
180 }
181
182 pub fn samples_by_source(&self, source_type: &str) -> Vec<&LabeledSample> {
184 self.samples
185 .iter()
186 .filter(|s| source_type_name(&s.source) == source_type)
187 .collect()
188 }
189
190 pub fn samples_by_label(&self, label: InferredOwnership) -> Vec<&LabeledSample> {
192 self.samples
193 .iter()
194 .filter(|s| s.sample.label == label)
195 .collect()
196 }
197
198 pub fn to_training_samples(&self) -> Vec<TrainingSample> {
200 self.samples.iter().map(|s| s.sample.clone()).collect()
201 }
202
203 pub fn stats(&self) -> DatasetStats {
205 let mut label_dist = HashMap::new();
206 let mut source_dist = HashMap::new();
207 let mut confidence_sum = 0.0_f64;
208 let mut min_conf: f64 = 1.0;
209 let mut max_conf: f64 = 0.0;
210
211 for sample in &self.samples {
212 let label_key = format!("{:?}", sample.sample.label);
214 *label_dist.entry(label_key).or_insert(0) += 1;
215
216 let source_key = source_type_name(&sample.source);
218 *source_dist.entry(source_key).or_insert(0) += 1;
219
220 confidence_sum += sample.label_confidence;
222 min_conf = min_conf.min(sample.label_confidence);
223 max_conf = max_conf.max(sample.label_confidence);
224 }
225
226 let avg_confidence = if self.samples.is_empty() {
227 0.0
228 } else {
229 confidence_sum / self.samples.len() as f64
230 };
231
232 DatasetStats {
233 total_samples: self.samples.len(),
234 label_distribution: label_dist,
235 source_distribution: source_dist,
236 avg_confidence,
237 min_confidence: if self.samples.is_empty() {
238 0.0
239 } else {
240 min_conf
241 },
242 max_confidence: max_conf,
243 }
244 }
245
246 pub fn filter_by_confidence(&self, min_confidence: f64) -> TrainingDataset {
248 let samples = self
249 .samples
250 .iter()
251 .filter(|s| s.label_confidence >= min_confidence)
252 .cloned()
253 .collect();
254
255 TrainingDataset {
256 samples,
257 name: self.name.clone(),
258 version: self.version.clone(),
259 }
260 }
261
262 pub fn merge(&mut self, other: TrainingDataset) {
264 self.samples.extend(other.samples);
265 }
266
267 pub fn name(&self) -> &str {
269 &self.name
270 }
271
272 pub fn version(&self) -> &str {
274 &self.version
275 }
276}
277
278fn source_type_name(source: &DataSource) -> String {
280 match source {
281 DataSource::RustPort { .. } => "RustPort".to_string(),
282 DataSource::CompilerFeedback { .. } => "CompilerFeedback".to_string(),
283 DataSource::Synthetic { .. } => "Synthetic".to_string(),
284 DataSource::HumanAnnotated { .. } => "HumanAnnotated".to_string(),
285 }
286}
287
/// Parameters for `SyntheticDataGenerator`.
#[derive(Debug, Clone)]
pub struct SyntheticConfig {
    /// How many samples each `generate_*_samples` method produces.
    pub samples_per_pattern: usize,
    /// Seed value. NOTE(review): not read by any generator method in this module;
    /// generated samples depend only on their index. Confirm whether it is used elsewhere.
    pub seed: u64,
    /// Edge-case toggle. NOTE(review): likewise not read by the generators here.
    pub include_edge_cases: bool,
}

impl Default for SyntheticConfig {
    /// 100 samples per pattern, seed 42, edge cases enabled.
    fn default() -> Self {
        let samples_per_pattern = 100;
        let seed = 42;
        let include_edge_cases = true;
        SyntheticConfig {
            samples_per_pattern,
            seed,
            include_edge_cases,
        }
    }
}
308
309pub struct SyntheticDataGenerator {
311 config: SyntheticConfig,
312}
313
314impl SyntheticDataGenerator {
315 pub fn new(config: SyntheticConfig) -> Self {
317 Self { config }
318 }
319
320 pub fn generate_malloc_box_samples(&self) -> Vec<LabeledSample> {
322 let mut samples = Vec::new();
323
324 for i in 0..self.config.samples_per_pattern {
326 let features = OwnershipFeaturesBuilder::default()
327 .pointer_depth(1)
328 .allocation_site(AllocationKind::Malloc)
329 .deallocation_count(1)
330 .build();
331
332 let sample = TrainingSample::new(
333 features,
334 InferredOwnership::Owned,
335 &format!("malloc_box_{}.c", i),
336 i as u32,
337 );
338
339 let c_code = format!(
340 "int* ptr{} = (int*)malloc(sizeof(int));\n*ptr{} = {};\nfree(ptr{});",
341 i, i, i, i
342 );
343 let rust_code = format!("let ptr{}: Box<i32> = Box::new({});", i, i);
344
345 samples.push(LabeledSample::new(
346 sample,
347 DataSource::Synthetic {
348 template: "malloc_free_box".to_string(),
349 },
350 &c_code,
351 &rust_code,
352 ));
353 }
354
355 samples
356 }
357
358 pub fn generate_array_vec_samples(&self) -> Vec<LabeledSample> {
360 let mut samples = Vec::new();
361
362 for i in 0..self.config.samples_per_pattern {
363 let size = (i % 10 + 1) * 10;
364 let features = OwnershipFeaturesBuilder::default()
365 .pointer_depth(1)
366 .allocation_site(AllocationKind::Malloc)
367 .has_size_param(true)
368 .array_decay(true)
369 .deallocation_count(1)
370 .build();
371
372 let sample = TrainingSample::new(
373 features,
374 InferredOwnership::Vec,
375 &format!("array_vec_{}.c", i),
376 i as u32,
377 );
378 let _ = size; let c_code = format!(
381 "int* arr = (int*)malloc({} * sizeof(int));\nfor(int j = 0; j < {}; j++) arr[j] = j;\nfree(arr);",
382 size, size
383 );
384 let rust_code = format!("let arr: Vec<i32> = (0..{}).collect();", size);
385
386 samples.push(LabeledSample::new(
387 sample,
388 DataSource::Synthetic {
389 template: "array_vec".to_string(),
390 },
391 &c_code,
392 &rust_code,
393 ));
394 }
395
396 samples
397 }
398
399 pub fn generate_const_ref_samples(&self) -> Vec<LabeledSample> {
401 let mut samples = Vec::new();
402
403 for i in 0..self.config.samples_per_pattern {
404 let features = OwnershipFeaturesBuilder::default()
405 .pointer_depth(1)
406 .const_qualified(true)
407 .build();
408
409 let sample = TrainingSample::new(
410 features,
411 InferredOwnership::Borrowed,
412 &format!("const_ref_{}.c", i),
413 i as u32,
414 );
415
416 let c_code = format!(
417 "void process{}(const int* ptr) {{ printf(\"%d\", *ptr); }}",
418 i
419 );
420 let rust_code = format!("fn process{}(ptr: &i32) {{ println!(\"{{}}\", ptr); }}", i);
421
422 samples.push(LabeledSample::new(
423 sample,
424 DataSource::Synthetic {
425 template: "const_ref".to_string(),
426 },
427 &c_code,
428 &rust_code,
429 ));
430 }
431
432 samples
433 }
434
435 pub fn generate_mut_ref_samples(&self) -> Vec<LabeledSample> {
437 let mut samples = Vec::new();
438
439 for i in 0..self.config.samples_per_pattern {
440 let features = OwnershipFeaturesBuilder::default()
441 .pointer_depth(1)
442 .const_qualified(false)
443 .write_count(1)
444 .build();
445
446 let sample = TrainingSample::new(
447 features,
448 InferredOwnership::BorrowedMut,
449 &format!("mut_ref_{}.c", i),
450 i as u32,
451 );
452
453 let c_code = format!("void increment{}(int* ptr) {{ (*ptr)++; }}", i);
454 let rust_code = format!("fn increment{}(ptr: &mut i32) {{ *ptr += 1; }}", i);
455
456 samples.push(LabeledSample::new(
457 sample,
458 DataSource::Synthetic {
459 template: "mut_ref".to_string(),
460 },
461 &c_code,
462 &rust_code,
463 ));
464 }
465
466 samples
467 }
468
469 pub fn generate_slice_samples(&self) -> Vec<LabeledSample> {
471 let mut samples = Vec::new();
472
473 for i in 0..self.config.samples_per_pattern {
474 let features = OwnershipFeaturesBuilder::default()
475 .pointer_depth(1)
476 .const_qualified(true)
477 .array_decay(true)
478 .has_size_param(true)
479 .build();
480
481 let sample = TrainingSample::new(
482 features,
483 InferredOwnership::Slice,
484 &format!("slice_{}.c", i),
485 i as u32,
486 );
487
488 let c_code = format!(
489 "int sum{}(const int* arr, size_t len) {{ int s = 0; for(size_t j = 0; j < len; j++) s += arr[j]; return s; }}",
490 i
491 );
492 let rust_code = format!("fn sum{}(arr: &[i32]) -> i32 {{ arr.iter().sum() }}", i);
493
494 samples.push(LabeledSample::new(
495 sample,
496 DataSource::Synthetic {
497 template: "slice".to_string(),
498 },
499 &c_code,
500 &rust_code,
501 ));
502 }
503
504 samples
505 }
506
507 pub fn generate_full_dataset(&self) -> TrainingDataset {
509 let mut dataset = TrainingDataset::new("synthetic", "1.0.0");
510
511 dataset.add_all(self.generate_malloc_box_samples());
512 dataset.add_all(self.generate_array_vec_samples());
513 dataset.add_all(self.generate_const_ref_samples());
514 dataset.add_all(self.generate_mut_ref_samples());
515 dataset.add_all(self.generate_slice_samples());
516
517 dataset
518 }
519}
520
/// Snapshot of a collection run: how many samples were gathered, which errors were
/// recorded, and (optionally) where the data came from.
#[derive(Debug)]
pub struct CollectionResult {
    /// Number of samples collected so far.
    pub samples_collected: usize,
    /// Error messages recorded during collection, in recording order.
    pub errors: Vec<String>,
    /// Origin path of the collected data, if any. `TrainingDataCollector::result`
    /// always leaves this `None`.
    pub source_path: Option<PathBuf>,
}
531
532pub struct TrainingDataCollector {
534 samples: Vec<LabeledSample>,
536 errors: Vec<String>,
538}
539
540impl TrainingDataCollector {
541 pub fn new() -> Self {
543 Self {
544 samples: Vec::new(),
545 errors: Vec::new(),
546 }
547 }
548
549 pub fn add_synthetic(&mut self, generator: &SyntheticDataGenerator) {
551 let dataset = generator.generate_full_dataset();
552 self.samples.extend(dataset.samples);
553 }
554
555 pub fn record_error(&mut self, error: &str) {
557 self.errors.push(error.to_string());
558 }
559
560 pub fn result(&self) -> CollectionResult {
562 CollectionResult {
563 samples_collected: self.samples.len(),
564 errors: self.errors.clone(),
565 source_path: None,
566 }
567 }
568
569 pub fn build(self, name: &str, version: &str) -> TrainingDataset {
571 let mut dataset = TrainingDataset::new(name, version);
572 dataset.add_all(self.samples);
573 dataset
574 }
575
576 pub fn sample_count(&self) -> usize {
578 self.samples.len()
579 }
580
581 pub fn error_count(&self) -> usize {
583 self.errors.len()
584 }
585}
586
587impl Default for TrainingDataCollector {
588 fn default() -> Self {
589 Self::new()
590 }
591}
592
/// Unit tests for data sources, labelled samples, dataset statistics, the synthetic
/// generator and the collector.
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // DataSource / source_type_name
    // ------------------------------------------------------------------

    #[test]
    fn data_source_rust_port() {
        let source = DataSource::RustPort {
            project: "rusqlite".to_string(),
        };
        assert_eq!(source_type_name(&source), "RustPort");
    }

    #[test]
    fn data_source_compiler_feedback() {
        let source = DataSource::CompilerFeedback {
            error_code: "E0382".to_string(),
        };
        assert_eq!(source_type_name(&source), "CompilerFeedback");
    }

    #[test]
    fn data_source_synthetic() {
        let source = DataSource::Synthetic {
            template: "malloc_box".to_string(),
        };
        assert_eq!(source_type_name(&source), "Synthetic");
    }

    #[test]
    fn data_source_human_annotated() {
        let source = DataSource::HumanAnnotated {
            annotator: "expert1".to_string(),
        };
        assert_eq!(source_type_name(&source), "HumanAnnotated");
    }

    // ------------------------------------------------------------------
    // LabeledSample
    // ------------------------------------------------------------------

    #[test]
    fn labeled_sample_new() {
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 42);

        let labeled = LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "int* p = malloc(4);",
            "let p: Box<i32> = Box::new(0);",
        );

        // `LabeledSample::new` defaults the confidence to 1.0.
        assert!((labeled.label_confidence - 1.0).abs() < 0.001);
        assert!(labeled.c_code.contains("malloc"));
        assert!(labeled.rust_code.contains("Box"));
    }

    #[test]
    fn labeled_sample_with_confidence() {
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);

        let labeled = LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        )
        .with_confidence(0.8);

        assert!((labeled.label_confidence - 0.8).abs() < 0.001);
    }

    #[test]
    fn labeled_sample_confidence_clamped() {
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);

        // Out-of-range confidence is clamped to the upper bound 1.0.
        let labeled = LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        )
        .with_confidence(1.5);

        assert!((labeled.label_confidence - 1.0).abs() < 0.001);
    }

    #[test]
    fn labeled_sample_with_metadata() {
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);

        let labeled = LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        )
        .with_metadata("commit", "abc123");

        assert_eq!(labeled.metadata.get("commit"), Some(&"abc123".to_string()));
    }

    // ------------------------------------------------------------------
    // DatasetStats
    // ------------------------------------------------------------------

    #[test]
    fn dataset_stats_is_balanced_empty() {
        let stats = DatasetStats::default();
        assert!(stats.is_balanced());
    }

    #[test]
    fn dataset_stats_is_balanced_even() {
        let mut stats = DatasetStats::default();
        stats.label_distribution.insert("Owned".to_string(), 100);
        stats.label_distribution.insert("Borrowed".to_string(), 100);
        assert!(stats.is_balanced());
    }

    #[test]
    fn dataset_stats_is_balanced_imbalanced() {
        let mut stats = DatasetStats::default();
        stats.label_distribution.insert("Owned".to_string(), 100);
        stats.label_distribution.insert("Borrowed".to_string(), 10);
        // 100 exceeds 3 * 10, so the distribution is reported as unbalanced.
        assert!(!stats.is_balanced());
    }

    #[test]
    fn dataset_stats_dominant_label() {
        let mut stats = DatasetStats::default();
        stats.label_distribution.insert("Owned".to_string(), 100);
        stats.label_distribution.insert("Borrowed".to_string(), 50);
        assert_eq!(stats.dominant_label(), Some("Owned".to_string()));
    }

    // ------------------------------------------------------------------
    // TrainingDataset
    // ------------------------------------------------------------------

    #[test]
    fn training_dataset_new() {
        let dataset = TrainingDataset::new("test", "1.0.0");
        assert_eq!(dataset.name(), "test");
        assert_eq!(dataset.version(), "1.0.0");
        assert!(dataset.is_empty());
    }

    #[test]
    fn training_dataset_add() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");

        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
        let labeled = LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        );

        dataset.add(labeled);
        assert_eq!(dataset.len(), 1);
    }

    #[test]
    fn training_dataset_samples_accessor() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
        dataset.add(LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "int* p;",
            "let p: Box<i32>;",
        ));

        let samples = dataset.samples();
        assert_eq!(samples.len(), 1);
        assert!(samples[0].c_code.contains("int*"));
    }

    #[test]
    fn training_dataset_samples_by_source() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");

        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
        dataset.add(LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        let sample2 = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 2);
        dataset.add(LabeledSample::new(
            sample2,
            DataSource::RustPort {
                project: "curl".to_string(),
            },
            "",
            "",
        ));

        // One sample per source kind; an unused kind yields nothing.
        assert_eq!(dataset.samples_by_source("Synthetic").len(), 1);
        assert_eq!(dataset.samples_by_source("RustPort").len(), 1);
        assert_eq!(dataset.samples_by_source("HumanAnnotated").len(), 0);
    }

    #[test]
    fn training_dataset_samples_by_label() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");

        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
        dataset.add(LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        let sample2 = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 2);
        dataset.add(LabeledSample::new(
            sample2,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        assert_eq!(dataset.samples_by_label(InferredOwnership::Owned).len(), 1);
        assert_eq!(
            dataset.samples_by_label(InferredOwnership::Borrowed).len(),
            1
        );
        assert_eq!(dataset.samples_by_label(InferredOwnership::Vec).len(), 0);
    }

    #[test]
    fn training_dataset_merge() {
        let mut dataset1 = TrainingDataset::new("test1", "1.0.0");
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
        dataset1.add(LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        let mut dataset2 = TrainingDataset::new("test2", "1.0.0");
        let sample2 = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 2);
        dataset2.add(LabeledSample::new(
            sample2,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        dataset1.merge(dataset2);
        assert_eq!(dataset1.len(), 2);
    }

    #[test]
    fn training_dataset_stats() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");

        // Two `Owned` samples...
        for _ in 0..2 {
            let features = OwnershipFeaturesBuilder::default().build();
            let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
            dataset.add(LabeledSample::new(
                sample,
                DataSource::Synthetic {
                    template: "test".to_string(),
                },
                "",
                "",
            ));
        }

        // ...and one `Borrowed` sample.
        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 1);
        dataset.add(LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        // Label keys are the Debug names of the labels.
        let stats = dataset.stats();
        assert_eq!(stats.total_samples, 3);
        assert_eq!(stats.label_distribution.get("Owned"), Some(&2));
        assert_eq!(stats.label_distribution.get("Borrowed"), Some(&1));
    }

    #[test]
    fn training_dataset_filter_by_confidence() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");

        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
        dataset.add(
            LabeledSample::new(
                sample,
                DataSource::Synthetic {
                    template: "test".to_string(),
                },
                "",
                "",
            )
            .with_confidence(0.9),
        );

        let sample2 = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 2);
        dataset.add(
            LabeledSample::new(
                sample2,
                DataSource::Synthetic {
                    template: "test".to_string(),
                },
                "",
                "",
            )
            .with_confidence(0.5),
        );

        // Only the 0.9-confidence sample meets the 0.8 threshold.
        let filtered = dataset.filter_by_confidence(0.8);
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn training_dataset_to_training_samples() {
        let mut dataset = TrainingDataset::new("test", "1.0.0");

        let features = OwnershipFeaturesBuilder::default().build();
        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
        dataset.add(LabeledSample::new(
            sample,
            DataSource::Synthetic {
                template: "test".to_string(),
            },
            "",
            "",
        ));

        let samples = dataset.to_training_samples();
        assert_eq!(samples.len(), 1);
    }

    // ------------------------------------------------------------------
    // SyntheticDataGenerator
    // ------------------------------------------------------------------

    #[test]
    fn synthetic_generator_config_default() {
        let config = SyntheticConfig::default();
        assert_eq!(config.samples_per_pattern, 100);
        assert_eq!(config.seed, 42);
        assert!(config.include_edge_cases);
    }

    #[test]
    fn synthetic_generator_malloc_box() {
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let samples = generator.generate_malloc_box_samples();

        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(matches!(sample.sample.label, InferredOwnership::Owned));
            assert!(sample.c_code.contains("malloc"));
            assert!(sample.rust_code.contains("Box"));
        }
    }

    #[test]
    fn synthetic_generator_array_vec() {
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let samples = generator.generate_array_vec_samples();

        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(matches!(sample.sample.label, InferredOwnership::Vec));
            assert!(sample.rust_code.contains("Vec"));
        }
    }

    #[test]
    fn synthetic_generator_const_ref() {
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let samples = generator.generate_const_ref_samples();

        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(matches!(sample.sample.label, InferredOwnership::Borrowed));
            assert!(sample.c_code.contains("const"));
        }
    }

    #[test]
    fn synthetic_generator_mut_ref() {
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let samples = generator.generate_mut_ref_samples();

        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(matches!(
                sample.sample.label,
                InferredOwnership::BorrowedMut
            ));
            assert!(sample.rust_code.contains("&mut"));
        }
    }

    #[test]
    fn synthetic_generator_slice() {
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let samples = generator.generate_slice_samples();

        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(matches!(sample.sample.label, InferredOwnership::Slice));
            assert!(sample.rust_code.contains("&[i32]"));
        }
    }

    #[test]
    fn synthetic_generator_full_dataset() {
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let dataset = generator.generate_full_dataset();

        // Five generators x 10 samples each.
        assert_eq!(dataset.len(), 50);

        let stats = dataset.stats();
        assert_eq!(stats.source_distribution.get("Synthetic"), Some(&50));
    }

    // ------------------------------------------------------------------
    // TrainingDataCollector
    // ------------------------------------------------------------------

    #[test]
    fn collector_new() {
        let collector = TrainingDataCollector::new();
        assert_eq!(collector.sample_count(), 0);
        assert_eq!(collector.error_count(), 0);
    }

    #[test]
    fn collector_add_synthetic() {
        let mut collector = TrainingDataCollector::new();
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);

        collector.add_synthetic(&generator);

        assert_eq!(collector.sample_count(), 50);
    }

    #[test]
    fn collector_record_error() {
        let mut collector = TrainingDataCollector::new();
        collector.record_error("Failed to parse file");

        assert_eq!(collector.error_count(), 1);
        assert!(collector.result().errors[0].contains("parse"));
    }

    #[test]
    fn collector_build() {
        let mut collector = TrainingDataCollector::new();
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);

        collector.add_synthetic(&generator);
        let dataset = collector.build("test", "1.0.0");

        assert_eq!(dataset.len(), 50);
        assert_eq!(dataset.name(), "test");
        assert_eq!(dataset.version(), "1.0.0");
    }

    #[test]
    fn collector_result() {
        let mut collector = TrainingDataCollector::new();
        let config = SyntheticConfig {
            samples_per_pattern: 10,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);

        collector.add_synthetic(&generator);
        let result = collector.result();

        assert_eq!(result.samples_collected, 50);
        assert!(result.errors.is_empty());
    }

    // ------------------------------------------------------------------
    // Scale
    // ------------------------------------------------------------------

    #[test]
    fn generate_1000_samples() {
        let config = SyntheticConfig {
            // 200 samples x 5 generators = 1000 samples.
            samples_per_pattern: 200,
            ..Default::default()
        };
        let generator = SyntheticDataGenerator::new(config);
        let dataset = generator.generate_full_dataset();

        assert!(dataset.len() >= 1000);

        // The generators cover five distinct labels; at least four must appear.
        let stats = dataset.stats();
        assert!(stats.label_distribution.len() >= 4);
    }
}