Skip to main content

decy_ownership/
training_data.rs

1//! Training data collection and management (DECY-ML-010).
2//!
3//! Provides infrastructure for collecting, storing, and managing
4//! labeled training data for ML-enhanced ownership inference.
5//!
6//! # Data Sources
7//!
8//! 1. **C Projects with Rust Ports**: Extract ground truth from
9//!    established C→Rust migrations (Linux kernel, SQLite, curl)
10//!
11//! 2. **Compiler Error Feedback Loop (CITL)**: Learn from rustc
12//!    errors on generated code to identify correct ownership patterns
13//!
14//! 3. **Synthetic Generation**: Use templates to generate labeled
15//!    C→Rust pairs with known ownership patterns
16//!
17//! # Toyota Way Principles
18//!
19//! - **Genchi Genbutsu**: Ground truth from real-world code
//! - **Kaizen**: Continuous data collection improves the model
21
22use std::collections::HashMap;
23use std::path::PathBuf;
24
25use crate::ml_features::{AllocationKind, InferredOwnership, OwnershipFeaturesBuilder};
26use crate::retraining_pipeline::TrainingSample;
27
/// Source of training data.
///
/// Each variant records where a labeled sample came from and carries a
/// single identifying string. The type derives `Eq` and `Hash`, so values
/// can be compared and used as hash-map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataSource {
    /// From open-source C→Rust migration (e.g., "rusqlite", "linux-rust").
    RustPort {
        /// Name of the project.
        project: String,
    },
    /// From compiler error feedback loop.
    CompilerFeedback {
        /// Error code (e.g., "E0382" for use after move).
        error_code: String,
    },
    /// Synthetically generated with known pattern.
    Synthetic {
        /// Pattern template name.
        template: String,
    },
    /// Manually annotated by human expert.
    HumanAnnotated {
        /// Annotator identifier.
        annotator: String,
    },
}
52
/// A labeled training sample with provenance metadata.
///
/// Pairs the core [`TrainingSample`] with the [`DataSource`] it came from,
/// the C and Rust snippets it was derived from, and free-form metadata.
#[derive(Debug, Clone)]
pub struct LabeledSample {
    /// Core training sample (features + label).
    pub sample: TrainingSample,
    /// Where this sample came from.
    pub source: DataSource,
    /// Confidence in the label (0.0 - 1.0).
    ///
    /// `LabeledSample::new` sets this to `1.0` and `with_confidence` clamps
    /// its argument into range; the field is public, so direct writes are
    /// not range-checked.
    pub label_confidence: f64,
    /// Original C code snippet.
    pub c_code: String,
    /// Expected Rust code snippet.
    pub rust_code: String,
    /// Additional metadata.
    pub metadata: HashMap<String, String>,
}
69
70impl LabeledSample {
71    /// Create a new labeled sample.
72    pub fn new(sample: TrainingSample, source: DataSource, c_code: &str, rust_code: &str) -> Self {
73        Self {
74            sample,
75            source,
76            label_confidence: 1.0,
77            c_code: c_code.to_string(),
78            rust_code: rust_code.to_string(),
79            metadata: HashMap::new(),
80        }
81    }
82
83    /// Set label confidence.
84    pub fn with_confidence(mut self, confidence: f64) -> Self {
85        self.label_confidence = confidence.clamp(0.0, 1.0);
86        self
87    }
88
89    /// Add metadata.
90    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
91        self.metadata.insert(key.to_string(), value.to_string());
92        self
93    }
94}
95
/// Statistics about a training dataset.
///
/// The `Default` value describes an empty dataset: zero samples, empty
/// distributions, and all confidence figures at `0.0`.
#[derive(Debug, Clone, Default)]
pub struct DatasetStats {
    /// Total sample count.
    pub total_samples: usize,
    /// Samples per label.
    pub label_distribution: HashMap<String, usize>,
    /// Samples per source type.
    pub source_distribution: HashMap<String, usize>,
    /// Average label confidence.
    pub avg_confidence: f64,
    /// Minimum label confidence.
    pub min_confidence: f64,
    /// Maximum label confidence.
    pub max_confidence: f64,
}

impl DatasetStats {
    /// Check if dataset is balanced (no class has >3x samples of another).
    ///
    /// Returns `true` for an empty label distribution. A label recorded with
    /// a count of zero makes the dataset unbalanced. Otherwise the dataset is
    /// balanced when the largest count is at most three times the smallest.
    pub fn is_balanced(&self) -> bool {
        let mut counts = self.label_distribution.values().copied();

        let (min_count, max_count) = match counts.next() {
            Some(first) => counts.fold((first, first), |(lo, hi), count| {
                (lo.min(count), hi.max(count))
            }),
            None => return true,
        };

        // Saturate instead of overflowing: if `min_count * 3` does not fit in
        // a `usize`, no count can exceed it, so the comparison must succeed.
        min_count > 0 && max_count <= min_count.saturating_mul(3)
    }

