use std::collections::HashSet;
use syn::{visit::Visit, Expr, ExprClosure};
use super::purity_detector::{MutationScope, PurityDetector};
use super::scope_tracker::ScopeTracker;
use crate::core::PurityLevel;
/// Outcome of analyzing a single closure expression.
#[derive(Debug, Clone)]
pub struct ClosurePurity {
    /// Purity classification assigned to the closure.
    pub level: PurityLevel,
    /// Confidence in `level`; the analyzer clamps it to `0.5..=1.0`.
    pub confidence: f32,
    /// Variables from the enclosing scope that the closure body refers to.
    pub captures: Vec<Capture>,
    /// `true` when another closure expression appears inside the body.
    pub has_nested_closures: bool,
}
/// How a closure takes hold of a variable from its enclosing scope.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CaptureMode {
    /// The variable is moved (or copied) into the closure.
    ByValue,
    /// The variable is borrowed immutably.
    ByRef,
    /// The variable is borrowed mutably.
    ByMutRef,
}
/// A single variable referenced by a closure from its enclosing scope.
#[derive(Debug, Clone)]
pub struct Capture {
    /// Identifier of the captured variable as written in the closure body.
    pub var_name: String,
    /// Inferred capture mode (by value, shared reference, or mutable reference).
    pub mode: CaptureMode,
    /// Whether the closure body was observed mutating this variable.
    pub is_mutated: bool,
    /// Whether the variable is local to the parent scope or external to it.
    pub scope: MutationScope,
}
/// Analyzes closure expressions against the scope in which they are defined.
#[derive(Debug)]
pub struct ClosureAnalyzer<'a> {
    /// Scope of the function enclosing the closure; used to resolve captures.
    parent_scope: &'a ScopeTracker,
    /// Captures found for the most recently analyzed closure.
    captures: Vec<Capture>,
    /// Tags naming the reasons the confidence score should be reduced.
    confidence_penalties: Vec<&'static str>,
}
impl<'a> ClosureAnalyzer<'a> {
pub fn new(parent_scope: &'a ScopeTracker) -> Self {
Self {
parent_scope,
captures: Vec::new(),
confidence_penalties: Vec::new(),
}
}
pub fn analyze_closure(&mut self, closure: &ExprClosure) -> ClosurePurity {
let mut body_detector = PurityDetector::new();
for input in &closure.inputs {
if let syn::Pat::Ident(pat_ident) = input {
body_detector
.scope_mut()
.add_local_var(pat_ident.ident.to_string());
}
}
body_detector.visit_expr(&closure.body);
self.captures = self.find_captures(closure, &body_detector);
self.infer_capture_modes(closure, &body_detector);
let has_nested_closures = self.contains_nested_closures(&closure.body);
if has_nested_closures {
self.confidence_penalties.push("nested_closures");
}
let level = self.determine_purity_level(&body_detector);
let confidence = self.calculate_confidence(&body_detector);
ClosurePurity {
level,
confidence,
captures: self.captures.clone(),
has_nested_closures,
}
}
fn find_captures(
&self,
closure: &ExprClosure,
_body_detector: &PurityDetector,
) -> Vec<Capture> {
let mut params: HashSet<String> = HashSet::new();
for input in &closure.inputs {
if let syn::Pat::Ident(pat_ident) = input {
params.insert(pat_ident.ident.to_string());
}
}
let mut visitor = CaptureDetector {
params: ¶ms,
parent_scope: self.parent_scope,
captures: Vec::new(),
};
visitor.visit_expr(&closure.body);
visitor.captures
}
fn infer_capture_modes(&mut self, closure: &ExprClosure, body_detector: &PurityDetector) {
let has_move = closure.capture.is_some();
for capture in &mut self.captures {
if has_move {
capture.mode = CaptureMode::ByValue;
continue;
}
let is_mutated = body_detector
.local_mutations()
.iter()
.any(|m| m.target == capture.var_name);
capture.is_mutated = is_mutated;
capture.mode = if is_mutated {
CaptureMode::ByMutRef
} else {
CaptureMode::ByRef
};
capture.scope = if self.parent_scope.is_local(&capture.var_name) {
MutationScope::Local
} else {
MutationScope::External
};
}
}
fn contains_nested_closures(&self, expr: &Expr) -> bool {
let mut visitor = ClosureDetector { found: false };
visitor.visit_expr(expr);
visitor.found
}
fn determine_purity_level(&self, body_detector: &PurityDetector) -> PurityLevel {
if body_detector.has_io_operations() || body_detector.has_unsafe_blocks() {
return PurityLevel::Impure;
}
if body_detector.modifies_external_state() {
return PurityLevel::Impure;
}
let mutates_external = self
.captures
.iter()
.any(|c| c.is_mutated && c.scope == MutationScope::External);
if mutates_external {
return PurityLevel::Impure;
}
let mutates_local = self
.captures
.iter()
.any(|c| c.is_mutated && c.scope == MutationScope::Local);
if mutates_local || !body_detector.local_mutations().is_empty() {
return PurityLevel::LocallyPure;
}
if body_detector.accesses_external_state() {
return PurityLevel::ReadOnly;
}
PurityLevel::StrictlyPure
}
fn calculate_confidence(&self, body_detector: &PurityDetector) -> f32 {
let mut confidence: f32 = 1.0;
if self.confidence_penalties.contains(&"nested_closures") {
confidence *= 0.85;
}
if body_detector.accesses_external_state() {
confidence *= 0.80;
}
if self.captures.len() > 3 {
confidence *= 0.90;
}
if self.captures.iter().any(|c| c.mode != CaptureMode::ByValue) {
confidence *= 0.95;
}
confidence.clamp(0.5, 1.0)
}
}
/// Expression visitor that records which parent-scope variables a closure
/// body refers to.
struct CaptureDetector<'a> {
    /// Names bound as closure parameters; these are never captures.
    params: &'a HashSet<String>,
    /// Scope of the enclosing function, used to recognise capturable names.
    parent_scope: &'a ScopeTracker,
    /// Captures found so far, at most one entry per variable name.
    captures: Vec<Capture>,
}
impl<'ast, 'a> Visit<'ast> for CaptureDetector<'a> {
    /// Records a capture for each single-identifier path that is not a
    /// parameter, is not `self`/`Self`, is known to the parent scope, and has
    /// not been recorded yet; then continues into the sub-expressions.
    fn visit_expr(&mut self, expr: &'ast Expr) {
        // Only plain single-segment paths (`foo`) can name a captured variable.
        let candidate = match expr {
            Expr::Path(expr_path) => expr_path.path.get_ident().map(|id| id.to_string()),
            _ => None,
        };

        if let Some(name) = candidate {
            // The checks short-circuit in the same order as the nested form.
            let should_record = !self.params.contains(&name)
                && name != "self"
                && name != "Self"
                && (self.parent_scope.is_local(&name) || self.parent_scope.is_self(&name))
                && !self.captures.iter().any(|seen| seen.var_name == name);

            if should_record {
                // Provisional values; the analyzer refines mode and scope later.
                self.captures.push(Capture {
                    var_name: name,
                    mode: CaptureMode::ByRef,
                    is_mutated: false,
                    scope: MutationScope::Local,
                });
            }
        }

        syn::visit::visit_expr(self, expr);
    }
}
/// Expression visitor that notes whether any closure expression is present.
struct ClosureDetector {
    /// Set to `true` once a closure expression has been seen.
    found: bool,
}
impl<'ast> Visit<'ast> for ClosureDetector {
    /// Flags the first closure expression reached on each branch and does not
    /// descend into it; every other expression is traversed normally.
    fn visit_expr(&mut self, expr: &'ast Expr) {
        match expr {
            Expr::Closure(_) => self.found = true,
            _ => syn::visit::visit_expr(self, expr),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Analyzes `closure` with each name in `locals` registered as a local
    /// variable of the parent scope.
    fn analyze_with_locals(closure: &ExprClosure, locals: &[&str]) -> ClosurePurity {
        let mut parent_scope = ScopeTracker::new();
        for name in locals {
            parent_scope.add_local_var((*name).to_string());
        }
        let mut analyzer = ClosureAnalyzer::new(&parent_scope);
        analyzer.analyze_closure(closure)
    }

    /// A closure that only uses its own parameter is strictly pure.
    #[test]
    fn test_simple_closure_no_captures() {
        let closure: ExprClosure = parse_quote!(|x| x * 2);

        let result = analyze_with_locals(&closure, &[]);

        assert_eq!(result.level, PurityLevel::StrictlyPure);
        assert!(result.captures.is_empty());
        assert!(!result.has_nested_closures);
    }

    /// Reading a parent-scope local is reported as a by-reference capture.
    #[test]
    fn test_closure_with_capture() {
        let closure: ExprClosure = parse_quote!(|x| x + y);

        let result = analyze_with_locals(&closure, &["y"]);

        assert_eq!(result.captures.len(), 1);
        assert_eq!(result.captures[0].var_name, "y");
        assert_eq!(result.captures[0].mode, CaptureMode::ByRef);
    }

    /// The `move` keyword turns captures into by-value captures.
    #[test]
    fn test_move_closure() {
        let closure: ExprClosure = parse_quote!(move |x| x + y);

        let result = analyze_with_locals(&closure, &["y"]);

        assert_eq!(result.captures.len(), 1);
        assert_eq!(result.captures[0].mode, CaptureMode::ByValue);
    }

    /// A closure defined inside the body is detected and lowers confidence.
    #[test]
    fn test_nested_closure_detection() {
        let closure: ExprClosure = parse_quote!(|x| {
            let f = |y| y * 2;
            f(x)
        });

        let result = analyze_with_locals(&closure, &[]);

        assert!(result.has_nested_closures);
        assert!(result.confidence < 0.9);
    }
}