#[cfg(feature = "parquet")]
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use crate::generator::{GeneratedCode, Generator, SamplingStrategy};
use crate::Language;
#[derive(Debug, Clone)]
pub struct PipelineConfig {
pub count: usize,
pub max_depth: usize,
pub seed: u64,
pub strategy: PipelineStrategy,
pub shard_size_bytes: usize,
pub output_dir: Option<String>,
pub show_progress: bool,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
count: 10_000,
max_depth: 3,
seed: 42,
strategy: PipelineStrategy::CoverageGuided,
shard_size_bytes: 1024 * 1024 * 1024, output_dir: None,
show_progress: true,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStrategy {
Exhaustive,
CoverageGuided,
Swarm,
Boundary,
Random,
}
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
pub total_generated: usize,
pub valid_count: usize,
pub invalid_count: usize,
pub shards_written: usize,
pub bytes_written: usize,
pub generation_time_ms: u64,
}
impl PipelineStats {
#[must_use]
pub fn throughput(&self) -> f64 {
if self.generation_time_ms == 0 {
return 0.0;
}
(self.total_generated as f64) / (self.generation_time_ms as f64 / 1000.0)
}
#[must_use]
pub fn pass_rate(&self) -> f64 {
if self.total_generated == 0 {
return 0.0;
}
(self.valid_count as f64 / self.total_generated as f64) * 100.0
}
}
#[derive(Debug)]
pub struct DataPipeline {
config: PipelineConfig,
languages: Vec<Language>,
}
impl DataPipeline {
#[must_use]
pub fn new() -> Self {
Self {
config: PipelineConfig::default(),
languages: vec![Language::Python],
}
}
#[must_use]
pub fn with_config(config: PipelineConfig) -> Self {
Self {
config,
languages: vec![Language::Python],
}
}
#[must_use]
pub fn languages(mut self, languages: Vec<Language>) -> Self {
self.languages = languages;
self
}
#[must_use]
pub fn count(mut self, count: usize) -> Self {
self.config.count = count;
self
}
#[must_use]
pub fn max_depth(mut self, depth: usize) -> Self {
self.config.max_depth = depth;
self
}
#[must_use]
pub fn seed(mut self, seed: u64) -> Self {
self.config.seed = seed;
self
}
#[must_use]
pub fn strategy(mut self, strategy: PipelineStrategy) -> Self {
self.config.strategy = strategy;
self
}
#[must_use]
pub fn output_dir(mut self, dir: impl Into<String>) -> Self {
self.config.output_dir = Some(dir.into());
self
}
#[must_use]
pub fn show_progress(mut self, show: bool) -> Self {
self.config.show_progress = show;
self
}
pub fn generate(&self) -> (Vec<GeneratedCode>, PipelineStats) {
contract_pre_generator_coverage!(input);
let start = std::time::Instant::now();
let count_per_language = self.config.count / self.languages.len().max(1);
let progress = if self.config.show_progress {
let pb = ProgressBar::new(self.config.count as u64);
if let Ok(style) = ProgressStyle::default_bar().template(
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
) {
pb.set_style(style.progress_chars("#>-"));
}
Some(pb)
} else {
None
};
let valid_count = Arc::new(AtomicUsize::new(0));
let invalid_count = Arc::new(AtomicUsize::new(0));
let all_programs: Vec<GeneratedCode> = self
.languages
.par_iter()
.flat_map(|lang| {
let generator = Generator::new(*lang);
self.generate_for_language(
&generator,
count_per_language,
progress.as_ref(),
&valid_count,
&invalid_count,
)
})
.collect();
if let Some(pb) = &progress {
pb.finish_with_message("Generation complete");
}
let elapsed = start.elapsed();
let stats = PipelineStats {
total_generated: all_programs.len(),
valid_count: valid_count.load(Ordering::Relaxed),
invalid_count: invalid_count.load(Ordering::Relaxed),
shards_written: 0,
bytes_written: 0,
generation_time_ms: elapsed.as_millis() as u64,
};
(all_programs, stats)
}
fn generate_for_language(
&self,
generator: &Generator,
count: usize,
progress: Option<&ProgressBar>,
valid_count: &Arc<AtomicUsize>,
_invalid_count: &Arc<AtomicUsize>,
) -> Vec<GeneratedCode> {
let batch_size = 100;
let num_batches = (count + batch_size - 1) / batch_size;
(0..num_batches)
.into_par_iter()
.flat_map(|batch_idx| {
let batch_count = if batch_idx == num_batches - 1 {
count - (batch_idx * batch_size)
} else {
batch_size
};
let batch_seed = self.config.seed.wrapping_add(batch_idx as u64);
let programs = self.generate_batch(generator, batch_count, batch_seed);
let valid = programs.len();
valid_count.fetch_add(valid, Ordering::Relaxed);
if let Some(pb) = progress {
pb.inc(batch_count as u64);
}
programs
})
.collect()
}
fn generate_batch(&self, generator: &Generator, count: usize, seed: u64) -> Vec<GeneratedCode> {
match self.config.strategy {
PipelineStrategy::Exhaustive => generator
.generate_exhaustive(self.config.max_depth)
.into_iter()
.take(count)
.collect(),
PipelineStrategy::CoverageGuided => {
generator.generate_coverage_guided(count, self.config.max_depth, seed)
}
PipelineStrategy::Swarm => {
generator.generate_swarm(count, self.config.max_depth, 5, seed)
}
PipelineStrategy::Boundary => {
let strategy = SamplingStrategy::Boundary {
boundary_probability: 0.3,
};
generator.generate(strategy, count).unwrap_or_default()
}
PipelineStrategy::Random => {
let strategy = SamplingStrategy::Random { seed, count };
generator.generate(strategy, count).unwrap_or_default()
}
}
}
#[cfg(feature = "parquet")]
pub fn generate_to_parquet(&self, output_dir: &Path) -> crate::Result<PipelineStats> {
use super::parquet::ParquetWriter;
use crate::data::{CodeFeatures, GenerationMetadata, TestCase, TestResult};
let (programs, mut stats) = self.generate();
std::fs::create_dir_all(output_dir)
.map_err(|e| crate::Error::Data(format!("Failed to create output dir: {e}")))?;
let shard_count = 1000; let mut shard_idx = 0;
let mut bytes_written = 0;
for chunk in programs.chunks(shard_count) {
let shard_path = output_dir.join(format!("shard_{shard_idx:05}.parquet"));
let mut writer = ParquetWriter::new(&shard_path, 100)?;
for prog in chunk {
let test_case = TestCase {
id: uuid::Uuid::new_v4(),
source_language: prog.language,
source_code: prog.code.clone(),
target_language: Language::Rust,
target_code: None,
result: TestResult::Pass, features: CodeFeatures {
ast_depth: prog.ast_depth as u32,
..Default::default()
},
metadata: GenerationMetadata {
strategy: format!("{:?}", self.config.strategy),
mutation_operators: vec![],
timestamp: format!(
"{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
),
transpiler_version: env!("CARGO_PKG_VERSION").to_string(),
},
};
writer.write(test_case)?;
}
writer.close()?;
if let Ok(meta) = std::fs::metadata(&shard_path) {
bytes_written += meta.len() as usize;
}
shard_idx += 1;
}
stats.shards_written = shard_idx;
stats.bytes_written = bytes_written;
Ok(stats)
}
}
impl Default for DataPipeline {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pipeline_config_default() {
let config = PipelineConfig::default();
assert_eq!(config.count, 10_000);
assert_eq!(config.max_depth, 3);
assert_eq!(config.seed, 42);
assert_eq!(config.strategy, PipelineStrategy::CoverageGuided);
}
#[test]
fn test_pipeline_new() {
let pipeline = DataPipeline::new();
assert_eq!(pipeline.config.count, 10_000);
assert_eq!(pipeline.languages.len(), 1);
}
#[test]
fn test_pipeline_builder() {
let pipeline = DataPipeline::new()
.count(1000)
.max_depth(2)
.seed(123)
.strategy(PipelineStrategy::Swarm)
.show_progress(false);
assert_eq!(pipeline.config.count, 1000);
assert_eq!(pipeline.config.max_depth, 2);
assert_eq!(pipeline.config.seed, 123);
assert_eq!(pipeline.config.strategy, PipelineStrategy::Swarm);
assert!(!pipeline.config.show_progress);
}
#[test]
fn test_pipeline_languages() {
let pipeline = DataPipeline::new().languages(vec![Language::Python, Language::Bash]);
assert_eq!(pipeline.languages.len(), 2);
}
#[test]
fn test_pipeline_generate_small() {
let pipeline = DataPipeline::new()
.count(10)
.max_depth(2)
.show_progress(false);
let (programs, stats) = pipeline.generate();
assert!(!programs.is_empty());
assert!(stats.total_generated > 0);
assert!(stats.generation_time_ms > 0 || stats.total_generated < 10);
}
#[test]
fn test_pipeline_generate_exhaustive() {
let pipeline = DataPipeline::new()
.count(50)
.max_depth(2)
.strategy(PipelineStrategy::Exhaustive)
.show_progress(false);
let (programs, stats) = pipeline.generate();
assert!(!programs.is_empty());
assert!(stats.valid_count > 0);
}
#[test]
fn test_pipeline_generate_coverage() {
let pipeline = DataPipeline::new()
.count(20)
.max_depth(2)
.strategy(PipelineStrategy::CoverageGuided)
.show_progress(false);
let (programs, _stats) = pipeline.generate();
assert!(!programs.is_empty());
}
#[test]
fn test_pipeline_generate_swarm() {
let pipeline = DataPipeline::new()
.count(20)
.max_depth(2)
.strategy(PipelineStrategy::Swarm)
.show_progress(false);
let (programs, _stats) = pipeline.generate();
assert!(!programs.is_empty());
}
#[test]
fn test_pipeline_generate_boundary() {
let pipeline = DataPipeline::new()
.count(10)
.strategy(PipelineStrategy::Boundary)
.show_progress(false);
let (programs, _stats) = pipeline.generate();
assert!(!programs.is_empty());
}
#[test]
fn test_pipeline_generate_random() {
let pipeline = DataPipeline::new()
.count(10)
.strategy(PipelineStrategy::Random)
.show_progress(false);
let (programs, _stats) = pipeline.generate();
assert!(!programs.is_empty());
}
#[test]
fn test_pipeline_stats_throughput() {
let stats = PipelineStats {
total_generated: 1000,
valid_count: 950,
invalid_count: 50,
shards_written: 1,
bytes_written: 1024,
generation_time_ms: 1000,
};
assert!((stats.throughput() - 1000.0).abs() < 0.1);
assert!((stats.pass_rate() - 95.0).abs() < 0.1);
}
#[test]
fn test_pipeline_stats_zero_time() {
let stats = PipelineStats {
total_generated: 100,
generation_time_ms: 0,
..Default::default()
};
assert!((stats.throughput() - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_pipeline_stats_default() {
let stats = PipelineStats::default();
assert_eq!(stats.total_generated, 0);
assert!((stats.pass_rate() - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_pipeline_multi_language() {
let pipeline = DataPipeline::new()
.languages(vec![Language::Python, Language::Bash])
.count(20)
.max_depth(2)
.show_progress(false);
let (programs, stats) = pipeline.generate();
let python_count = programs
.iter()
.filter(|p| p.language == Language::Python)
.count();
let bash_count = programs
.iter()
.filter(|p| p.language == Language::Bash)
.count();
assert!(python_count > 0 || bash_count > 0);
assert!(stats.total_generated > 0);
}
#[test]
fn test_pipeline_config_clone() {
let config = PipelineConfig::default();
let cloned = config.clone();
assert_eq!(cloned.count, config.count);
}
#[test]
fn test_pipeline_config_debug() {
let config = PipelineConfig::default();
let debug = format!("{:?}", config);
assert!(debug.contains("PipelineConfig"));
}
#[test]
fn test_pipeline_strategy_eq() {
assert_eq!(PipelineStrategy::Exhaustive, PipelineStrategy::Exhaustive);
assert_ne!(PipelineStrategy::Exhaustive, PipelineStrategy::Swarm);
}
}