    /// Get the dominant label (most samples).
    ///
    /// Returns `None` when the label distribution is empty. When several
    /// labels share the highest count, the lexicographically smallest label
    /// is returned so the result does not depend on hash-map iteration order.
    pub fn dominant_label(&self) -> Option<String> {
        self.label_distribution
            .iter()
            // Higher count wins; on equal counts the smaller label compares
            // as "greater" so that it is the one selected.
            .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
            .map(|(label, _)| label.clone())
    }
}
135
/// A collection of training samples with metadata.
///
/// All fields are private; samples are added through the methods on this
/// type and read back via `samples()`. The derived `Default` yields an
/// empty dataset with an empty name and version.
#[derive(Debug, Clone, Default)]
pub struct TrainingDataset {
    /// All labeled samples.
    samples: Vec<LabeledSample>,
    /// Dataset name.
    name: String,
    /// Version identifier.
    version: String,
}
146
147impl TrainingDataset {
148    /// Create a new empty dataset.
149    pub fn new(name: &str, version: &str) -> Self {
150        Self {
151            samples: Vec::new(),
152            name: name.to_string(),
153            version: version.to_string(),
154        }
155    }
156
157    /// Add a sample to the dataset.
158    pub fn add(&mut self, sample: LabeledSample) {
159        self.samples.push(sample);
160    }
161
162    /// Add multiple samples.
163    pub fn add_all(&mut self, samples: impl IntoIterator<Item = LabeledSample>) {
164        self.samples.extend(samples);
165    }
166
167    /// Get sample count.
168    pub fn len(&self) -> usize {
169        self.samples.len()
170    }
171
172    /// Check if empty.
173    pub fn is_empty(&self) -> bool {
174        self.samples.is_empty()
175    }
176
177    /// Get all samples.
178    pub fn samples(&self) -> &[LabeledSample] {
179        &self.samples
180    }
181
182    /// Get samples by source.
183    pub fn samples_by_source(&self, source_type: &str) -> Vec<&LabeledSample> {
184        self.samples
185            .iter()
186            .filter(|s| source_type_name(&s.source) == source_type)
187            .collect()
188    }
189
190    /// Get samples by label.
191    pub fn samples_by_label(&self, label: InferredOwnership) -> Vec<&LabeledSample> {
192        self.samples
193            .iter()
194            .filter(|s| s.sample.label == label)
195            .collect()
196    }
197
198    /// Convert to training samples for the pipeline.
199    pub fn to_training_samples(&self) -> Vec<TrainingSample> {
200        self.samples.iter().map(|s| s.sample.clone()).collect()
201    }
202
203    /// Compute dataset statistics.
204    pub fn stats(&self) -> DatasetStats {
205        let mut label_dist = HashMap::new();
206        let mut source_dist = HashMap::new();
207        let mut confidence_sum = 0.0_f64;
208        let mut min_conf: f64 = 1.0;
209        let mut max_conf: f64 = 0.0;
210
211        for sample in &self.samples {
212            // Label distribution
213            let label_key = format!("{:?}", sample.sample.label);
214            *label_dist.entry(label_key).or_insert(0) += 1;
215
216            // Source distribution
217            let source_key = source_type_name(&sample.source);
218            *source_dist.entry(source_key).or_insert(0) += 1;
219
220            // Confidence stats
221            confidence_sum += sample.label_confidence;
222            min_conf = min_conf.min(sample.label_confidence);
223            max_conf = max_conf.max(sample.label_confidence);
224        }
225
226        let avg_confidence = if self.samples.is_empty() {
227            0.0
228        } else {
229            confidence_sum / self.samples.len() as f64
230        };
231
232        DatasetStats {
233            total_samples: self.samples.len(),
234            label_distribution: label_dist,
235            source_distribution: source_dist,
236            avg_confidence,
237            min_confidence: if self.samples.is_empty() {
238                0.0
239            } else {
240                min_conf
241            },
242            max_confidence: max_conf,
243        }
244    }
245
246    /// Filter samples by confidence threshold.
247    pub fn filter_by_confidence(&self, min_confidence: f64) -> TrainingDataset {
248        let samples = self
249            .samples
250            .iter()
251            .filter(|s| s.label_confidence >= min_confidence)
252            .cloned()
253            .collect();
254
255        TrainingDataset {
256            samples,
257            name: self.name.clone(),
258            version: self.version.clone(),
259        }
260    }
261
262    /// Merge with another dataset.
263    pub fn merge(&mut self, other: TrainingDataset) {
264        self.samples.extend(other.samples);
265    }
266
267    /// Get dataset name.
268    pub fn name(&self) -> &str {
269        &self.name
270    }
271
272    /// Get dataset version.
273    pub fn version(&self) -> &str {
274        &self.version
275    }
276}
277
278/// Get source type name for statistics.
279fn source_type_name(source: &DataSource) -> String {
280    match source {
281        DataSource::RustPort { .. } => "RustPort".to_string(),
282        DataSource::CompilerFeedback { .. } => "CompilerFeedback".to_string(),
283        DataSource::Synthetic { .. } => "Synthetic".to_string(),
284        DataSource::HumanAnnotated { .. } => "HumanAnnotated".to_string(),
285    }
286}
287
/// Settings that control synthetic training-data generation.
#[derive(Debug, Clone)]
pub struct SyntheticConfig {
    /// How many samples each pattern generator produces.
    pub samples_per_pattern: usize,
    /// Randomization seed for reproducibility.
    // NOTE(review): the generator methods visible in this file do not read
    // this field — confirm whether it is consumed elsewhere.
    pub seed: u64,
    /// Whether to include edge cases.
    // NOTE(review): the generator methods visible in this file do not read
    // this field — confirm whether it is consumed elsewhere.
    pub include_edge_cases: bool,
}

