use std::collections::HashSet;
use syn::{visit::Visit, Expr, ExprClosure, ExprMethodCall, File, Item, Stmt};
/// Library or runtime that a detected [`ParallelPattern`] is attributed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ParallelLibrary {
/// The `rayon` crate (data-parallel iterators such as `par_iter` / `par_bridge`).
Rayon,
/// The `tokio` runtime.
Tokio,
/// The standard library's `std::thread` module.
StdThread,
/// The `crossbeam` crate.
Crossbeam,
}
impl std::fmt::Display for ParallelLibrary {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ParallelLibrary::Rayon => write!(f, "rayon"),
ParallelLibrary::Tokio => write!(f, "tokio"),
ParallelLibrary::StdThread => write!(f, "std::thread"),
ParallelLibrary::Crossbeam => write!(f, "crossbeam"),
}
}
}
/// Per-closure measurements gathered by [`ParallelPatternDetector`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ClosureInfo {
/// Source line of the closure. The detector currently always sets this to 0.
pub line_number: usize,
/// Names estimated to be captured from the enclosing scope. This is a heuristic
/// and may also contain function, method, or type names used in the body.
pub captures: Vec<String>,
/// `true` when the closure is declared with the `move` keyword.
pub is_move: bool,
/// Branch-based complexity estimate, starting at 1.
pub closure_complexity: usize,
/// Size estimate: the number of statements in a block body, or 1 for an
/// expression body (not a literal source line count).
pub lines: usize,
/// Set when the closure has at most 2 captures and `lines > 20`.
pub extractable: bool,
}
/// Summary of the parallel-coordination pattern found in one parsed file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParallelPattern {
/// Library the parallel calls were attributed to.
pub library: ParallelLibrary,
/// Number of closures found in the file.
pub closure_count: usize,
/// Sum of estimated captures over all closures.
pub total_captures: usize,
/// `total_captures / closure_count`.
pub avg_captures_per_closure: f64,
/// Heuristic: one fifth of the source's line count (integer division), not measured.
pub setup_lines: usize,
/// Heuristic: half of the source's line count (integer division), not measured.
pub execution_lines: usize,
/// Remaining lines after `setup_lines` and `execution_lines`.
pub aggregation_lines: usize,
/// Set to 0 by `ParallelPatternDetector::detect`.
pub cyclomatic_complexity: usize,
/// Weighted sum of captures, closures and sync primitives
/// (see `calculate_coordination_complexity`).
pub coordination_complexity: f64,
/// Sync primitive names found by the visitor (e.g. `Mutex`, `RwLock`, `AtomicBool`, `Arc`).
pub sync_primitives: Vec<String>,
/// `true` if any closure is a `move` closure.
pub has_move_closures: bool,
/// Per-closure details.
pub closures: Vec<ClosureInfo>,
}
/// Configurable thresholds for detecting parallel-coordination patterns in a parsed file.
pub struct ParallelPatternDetector {
/// Minimum total number of estimated closure captures (summed over all closures)
/// required before a pattern is reported. Default: 3.
pub min_closure_captures: usize,
/// Threshold on the number of recognized parallel calls. Default: 1.
pub min_parallel_calls: usize,
}
impl Default for ParallelPatternDetector {
    /// Returns a detector with `min_closure_captures = 3` and `min_parallel_calls = 1`.
    fn default() -> Self {
        let min_closure_captures = 3;
        let min_parallel_calls = 1;
        Self {
            min_closure_captures,
            min_parallel_calls,
        }
    }
}
impl ParallelPatternDetector {
    /// Scans `ast` for parallel call sites and closures and summarizes them as a
    /// [`ParallelPattern`].
    ///
    /// Returns `None` when:
    /// - fewer than `max(min_parallel_calls, 1)` parallel calls were recorded,
    /// - none of the recorded calls maps to a known [`ParallelLibrary`],
    /// - the file contains no closures, or
    /// - the total estimated captures across all closures is below
    ///   `min_closure_captures`.
    ///
    /// `source_content` is used for closure estimates and for its line count. The
    /// setup/execution/aggregation split is a fixed 1/5, 1/2, remainder ratio of
    /// that count, not a measurement of the code.
    pub fn detect(&self, ast: &File, source_content: &str) -> Option<ParallelPattern> {
        let mut visitor = ParallelVisitor::new();
        visitor.visit_file(ast);
        // `min_parallel_calls` was previously ignored (only emptiness was checked).
        // `max(1)` keeps a zero threshold from accepting files with no calls at all.
        if visitor.parallel_calls.len() < self.min_parallel_calls.max(1) {
            return None;
        }
        let library = self.detect_library(&visitor)?;
        let closures = self.analyze_closures(&visitor, source_content);
        let closure_count = closures.len();
        let total_captures: usize = closures.iter().map(|c| c.captures.len()).sum();
        if closure_count == 0 || total_captures < self.min_closure_captures {
            return None;
        }
        // `closure_count > 0` is guaranteed above, so this never divides by zero.
        let avg_captures_per_closure = total_captures as f64 / closure_count as f64;
        let coordination_complexity = calculate_coordination_complexity(
            total_captures,
            closure_count,
            visitor.sync_primitives.len(),
        );
        let total_lines = source_content.lines().count();
        let setup_lines = total_lines / 5;
        let execution_lines = total_lines / 2;
        // n/5 + n/2 <= n for every n, so this subtraction cannot underflow.
        let aggregation_lines = total_lines - setup_lines - execution_lines;
        let has_move_closures = closures.iter().any(|c| c.is_move);
        Some(ParallelPattern {
            library,
            closure_count,
            total_captures,
            avg_captures_per_closure,
            setup_lines,
            execution_lines,
            aggregation_lines,
            // Not computed by this detector.
            cyclomatic_complexity: 0,
            coordination_complexity,
            // The visitor is not used again, so move the list instead of cloning it.
            sync_primitives: visitor.sync_primitives,
            has_move_closures,
            closures,
        })
    }

