use std::collections::HashSet;
use syn::visit::Visit;
use syn::{Expr, Pat};
use super::super::types::CaptureMode;
/// One outer-scope variable captured by a closure, as reported by
/// [`ClosureCaptureVisitor::finalize_captures`].
#[derive(Debug, Clone)]
pub(super) struct CaptureInfo {
    /// Name of the captured variable, as written in the closure body.
    pub var_name: String,
    /// How the closure holds the variable; this module assigns `ByValue`
    /// (for `move` closures), `ByRef`, or `ByMutRef` (when it is mutated).
    pub mode: CaptureMode,
    /// Whether the visitor saw the closure body mutate this variable.
    pub is_mutated: bool,
}
/// Syntax visitor that walks a closure body and collects the outer-scope
/// variables it refers to, together with which of them it mutates.
pub(super) struct ClosureCaptureVisitor<'a> {
    /// Names visible in the scope enclosing the closure; only these can be captured.
    outer_scope: &'a HashSet<String>,
    /// Names bound by the closure's own parameters; they shadow outer names
    /// and are never reported as captures.
    closure_params: &'a HashSet<String>,
    /// Captures in first-seen order, one entry per variable name.
    captures: Vec<CaptureInfo>,
    /// Names seen being mutated; folded into `captures` by `finalize_captures`.
    mutated_vars: HashSet<String>,
    /// Whether the closure is a `move` closure (captures default to by-value).
    is_move: bool,
}
impl<'a> ClosureCaptureVisitor<'a> {
    /// Creates a visitor for a single closure body.
    ///
    /// * `outer_scope` - names defined around the closure; only these may be captured.
    /// * `closure_params` - names bound by the closure's parameters, which shadow
    ///   outer names and are therefore never captures.
    /// * `is_move` - whether the closure was written with the `move` keyword.
    pub fn new(
        outer_scope: &'a HashSet<String>,
        closure_params: &'a HashSet<String>,
        is_move: bool,
    ) -> Self {
        Self {
            outer_scope,
            closure_params,
            is_move,
            mutated_vars: HashSet::new(),
            captures: Vec::new(),
        }
    }

    /// Consumes the visitor and returns its captures in first-seen order.
    ///
    /// Every capture whose variable was recorded as mutated is flagged with
    /// `is_mutated = true`. In a non-`move` closure its mode is additionally
    /// promoted to `CaptureMode::ByMutRef`; a `move` closure keeps its mode.
    pub fn finalize_captures(mut self) -> Vec<CaptureInfo> {
        let promote_to_mut_ref = !self.is_move;
        let mutated = &self.mutated_vars;
        for capture in self
            .captures
            .iter_mut()
            .filter(|capture| mutated.contains(&capture.var_name))
        {
            capture.is_mutated = true;
            if promote_to_mut_ref {
                capture.mode = CaptureMode::ByMutRef;
            }
        }
        self.captures
    }

    /// Notes that the closure body mutates the variable `name`.
    fn record_mutation(&mut self, name: &str) {
        self.mutated_vars.insert(name.to_owned());
    }