impl Default for SyntheticConfig {
    /// Defaults: 100 samples per pattern, seed 42, edge cases enabled.
    fn default() -> Self {
        SyntheticConfig {
            include_edge_cases: true,
            seed: 42,
            samples_per_pattern: 100,
        }
    }
}
308
/// Generator for synthetic training data.
///
/// Holds the [`SyntheticConfig`] that determines how many samples each
/// pattern generator emits.
pub struct SyntheticDataGenerator {
    /// Generation settings supplied at construction time.
    config: SyntheticConfig,
}
313
314impl SyntheticDataGenerator {
315    /// Create a new generator with configuration.
316    pub fn new(config: SyntheticConfig) -> Self {
317        Self { config }
318    }
319
320    /// Generate samples for malloc/free → Box pattern.
321    pub fn generate_malloc_box_samples(&self) -> Vec<LabeledSample> {
322        let mut samples = Vec::new();
323
324        // Basic malloc pattern
325        for i in 0..self.config.samples_per_pattern {
326            let features = OwnershipFeaturesBuilder::default()
327                .pointer_depth(1)
328                .allocation_site(AllocationKind::Malloc)
329                .deallocation_count(1)
330                .build();
331
332            let sample = TrainingSample::new(
333                features,
334                InferredOwnership::Owned,
335                &format!("malloc_box_{}.c", i),
336                i as u32,
337            );
338
339            let c_code = format!(
340                "int* ptr{} = (int*)malloc(sizeof(int));\n*ptr{} = {};\nfree(ptr{});",
341                i, i, i, i
342            );
343            let rust_code = format!("let ptr{}: Box<i32> = Box::new({});", i, i);
344
345            samples.push(LabeledSample::new(
346                sample,
347                DataSource::Synthetic {
348                    template: "malloc_free_box".to_string(),
349                },
350                &c_code,
351                &rust_code,
352            ));
353        }
354
355        samples
356    }
357
358    /// Generate samples for array allocation → Vec pattern.
359    pub fn generate_array_vec_samples(&self) -> Vec<LabeledSample> {
360        let mut samples = Vec::new();
361
362        for i in 0..self.config.samples_per_pattern {
363            let size = (i % 10 + 1) * 10;
364            let features = OwnershipFeaturesBuilder::default()
365                .pointer_depth(1)
366                .allocation_site(AllocationKind::Malloc)
367                .has_size_param(true)
368                .array_decay(true)
369                .deallocation_count(1)
370                .build();
371
372            let sample = TrainingSample::new(
373                features,
374                InferredOwnership::Vec,
375                &format!("array_vec_{}.c", i),
376                i as u32,
377            );
378            let _ = size; // Used in c_code formatting
379
380            let c_code = format!(
381                "int* arr = (int*)malloc({} * sizeof(int));\nfor(int j = 0; j < {}; j++) arr[j] = j;\nfree(arr);",
382                size, size
383            );
384            let rust_code = format!("let arr: Vec<i32> = (0..{}).collect();", size);
385
386            samples.push(LabeledSample::new(
387                sample,
388                DataSource::Synthetic {
389                    template: "array_vec".to_string(),
390                },
391                &c_code,
392                &rust_code,
393            ));
394        }
395
396        samples
397    }
398
399    /// Generate samples for const pointer → &T pattern.
400    pub fn generate_const_ref_samples(&self) -> Vec<LabeledSample> {
401        let mut samples = Vec::new();
402
403        for i in 0..self.config.samples_per_pattern {
404            let features = OwnershipFeaturesBuilder::default()
405                .pointer_depth(1)
406                .const_qualified(true)
407                .build();
408
409            let sample = TrainingSample::new(
410                features,
411                InferredOwnership::Borrowed,
412                &format!("const_ref_{}.c", i),
413                i as u32,
414            );
415
416            let c_code = format!(
417                "void process{}(const int* ptr) {{ printf(\"%d\", *ptr); }}",
418                i
419            );
420            let rust_code = format!("fn process{}(ptr: &i32) {{ println!(\"{{}}\", ptr); }}", i);
421
422            samples.push(LabeledSample::new(
423                sample,
424                DataSource::Synthetic {
425                    template: "const_ref".to_string(),
426                },
427                &c_code,
428                &rust_code,
429            ));
430        }
431
432        samples
433    }
434
435    /// Generate samples for mutable pointer → &mut T pattern.
436    pub fn generate_mut_ref_samples(&self) -> Vec<LabeledSample> {
437        let mut samples = Vec::new();
438
439        for i in 0..self.config.samples_per_pattern {
440            let features = OwnershipFeaturesBuilder::default()
441                .pointer_depth(1)
442                .const_qualified(false)
443                .write_count(1)
444                .build();
445
446            let sample = TrainingSample::new(
447                features,
448                InferredOwnership::BorrowedMut,
449                &format!("mut_ref_{}.c", i),
450                i as u32,
451            );
452
453            let c_code = format!("void increment{}(int* ptr) {{ (*ptr)++; }}", i);
454            let rust_code = format!("fn increment{}(ptr: &mut i32) {{ *ptr += 1; }}", i);
455
456            samples.push(LabeledSample::new(
457                sample,
458                DataSource::Synthetic {
459                    template: "mut_ref".to_string(),
460                },
461                &c_code,
462                &rust_code,
463            ));
464        }
465
466        samples
467    }
468
469    /// Generate samples for slice parameter pattern.
470    pub fn generate_slice_samples(&self) -> Vec<LabeledSample> {
471        let mut samples = Vec::new();
472
473        for i in 0..self.config.samples_per_pattern {
474            let features = OwnershipFeaturesBuilder::default()
475                .pointer_depth(1)
476                .const_qualified(true)
477                .array_decay(true)
478                .has_size_param(true)
479                .build();
480
481            let sample = TrainingSample::new(
482                features,
483                InferredOwnership::Slice,
484                &format!("slice_{}.c", i),
485                i as u32,
486            );
487
488            let c_code = format!(
489                "int sum{}(const int* arr, size_t len) {{ int s = 0; for(size_t j = 0; j < len; j++) s += arr[j]; return s; }}",
490                i
491            );
492            let rust_code = format!("fn sum{}(arr: &[i32]) -> i32 {{ arr.iter().sum() }}", i);
493
494            samples.push(LabeledSample::new(
495                sample,
496                DataSource::Synthetic {
497                    template: "slice".to_string(),
498                },
499                &c_code,
500                &rust_code,
501            ));
502        }
503
504        samples
505    }
506
507    /// Generate a complete synthetic dataset with all patterns.
508    pub fn generate_full_dataset(&self) -> TrainingDataset {
509        let mut dataset = TrainingDataset::new("synthetic", "1.0.0");
510
511        dataset.add_all(self.generate_malloc_box_samples());
512        dataset.add_all(self.generate_array_vec_samples());
513        dataset.add_all(self.generate_const_ref_samples());
514        dataset.add_all(self.generate_mut_ref_samples());
515        dataset.add_all(self.generate_slice_samples());
516
517        dataset
518    }
519}
520
/// Result of data collection.
///
/// A snapshot of how many samples a collector holds and which errors were
/// recorded along the way.
#[derive(Debug)]
pub struct CollectionResult {
    /// Samples collected.
    pub samples_collected: usize,
    /// Errors encountered.
    pub errors: Vec<String>,
    /// Source path.
    ///
    /// `TrainingDataCollector::result` in this file always leaves this as
    /// `None`.
    pub source_path: Option<PathBuf>,
}
531
/// Collector for training data from various sources.
///
/// Accumulates labeled samples and error messages until `build` turns the
/// samples into a [`TrainingDataset`].
pub struct TrainingDataCollector {
    /// Collected samples.
    samples: Vec<LabeledSample>,
    /// Collection errors.
    errors: Vec<String>,
}
539
540impl TrainingDataCollector {
541    /// Create a new collector.
542    pub fn new() -> Self {
543        Self {
544            samples: Vec::new(),
545            errors: Vec::new(),
546        }
547    }
548
549    /// Add synthetic data from generator.
550    pub fn add_synthetic(&mut self, generator: &SyntheticDataGenerator) {
551        let dataset = generator.generate_full_dataset();
552        self.samples.extend(dataset.samples);
553    }
554
555    /// Record a collection error.
556    pub fn record_error(&mut self, error: &str) {
557        self.errors.push(error.to_string());
558    }
559
560    /// Get collection result.
561    pub fn result(&self) -> CollectionResult {
562        CollectionResult {
563            samples_collected: self.samples.len(),
564            errors: self.errors.clone(),
565            source_path: None,
566        }
567    }
568
569    /// Build the final dataset.
570    pub fn build(self, name: &str, version: &str) -> TrainingDataset {
571        let mut dataset = TrainingDataset::new(name, version);
572        dataset.add_all(self.samples);
573        dataset
574    }
575
576    /// Get sample count.
577    pub fn sample_count(&self) -> usize {
578        self.samples.len()
579    }
580
581    /// Get error count.
582    pub fn error_count(&self) -> usize {
583        self.errors.len()
584    }
585}
586
587impl Default for TrainingDataCollector {
588    fn default() -> Self {
589        Self::new()
590    }
591}
592
593#[cfg(test)]
594mod tests {
595    use super::*;
596
597    // ========================================================================
598    // DataSource tests
599    // ========================================================================
600
601    #[test]
602    fn data_source_rust_port() {
603        let source = DataSource::RustPort {
604            project: "rusqlite".to_string(),
605        };
606        assert_eq!(source_type_name(&source), "RustPort");
607    }
608
609    #[test]
610    fn data_source_compiler_feedback() {
611        let source = DataSource::CompilerFeedback {
612            error_code: "E0382".to_string(),
613        };
614        assert_eq!(source_type_name(&source), "CompilerFeedback");
615    }
616
617    #[test]
618    fn data_source_synthetic() {
619        let source = DataSource::Synthetic {
620            template: "malloc_box".to_string(),
621        };
622        assert_eq!(source_type_name(&source), "Synthetic");
623    }
624
625    #[test]
626    fn data_source_human_annotated() {
627        let source = DataSource::HumanAnnotated {
628            annotator: "expert1".to_string(),
629        };
630        assert_eq!(source_type_name(&source), "HumanAnnotated");
631    }
632
633    // ========================================================================
634    // LabeledSample tests
635    // ========================================================================
636
637    #[test]
638    fn labeled_sample_new() {
639        let features = OwnershipFeaturesBuilder::default().build();
640        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 42);
641
642        let labeled = LabeledSample::new(
643            sample,
644            DataSource::Synthetic {
645                template: "test".to_string(),
646            },
647            "int* p = malloc(4);",
648            "let p: Box<i32> = Box::new(0);",
649        );
650
651        assert!((labeled.label_confidence - 1.0).abs() < 0.001);
652        assert!(labeled.c_code.contains("malloc"));
653        assert!(labeled.rust_code.contains("Box"));
654    }
655
656    #[test]
657    fn labeled_sample_with_confidence() {
658        let features = OwnershipFeaturesBuilder::default().build();
659        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
660
661        let labeled = LabeledSample::new(
662            sample,
663            DataSource::Synthetic {
664                template: "test".to_string(),
665            },
666            "",
667            "",
668        )
669        .with_confidence(0.8);
670
671        assert!((labeled.label_confidence - 0.8).abs() < 0.001);
672    }
673
674    #[test]
675    fn labeled_sample_confidence_clamped() {
676        let features = OwnershipFeaturesBuilder::default().build();
677        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
678
679        let labeled = LabeledSample::new(
680            sample,
681            DataSource::Synthetic {
682                template: "test".to_string(),
683            },
684            "",
685            "",
686        )
687        .with_confidence(1.5);
688
689        assert!((labeled.label_confidence - 1.0).abs() < 0.001);
690    }
691
692    #[test]
693    fn labeled_sample_with_metadata() {
694        let features = OwnershipFeaturesBuilder::default().build();
695        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
696
697        let labeled = LabeledSample::new(
698            sample,
699            DataSource::Synthetic {
700                template: "test".to_string(),
701            },
702            "",
703            "",
704        )
705        .with_metadata("commit", "abc123");
706
707        assert_eq!(labeled.metadata.get("commit"), Some(&"abc123".to_string()));
708    }
709
710    // ========================================================================
711    // DatasetStats tests
712    // ========================================================================
713
714    #[test]
715    fn dataset_stats_is_balanced_empty() {
716        let stats = DatasetStats::default();
717        assert!(stats.is_balanced());
718    }
719
720    #[test]
721    fn dataset_stats_is_balanced_even() {
722        let mut stats = DatasetStats::default();
723        stats.label_distribution.insert("Owned".to_string(), 100);
724        stats.label_distribution.insert("Borrowed".to_string(), 100);
725        assert!(stats.is_balanced());
726    }
727
728    #[test]
729    fn dataset_stats_is_balanced_imbalanced() {
730        let mut stats = DatasetStats::default();
731        stats.label_distribution.insert("Owned".to_string(), 100);
732        stats.label_distribution.insert("Borrowed".to_string(), 10);
733        assert!(!stats.is_balanced()); // 100 > 10 * 3
734    }
735
736    #[test]
737    fn dataset_stats_dominant_label() {
738        let mut stats = DatasetStats::default();
739        stats.label_distribution.insert("Owned".to_string(), 100);
740        stats.label_distribution.insert("Borrowed".to_string(), 50);
741        assert_eq!(stats.dominant_label(), Some("Owned".to_string()));
742    }
743
744    // ========================================================================
745    // TrainingDataset tests
746    // ========================================================================
747
748    #[test]
749    fn training_dataset_new() {
750        let dataset = TrainingDataset::new("test", "1.0.0");
751        assert_eq!(dataset.name(), "test");
752        assert_eq!(dataset.version(), "1.0.0");
753        assert!(dataset.is_empty());
754    }
755
756    #[test]
757    fn training_dataset_add() {
758        let mut dataset = TrainingDataset::new("test", "1.0.0");
759
760        let features = OwnershipFeaturesBuilder::default().build();
761        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
762        let labeled = LabeledSample::new(
763            sample,
764            DataSource::Synthetic {
765                template: "test".to_string(),
766            },
767            "",
768            "",
769        );
770
771        dataset.add(labeled);
772        assert_eq!(dataset.len(), 1);
773    }
774
775    #[test]
776    fn training_dataset_samples_accessor() {
777        let mut dataset = TrainingDataset::new("test", "1.0.0");
778        let features = OwnershipFeaturesBuilder::default().build();
779        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
780        dataset.add(LabeledSample::new(
781            sample,
782            DataSource::Synthetic {
783                template: "test".to_string(),
784            },
785            "int* p;",
786            "let p: Box<i32>;",
787        ));
788
789        let samples = dataset.samples();
790        assert_eq!(samples.len(), 1);
791        assert!(samples[0].c_code.contains("int*"));
792    }
793
794    #[test]
795    fn training_dataset_samples_by_source() {
796        let mut dataset = TrainingDataset::new("test", "1.0.0");
797
798        let features = OwnershipFeaturesBuilder::default().build();
799        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
800        dataset.add(LabeledSample::new(
801            sample,
802            DataSource::Synthetic {
803                template: "test".to_string(),
804            },
805            "",
806            "",
807        ));
808
809        let sample2 = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 2);
810        dataset.add(LabeledSample::new(
811            sample2,
812            DataSource::RustPort {
813                project: "curl".to_string(),
814            },
815            "",
816            "",
817        ));
818
819        assert_eq!(dataset.samples_by_source("Synthetic").len(), 1);
820        assert_eq!(dataset.samples_by_source("RustPort").len(), 1);
821        assert_eq!(dataset.samples_by_source("HumanAnnotated").len(), 0);
822    }
823
824    #[test]
825    fn training_dataset_samples_by_label() {
826        let mut dataset = TrainingDataset::new("test", "1.0.0");
827
828        let features = OwnershipFeaturesBuilder::default().build();
829        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
830        dataset.add(LabeledSample::new(
831            sample,
832            DataSource::Synthetic {
833                template: "test".to_string(),
834            },
835            "",
836            "",
837        ));
838
839        let sample2 = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 2);
840        dataset.add(LabeledSample::new(
841            sample2,
842            DataSource::Synthetic {
843                template: "test".to_string(),
844            },
845            "",
846            "",
847        ));
848
849        assert_eq!(dataset.samples_by_label(InferredOwnership::Owned).len(), 1);
850        assert_eq!(
851            dataset.samples_by_label(InferredOwnership::Borrowed).len(),
852            1
853        );
854        assert_eq!(dataset.samples_by_label(InferredOwnership::Vec).len(), 0);
855    }
856
857    #[test]
858    fn training_dataset_merge() {
859        let mut dataset1 = TrainingDataset::new("test1", "1.0.0");
860        let features = OwnershipFeaturesBuilder::default().build();
861        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
862        dataset1.add(LabeledSample::new(
863            sample,
864            DataSource::Synthetic {
865                template: "test".to_string(),
866            },
867            "",
868            "",
869        ));
870
871        let mut dataset2 = TrainingDataset::new("test2", "1.0.0");
872        let sample2 = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 2);
873        dataset2.add(LabeledSample::new(
874            sample2,
875            DataSource::Synthetic {
876                template: "test".to_string(),
877            },
878            "",
879            "",
880        ));
881
882        dataset1.merge(dataset2);
883        assert_eq!(dataset1.len(), 2);
884    }
885
886    #[test]
887    fn training_dataset_stats() {
888        let mut dataset = TrainingDataset::new("test", "1.0.0");
889
890        // Add 2 Owned samples
891        for _ in 0..2 {
892            let features = OwnershipFeaturesBuilder::default().build();
893            let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
894            dataset.add(LabeledSample::new(
895                sample,
896                DataSource::Synthetic {
897                    template: "test".to_string(),
898                },
899                "",
900                "",
901            ));
902        }
903
904        // Add 1 Borrowed sample
905        let features = OwnershipFeaturesBuilder::default().build();
906        let sample = TrainingSample::new(features, InferredOwnership::Borrowed, "test.c", 1);
907        dataset.add(LabeledSample::new(
908            sample,
909            DataSource::Synthetic {
910                template: "test".to_string(),
911            },
912            "",
913            "",
914        ));
915
916        let stats = dataset.stats();
917        assert_eq!(stats.total_samples, 3);
918        assert_eq!(stats.label_distribution.get("Owned"), Some(&2));
919        assert_eq!(stats.label_distribution.get("Borrowed"), Some(&1));
920    }
921
922    #[test]
923    fn training_dataset_filter_by_confidence() {
924        let mut dataset = TrainingDataset::new("test", "1.0.0");
925
926        let features = OwnershipFeaturesBuilder::default().build();
927        let sample = TrainingSample::new(features.clone(), InferredOwnership::Owned, "test.c", 1);
928        dataset.add(
929            LabeledSample::new(
930                sample,
931                DataSource::Synthetic {
932                    template: "test".to_string(),
933                },
934                "",
935                "",
936            )
937            .with_confidence(0.9),
938        );
939
940        let sample2 = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 2);
941        dataset.add(
942            LabeledSample::new(
943                sample2,
944                DataSource::Synthetic {
945                    template: "test".to_string(),
946                },
947                "",
948                "",
949            )
950            .with_confidence(0.5),
951        );
952
953        let filtered = dataset.filter_by_confidence(0.8);
954        assert_eq!(filtered.len(), 1);
955    }
956
957    #[test]
958    fn training_dataset_to_training_samples() {
959        let mut dataset = TrainingDataset::new("test", "1.0.0");
960
961        let features = OwnershipFeaturesBuilder::default().build();
962        let sample = TrainingSample::new(features, InferredOwnership::Owned, "test.c", 1);
963        dataset.add(LabeledSample::new(
964            sample,
965            DataSource::Synthetic {
966                template: "test".to_string(),
967            },
968            "",
969            "",
970        ));
971
972        let samples = dataset.to_training_samples();
973        assert_eq!(samples.len(), 1);
974    }
975
976    // ========================================================================
977    // SyntheticDataGenerator tests
978    // ========================================================================
979
980    #[test]
981    fn synthetic_generator_config_default() {
982        let config = SyntheticConfig::default();
983        assert_eq!(config.samples_per_pattern, 100);
984        assert_eq!(config.seed, 42);
985        assert!(config.include_edge_cases);
986    }
987
988    #[test]
989    fn synthetic_generator_malloc_box() {
990        let config = SyntheticConfig {
991            samples_per_pattern: 10,
992            ..Default::default()
993        };
994        let generator = SyntheticDataGenerator::new(config);
995        let samples = generator.generate_malloc_box_samples();
996
997        assert_eq!(samples.len(), 10);
998        for sample in &samples {
999            assert!(matches!(sample.sample.label, InferredOwnership::Owned));
1000            assert!(sample.c_code.contains("malloc"));
1001            assert!(sample.rust_code.contains("Box"));
1002        }
1003    }
1004
1005    #[test]
1006    fn synthetic_generator_array_vec() {
1007        let config = SyntheticConfig {
1008            samples_per_pattern: 10,
1009            ..Default::default()
1010        };
1011        let generator = SyntheticDataGenerator::new(config);
1012        let samples = generator.generate_array_vec_samples();
1013
1014        assert_eq!(samples.len(), 10);
1015        for sample in &samples {
1016            assert!(matches!(sample.sample.label, InferredOwnership::Vec));
1017            assert!(sample.rust_code.contains("Vec"));
1018        }
1019    }
1020
1021    #[test]
1022    fn synthetic_generator_const_ref() {
1023        let config = SyntheticConfig {
1024            samples_per_pattern: 10,
1025            ..Default::default()
1026        };
1027        let generator = SyntheticDataGenerator::new(config);
1028        let samples = generator.generate_const_ref_samples();
1029
1030        assert_eq!(samples.len(), 10);
1031        for sample in &samples {
1032            assert!(matches!(sample.sample.label, InferredOwnership::Borrowed));
1033            assert!(sample.c_code.contains("const"));
1034        }
1035    }
1036
1037    #[test]
1038    fn synthetic_generator_mut_ref() {
1039        let config = SyntheticConfig {
1040            samples_per_pattern: 10,
1041            ..Default::default()
1042        };
1043        let generator = SyntheticDataGenerator::new(config);
1044        let samples = generator.generate_mut_ref_samples();
1045
1046        assert_eq!(samples.len(), 10);
1047        for sample in &samples {
1048            assert!(matches!(
1049                sample.sample.label,
1050                InferredOwnership::BorrowedMut
1051            ));
1052            assert!(sample.rust_code.contains("&mut"));
1053        }
1054    }
1055
1056    #[test]
1057    fn synthetic_generator_slice() {
1058        let config = SyntheticConfig {
1059            samples_per_pattern: 10,
1060            ..Default::default()
1061        };
1062        let generator = SyntheticDataGenerator::new(config);
1063        let samples = generator.generate_slice_samples();
1064
1065        assert_eq!(samples.len(), 10);
1066        for sample in &samples {
1067            assert!(matches!(sample.sample.label, InferredOwnership::Slice));
1068            assert!(sample.rust_code.contains("&[i32]"));
1069        }
1070    }
1071
1072    #[test]
1073    fn synthetic_generator_full_dataset() {
1074        let config = SyntheticConfig {
1075            samples_per_pattern: 10,
1076            ..Default::default()
1077        };
1078        let generator = SyntheticDataGenerator::new(config);
1079        let dataset = generator.generate_full_dataset();
1080
1081        // 5 patterns * 10 samples each = 50
1082        assert_eq!(dataset.len(), 50);
1083
1084        let stats = dataset.stats();
1085        assert_eq!(stats.source_distribution.get("Synthetic"), Some(&50));
1086    }
1087
1088    // ========================================================================
1089    // TrainingDataCollector tests
1090    // ========================================================================
1091
1092    #[test]
1093    fn collector_new() {
1094        let collector = TrainingDataCollector::new();
1095        assert_eq!(collector.sample_count(), 0);
1096        assert_eq!(collector.error_count(), 0);
1097    }
1098
1099    #[test]
1100    fn collector_add_synthetic() {
1101        let mut collector = TrainingDataCollector::new();
1102        let config = SyntheticConfig {
1103            samples_per_pattern: 10,
1104            ..Default::default()
1105        };
1106        let generator = SyntheticDataGenerator::new(config);
1107
1108        collector.add_synthetic(&generator);
1109
1110        assert_eq!(collector.sample_count(), 50);
1111    }
1112
1113    #[test]
1114    fn collector_record_error() {
1115        let mut collector = TrainingDataCollector::new();
1116        collector.record_error("Failed to parse file");
1117
1118        assert_eq!(collector.error_count(), 1);
1119        assert!(collector.result().errors[0].contains("parse"));
1120    }
1121
1122    #[test]
1123    fn collector_build() {
1124        let mut collector = TrainingDataCollector::new();
1125        let config = SyntheticConfig {
1126            samples_per_pattern: 10,
1127            ..Default::default()
1128        };
1129        let generator = SyntheticDataGenerator::new(config);
1130
1131        collector.add_synthetic(&generator);
1132        let dataset = collector.build("test", "1.0.0");
1133
1134        assert_eq!(dataset.len(), 50);
1135        assert_eq!(dataset.name(), "test");
1136        assert_eq!(dataset.version(), "1.0.0");
1137    }
1138
1139    #[test]
1140    fn collector_result() {
1141        let mut collector = TrainingDataCollector::new();
1142        let config = SyntheticConfig {
1143            samples_per_pattern: 10,
1144            ..Default::default()
1145        };
1146        let generator = SyntheticDataGenerator::new(config);
1147
1148        collector.add_synthetic(&generator);
1149        let result = collector.result();
1150
1151        assert_eq!(result.samples_collected, 50);
1152        assert!(result.errors.is_empty());
1153    }
1154
1155    // ========================================================================
1156    // Integration tests
1157    // ========================================================================
1158
1159    #[test]
1160    fn generate_1000_samples() {
1161        // This test verifies we can generate 1000+ samples for training
1162        let config = SyntheticConfig {
1163            samples_per_pattern: 200,
1164            ..Default::default()
1165        };
1166        let generator = SyntheticDataGenerator::new(config);
1167        let dataset = generator.generate_full_dataset();
1168
1169        // 5 patterns * 200 samples = 1000
1170        assert!(dataset.len() >= 1000);
1171
1172        let stats = dataset.stats();
1173        // Should be reasonably balanced (each label type has samples)
1174        assert!(stats.label_distribution.len() >= 4);
1175    }
1176}