use crate::parsers::lightweight::{LightweightFileInfo, LightweightParseStats};
use crate::parsers::parse_file_lightweight;
use crossbeam_channel::{bounded, Receiver};
use std::path::PathBuf;
use std::thread;
/// Handle to an in-flight parallel parsing pipeline.
///
/// Holds the consuming end of the result channel plus join handles for the
/// producer thread and all worker threads spawned by `parse_files_pipeline`.
pub struct ParallelPipelineResult {
// Consuming end of the result channel; `Option` so a caller can take
// ownership exactly once (see `take_receiver`) while `join` still works.
receiver: Option<Receiver<LightweightFileInfo>>,
/// Number of files submitted to the pipeline.
pub total_files: usize,
// Join handle for the thread that feeds paths into the file channel.
producer_handle: Option<thread::JoinHandle<()>>,
// Join handles for the parser workers; each returns its own counters.
worker_handles: Vec<thread::JoinHandle<WorkerStats>>,
}
impl ParallelPipelineResult {
    /// Takes ownership of the result channel's receiving end.
    ///
    /// Returns `None` if the receiver was already taken.
    pub fn take_receiver(&mut self) -> Option<Receiver<LightweightFileInfo>> {
        std::mem::take(&mut self.receiver)
    }

    /// Consumes the receiver and yields parsed files as workers produce
    /// them. Yields nothing if the receiver was already taken.
    pub fn iter(&mut self) -> impl Iterator<Item = LightweightFileInfo> + '_ {
        self.take_receiver().into_iter().flatten()
    }
}
/// Per-worker parse counters, returned by each worker thread when it exits.
#[derive(Debug, Default)]
pub struct WorkerStats {
/// Files this worker parsed successfully.
pub parsed: usize,
/// Files this worker failed to parse.
pub errors: usize,
}
/// Aggregate statistics for a completed pipeline run, built by
/// `ParallelPipelineResult::join` from the per-worker counters.
#[derive(Debug, Default)]
pub struct PipelineStats {
/// Number of files submitted to the pipeline.
pub total_files: usize,
/// Files parsed successfully across all workers.
pub parsed_files: usize,
/// Files that failed to parse across all workers.
pub parse_errors: usize,
// NOTE(review): the two fields below are never populated by `join` in
// this file — presumably filled elsewhere or planned; confirm.
pub total_functions: usize,
pub total_classes: usize,
}
impl ParallelPipelineResult {
    /// Waits for the producer and all workers to finish, then aggregates
    /// their per-worker counters into a [`PipelineStats`].
    ///
    /// If the result receiver was never taken (via `take_receiver`/`iter`),
    /// it is dropped here first. Without that, a worker blocked on sending
    /// into the bounded result channel would never exit, the producer would
    /// in turn block on the full file channel, and the joins below would
    /// deadlock.
    pub fn join(mut self) -> PipelineStats {
        // Close the consumer side: pending/future worker sends now fail,
        // workers break out of their loops, and (transitively) the producer
        // unblocks — even when nobody drained the results.
        drop(self.receiver.take());
        if let Some(h) = self.producer_handle.take() {
            let _ = h.join();
        }
        let mut stats = PipelineStats {
            total_files: self.total_files,
            ..Default::default()
        };
        for handle in self.worker_handles.drain(..) {
            // A panicked worker is skipped; its counters are simply lost.
            if let Ok(worker_stats) = handle.join() {
                stats.parsed_files += worker_stats.parsed;
                stats.parse_errors += worker_stats.errors;
            }
        }
        stats
    }
}
/// Spawns a producer/worker parsing pipeline over `files`.
///
/// One producer thread feeds paths into a bounded channel; `num_workers`
/// (clamped to at least 1) worker threads parse them with
/// `parse_file_lightweight` and push results into a second bounded channel
/// of capacity `buffer_size`. Returns a handle from which results can be
/// streamed and final statistics collected.
pub fn parse_files_pipeline(
    files: Vec<PathBuf>,
    num_workers: usize,
    buffer_size: usize,
) -> ParallelPipelineResult {
    let total_files = files.len();
    let workers = num_workers.max(1);

    let (path_tx, path_rx) = bounded::<PathBuf>(buffer_size);
    let (out_tx, out_rx) = bounded::<LightweightFileInfo>(buffer_size);

    // Producer: enqueue every path; stop early if all workers are gone.
    let producer_handle = thread::spawn(move || {
        for path in files {
            if path_tx.send(path).is_err() {
                break;
            }
        }
        // `path_tx` drops here, closing the file channel for the workers.
    });

    // Workers: parse until the file channel closes, tallying counters.
    let worker_handles: Vec<_> = (0..workers)
        .map(|_| {
            let paths = path_rx.clone();
            let results = out_tx.clone();
            thread::spawn(move || {
                let mut tally = WorkerStats::default();
                for path in paths {
                    match parse_file_lightweight(&path) {
                        Ok(info) => {
                            tally.parsed += 1;
                            // Receiver gone: nobody wants further results.
                            if results.send(info).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            tally.errors += 1;
                            tracing::warn!("Failed to parse {}: {}", path.display(), e);
                        }
                    }
                }
                tally
            })
        })
        .collect();

    // Release this thread's clones so the channels actually close once the
    // producer and every worker have finished.
    drop(path_rx);
    drop(out_tx);

    ParallelPipelineResult {
        receiver: Some(out_rx),
        total_files,
        producer_handle: Some(producer_handle),
        worker_handles,
    }
}
pub fn parse_files_parallel_pipeline(
files: Vec<PathBuf>,
num_workers: usize,
buffer_size: usize,
progress: Option<&(dyn Fn(usize, usize) + Sync)>,
) -> (Vec<LightweightFileInfo>, LightweightParseStats) {
let total = files.len();
let mut pipeline = parse_files_pipeline(files, num_workers, buffer_size);
let mut results = Vec::with_capacity(total);
let mut stats = LightweightParseStats {
total_files: total,
..Default::default()
};
let receiver = pipeline.take_receiver().expect("receiver already taken");
let mut count = 0;
for info in receiver {
count += 1;
if let Some(cb) = progress {
if count % 100 == 0 || count == total {
cb(count, total);
}
}
stats.add_file(&info);
results.push(info);
}
let pipeline_stats = pipeline.join();
stats.parse_errors = pipeline_stats.parse_errors;
stats.parsed_files = pipeline_stats.parsed_files;
(results, stats)
}
/// Streams parsed files through `on_file` as the workers produce them,
/// without buffering the whole result set.
///
/// `progress`, when provided, is invoked with `(processed, total)` every
/// 100 files and on the final file. Returns aggregate statistics after all
/// pipeline threads have been joined.
pub fn stream_parse_parallel<F>(
    files: Vec<PathBuf>,
    num_workers: usize,
    buffer_size: usize,
    mut on_file: F,
    progress: Option<&(dyn Fn(usize, usize) + Sync)>,
) -> LightweightParseStats
where
    F: FnMut(LightweightFileInfo),
{
    let total = files.len();
    let mut pipeline = parse_files_pipeline(files, num_workers, buffer_size);

    let mut stats = LightweightParseStats {
        total_files: total,
        ..Default::default()
    };

    let rx = pipeline.take_receiver().expect("receiver already taken");
    let mut processed = 0usize;
    for info in rx {
        processed += 1;
        // Throttled progress: every 100th file, plus the very last one.
        match progress {
            Some(report) if processed % 100 == 0 || processed == total => {
                report(processed, total)
            }
            _ => {}
        }
        stats.add_file(&info);
        on_file(info);
    }

    // Join threads and fold the authoritative worker counters in.
    let joined = pipeline.join();
    stats.parse_errors = joined.parse_errors;
    stats.parsed_files = joined.parsed_files;
    stats
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
// Single file through a single worker: the parsed result must come out of
// the channel and contain the one function defined in the temp file.
#[test]
fn test_pipeline_single_file() {
let mut file = NamedTempFile::with_suffix(".py").unwrap();
writeln!(file, "def hello():\n    pass").unwrap();
let files = vec![file.path().to_path_buf()];
let mut pipeline = parse_files_pipeline(files, 1, 10);
let receiver = pipeline.take_receiver().unwrap();
let results: Vec<_> = receiver.iter().collect();
assert_eq!(results.len(), 1);
assert!(!results[0].functions.is_empty());
// Join threads; stats are irrelevant here.
let _ = pipeline.join();
}
// Ten files across four workers; `temp_files` keeps the NamedTempFile
// guards alive so the paths aren't deleted before the workers read them.
#[test]
fn test_pipeline_multiple_workers() {
let mut files = Vec::new();
let mut temp_files = Vec::new();
for i in 0..10 {
let mut file = NamedTempFile::with_suffix(".py").unwrap();
writeln!(file, "def func{}():\n    pass", i).unwrap();
files.push(file.path().to_path_buf());
temp_files.push(file);
}
let (results, stats) = parse_files_parallel_pipeline(files, 4, 5, None);
assert_eq!(results.len(), 10);
assert_eq!(stats.parsed_files, 10);
assert_eq!(stats.total_functions, 10);
}
// Streaming variant: the callback must fire once per parsed file.
#[test]
fn test_stream_parse() {
let mut file = NamedTempFile::with_suffix(".py").unwrap();
writeln!(file, "def test(): pass").unwrap();
let files = vec![file.path().to_path_buf()];
let mut count = 0;
let stats = stream_parse_parallel(files, 1, 10, |_info| count += 1, None);
assert_eq!(count, 1);
assert_eq!(stats.parsed_files, 1);
}
}