    /// Heuristic confidence score for `pattern`, clamped to `[0.0, 1.0]`.
    ///
    /// The score is built up as follows:
    /// - start at 0.7;
    /// - add 0.1 for two or more closures;
    /// - add another 0.1 unconditionally;
    /// - add 0.1 when sync primitives were found;
    /// - subtract 0.1 when closures capture fewer than two names on average.
    pub fn confidence(&self, pattern: &ParallelPattern) -> f64 {
        let mut confidence: f64 = 0.7;
        if pattern.closure_count >= 2 {
            confidence += 0.1;
        }
        // Applied to every pattern. It is kept as a separate addition (rather than
        // a 0.8 base) so the floating-point result is unchanged.
        confidence += 0.1;
        if !pattern.sync_primitives.is_empty() {
            confidence += 0.1;
        }
        if pattern.avg_captures_per_closure < 2.0 {
            confidence -= 0.1;
        }
        confidence.clamp(0.0, 1.0)
    }

    /// Attributes the recorded parallel calls to a [`ParallelLibrary`].
    ///
    /// Checks run from most to least specific:
    /// - Rayon (`par_iter` / `par_bridge`);
    /// - crossbeam (before std, so `crossbeam::thread::scope` stays crossbeam);
    /// - `std::thread` (`thread::spawn` / `thread::scope`);
    /// - Tokio (any `spawn`, `join!`, `select!`).
    ///
    /// Tokio's generic `"spawn"` test used to run first, which made the
    /// `thread::spawn` branch unreachable.
    fn detect_library(&self, visitor: &ParallelVisitor) -> Option<ParallelLibrary> {
        if Self::calls_contain_any(visitor, &["par_iter", "par_bridge"]) {
            Some(ParallelLibrary::Rayon)
        } else if Self::calls_contain_any(visitor, &["crossbeam"]) {
            Some(ParallelLibrary::Crossbeam)
        } else if Self::calls_contain_any(visitor, &["thread::spawn", "thread::scope"]) {
            Some(ParallelLibrary::StdThread)
        } else if Self::calls_contain_any(visitor, &["spawn", "join!", "select!"]) {
            // `"spawn"` also covers `tokio::spawn`.
            Some(ParallelLibrary::Tokio)
        } else {
            None
        }
    }