    /// Records `name` as a capture if it refers to an outer-scope variable that
    /// is neither shadowed by a closure parameter nor already recorded.
    ///
    /// `self` and `Self` are never recorded.
    fn try_record_capture(&mut self, name: String) {
        if matches!(name.as_str(), "self" | "Self") {
            return;
        }
        let is_capturable =
            self.outer_scope.contains(&name) && !self.closure_params.contains(&name);
        if !is_capturable || self.captures.iter().any(|c| c.var_name == name) {
            return;
        }
        let mode = if self.is_move {
            CaptureMode::ByValue
        } else {
            CaptureMode::ByRef
        };
        self.captures.push(CaptureInfo {
            var_name: name,
            mode,
            is_mutated: false,
        });
    }
}
impl<'ast, 'a> Visit<'ast> for ClosureCaptureVisitor<'a> {
fn visit_expr(&mut self, expr: &'ast Expr) {
match expr {
Expr::Path(path) => {
if let Some(ident) = path.path.get_ident() {
self.try_record_capture(ident.to_string());
}
}
Expr::MethodCall(method_call) => {
self.visit_expr(&method_call.receiver);
let method_name = method_call.method.to_string();
if is_mutating_method(&method_name) {
if let Expr::Path(path) = &*method_call.receiver {
if let Some(ident) = path.path.get_ident() {
self.record_mutation(&ident.to_string());
}
}
}
for arg in &method_call.args {
self.visit_expr(arg);
}
}
Expr::Assign(assign) => {
if let Expr::Path(path) = &*assign.left {
if let Some(ident) = path.path.get_ident() {
self.record_mutation(&ident.to_string());
}
}
self.visit_expr(&assign.right);
}
Expr::Binary(binary) => {
if is_compound_assignment(&binary.op) {
if let Expr::Path(path) = &*binary.left {
if let Some(ident) = path.path.get_ident() {
self.record_mutation(&ident.to_string());
}
}
}
self.visit_expr(&binary.left);
self.visit_expr(&binary.right);
}
Expr::Closure(nested_closure) => {
let nested_params: HashSet<String> = nested_closure
.inputs
.iter()
.filter_map(extract_pattern_name)
.collect();
let nested_is_move = nested_closure.capture.is_some();
let mut nested_visitor =
ClosureCaptureVisitor::new(self.outer_scope, &nested_params, nested_is_move);
nested_visitor.visit_expr(&nested_closure.body);
for capture in nested_visitor.finalize_captures() {
if !self.captures.iter().any(|c| c.var_name == capture.var_name) {
self.captures.push(capture);
}
}
}
_ => {
syn::visit::visit_expr(self, expr);
}
}
}
}
fn is_compound_assignment(op: &syn::BinOp) -> bool {
matches!(
op,
syn::BinOp::AddAssign(_)
| syn::BinOp::SubAssign(_)
| syn::BinOp::MulAssign(_)
| syn::BinOp::DivAssign(_)
| syn::BinOp::RemAssign(_)
| syn::BinOp::BitAndAssign(_)
| syn::BinOp::BitOrAssign(_)
| syn::BinOp::BitXorAssign(_)
| syn::BinOp::ShlAssign(_)
| syn::BinOp::ShrAssign(_)
)
}
/// Reports whether a method called `name` typically mutates its receiver.
///
/// This is a name-based heuristic over common standard-library collection,
/// string and cell methods: the receiver's type is unknown at this point. As a
/// result, a user-defined method sharing one of these names is treated the
/// same way, and mutating methods missing from the list go undetected.
pub(super) fn is_mutating_method(name: &str) -> bool {
    matches!(
        name,
        // Growing a collection or string.
        "push"
            | "push_str"
            | "push_back"
            | "push_front"
            | "insert"
            | "insert_str"
            | "extend"
            | "extend_from_slice"
            | "append"
            | "resize"
            | "resize_with"
            // Removing elements.
            | "pop"
            | "pop_back"
            | "pop_front"
            | "remove"
            | "swap_remove"
            | "clear"
            | "drain"
            | "truncate"
            | "split_off"
            | "retain"
            | "retain_mut"
            | "dedup"
            | "dedup_by"
            | "dedup_by_key"
            // Capacity management.
            | "reserve"
            | "reserve_exact"
            | "shrink_to_fit"
            | "shrink_to"
            // Reordering or overwriting in place.
            | "set"
            | "swap"
            | "sort"
            | "sort_by"
            | "sort_by_key"
            | "sort_by_cached_key"
            | "sort_unstable"
            | "sort_unstable_by"
            | "sort_unstable_by_key"
            | "reverse"
            | "fill"
            // Handing out mutable access.
            | "entry"
            | "get_mut"
            | "iter_mut"
            | "values_mut"
            | "get_or_insert_with"
    )
}
/// Returns the single identifier bound by a simple pattern, such as a closure
/// parameter.
///
/// Recognised forms are `x`, `mut x`, `ref x`, `x @ ..`, `x: T`, `&x`,
/// `&mut x` and `(x)`, nested in any combination (e.g. `&x: &T`).
/// Destructuring patterns such as `(a, b)` bind several names and yield
/// `None`, as do wildcards, literals and paths.
pub(super) fn extract_pattern_name(pat: &Pat) -> Option<String> {
    match pat {
        Pat::Ident(pat_ident) => Some(pat_ident.ident.to_string()),
        Pat::Type(pat_type) => extract_pattern_name(&pat_type.pat),
        // `|&x|` is common in iterator closures and still binds `x`.
        Pat::Reference(pat_ref) => extract_pattern_name(&pat_ref.pat),
        Pat::Paren(pat_paren) => extract_pattern_name(&pat_paren.pat),
        _ => None,
    }
}