use crate::parse::grammar::LanguageId;
use crate::parse::languages::parser_for_language;
use crate::parse::traits::SignatureInfo;
use rayon::prelude::*;
use std::cell::RefCell;
use std::path::PathBuf;
use std::time::Instant;
use tree_sitter::Parser;
thread_local! {
    /// Per-thread tree-sitter `Parser`, borrowed mutably by
    /// `ParallelParser::parse_single_file` so each worker thread reuses one
    /// parser instance instead of creating a new one for every file.
    static THREAD_PARSER: RefCell<Parser> = RefCell::new(Parser::new());
}
/// Outcome of processing a single file.
///
/// A result counts as a success exactly when `error` is `None`.
#[derive(Debug, Clone)]
pub struct ParsingResult {
    /// Path of the file this result describes, as supplied by the caller.
    pub file_path: PathBuf,
    /// Name of the detected language: `Some` on success, `None` on failure.
    pub language: Option<String>,
    /// Signatures extracted from the file; empty on failure.
    pub signatures: Vec<SignatureInfo>,
    /// Raw bytes of the file that was parsed: `Some` on success, `None` on failure.
    pub source_bytes: Option<Vec<u8>>,
    /// Human-readable description of what went wrong, if anything.
    pub error: Option<String>,
    /// Milliseconds spent handling this file (including reading it); `0` on failure.
    pub parse_time_ms: u64,
}
impl ParsingResult {
    /// Builds a successful result carrying the detected language name, the
    /// extracted signatures, the raw source bytes and the elapsed time.
    fn success(
        file_path: PathBuf,
        language: String,
        signatures: Vec<SignatureInfo>,
        source_bytes: Vec<u8>,
        parse_time_ms: u64,
    ) -> Self {
        let language = Some(language);
        let source_bytes = Some(source_bytes);
        Self {
            file_path,
            language,
            signatures,
            source_bytes,
            error: None,
            parse_time_ms,
        }
    }

    /// Builds a failed result: no language, no signatures, no source bytes,
    /// zero elapsed time, and `error` set to the given message.
    fn failure(file_path: PathBuf, error: String) -> Self {
        let error = Some(error);
        Self {
            file_path,
            language: None,
            signatures: Vec::new(),
            source_bytes: None,
            error,
            parse_time_ms: 0,
        }
    }

    /// Returns `true` when no error was recorded for this file.
    pub fn is_success(&self) -> bool {
        !self.is_failure()
    }

