1#[cfg(feature = "parquet")]
14use std::path::Path;
15use std::sync::atomic::{AtomicUsize, Ordering};
16use std::sync::Arc;
17
18use indicatif::{ProgressBar, ProgressStyle};
19use rayon::prelude::*;
20
21use crate::generator::{GeneratedCode, Generator, SamplingStrategy};
22use crate::Language;
23
/// Settings that control a single pipeline run.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Total number of programs requested, across all languages.
    pub count: usize,
    /// Maximum AST depth handed to the generator.
    pub max_depth: usize,
    /// Base seed; batch `i` of a language is generated with `seed.wrapping_add(i)`.
    pub seed: u64,
    /// Sampling strategy applied to every language.
    pub strategy: PipelineStrategy,
    /// Target shard size in bytes (default: 1 GiB).
    ///
    /// NOTE(review): not read by `DataPipeline` in this module; Parquet shards are
    /// split by program count instead.
    pub shard_size_bytes: usize,
    /// Optional output directory.
    ///
    /// NOTE(review): `generate_to_parquet` takes its directory as an argument and
    /// does not read this field.
    pub output_dir: Option<String>,
    /// Whether a terminal progress bar is drawn while generating.
    pub show_progress: bool,
}

impl Default for PipelineConfig {
    /// 10 000 programs, depth 3, seed 42, coverage-guided sampling, 1 GiB shards,
    /// no output directory, progress bar enabled.
    fn default() -> Self {
        const ONE_GIB: usize = 1 << 30;
        Self {
            strategy: PipelineStrategy::CoverageGuided,
            show_progress: true,
            output_dir: None,
            shard_size_bytes: ONE_GIB,
            seed: 42,
            max_depth: 3,
            count: 10_000,
        }
    }
}

/// Selects which generator entry point is used to sample programs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStrategy {
    /// Enumerates programs via `Generator::generate_exhaustive(max_depth)`; seed-independent.
    Exhaustive,
    /// Uses `Generator::generate_coverage_guided` (the default strategy).
    CoverageGuided,
    /// Uses `Generator::generate_swarm`.
    Swarm,
    /// Uses `SamplingStrategy::Boundary` with a boundary probability of 0.3.
    Boundary,
    /// Uses seeded `SamplingStrategy::Random`.
    Random,
}
71
/// Counters and timing collected by a pipeline run.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Number of programs returned by the run.
    pub total_generated: usize,
    /// Programs counted as valid.
    pub valid_count: usize,
    /// Programs counted as invalid.
    pub invalid_count: usize,
    /// Parquet shards written (0 when no shards were written).
    pub shards_written: usize,
    /// Bytes written to disk across all shards.
    pub bytes_written: usize,
    /// Wall-clock generation time, in milliseconds.
    pub generation_time_ms: u64,
}

impl PipelineStats {
    /// Programs generated per second of wall-clock time.
    ///
    /// Returns `0.0` when no time was recorded (`generation_time_ms == 0`), which
    /// sidesteps a division by zero for sub-millisecond runs.
    #[must_use]
    pub fn throughput(&self) -> f64 {
        match self.generation_time_ms {
            0 => 0.0,
            millis => {
                let seconds = millis as f64 / 1000.0;
                (self.total_generated as f64) / seconds
            }
        }
    }