    /// Returns `true` when any recorded parallel call contains any of `needles`.
    fn calls_contain_any(visitor: &ParallelVisitor, needles: &[&str]) -> bool {
        visitor
            .parallel_calls
            .iter()
            .any(|call| needles.iter().any(|needle| call.contains(*needle)))
    }

    /// Builds a [`ClosureInfo`] for every closure the visitor collected.
    fn analyze_closures(
        &self,
        visitor: &ParallelVisitor,
        source_content: &str,
    ) -> Vec<ClosureInfo> {
        visitor
            .closures
            .iter()
            .map(|closure_expr| {
                let captures = estimate_captures(closure_expr, source_content);
                let lines = estimate_closure_lines(closure_expr, source_content);
                let extractable = captures.len() <= 2 && lines > 20;
                ClosureInfo {
                    // Not computed. NOTE(review): proc-macro2 spans only expose line
                    // numbers with its `span-locations` feature enabled.
                    line_number: 0,
                    captures,
                    is_move: is_move_closure(closure_expr),
                    closure_complexity: estimate_closure_complexity(closure_expr),
                    lines,
                    extractable,
                }
            })
            .collect()
    }
}
/// AST visitor that collects the raw signals used by `ParallelPatternDetector::detect`.
struct ParallelVisitor {
/// Names of recognized parallel call sites (method names such as `par_iter` or `spawn`).
parallel_calls: Vec<String>,
/// Every closure expression encountered, cloned out of the AST.
closures: Vec<ExprClosure>,
/// Sync primitive names (e.g. `Mutex`, `Arc`) found in `type` aliases.
sync_primitives: Vec<String>,
}
impl ParallelVisitor {
fn new() -> Self {
Self {
parallel_calls: Vec::new(),
closures: Vec::new(),
sync_primitives: Vec::new(),
}
}
}
impl<'ast> Visit<'ast> for ParallelVisitor {
    /// Records method calls that look like parallel entry points, then keeps
    /// walking the receiver and arguments.
    ///
    /// A call is recorded when its name contains `par_iter` (which also matches
    /// `into_par_iter` / `par_iter_mut`) or `par_bridge`, or when it is exactly
    /// `spawn`.
    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
        let method_name = node.method.to_string();
        if method_name.contains("par_iter")
            || method_name.contains("par_bridge")
            || method_name == "spawn"
        {
            self.parallel_calls.push(method_name);
        }
        syn::visit::visit_expr_method_call(self, node);
    }

    /// Records every closure, including closures nested inside other closures.
    fn visit_expr_closure(&mut self, node: &'ast ExprClosure) {
        self.closures.push(node.clone());
        syn::visit::visit_expr_closure(self, node);
    }

    /// Records at most one sync primitive per `type` alias, then visits the
    /// item's children.
    ///
    /// Identifiers in the alias are compared exactly against the known primitive
    /// names. The previous substring test (`contains("Arc")`) also fired on
    /// unrelated identifiers such as `Archive` or `SearchArc`.
    fn visit_item(&mut self, node: &'ast Item) {
        const SYNC_PRIMITIVES: [&str; 4] = ["Mutex", "RwLock", "AtomicBool", "Arc"];
        if let Item::Type(ty) = node {
            let ty_str = quote::quote!(#ty).to_string();
            let found: Vec<&str> = ty_str
                .split(|c: char| !(c.is_alphanumeric() || c == '_'))
                .filter(|token| SYNC_PRIMITIVES.contains(token))
                .collect();
            if !found.is_empty() {
                // Only exact primitive names are passed on. The priority order of
                // `extract_sync_primitive_name` (Mutex > RwLock > AtomicBool > Arc)
                // still decides which one is reported.
                self.sync_primitives
                    .push(extract_sync_primitive_name(&found.join(" ")));
            }
        }
        syn::visit::visit_item(self, node);
    }
}
/// Weighted coordination score: 0.5 per capture, 1.0 per closure and
/// 0.8 per sync primitive.
fn calculate_coordination_complexity(
    total_captures: usize,
    closure_count: usize,
    sync_primitive_count: usize,
) -> f64 {
    const PER_CAPTURE: f64 = 0.5;
    const PER_CLOSURE: f64 = 1.0;
    const PER_SYNC_PRIMITIVE: f64 = 0.8;
    let weighted_terms = [
        (total_captures, PER_CAPTURE),
        (closure_count, PER_CLOSURE),
        (sync_primitive_count, PER_SYNC_PRIMITIVE),
    ];
    // Summed left to right, matching (captures + closures) + sync.
    weighted_terms
        .iter()
        .fold(0.0, |acc, &(count, weight)| acc + count as f64 * weight)
}
/// Scales `base_score` down according to how much coordination `pattern` needs.
///
/// The coordination factor is 0.5 when closures capture more than 5 names on
/// average, 0.6 when more than 3, and 0.8 otherwise. A further factor of 0.9
/// applies when there are more than two closures.
pub fn adjust_parallel_score(base_score: f64, pattern: &ParallelPattern) -> f64 {
    let avg_captures = pattern.avg_captures_per_closure;
    let coordination_factor = match avg_captures {
        avg if avg > 5.0 => 0.5,
        avg if avg > 3.0 => 0.6,
        _ => 0.8,
    };
    let closure_factor = if pattern.closure_count <= 2 { 1.0 } else { 0.9 };
    base_score * coordination_factor * closure_factor
}
fn estimate_captures(closure: &ExprClosure, _source_content: &str) -> Vec<String> {
let mut captures = HashSet::new();
let params: HashSet<String> = closure
.inputs
.iter()
.filter_map(|pat| {
if let syn::Pat::Ident(ident) = pat {
Some(ident.ident.to_string())
} else {
None
}
})
.collect();
let mut identifier_visitor = IdentifierVisitor {
identifiers: HashSet::new(),
};
identifier_visitor.visit_expr(&closure.body);
for ident in identifier_visitor.identifiers {
if !params.contains(&ident) && !is_keyword(&ident) {
captures.insert(ident);
}
}
captures.into_iter().collect()
}
/// Returns `true` when the closure is written with the `move` keyword.
fn is_move_closure(closure: &ExprClosure) -> bool {
    matches!(closure.capture, Some(_))
}
/// Approximates a closure's size by its number of top-level statements.
///
/// Block bodies report their statement count (0 for `{}`), and expression
/// bodies count as 1. The source text is not consulted, so despite the name
/// this is a statement count rather than a line count.
fn estimate_closure_lines(closure: &ExprClosure, _source_content: &str) -> usize {
    match &*closure.body {
        Expr::Block(expr_block) => expr_block.block.stmts.len(),
        _ => 1,
    }
}
/// Rough branch-based complexity of a closure: 1 plus the branches counted by
/// `count_branches` for its top-level constructs.
///
/// Inspected constructs:
/// - for block bodies: expression statements and `let` initializers;
/// - for expression bodies (e.g. `|x| match x { .. }`): the body itself.
///
/// Branches nested deeper (e.g. an `if` inside a loop) are not counted.
fn estimate_closure_complexity(closure: &ExprClosure) -> usize {
    let branches: usize = match closure.body.as_ref() {
        Expr::Block(block) => block
            .block
            .stmts
            .iter()
            .map(|stmt| match stmt {
                Stmt::Expr(expr, _) => count_branches(expr),
                // `let v = match .. { .. };` previously contributed nothing.
                Stmt::Local(local) => local
                    .init
                    .as_ref()
                    .map_or(0, |init| count_branches(&init.expr)),
                _ => 0,
            })
            .sum(),
        // Expression-bodied closures were previously always scored 1.
        body => count_branches(body),
    };
    1 + branches
}
/// Branch weight of a single expression (not recursive).
///
/// A `match` counts one per arm. `if`, `while`, `for` and `loop` count 1.
/// Anything else counts 0.
fn count_branches(expr: &Expr) -> usize {
    if let Expr::Match(match_expr) = expr {
        return match_expr.arms.len();
    }
    let is_single_branch = matches!(
        expr,
        Expr::If(_) | Expr::While(_) | Expr::ForLoop(_) | Expr::Loop(_)
    );
    usize::from(is_single_branch)
}
/// Returns the first known sync primitive whose name occurs in `ty_str`.
///
/// Names are tried in priority order `Mutex`, `RwLock`, `AtomicBool`, `Arc`
/// using substring search. Returns `"Unknown"` when none occurs.
fn extract_sync_primitive_name(ty_str: &str) -> String {
    const BY_PRIORITY: [&str; 4] = ["Mutex", "RwLock", "AtomicBool", "Arc"];
    BY_PRIORITY
        .iter()
        .find(|name| ty_str.contains(**name))
        .map_or_else(|| "Unknown".to_string(), |name| (*name).to_string())
}
/// Returns `true` for names that are never treated as captured variables:
/// `self`/`Self`, the boolean literals, and the `Option`/`Result` variant names.
fn is_keyword(ident: &str) -> bool {
    const NON_CAPTURES: [&str; 8] = ["self", "Self", "true", "false", "Some", "None", "Ok", "Err"];
    NON_CAPTURES.contains(&ident)
}
/// Visitor that records the text of every identifier it reaches.
struct IdentifierVisitor {
/// Distinct identifier strings seen so far.
identifiers: HashSet<String>,
}
impl<'ast> Visit<'ast> for IdentifierVisitor {
    /// Stores the identifier's text. It does not call the default traversal
    /// for the identifier.
    fn visit_ident(&mut self, ident: &'ast syn::Ident) {
        let text = ident.to_string();
        self.identifiers.insert(text);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated pattern shared by the scoring tests.
    fn sample_pattern() -> ParallelPattern {
        ParallelPattern {
            library: ParallelLibrary::Rayon,
            closure_count: 1,
            total_captures: 6,
            avg_captures_per_closure: 6.0,
            setup_lines: 10,
            execution_lines: 40,
            aggregation_lines: 5,
            cyclomatic_complexity: 15,
            coordination_complexity: 8.0,
            sync_primitives: vec!["AtomicBool".into(), "Mutex".into()],
            has_move_closures: true,
            closures: vec![],
        }
    }

    #[test]
    fn test_parallel_library_detection() {
        let code = r#"
fn search_parallel(args: &Args) -> Result<bool> {
    let results = items.par_iter().map(|item| {
        process(item)
    }).collect();
    Ok(results)
}
"#;
        let ast = syn::parse_file(code).unwrap();
        // This snippet's estimated capture count is below the default
        // `min_closure_captures` of 3. With `default()`, `detect` returns `None`,
        // so the old `if let Some(..)` check asserted nothing. Lower the threshold
        // so the library assertion actually runs.
        let detector = ParallelPatternDetector {
            min_closure_captures: 0,
            min_parallel_calls: 1,
        };
        let pattern = detector
            .detect(&ast, code)
            .expect("a par_iter call with a closure should be detected");
        assert_eq!(pattern.library, ParallelLibrary::Rayon);
        assert_eq!(pattern.closure_count, 1);
    }

    #[test]
    fn test_coordination_complexity_calculation() {
        let complexity = calculate_coordination_complexity(6, 1, 2);
        assert!((complexity - 5.6).abs() < 0.01);
    }

    #[test]
    fn test_score_adjustment() {
        let adjusted = adjust_parallel_score(1000.0, &sample_pattern());
        assert!((adjusted - 500.0).abs() < 0.01);

        // More than two closures applies the additional 0.9 factor.
        let many_closures = ParallelPattern {
            closure_count: 3,
            ..sample_pattern()
        };
        let adjusted = adjust_parallel_score(1000.0, &many_closures);
        assert!((adjusted - 450.0).abs() < 0.01);
    }

    #[test]
    fn test_confidence() {
        let detector = ParallelPatternDetector::default();
        // 0.7 base + 0.1 unconditional + 0.1 for sync primitives.
        assert!((detector.confidence(&sample_pattern()) - 0.9).abs() < 1e-9);

        let no_sync = ParallelPattern {
            sync_primitives: vec![],
            ..sample_pattern()
        };
        assert!((detector.confidence(&no_sync) - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_move_closure_detection() {
        let move_closure: ExprClosure =
            syn::parse_str(r#"move || { println!("test"); }"#).unwrap();
        assert!(is_move_closure(&move_closure));

        let borrowing_closure: ExprClosure = syn::parse_str("|| 1").unwrap();
        assert!(!is_move_closure(&borrowing_closure));
    }
}