1use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TestGenSample {
11 pub function: String,
13 pub unit_tests: String,
15 #[serde(default)]
17 pub property_tests: Option<String>,
18 #[serde(default)]
20 pub metadata: SampleMetadata,
21}
22
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
25pub struct SampleMetadata {
26 #[serde(default)]
28 pub crate_name: Option<String>,
29 #[serde(default)]
31 pub complexity: Option<u32>,
32 #[serde(default)]
34 pub has_generics: bool,
35 #[serde(default)]
37 pub has_lifetimes: bool,
38 #[serde(default)]
40 pub is_async: bool,
41}
42
43#[derive(Debug, Clone)]
45pub struct TestGenCorpus {
46 pub train: Vec<TestGenSample>,
48 pub validation: Vec<TestGenSample>,
50 pub test: Vec<TestGenSample>,
52}
53
/// Aggregate counts and averages computed over every split of a corpus.
///
/// `CorpusStats::default()` is all zeros, which is exactly the statistics of
/// an empty corpus.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CorpusStats {
    /// Number of samples across train, validation and test.
    pub total_samples: usize,
    /// Number of samples in the train split.
    pub train_samples: usize,
    /// Number of samples in the validation split.
    pub validation_samples: usize,
    /// Number of samples in the test split.
    pub test_samples: usize,
    /// Samples whose `property_tests` is `Some`.
    pub with_proptest: usize,
    /// Samples whose metadata has `has_generics` set.
    pub with_generics: usize,
    /// Samples whose metadata has `has_lifetimes` set.
    pub with_lifetimes: usize,
    /// Samples whose metadata has `is_async` set.
    pub with_async: usize,
    /// Mean byte length of `function`, rounded down; 0 for an empty corpus.
    pub avg_function_len: usize,
    /// Mean byte length of `unit_tests` (property tests excluded), rounded
    /// down; 0 for an empty corpus.
    pub avg_test_len: usize,
}
78
79impl TestGenCorpus {
80 #[must_use]
82 pub const fn new() -> Self {
83 Self { train: Vec::new(), validation: Vec::new(), test: Vec::new() }
84 }
85
86 pub fn load_jsonl(
92 train_path: &Path,
93 validation_path: &Path,
94 test_path: &Path,
95 ) -> Result<Self, CorpusError> {
96 let train = Self::load_jsonl_file(train_path)?;
97 let validation = Self::load_jsonl_file(validation_path)?;
98 let test = Self::load_jsonl_file(test_path)?;
99
100 Ok(Self { train, validation, test })
101 }
102
103 fn load_jsonl_file(path: &Path) -> Result<Vec<TestGenSample>, CorpusError> {
105 let content =
106 std::fs::read_to_string(path).map_err(|e| CorpusError::IoError(e.to_string()))?;
107
108 let mut samples = Vec::new();
109 for (line_num, line) in content.lines().enumerate() {
110 if line.trim().is_empty() {
111 continue;
112 }
113 let sample: TestGenSample = serde_json::from_str(line).map_err(|e| {
114 CorpusError::ParseError { line: line_num + 1, message: e.to_string() }
115 })?;
116 samples.push(sample);
117 }
118
119 Ok(samples)
120 }
121
122 #[must_use]
124 pub fn mock(train_size: usize, val_size: usize, test_size: usize) -> Self {
125 let make_samples = |n: usize| -> Vec<TestGenSample> {
126 (0..n)
127 .map(|i| TestGenSample {
128 function: format!(
129 "/// Sample function {i}\npub fn sample_{i}(x: i32) -> i32 {{ x + {i} }}"
130 ),
131 unit_tests: format!(
132 "#[test]\nfn test_sample_{i}() {{ assert_eq!(sample_{i}(0), {i}); }}"
133 ),
134 property_tests: if i % 4 == 0 {
135 Some(format!(
136 "proptest! {{ #[test] fn prop_{i}(x in any::<i32>()) {{ prop_assert!(sample_{i}(x) >= x); }} }}"
137 ))
138 } else {
139 None
140 },
141 metadata: SampleMetadata {
142 crate_name: Some(format!("crate_{}", i % 10)),
143 complexity: Some((i % 15) as u32 + 1),
144 has_generics: i % 5 == 0,
145 has_lifetimes: i % 7 == 0,
146 is_async: i % 10 == 0,
147 },
148 })
149 .collect()
150 };
151
152 Self {
153 train: make_samples(train_size),
154 validation: make_samples(val_size),
155 test: make_samples(test_size),
156 }
157 }
158
159 #[must_use]
161 pub fn stats(&self) -> CorpusStats {
162 let all: Vec<&TestGenSample> =
163 self.train.iter().chain(self.validation.iter()).chain(self.test.iter()).collect();
164
165 let total = all.len();
166 if total == 0 {
167 return CorpusStats {
168 total_samples: 0,
169 train_samples: 0,
170 validation_samples: 0,
171 test_samples: 0,
172 with_proptest: 0,
173 with_generics: 0,
174 with_lifetimes: 0,
175 with_async: 0,
176 avg_function_len: 0,
177 avg_test_len: 0,
178 };
179 }
180
181 let with_proptest = all.iter().filter(|s| s.property_tests.is_some()).count();
182 let with_generics = all.iter().filter(|s| s.metadata.has_generics).count();
183 let with_lifetimes = all.iter().filter(|s| s.metadata.has_lifetimes).count();
184 let with_async = all.iter().filter(|s| s.metadata.is_async).count();
185
186 let total_fn_len: usize = all.iter().map(|s| s.function.len()).sum();
187 let total_test_len: usize = all.iter().map(|s| s.unit_tests.len()).sum();
188
189 CorpusStats {
190 total_samples: total,
191 train_samples: self.train.len(),
192 validation_samples: self.validation.len(),
193 test_samples: self.test.len(),
194 with_proptest,
195 with_generics,
196 with_lifetimes,
197 with_async,
198 avg_function_len: total_fn_len / total,
199 avg_test_len: total_test_len / total,
200 }
201 }
202
203 #[must_use]
205 pub fn len(&self) -> usize {
206 self.train.len() + self.validation.len() + self.test.len()
207 }
208
209 #[must_use]
211 pub fn is_empty(&self) -> bool {
212 self.train.is_empty() && self.validation.is_empty() && self.test.is_empty()
213 }
214
215 pub fn shuffle_train(&mut self, seed: u64) {
217 use std::collections::hash_map::DefaultHasher;
218 use std::hash::{Hash, Hasher};
219
220 let n = self.train.len();
222 for i in (1..n).rev() {
223 let mut hasher = DefaultHasher::new();
224 seed.hash(&mut hasher);
225 i.hash(&mut hasher);
226 let j = (hasher.finish() as usize) % (i + 1);
227 self.train.swap(i, j);
228 }
229 }
230
231 #[must_use]
233 pub fn format_prompt(sample: &TestGenSample) -> String {
234 format!(
235 "<|im_start|>system\n\
236 You are a Rust testing expert. Generate comprehensive unit tests and property-based tests.\n\
237 <|im_end|>\n\
238 <|im_start|>user\n\
239 Generate tests for this function:\n\n\
240 ```rust\n{}\n```\n\
241 <|im_end|>\n\
242 <|im_start|>assistant\n",
243 sample.function
244 )
245 }
246
247 #[must_use]
249 pub fn format_target(sample: &TestGenSample) -> String {
250 let mut target = sample.unit_tests.clone();
251 if let Some(ref prop) = sample.property_tests {
252 target.push_str("\n\n");
253 target.push_str(prop);
254 }
255 target.push_str("\n<|im_end|>");
256 target
257 }
258}
259
260impl Default for TestGenCorpus {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
/// Errors returned while loading a test-generation corpus.
///
/// Wrapped errors are stored as text, so the type stays `Clone` and comparable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorpusError {
    /// A corpus file could not be read; holds the rendered I/O error.
    IoError(String),
    /// A line could not be deserialized into a sample.
    ParseError {
        /// 1-based physical line number within the file.
        line: usize,
        /// Rendered description of the parse failure.
        message: String,
    },
    /// Content is in an unexpected format; the string describes the problem.
    /// Not currently produced by `TestGenCorpus::load_jsonl`.
    InvalidFormat(String),
}

