Skip to main content

entrenar/finetune/
corpus.rs

1//! Test generation corpus loader
2//!
3//! Loads and manages training data for Rust test generation.
4
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
/// A single training sample for test generation.
///
/// Pairs the source of one function with the tests written for it. The type is
/// (de)serializable with serde; when deserializing, only `function` and `unit_tests`
/// are required keys, the remaining fields fall back to their defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestGenSample {
    /// Source code of the function under test.
    pub function: String,
    /// Unit tests generated for `function`.
    pub unit_tests: String,
    /// Property-based tests, if any. `None` when the key is missing from the input
    /// (`#[serde(default)]`).
    #[serde(default)]
    pub property_tests: Option<String>,
    /// Descriptive metadata about the sample. Falls back to `SampleMetadata::default()`
    /// when the key is missing from the input.
    #[serde(default)]
    pub metadata: SampleMetadata,
}
22
/// Metadata about a training sample.
///
/// Every field is optional on input: each carries `#[serde(default)]`, so a missing key
/// deserializes to `None` / `false`. The derived `Default` yields the same values.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SampleMetadata {
    /// Name of the crate the function was taken from, if known.
    #[serde(default)]
    pub crate_name: Option<String>,
    /// Cyclomatic complexity of the function, if it was measured.
    #[serde(default)]
    pub complexity: Option<u32>,
    /// Whether the function uses generics.
    #[serde(default)]
    pub has_generics: bool,
    /// Whether the function uses explicit lifetimes.
    #[serde(default)]
    pub has_lifetimes: bool,
    /// Whether the function is `async`.
    #[serde(default)]
    pub is_async: bool,
}
42
/// Test generation corpus, split into train / validation / test sets.
///
/// The three splits are independent vectors; nothing in this type enforces that they
/// are disjoint.
#[derive(Debug, Clone)]
pub struct TestGenCorpus {
    /// Training samples (the only split reordered by `shuffle_train`).
    pub train: Vec<TestGenSample>,
    /// Validation samples.
    pub validation: Vec<TestGenSample>,
    /// Test samples (holdout).
    pub test: Vec<TestGenSample>,
}
53
/// Corpus statistics, as computed by `TestGenCorpus::stats`.
///
/// The `with_*` counters and the averages are taken over all three splits combined.
#[derive(Debug, Clone)]
pub struct CorpusStats {
    /// Total number of samples across all splits.
    pub total_samples: usize,
    /// Number of training samples.
    pub train_samples: usize,
    /// Number of validation samples.
    pub validation_samples: usize,
    /// Number of test (holdout) samples.
    pub test_samples: usize,
    /// Samples whose `property_tests` is `Some`.
    pub with_proptest: usize,
    /// Samples flagged `has_generics`.
    pub with_generics: usize,
    /// Samples flagged `has_lifetimes`.
    pub with_lifetimes: usize,
    /// Samples flagged `is_async`.
    pub with_async: usize,
    /// Average length of `function` in bytes (`String::len`, i.e. UTF-8 bytes rather than
    /// characters), using integer division; `0` for an empty corpus.
    pub avg_function_len: usize,
    /// Average length of `unit_tests` in bytes (`String::len`), using integer division;
    /// property tests are not included. `0` for an empty corpus.
    pub avg_test_len: usize,
}
78
79impl TestGenCorpus {
80    /// Create empty corpus
81    #[must_use]
82    pub const fn new() -> Self {
83        Self { train: Vec::new(), validation: Vec::new(), test: Vec::new() }
84    }
85
86    /// Load corpus from JSONL files
87    ///
88    /// # Errors
89    ///
90    /// Returns error if files cannot be read or parsed.
91    pub fn load_jsonl(
92        train_path: &Path,
93        validation_path: &Path,
94        test_path: &Path,
95    ) -> Result<Self, CorpusError> {
96        let train = Self::load_jsonl_file(train_path)?;
97        let validation = Self::load_jsonl_file(validation_path)?;
98        let test = Self::load_jsonl_file(test_path)?;
99
100        Ok(Self { train, validation, test })
101    }
102
103    /// Load samples from a single JSONL file
104    fn load_jsonl_file(path: &Path) -> Result<Vec<TestGenSample>, CorpusError> {
105        let content =
106            std::fs::read_to_string(path).map_err(|e| CorpusError::IoError(e.to_string()))?;
107
108        let mut samples = Vec::new();
109        for (line_num, line) in content.lines().enumerate() {
110            if line.trim().is_empty() {
111                continue;
112            }
113            let sample: TestGenSample = serde_json::from_str(line).map_err(|e| {
114                CorpusError::ParseError { line: line_num + 1, message: e.to_string() }
115            })?;
116            samples.push(sample);
117        }
118
119        Ok(samples)
120    }
121
122    /// Create mock corpus for testing
123    #[must_use]
124    pub fn mock(train_size: usize, val_size: usize, test_size: usize) -> Self {
125        let make_samples = |n: usize| -> Vec<TestGenSample> {
126            (0..n)
127                .map(|i| TestGenSample {
128                    function: format!(
129                        "/// Sample function {i}\npub fn sample_{i}(x: i32) -> i32 {{ x + {i} }}"
130                    ),
131                    unit_tests: format!(
132                        "#[test]\nfn test_sample_{i}() {{ assert_eq!(sample_{i}(0), {i}); }}"
133                    ),
134                    property_tests: if i % 4 == 0 {
135                        Some(format!(
136                            "proptest! {{ #[test] fn prop_{i}(x in any::<i32>()) {{ prop_assert!(sample_{i}(x) >= x); }} }}"
137                        ))
138                    } else {
139                        None
140                    },
141                    metadata: SampleMetadata {
142                        crate_name: Some(format!("crate_{}", i % 10)),
143                        complexity: Some((i % 15) as u32 + 1),
144                        has_generics: i % 5 == 0,
145                        has_lifetimes: i % 7 == 0,
146                        is_async: i % 10 == 0,
147                    },
148                })
149                .collect()
150        };
151
152        Self {
153            train: make_samples(train_size),
154            validation: make_samples(val_size),
155            test: make_samples(test_size),
156        }
157    }
158
159    /// Get corpus statistics
160    #[must_use]
161    pub fn stats(&self) -> CorpusStats {
162        let all: Vec<&TestGenSample> =
163            self.train.iter().chain(self.validation.iter()).chain(self.test.iter()).collect();
164
165        let total = all.len();
166        if total == 0 {
167            return CorpusStats {
168                total_samples: 0,
169                train_samples: 0,
170                validation_samples: 0,
171                test_samples: 0,
172                with_proptest: 0,
173                with_generics: 0,
174                with_lifetimes: 0,
175                with_async: 0,
176                avg_function_len: 0,
177                avg_test_len: 0,
178            };
179        }
180
181        let with_proptest = all.iter().filter(|s| s.property_tests.is_some()).count();
182        let with_generics = all.iter().filter(|s| s.metadata.has_generics).count();
183        let with_lifetimes = all.iter().filter(|s| s.metadata.has_lifetimes).count();
184        let with_async = all.iter().filter(|s| s.metadata.is_async).count();
185
186        let total_fn_len: usize = all.iter().map(|s| s.function.len()).sum();
187        let total_test_len: usize = all.iter().map(|s| s.unit_tests.len()).sum();
188
189        CorpusStats {
190            total_samples: total,
191            train_samples: self.train.len(),
192            validation_samples: self.validation.len(),
193            test_samples: self.test.len(),
194            with_proptest,
195            with_generics,
196            with_lifetimes,
197            with_async,
198            avg_function_len: total_fn_len / total,
199            avg_test_len: total_test_len / total,
200        }
201    }
202
203    /// Total number of samples
204    #[must_use]
205    pub fn len(&self) -> usize {
206        self.train.len() + self.validation.len() + self.test.len()
207    }
208
209    /// Check if corpus is empty
210    #[must_use]
211    pub fn is_empty(&self) -> bool {
212        self.train.is_empty() && self.validation.is_empty() && self.test.is_empty()
213    }
214
215    /// Shuffle training data with seed
216    pub fn shuffle_train(&mut self, seed: u64) {
217        use std::collections::hash_map::DefaultHasher;
218        use std::hash::{Hash, Hasher};
219
220        // Simple Fisher-Yates with deterministic pseudo-random
221        let n = self.train.len();
222        for i in (1..n).rev() {
223            let mut hasher = DefaultHasher::new();
224            seed.hash(&mut hasher);
225            i.hash(&mut hasher);
226            let j = (hasher.finish() as usize) % (i + 1);
227            self.train.swap(i, j);
228        }
229    }
230
231    /// Format sample as prompt for model
232    #[must_use]
233    pub fn format_prompt(sample: &TestGenSample) -> String {
234        format!(
235            "<|im_start|>system\n\
236            You are a Rust testing expert. Generate comprehensive unit tests and property-based tests.\n\
237            <|im_end|>\n\
238            <|im_start|>user\n\
239            Generate tests for this function:\n\n\
240            ```rust\n{}\n```\n\
241            <|im_end|>\n\
242            <|im_start|>assistant\n",
243            sample.function
244        )
245    }
246
247    /// Format sample as target for training
248    #[must_use]
249    pub fn format_target(sample: &TestGenSample) -> String {
250        let mut target = sample.unit_tests.clone();
251        if let Some(ref prop) = sample.property_tests {
252            target.push_str("\n\n");
253            target.push_str(prop);
254        }
255        target.push_str("\n<|im_end|>");
256        target
257    }
258}
259
260impl Default for TestGenCorpus {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
/// Corpus loading error.
///
/// Underlying errors are stored as their string representation, which keeps this type
/// `Clone`.
#[derive(Debug, Clone)]
pub enum CorpusError {
    /// A file could not be read; the payload is a human-readable description.
    IoError(String),
    /// A line could not be parsed as a JSON sample.
    ///
    /// `line` is the 1-based line number within the file (blank lines are counted);
    /// `message` describes the parse failure.
    ParseError { line: usize, message: String },
    /// The input had an invalid format; the payload describes the problem.
    /// NOTE(review): no code visible in this module constructs this variant.
    InvalidFormat(String),
}
276
277impl std::fmt::Display for CorpusError {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        match self {
280            Self::IoError(msg) => write!(f, "IO error: {msg}"),
281            Self::ParseError { line, message } => {
282                write!(f, "Parse error at line {line}: {message}")
283            }
284            Self::InvalidFormat(msg) => write!(f, "Invalid format: {msg}"),
285        }
286    }
287}
288
// Marker impl relying on the trait's default methods; no `source()` is provided because
// the variants store underlying errors only as strings.
impl std::error::Error for CorpusError {}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to a file in the OS temp directory and returns its path.
    ///
    /// `tag` must be unique per test because tests in one process share the same pid.
    fn write_temp(tag: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir()
            .join(format!("entrenar_corpus_{}_{tag}.jsonl", std::process::id()));
        std::fs::write(&path, contents).expect("temp file should be writable");
        path
    }

    /// `new` yields an empty corpus.
    #[test]
    fn test_corpus_new() {
        let corpus = TestGenCorpus::new();
        assert!(corpus.is_empty());
        assert_eq!(corpus.len(), 0);
    }

    /// `mock` honours the requested split sizes.
    #[test]
    fn test_corpus_mock() {
        let corpus = TestGenCorpus::mock(100, 20, 20);
        assert_eq!(corpus.train.len(), 100);
        assert_eq!(corpus.validation.len(), 20);
        assert_eq!(corpus.test.len(), 20);
        assert_eq!(corpus.len(), 140);
        assert!(!corpus.is_empty());
    }

    /// `stats` reports split sizes and non-zero aggregates for a populated corpus.
    #[test]
    fn test_corpus_stats() {
        let corpus = TestGenCorpus::mock(80, 10, 10);
        let stats = corpus.stats();

        assert_eq!(stats.total_samples, 100);
        assert_eq!(stats.train_samples, 80);
        assert_eq!(stats.validation_samples, 10);
        assert_eq!(stats.test_samples, 10);
        assert!(stats.with_proptest > 0);
        assert!(stats.avg_function_len > 0);
        assert!(stats.avg_test_len > 0);
    }

    /// `stats` on an empty corpus is all zeros and must not divide by zero.
    #[test]
    fn test_corpus_stats_empty() {
        let corpus = TestGenCorpus::new();
        let stats = corpus.stats();
        assert_eq!(stats.total_samples, 0);
        assert_eq!(stats.avg_function_len, 0);
        assert_eq!(stats.avg_test_len, 0);
    }

    /// Averages are integer means of the byte lengths over all splits.
    #[test]
    fn test_corpus_stats_averages_match_manual_computation() {
        let corpus = TestGenCorpus::mock(13, 5, 3);
        let stats = corpus.stats();

        let all: Vec<&TestGenSample> =
            corpus.train.iter().chain(&corpus.validation).chain(&corpus.test).collect();
        let fn_total: usize = all.iter().map(|s| s.function.len()).sum();
        let test_total: usize = all.iter().map(|s| s.unit_tests.len()).sum();

        assert_eq!(stats.avg_function_len, fn_total / all.len());
        assert_eq!(stats.avg_test_len, test_total / all.len());
    }

    /// The same seed must give the same order.
    #[test]
    fn test_corpus_shuffle_deterministic() {
        let mut corpus1 = TestGenCorpus::mock(50, 0, 0);
        let mut corpus2 = TestGenCorpus::mock(50, 0, 0);

        corpus1.shuffle_train(42);
        corpus2.shuffle_train(42);

        // Same seed should produce same order
        for (a, b) in corpus1.train.iter().zip(corpus2.train.iter()) {
            assert_eq!(a.function, b.function);
        }
    }

    /// Different seeds should not give an identical order.
    #[test]
    fn test_corpus_shuffle_different_seeds() {
        let mut corpus1 = TestGenCorpus::mock(50, 0, 0);
        let mut corpus2 = TestGenCorpus::mock(50, 0, 0);

        corpus1.shuffle_train(42);
        corpus2.shuffle_train(123);

        // Different seeds should produce different order
        let same_count = corpus1
            .train
            .iter()
            .zip(corpus2.train.iter())
            .filter(|(a, b)| a.function == b.function)
            .count();

        // Some might match by chance, but not all
        assert!(same_count < 50);
    }

    /// Shuffling only reorders the training split: no sample is lost or duplicated and
    /// the other splits keep their order.
    #[test]
    fn test_corpus_shuffle_is_permutation() {
        let mut corpus = TestGenCorpus::mock(50, 5, 5);
        let functions = |samples: &[TestGenSample]| -> Vec<String> {
            samples.iter().map(|s| s.function.clone()).collect()
        };

        let mut train_before = functions(&corpus.train);
        let validation_before = functions(&corpus.validation);
        let test_before = functions(&corpus.test);

        corpus.shuffle_train(7);

        let mut train_after = functions(&corpus.train);
        train_before.sort();
        train_after.sort();

        assert_eq!(train_before, train_after);
        assert_eq!(functions(&corpus.validation), validation_before);
        assert_eq!(functions(&corpus.test), test_before);
    }

    /// Shuffling zero or one training samples is a no-op and must not panic.
    #[test]
    fn test_corpus_shuffle_trivial_sizes() {
        let mut empty = TestGenCorpus::new();
        empty.shuffle_train(1);
        assert!(empty.is_empty());

        let mut single = TestGenCorpus::mock(1, 0, 0);
        let expected = single.train[0].function.clone();
        single.shuffle_train(1);
        assert_eq!(single.train.len(), 1);
        assert_eq!(single.train[0].function, expected);
    }

    /// A sample survives a JSON round trip.
    #[test]
    fn test_sample_serialization() {
        let sample = TestGenSample {
            function: "pub fn foo() {}".into(),
            unit_tests: "#[test] fn test_foo() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata {
                crate_name: Some("test".into()),
                complexity: Some(5),
                has_generics: true,
                has_lifetimes: false,
                is_async: false,
            },
        };

        let json = serde_json::to_string(&sample).expect("JSON serialization should succeed");
        let restored: TestGenSample =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");

        assert_eq!(restored.function, sample.function);
        assert_eq!(restored.unit_tests, sample.unit_tests);
        assert_eq!(restored.property_tests, sample.property_tests);
        assert!(restored.metadata.has_generics);
    }

    /// The prompt contains the chat markers and the function source.
    #[test]
    fn test_format_prompt() {
        let sample = TestGenSample {
            function: "pub fn add(a: i32, b: i32) -> i32 { a + b }".into(),
            unit_tests: String::new(),
            property_tests: None,
            metadata: SampleMetadata::default(),
        };

        let prompt = TestGenCorpus::format_prompt(&sample);
        assert!(prompt.contains("<|im_start|>system"));
        assert!(prompt.contains("pub fn add"));
        assert!(prompt.contains("<|im_start|>assistant"));
    }

    /// The target contains unit and property tests and ends with the end marker.
    #[test]
    fn test_format_target() {
        let sample = TestGenSample {
            function: String::new(),
            unit_tests: "#[test] fn test() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata::default(),
        };

        let target = TestGenCorpus::format_target(&sample);
        assert!(target.contains("#[test]"));
        assert!(target.contains("proptest!"));
        assert!(target.ends_with("<|im_end|>"));
    }

    /// `Display` output names the error kind and the line number.
    #[test]
    fn test_corpus_error_display() {
        let io_err = CorpusError::IoError("file not found".into());
        assert!(io_err.to_string().contains("IO error"));

        let parse_err = CorpusError::ParseError { line: 5, message: "invalid json".into() };
        assert!(parse_err.to_string().contains("line 5"));
    }

    /// Mock metadata flags follow their documented periods.
    #[test]
    fn test_mock_metadata_distribution() {
        let corpus = TestGenCorpus::mock(100, 0, 0);
        let stats = corpus.stats();

        // ~20% should have generics (every 5th)
        assert!(stats.with_generics >= 15 && stats.with_generics <= 25);

        // ~25% should have proptest (every 4th)
        assert!(stats.with_proptest >= 20 && stats.with_proptest <= 30);

        // ~10% should be async (every 10th)
        assert!(stats.with_async >= 8 && stats.with_async <= 12);
    }

    /// `InvalidFormat` renders both its label and its payload.
    #[test]
    fn test_corpus_error_invalid_format() {
        let err = CorpusError::InvalidFormat("bad format".into());
        assert!(err.to_string().contains("Invalid format"));
        assert!(err.to_string().contains("bad format"));
    }

    /// Default metadata has no optional values and all flags cleared.
    #[test]
    fn test_sample_metadata_default() {
        let meta = SampleMetadata::default();
        assert!(meta.crate_name.is_none());
        assert!(meta.complexity.is_none());
        assert!(!meta.has_generics);
        assert!(!meta.has_lifetimes);
        assert!(!meta.is_async);
    }

    /// `Default` is the empty corpus.
    #[test]
    fn test_corpus_default() {
        let corpus = TestGenCorpus::default();
        assert!(corpus.is_empty());
        assert_eq!(corpus.len(), 0);
    }

    /// Without property tests the target holds only the unit tests and the end marker.
    #[test]
    fn test_format_target_without_proptest() {
        let sample = TestGenSample {
            function: String::new(),
            unit_tests: "#[test] fn test() { assert!(true); }".into(),
            property_tests: None,
            metadata: SampleMetadata::default(),
        };

        let target = TestGenCorpus::format_target(&sample);
        assert!(target.contains("#[test]"));
        assert!(!target.contains("proptest!"));
        assert!(target.ends_with("<|im_end|>"));
    }

    /// Lifetime flags are counted by `stats`.
    #[test]
    fn test_corpus_stats_with_lifetimes() {
        let corpus = TestGenCorpus::mock(7, 0, 0);
        let stats = corpus.stats();
        // Every 7th sample has lifetimes (i % 7 == 0)
        assert!(stats.with_lifetimes >= 1);
    }

    /// A fully populated sample keeps all metadata values.
    #[test]
    fn test_sample_with_all_metadata() {
        let sample = TestGenSample {
            function: "pub fn foo<T: Clone + 'a>(x: &'a T) -> T { x.clone() }".into(),
            unit_tests: "#[test] fn test() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata {
                crate_name: Some("my_crate".into()),
                complexity: Some(15),
                has_generics: true,
                has_lifetimes: true,
                is_async: false,
            },
        };

        assert!(sample.metadata.has_generics);
        assert!(sample.metadata.has_lifetimes);
        assert_eq!(sample.metadata.complexity, Some(15));
    }

    /// Missing files surface as an I/O error instead of a panic.
    #[test]
    fn test_corpus_load_jsonl_nonexistent() {
        let result = TestGenCorpus::load_jsonl(
            std::path::Path::new("/nonexistent/train.jsonl"),
            std::path::Path::new("/nonexistent/val.jsonl"),
            std::path::Path::new("/nonexistent/test.jsonl"),
        );
        assert!(matches!(result, Err(CorpusError::IoError(_))));
    }

    /// Well-formed JSONL loads, blank lines are skipped and omitted keys take defaults.
    #[test]
    fn test_corpus_load_jsonl_roundtrip() {
        let bare = r#"{"function":"fn a() {}","unit_tests":"fn t_a() {}"}"#;
        let full = serde_json::to_string(&TestGenSample {
            function: "fn b() {}".into(),
            unit_tests: "fn t_b() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata { complexity: Some(3), ..SampleMetadata::default() },
        })
        .expect("JSON serialization should succeed");

        let contents = format!("{bare}\n\n   \n{full}\n");
        let path = write_temp("roundtrip", &contents);
        let loaded = TestGenCorpus::load_jsonl(&path, &path, &path);
        let _ = std::fs::remove_file(&path);

        let corpus = loaded.expect("well-formed JSONL should load");
        assert_eq!(corpus.train.len(), 2);
        assert_eq!(corpus.len(), 6);

        let first = &corpus.train[0];
        assert_eq!(first.function, "fn a() {}");
        assert!(first.property_tests.is_none());
        assert!(first.metadata.crate_name.is_none());
        assert!(!first.metadata.has_generics);

        let second = &corpus.validation[1];
        assert_eq!(second.property_tests.as_deref(), Some("proptest! {}"));
        assert_eq!(second.metadata.complexity, Some(3));
    }

    /// A malformed line is reported with its 1-based line number, counting blank lines.
    #[test]
    fn test_corpus_load_jsonl_parse_error_line() {
        let valid = r#"{"function":"fn a() {}","unit_tests":"fn t_a() {}"}"#;
        let contents = format!("{valid}\n\nnot json\n");
        let path = write_temp("parse_error", &contents);
        let result = TestGenCorpus::load_jsonl(&path, &path, &path);
        let _ = std::fs::remove_file(&path);

        match result {
            Err(CorpusError::ParseError { line, .. }) => assert_eq!(line, 3),
            other => panic!("expected a parse error, got {other:?}"),
        }
    }
}