    /// Returns `true` when an error message was recorded for this file.
    pub fn is_failure(&self) -> bool {
        matches!(self.error, Some(_))
    }
}
/// Aggregate counters describing one batch of parsed files.
#[derive(Debug, Clone)]
pub struct ParsingStats {
    /// Number of results in the batch.
    pub total_files: usize,
    /// Number of results without an error.
    pub successful_files: usize,
    /// Number of results with an error.
    pub failed_files: usize,
    /// Sum of the signature counts of every result.
    pub total_signatures: usize,
    /// Wall-clock duration of the whole batch, in milliseconds.
    pub total_time_ms: u64,
    /// `total_time_ms` divided by `total_files` (`0.0` for an empty batch).
    pub avg_time_per_file_ms: f64,
}
impl ParsingStats {
    /// Summarizes `results` for a batch that took `total_time_ms` of wall time.
    ///
    /// The average is the batch's wall-clock time divided by the number of
    /// results, not the mean of the individual `parse_time_ms` values; it is
    /// `0.0` when `results` is empty.
    fn from_results(results: &[ParsingResult], total_time_ms: u64) -> Self {
        let mut successful_files = 0;
        let mut failed_files = 0;
        let mut total_signatures = 0;
        for result in results {
            // `is_success` and `is_failure` are exact complements.
            if result.is_success() {
                successful_files += 1;
            } else {
                failed_files += 1;
            }
            total_signatures += result.signatures.len();
        }

        let total_files = results.len();
        let avg_time_per_file_ms = match total_files {
            0 => 0.0,
            n => total_time_ms as f64 / n as f64,
        };

        Self {
            total_files,
            successful_files,
            failed_files,
            total_signatures,
            total_time_ms,
            avg_time_per_file_ms,
        }
    }
}
/// Parses batches of source files concurrently using rayon.
///
/// Configure it with [`ParallelParser::with_max_threads`] and
/// [`ParallelParser::without_stats`]. Derives `Debug`/`Clone` for consistency
/// with the other public types in this module.
#[derive(Debug, Clone)]
pub struct ParallelParser {
    /// Thread count requested via `with_max_threads`; `None` when unset.
    max_threads: Option<usize>,
    /// Whether a summary line is logged after each batch.
    collect_stats: bool,
}
impl Default for ParallelParser {
fn default() -> Self {
Self::new()
}
}
impl ParallelParser {
pub fn new() -> Self {
Self {
max_threads: None,
collect_stats: true,
}
}
pub fn with_max_threads(mut self, max_threads: usize) -> Self {
self.max_threads = Some(max_threads);
self
}
pub fn without_stats(mut self) -> Self {
self.collect_stats = false;
self
}
pub fn parse_files(&self, file_paths: Vec<PathBuf>) -> Vec<ParsingResult> {
let (results, _) = self.parse_files_with_stats(file_paths);
results
}
pub fn parse_files_with_stats(
&self,
file_paths: Vec<PathBuf>,
) -> (Vec<ParsingResult>, ParsingStats) {
let start_time = Instant::now();
let results: Vec<ParsingResult> = file_paths
.into_par_iter()
.map(|path| self.parse_single_file(path))
.collect();
let total_time = start_time.elapsed().as_millis() as u64;
let stats = ParsingStats::from_results(&results, total_time);
if self.collect_stats {
tracing::info!(
"Parsed {} files: {} successful, {} failed, {} signatures, {:.2}ms avg",
stats.total_files,
stats.successful_files,
stats.failed_files,
stats.total_signatures,
stats.avg_time_per_file_ms
);
}
(results, stats)
}
fn parse_single_file(&self, file_path: PathBuf) -> ParsingResult {
let start_time = Instant::now();
let extension = file_path
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("");
let language_id = match LanguageId::from_extension(extension) {
Some(id) => id,
None => {
let ext = extension.to_string();
return ParsingResult::failure(
file_path,
format!("Unsupported file extension: {}", ext),
);
}
};
let language_name = language_id.config().name.clone();
let source = match std::fs::read(&file_path) {
Ok(contents) => contents,
Err(e) => {
return ParsingResult::failure(file_path, format!("Failed to read file: {}", e))
}
};
let lang_parser = match parser_for_language(&language_name) {
Some(p) => p,
None => {
return ParsingResult::failure(
file_path,
format!("No parser found for language: {}", language_name),
)
}
};
let result = THREAD_PARSER.with(|parser_cell| {
let mut parser = parser_cell.borrow_mut();
lang_parser.get_signatures_with_parser(&source, &mut parser)
});
let parse_time_ms = start_time.elapsed().as_millis() as u64;
match result {
Ok(signatures) => {
ParsingResult::success(file_path, language_name, signatures, source, parse_time_ms)
}
Err(e) => ParsingResult::failure(file_path, format!("Parse error: {}", e)),
}
}
pub fn successful_results(results: &[ParsingResult]) -> Vec<&ParsingResult> {
results.iter().filter(|r| r.is_success()).collect()
}
pub fn failed_results(results: &[ParsingResult]) -> Vec<&ParsingResult> {
results.iter().filter(|r| r.is_failure()).collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use std::path::Path;
    use tempfile::tempdir;

    /// Creates `path` containing `contents` followed by a newline.
    fn write_file(path: &Path, contents: &str) {
        let mut file = File::create(path).unwrap();
        writeln!(file, "{}", contents).unwrap();
    }

    #[test]
    fn test_parallel_parser_multiple_files() {
        let dir = tempdir().unwrap();
        let py_path = dir.path().join("test.py");
        let rs_path = dir.path().join("test.rs");
        write_file(&py_path, "def hello(): pass");
        write_file(&rs_path, "fn main() {}");
        let parser = ParallelParser::new();
        let results = parser.parse_files(vec![py_path, rs_path]);
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.is_success()));
    }

    #[test]
    fn test_parallel_parser_with_error() {
        let dir = tempdir().unwrap();
        let unsupported_path = dir.path().join("test.unknown");
        File::create(&unsupported_path).unwrap();
        let parser = ParallelParser::new();
        let results = parser.parse_files(vec![unsupported_path]);
        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert!(result.is_failure());
        assert!(result
            .error
            .as_deref()
            .unwrap()
            .contains("Unsupported file extension"));
        assert!(result.language.is_none());
        assert!(result.signatures.is_empty());
        assert!(result.source_bytes.is_none());
    }

    #[test]
    fn test_parsing_stats() {
        let dir = tempdir().unwrap();
        let py_path = dir.path().join("test.py");
        write_file(&py_path, "def hello(): pass");
        let parser = ParallelParser::new();
        let (_, stats) = parser.parse_files_with_stats(vec![py_path]);
        assert_eq!(stats.total_files, 1);
        assert_eq!(stats.successful_files, 1);
        assert_eq!(stats.failed_files, 0);
    }

    #[test]
    fn test_parsing_stats_empty_batch() {
        let parser = ParallelParser::new().without_stats();
        let (results, stats) = parser.parse_files_with_stats(Vec::new());
        assert!(results.is_empty());
        assert_eq!(stats.total_files, 0);
        assert_eq!(stats.total_signatures, 0);
        assert_eq!(stats.avg_time_per_file_ms, 0.0);
    }

    #[test]
    fn test_parsing_result_includes_source_bytes() {
        let dir = tempdir().unwrap();
        let py_path = dir.path().join("test.py");
        write_file(&py_path, "def hello(): pass");
        let parser = ParallelParser::new();
        let (results, _) = parser.parse_files_with_stats(vec![py_path]);
        assert_eq!(results.len(), 1);
        let bytes = results[0]
            .source_bytes
            .as_ref()
            .expect("Successful parse should include source_bytes");
        assert_eq!(bytes.as_slice(), b"def hello(): pass\n");
    }

    #[test]
    fn test_max_threads_preserves_order_and_results() {
        let dir = tempdir().unwrap();
        let py_path = dir.path().join("a.py");
        let rs_path = dir.path().join("b.rs");
        let bad_path = dir.path().join("c.unknown");
        write_file(&py_path, "def hello(): pass");
        write_file(&rs_path, "fn main() {}");
        write_file(&bad_path, "");
        let paths = vec![py_path, rs_path, bad_path];
        let parser = ParallelParser::new().with_max_threads(1).without_stats();
        let results = parser.parse_files(paths.clone());
        assert_eq!(results.len(), 3);
        for (result, path) in results.iter().zip(&paths) {
            assert_eq!(&result.file_path, path);
        }
        assert_eq!(ParallelParser::successful_results(&results).len(), 2);
        assert_eq!(ParallelParser::failed_results(&results).len(), 1);
    }
}