    /// Percentage of generated programs that were counted as valid.
    ///
    /// Returns `0.0` when nothing was generated.
    #[must_use]
    pub fn pass_rate(&self) -> f64 {
        let total = self.total_generated;
        if total == 0 {
            0.0
        } else {
            let ratio = self.valid_count as f64 / total as f64;
            ratio * 100.0
        }
    }
}
108
/// Generates synthetic programs for one or more languages using the settings in a
/// [`PipelineConfig`].
///
/// Construct with `DataPipeline::new` or `DataPipeline::with_config`, then adjust
/// with the chained builder methods.
#[derive(Debug)]
pub struct DataPipeline {
    /// Generation settings (count, depth, seed, strategy, output options).
    config: PipelineConfig,
    /// Languages to generate for; both constructors default this to `[Language::Python]`.
    languages: Vec<Language>,
}
115
116impl DataPipeline {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 config: PipelineConfig::default(),
122 languages: vec![Language::Python],
123 }
124 }
125
126 #[must_use]
128 pub fn with_config(config: PipelineConfig) -> Self {
129 Self {
130 config,
131 languages: vec![Language::Python],
132 }
133 }
134
135 #[must_use]
137 pub fn languages(mut self, languages: Vec<Language>) -> Self {
138 self.languages = languages;
139 self
140 }
141
142 #[must_use]
144 pub fn count(mut self, count: usize) -> Self {
145 self.config.count = count;
146 self
147 }
148
149 #[must_use]
151 pub fn max_depth(mut self, depth: usize) -> Self {
152 self.config.max_depth = depth;
153 self
154 }
155
156 #[must_use]
158 pub fn seed(mut self, seed: u64) -> Self {
159 self.config.seed = seed;
160 self
161 }
162
163 #[must_use]
165 pub fn strategy(mut self, strategy: PipelineStrategy) -> Self {
166 self.config.strategy = strategy;
167 self
168 }
169
170 #[must_use]
172 pub fn output_dir(mut self, dir: impl Into<String>) -> Self {
173 self.config.output_dir = Some(dir.into());
174 self
175 }
176
177 #[must_use]
179 pub fn show_progress(mut self, show: bool) -> Self {
180 self.config.show_progress = show;
181 self
182 }
183
184 pub fn generate(&self) -> (Vec<GeneratedCode>, PipelineStats) {
188 contract_pre_generator_coverage!(input);
189 let start = std::time::Instant::now();
190 let count_per_language = self.config.count / self.languages.len().max(1);
191
192 let progress = if self.config.show_progress {
194 let pb = ProgressBar::new(self.config.count as u64);
195 if let Ok(style) = ProgressStyle::default_bar().template(
197 "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
198 ) {
199 pb.set_style(style.progress_chars("#>-"));
200 }
201 Some(pb)
202 } else {
203 None
204 };
205
206 let valid_count = Arc::new(AtomicUsize::new(0));
207 let invalid_count = Arc::new(AtomicUsize::new(0));
208
209 let all_programs: Vec<GeneratedCode> = self
211 .languages
212 .par_iter()
213 .flat_map(|lang| {
214 let generator = Generator::new(*lang);
215 self.generate_for_language(
216 &generator,
217 count_per_language,
218 progress.as_ref(),
219 &valid_count,
220 &invalid_count,
221 )
222 })
223 .collect();
224
225 if let Some(pb) = &progress {
226 pb.finish_with_message("Generation complete");
227 }
228
229 let elapsed = start.elapsed();
230 let stats = PipelineStats {
231 total_generated: all_programs.len(),
232 valid_count: valid_count.load(Ordering::Relaxed),
233 invalid_count: invalid_count.load(Ordering::Relaxed),
234 shards_written: 0,
235 bytes_written: 0,
236 generation_time_ms: elapsed.as_millis() as u64,
237 };
238
239 (all_programs, stats)
240 }
241
242 fn generate_for_language(
244 &self,
245 generator: &Generator,
246 count: usize,
247 progress: Option<&ProgressBar>,
248 valid_count: &Arc<AtomicUsize>,
249 _invalid_count: &Arc<AtomicUsize>,
250 ) -> Vec<GeneratedCode> {
251 let batch_size = 100;
252 let num_batches = (count + batch_size - 1) / batch_size;
253
254 (0..num_batches)
255 .into_par_iter()
256 .flat_map(|batch_idx| {
257 let batch_count = if batch_idx == num_batches - 1 {
258 count - (batch_idx * batch_size)
259 } else {
260 batch_size
261 };
262
263 let batch_seed = self.config.seed.wrapping_add(batch_idx as u64);
264 let programs = self.generate_batch(generator, batch_count, batch_seed);
265
266 let valid = programs.len();
268 valid_count.fetch_add(valid, Ordering::Relaxed);
269
270 if let Some(pb) = progress {
271 pb.inc(batch_count as u64);
272 }
273
274 programs
275 })
276 .collect()
277 }
278
279 fn generate_batch(&self, generator: &Generator, count: usize, seed: u64) -> Vec<GeneratedCode> {
281 match self.config.strategy {
282 PipelineStrategy::Exhaustive => generator
283 .generate_exhaustive(self.config.max_depth)
284 .into_iter()
285 .take(count)
286 .collect(),
287 PipelineStrategy::CoverageGuided => {
288 generator.generate_coverage_guided(count, self.config.max_depth, seed)
289 }
290 PipelineStrategy::Swarm => {
291 generator.generate_swarm(count, self.config.max_depth, 5, seed)
292 }
293 PipelineStrategy::Boundary => {
294 let strategy = SamplingStrategy::Boundary {
295 boundary_probability: 0.3,
296 };
297 generator.generate(strategy, count).unwrap_or_default()
298 }
299 PipelineStrategy::Random => {
300 let strategy = SamplingStrategy::Random { seed, count };
301 generator.generate(strategy, count).unwrap_or_default()
302 }
303 }
304 }
305
306 #[cfg(feature = "parquet")]
316 pub fn generate_to_parquet(&self, output_dir: &Path) -> crate::Result<PipelineStats> {
317 use super::parquet::ParquetWriter;
318 use crate::data::{CodeFeatures, GenerationMetadata, TestCase, TestResult};
319
320 let (programs, mut stats) = self.generate();
321
322 std::fs::create_dir_all(output_dir)
324 .map_err(|e| crate::Error::Data(format!("Failed to create output dir: {e}")))?;
325
326 let shard_count = 1000; let mut shard_idx = 0;
329 let mut bytes_written = 0;
330
331 for chunk in programs.chunks(shard_count) {
332 let shard_path = output_dir.join(format!("shard_{shard_idx:05}.parquet"));
333 let mut writer = ParquetWriter::new(&shard_path, 100)?;
334
335 for prog in chunk {
336 let test_case = TestCase {
337 id: uuid::Uuid::new_v4(),
338 source_language: prog.language,
339 source_code: prog.code.clone(),
340 target_language: Language::Rust,
341 target_code: None,
342 result: TestResult::Pass, features: CodeFeatures {
344 ast_depth: prog.ast_depth as u32,
345 ..Default::default()
346 },
347 metadata: GenerationMetadata {
348 strategy: format!("{:?}", self.config.strategy),
349 mutation_operators: vec![],
350 timestamp: format!(
351 "{}",
352 std::time::SystemTime::now()
353 .duration_since(std::time::UNIX_EPOCH)
354 .unwrap_or_default()
355 .as_secs()
356 ),
357 transpiler_version: env!("CARGO_PKG_VERSION").to_string(),
358 },
359 };
360 writer.write(test_case)?;
361 }
362
363 writer.close()?;
364
365 if let Ok(meta) = std::fs::metadata(&shard_path) {
367 bytes_written += meta.len() as usize;
368 }
369 shard_idx += 1;
370 }
371
372 stats.shards_written = shard_idx;
373 stats.bytes_written = bytes_written;
374
375 Ok(stats)
376 }
377}
378
379impl Default for DataPipeline {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Pipeline requesting `count` programs with the progress bar disabled.
    fn quiet(count: usize) -> DataPipeline {
        DataPipeline::new().count(count).show_progress(false)
    }

    /// Runs `pipeline` and asserts it produced at least one program.
    fn assert_generates_something(pipeline: DataPipeline) {
        let (programs, _stats) = pipeline.generate();
        assert!(!programs.is_empty());
    }

    #[test]
    fn test_pipeline_config_default() {
        let PipelineConfig {
            count,
            max_depth,
            seed,
            strategy,
            ..
        } = PipelineConfig::default();
        assert_eq!(count, 10_000);
        assert_eq!(max_depth, 3);
        assert_eq!(seed, 42);
        assert_eq!(strategy, PipelineStrategy::CoverageGuided);
    }

    #[test]
    fn test_pipeline_new() {
        let fresh = DataPipeline::new();
        assert_eq!(fresh.config.count, 10_000);
        assert_eq!(fresh.languages.len(), 1);
    }

    #[test]
    fn test_pipeline_builder() {
        let DataPipeline { config, .. } = quiet(1000)
            .max_depth(2)
            .seed(123)
            .strategy(PipelineStrategy::Swarm);

        assert_eq!(config.count, 1000);
        assert_eq!(config.max_depth, 2);
        assert_eq!(config.seed, 123);
        assert_eq!(config.strategy, PipelineStrategy::Swarm);
        assert!(!config.show_progress);
    }

    #[test]
    fn test_pipeline_languages() {
        let wanted = vec![Language::Python, Language::Bash];
        let pipeline = DataPipeline::new().languages(wanted);
        assert_eq!(pipeline.languages.len(), 2);
    }

    #[test]
    fn test_pipeline_generate_small() {
        let (programs, stats) = quiet(10).max_depth(2).generate();

        assert!(!programs.is_empty());
        assert!(stats.total_generated > 0);
        // Either fewer than 10 programs came back or some time was recorded.
        // NOTE(review): can be flaky if 10 programs are generated in under 1 ms.
        assert!(stats.total_generated < 10 || stats.generation_time_ms > 0);
    }

    #[test]
    fn test_pipeline_generate_exhaustive() {
        let (programs, stats) = quiet(50)
            .max_depth(2)
            .strategy(PipelineStrategy::Exhaustive)
            .generate();
        assert!(!programs.is_empty());
        assert!(stats.valid_count > 0);
    }

    #[test]
    fn test_pipeline_generate_coverage() {
        assert_generates_something(
            quiet(20)
                .max_depth(2)
                .strategy(PipelineStrategy::CoverageGuided),
        );
    }

    #[test]
    fn test_pipeline_generate_swarm() {
        assert_generates_something(quiet(20).max_depth(2).strategy(PipelineStrategy::Swarm));
    }

    #[test]
    fn test_pipeline_generate_boundary() {
        assert_generates_something(quiet(10).strategy(PipelineStrategy::Boundary));
    }

    #[test]
    fn test_pipeline_generate_random() {
        assert_generates_something(quiet(10).strategy(PipelineStrategy::Random));
    }

    #[test]
    fn test_pipeline_stats_throughput() {
        let stats = PipelineStats {
            total_generated: 1000,
            valid_count: 950,
            invalid_count: 50,
            shards_written: 1,
            bytes_written: 1024,
            generation_time_ms: 1000,
        };

        assert!((stats.throughput() - 1000.0).abs() < 0.1);
        assert!((stats.pass_rate() - 95.0).abs() < 0.1);
    }

    #[test]
    fn test_pipeline_stats_zero_time() {
        let stats = PipelineStats {
            total_generated: 100,
            generation_time_ms: 0,
            ..Default::default()
        };
        assert!(stats.throughput().abs() < f64::EPSILON);
    }

    #[test]
    fn test_pipeline_stats_default() {
        let empty = PipelineStats::default();
        assert_eq!(empty.total_generated, 0);
        assert!(empty.pass_rate().abs() < f64::EPSILON);
    }

    #[test]
    fn test_pipeline_multi_language() {
        let (programs, stats) = quiet(20)
            .languages(vec![Language::Python, Language::Bash])
            .max_depth(2)
            .generate();

        let count_of = |lang: Language| programs.iter().filter(|p| p.language == lang).count();
        let python_count = count_of(Language::Python);
        let bash_count = count_of(Language::Bash);

        assert!(python_count > 0 || bash_count > 0);
        assert!(stats.total_generated > 0);
    }

    #[test]
    fn test_pipeline_config_clone() {
        let original = PipelineConfig::default();
        let copy = original.clone();
        assert_eq!(copy.count, original.count);
    }

    #[test]
    fn test_pipeline_config_debug() {
        let config = PipelineConfig::default();
        let rendered = format!("{config:?}");
        assert!(rendered.contains("PipelineConfig"));
    }

    #[test]
    fn test_pipeline_strategy_eq() {
        let exhaustive = PipelineStrategy::Exhaustive;
        assert_eq!(exhaustive, PipelineStrategy::Exhaustive);
        assert_ne!(exhaustive, PipelineStrategy::Swarm);
    }
}