impl std::fmt::Display for CorpusError {
    /// Writes a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IoError(msg) => write!(f, "IO error: {msg}"),
            Self::ParseError { line, message } => {
                write!(f, "Parse error at line {line}: {message}")
            }
            Self::InvalidFormat(msg) => write!(f, "Invalid format: {msg}"),
        }
    }
}

/// No `source()` override: wrapped causes are stored as text, not as errors.
impl std::error::Error for CorpusError {}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to a file in the system temp directory and returns
    /// its path. `name` must be unique per test because tests run concurrently
    /// within one process; the process id keeps parallel test binaries apart.
    fn write_temp(name: &str, contents: &str) -> std::path::PathBuf {
        let path =
            std::env::temp_dir().join(format!("testgen_corpus_{}_{name}", std::process::id()));
        std::fs::write(&path, contents).expect("writing a temp fixture should succeed");
        path
    }

    /// Serializes `samples` as JSON Lines with a blank line between records,
    /// so loading them back also exercises blank-line skipping.
    fn to_jsonl(samples: &[TestGenSample]) -> String {
        samples
            .iter()
            .map(|s| serde_json::to_string(s).expect("JSON serialization should succeed"))
            .collect::<Vec<_>>()
            .join("\n\n")
    }

    /// Returns the `function` text of each sample, preserving order.
    fn functions(samples: &[TestGenSample]) -> Vec<&str> {
        samples.iter().map(|s| s.function.as_str()).collect()
    }

    #[test]
    fn test_corpus_new() {
        let corpus = TestGenCorpus::new();
        assert!(corpus.is_empty());
        assert_eq!(corpus.len(), 0);
    }

    #[test]
    fn test_corpus_mock() {
        let corpus = TestGenCorpus::mock(100, 20, 20);
        assert_eq!(corpus.train.len(), 100);
        assert_eq!(corpus.validation.len(), 20);
        assert_eq!(corpus.test.len(), 20);
        assert_eq!(corpus.len(), 140);
        assert!(!corpus.is_empty());
    }

    #[test]
    fn test_corpus_stats() {
        let corpus = TestGenCorpus::mock(80, 10, 10);
        let stats = corpus.stats();

        assert_eq!(stats.total_samples, 100);
        assert_eq!(stats.train_samples, 80);
        assert_eq!(stats.validation_samples, 10);
        assert_eq!(stats.test_samples, 10);
        assert!(stats.with_proptest > 0);
        assert!(stats.avg_function_len > 0);
        assert!(stats.avg_test_len > 0);
    }

    #[test]
    fn test_corpus_stats_empty() {
        let corpus = TestGenCorpus::new();
        let stats = corpus.stats();
        assert_eq!(stats.total_samples, 0);
        assert_eq!(stats.avg_function_len, 0);
        assert_eq!(stats.avg_test_len, 0);
    }

    #[test]
    fn test_corpus_shuffle_deterministic() {
        let mut corpus1 = TestGenCorpus::mock(50, 0, 0);
        let mut corpus2 = TestGenCorpus::mock(50, 0, 0);

        corpus1.shuffle_train(42);
        corpus2.shuffle_train(42);

        // Same seed, same input => identical order (and identical length).
        assert_eq!(functions(&corpus1.train), functions(&corpus2.train));
    }

    #[test]
    fn test_corpus_shuffle_different_seeds() {
        let mut corpus1 = TestGenCorpus::mock(50, 0, 0);
        let mut corpus2 = TestGenCorpus::mock(50, 0, 0);

        corpus1.shuffle_train(42);
        corpus2.shuffle_train(123);

        let same_count = corpus1
            .train
            .iter()
            .zip(corpus2.train.iter())
            .filter(|(a, b)| a.function == b.function)
            .count();

        assert!(same_count < 50);
    }

    #[test]
    fn test_corpus_shuffle_is_permutation_of_train_only() {
        let original = TestGenCorpus::mock(30, 5, 5);
        let mut shuffled = original.clone();
        shuffled.shuffle_train(7);

        // Shuffling must neither drop nor duplicate samples.
        let mut before = functions(&original.train);
        let mut after = functions(&shuffled.train);
        before.sort_unstable();
        after.sort_unstable();
        assert_eq!(before, after);

        // Validation and test splits are never reordered.
        assert_eq!(functions(&original.validation), functions(&shuffled.validation));
        assert_eq!(functions(&original.test), functions(&shuffled.test));
    }

    #[test]
    fn test_corpus_shuffle_tiny_train_is_noop() {
        let mut empty = TestGenCorpus::new();
        empty.shuffle_train(1);
        assert!(empty.train.is_empty());

        let mut single = TestGenCorpus::mock(1, 0, 0);
        single.shuffle_train(1);
        assert_eq!(functions(&single.train), functions(&TestGenCorpus::mock(1, 0, 0).train));
    }

    #[test]
    fn test_sample_serialization() {
        let sample = TestGenSample {
            function: "pub fn foo() {}".into(),
            unit_tests: "#[test] fn test_foo() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata {
                crate_name: Some("test".into()),
                complexity: Some(5),
                has_generics: true,
                has_lifetimes: false,
                is_async: false,
            },
        };

        let json = serde_json::to_string(&sample).expect("JSON serialization should succeed");
        let restored: TestGenSample =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");

        assert_eq!(restored.function, sample.function);
        assert_eq!(restored.unit_tests, sample.unit_tests);
        assert_eq!(restored.property_tests, sample.property_tests);
        assert!(restored.metadata.has_generics);
    }

    #[test]
    fn test_sample_deserialize_minimal_uses_defaults() {
        let sample: TestGenSample = serde_json::from_str(r#"{"function":"f","unit_tests":"t"}"#)
            .expect("a sample with only required fields should deserialize");
        assert_eq!(sample.function, "f");
        assert_eq!(sample.unit_tests, "t");
        assert!(sample.property_tests.is_none());
        assert!(sample.metadata.crate_name.is_none());
        assert!(sample.metadata.complexity.is_none());
        assert!(!sample.metadata.is_async);
    }

    #[test]
    fn test_format_prompt() {
        let sample = TestGenSample {
            function: "pub fn add(a: i32, b: i32) -> i32 { a + b }".into(),
            unit_tests: String::new(),
            property_tests: None,
            metadata: SampleMetadata::default(),
        };

        let prompt = TestGenCorpus::format_prompt(&sample);
        assert!(prompt.contains("<|im_start|>system"));
        assert!(prompt.contains("pub fn add"));
        assert!(prompt.contains("<|im_start|>assistant"));
    }

    #[test]
    fn test_format_target() {
        let sample = TestGenSample {
            function: String::new(),
            unit_tests: "#[test] fn test() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata::default(),
        };

        let target = TestGenCorpus::format_target(&sample);
        assert!(target.contains("#[test]"));
        assert!(target.contains("proptest!"));
        assert!(target.ends_with("<|im_end|>"));
    }

    #[test]
    fn test_corpus_error_display() {
        let io_err = CorpusError::IoError("file not found".into());
        assert!(io_err.to_string().contains("IO error"));

        let parse_err = CorpusError::ParseError { line: 5, message: "invalid json".into() };
        assert!(parse_err.to_string().contains("line 5"));
    }

    #[test]
    fn test_mock_metadata_distribution() {
        let corpus = TestGenCorpus::mock(100, 0, 0);
        let stats = corpus.stats();

        // mock() assigns flags by index modulus, so the counts over 0..100 are exact.
        assert_eq!(stats.with_generics, 20); // i % 5 == 0
        assert_eq!(stats.with_proptest, 25); // i % 4 == 0
        assert_eq!(stats.with_async, 10); // i % 10 == 0
        assert_eq!(stats.with_lifetimes, 15); // i % 7 == 0
    }

    #[test]
    fn test_corpus_error_invalid_format() {
        let err = CorpusError::InvalidFormat("bad format".into());
        assert!(err.to_string().contains("Invalid format"));
        assert!(err.to_string().contains("bad format"));
    }

    #[test]
    fn test_sample_metadata_default() {
        let meta = SampleMetadata::default();
        assert!(meta.crate_name.is_none());
        assert!(meta.complexity.is_none());
        assert!(!meta.has_generics);
        assert!(!meta.has_lifetimes);
        assert!(!meta.is_async);
    }

    #[test]
    fn test_corpus_default() {
        let corpus = TestGenCorpus::default();
        assert!(corpus.is_empty());
        assert_eq!(corpus.len(), 0);
    }

    #[test]
    fn test_format_target_without_proptest() {
        let sample = TestGenSample {
            function: String::new(),
            unit_tests: "#[test] fn test() { assert!(true); }".into(),
            property_tests: None,
            metadata: SampleMetadata::default(),
        };

        let target = TestGenCorpus::format_target(&sample);
        assert!(target.contains("#[test]"));
        assert!(!target.contains("proptest!"));
        assert!(target.ends_with("<|im_end|>"));
    }

    #[test]
    fn test_corpus_stats_with_lifetimes() {
        // Only index 0 of 0..7 satisfies i % 7 == 0.
        let corpus = TestGenCorpus::mock(7, 0, 0);
        let stats = corpus.stats();
        assert_eq!(stats.with_lifetimes, 1);
    }

    #[test]
    fn test_sample_with_all_metadata() {
        let sample = TestGenSample {
            function: "pub fn foo<T: Clone + 'a>(x: &'a T) -> T { x.clone() }".into(),
            unit_tests: "#[test] fn test() {}".into(),
            property_tests: Some("proptest! {}".into()),
            metadata: SampleMetadata {
                crate_name: Some("my_crate".into()),
                complexity: Some(15),
                has_generics: true,
                has_lifetimes: true,
                is_async: false,
            },
        };

        assert!(sample.metadata.has_generics);
        assert!(sample.metadata.has_lifetimes);
        assert_eq!(sample.metadata.complexity, Some(15));
    }

    #[test]
    fn test_corpus_load_jsonl_nonexistent() {
        let result = TestGenCorpus::load_jsonl(
            std::path::Path::new("/nonexistent/train.jsonl"),
            std::path::Path::new("/nonexistent/val.jsonl"),
            std::path::Path::new("/nonexistent/test.jsonl"),
        );
        assert!(matches!(result, Err(CorpusError::IoError(_))));
    }

    #[test]
    fn test_corpus_load_jsonl_roundtrip() {
        let source = TestGenCorpus::mock(3, 2, 1);
        let train = write_temp("roundtrip_train.jsonl", &to_jsonl(&source.train));
        let val = write_temp("roundtrip_val.jsonl", &to_jsonl(&source.validation));
        let test = write_temp("roundtrip_test.jsonl", &to_jsonl(&source.test));

        let loaded = TestGenCorpus::load_jsonl(&train, &val, &test);
        // Best-effort cleanup before asserting, so failures don't leak files.
        for path in [&train, &val, &test] {
            let _ = std::fs::remove_file(path);
        }
        let loaded = loaded.expect("loading valid JSONL should succeed");

        assert_eq!(loaded.train.len(), 3);
        assert_eq!(loaded.validation.len(), 2);
        assert_eq!(loaded.test.len(), 1);
        assert_eq!(functions(&loaded.train), functions(&source.train));
        for (got, want) in loaded.train.iter().zip(&source.train) {
            assert_eq!(got.unit_tests, want.unit_tests);
            assert_eq!(got.property_tests, want.property_tests);
            assert_eq!(got.metadata.complexity, want.metadata.complexity);
            assert_eq!(got.metadata.crate_name, want.metadata.crate_name);
        }
    }

    #[test]
    fn test_corpus_load_jsonl_parse_error_reports_physical_line() {
        // Line 1 is blank, line 2 is valid, line 3 is malformed.
        let path = write_temp(
            "parse_error.jsonl",
            "\n{\"function\":\"f\",\"unit_tests\":\"t\"}\nnot json\n",
        );
        let result = TestGenCorpus::load_jsonl(&path, &path, &path);
        let _ = std::fs::remove_file(&path);

        match result {
            Err(CorpusError::ParseError { line, .. }) => assert_eq!(line, 3),
            other => panic!("expected a parse error, got {other:?}"),
        }
    }
}