use std::collections::HashSet;
use syn::visit::Visit;
use syn::{Expr, ExprAssign, ExprBinary, ExprClosure, ExprMethodCall, Pat};
use super::super::types::CaptureMode;
/// Describes one variable that a closure body was found to capture.
#[derive(Debug, Clone)]
pub(super) struct CaptureInfo {
/// Identifier of the captured variable.
pub var_name: String,
/// How the closure captures the variable.
pub mode: CaptureMode,
/// `true` once the variable has been recorded as mutated inside the closure body.
pub is_mutated: bool,
}
/// Expression visitor that collects the outer-scope variables referenced by a closure body.
pub(super) struct ClosureCaptureVisitor<'a> {
// Names that count as capturable; a name must be in this set to be recorded.
outer_scope: &'a HashSet<String>,
// Names bound as closure parameters; these are never recorded as captures.
closure_params: &'a HashSet<String>,
// Captures recorded so far, at most one entry per variable name.
captures: Vec<CaptureInfo>,
// Names seen as the target of an assignment or of a mutating method call.
mutated_vars: HashSet<String>,
// Whether the closure being analysed was declared with `move`.
is_move: bool,
}
impl<'a> ClosureCaptureVisitor<'a> {
pub fn new(
outer_scope: &'a HashSet<String>,
closure_params: &'a HashSet<String>,
is_move: bool,
) -> Self {
Self {
outer_scope,
closure_params,
captures: Vec::new(),
mutated_vars: HashSet::new(),
is_move,
}
}
pub fn finalize_captures(mut self) -> Vec<CaptureInfo> {
for capture in &mut self.captures {
if self.mutated_vars.contains(&capture.var_name) {
capture.is_mutated = true;
if !self.is_move {
capture.mode = CaptureMode::ByMutRef;
}
}
}
self.captures
}
fn record_mutation(&mut self, name: &str) {
self.mutated_vars.insert(name.to_string());
}
fn try_record_capture(&mut self, name: String) {
if name == "self" || name == "Self" {
return;
}
if self.outer_scope.contains(&name) && !self.closure_params.contains(&name) {
if !self.captures.iter().any(|c| c.var_name == name) {
self.captures.push(CaptureInfo {
var_name: name,
mode: if self.is_move {
CaptureMode::ByValue
} else {
CaptureMode::ByRef
},
is_mutated: false,
});
}
}
}
fn try_record_path_capture(&mut self, expr: &Expr) {
if let Some(name) = path_ident(expr) {
self.try_record_capture(name);
}
}
fn record_path_mutation(&mut self, expr: &Expr) {
if let Some(name) = path_ident(expr) {
self.record_mutation(&name);
self.try_record_capture(name);
}
}
fn visit_method_call(&mut self, method_call: &ExprMethodCall) {
self.visit_expr(&method_call.receiver);
self.record_mutating_method_receiver(method_call);
for arg in &method_call.args {
self.visit_expr(arg);
}
}
fn record_mutating_method_receiver(&mut self, method_call: &ExprMethodCall) {
let method_name = method_call.method.to_string();
if is_mutating_method(&method_name) {
self.record_path_mutation(&method_call.receiver);
}
}
fn visit_assign(&mut self, assign: &ExprAssign) {
self.record_path_mutation(&assign.left);
self.visit_expr(&assign.right);
}
fn visit_binary(&mut self, binary: &ExprBinary) {
if is_compound_assignment(&binary.op) {
self.record_path_mutation(&binary.left);
}
self.visit_expr(&binary.left);
self.visit_expr(&binary.right);
}
fn visit_nested_closure(&mut self, nested_closure: &ExprClosure) {
let nested_params = nested_closure_params(nested_closure);
let nested_is_move = nested_closure.capture.is_some();
let mut nested_visitor =
ClosureCaptureVisitor::new(self.outer_scope, &nested_params, nested_is_move);
nested_visitor.visit_expr(&nested_closure.body);
for capture in nested_visitor.finalize_captures() {
self.push_unique_capture(capture);
}
}
fn push_unique_capture(&mut self, capture: CaptureInfo) {
if !self
.captures
.iter()
.any(|existing| existing.var_name == capture.var_name)
{
self.captures.push(capture);
}
}
}
/// Routes the expression kinds that matter for capture analysis to the
/// dedicated handlers and lets `syn` walk everything else.
impl<'ast, 'a> Visit<'ast> for ClosureCaptureVisitor<'a> {
    fn visit_expr(&mut self, expr: &'ast Expr) {
        match expr {
            // Assignments and compound assignments may mutate a capture.
            Expr::Assign(assignment) => self.visit_assign(assignment),
            Expr::Binary(operation) => self.visit_binary(operation),
            // Nested closures are analysed with their own parameter set.
            Expr::Closure(inner) => self.visit_nested_closure(inner),
            Expr::MethodCall(call) => self.visit_method_call(call),
            // A bare path is a potential variable reference.
            Expr::Path(_) => self.try_record_path_capture(expr),
            // Default traversal for every other expression kind.
            other => syn::visit::visit_expr(self, other),
        }
    }
}
/// Returns the identifier when `expr` is a path that `get_ident` accepts as a
/// single identifier, and `None` for every other expression.
fn path_ident(expr: &Expr) -> Option<String> {
    if let Expr::Path(expr_path) = expr {
        expr_path.path.get_ident().map(|ident| ident.to_string())
    } else {
        None
    }
}
fn nested_closure_params(closure: &ExprClosure) -> HashSet<String> {
closure
.inputs
.iter()
.filter_map(extract_pattern_name)
.collect()
}
fn is_compound_assignment(op: &syn::BinOp) -> bool {
matches!(
op,
syn::BinOp::AddAssign(_)
| syn::BinOp::SubAssign(_)
| syn::BinOp::MulAssign(_)
| syn::BinOp::DivAssign(_)
| syn::BinOp::RemAssign(_)
| syn::BinOp::BitAndAssign(_)
| syn::BinOp::BitOrAssign(_)
| syn::BinOp::BitXorAssign(_)
| syn::BinOp::ShlAssign(_)
| syn::BinOp::ShrAssign(_)
)
}
/// Name-based heuristic: returns `true` when `name` is a method that is known
/// to modify its receiver.
///
/// The check looks at the method name only; the receiver type is not known here.
/// Besides the basic mutators (`push`, `pop`, `sort`, ...) their closely related
/// variants (`push_str`, `push_back`, `pop_front`, `sort_unstable`, `dedup_by`,
/// `swap_remove`, ...) are recognised too, so the two groups are classified
/// consistently.
pub(super) fn is_mutating_method(name: &str) -> bool {
    matches!(
        name,
        // Adding elements
        "push"
            | "push_str"
            | "push_back"
            | "push_front"
            | "insert"
            | "insert_str"
            | "extend"
            | "extend_from_slice"
            | "append"
            // Removing elements
            | "pop"
            | "pop_back"
            | "pop_front"
            | "remove"
            | "swap_remove"
            | "clear"
            | "drain"
            | "split_off"
            | "truncate"
            | "retain"
            | "retain_mut"
            | "dedup"
            | "dedup_by"
            | "dedup_by_key"
            // Capacity / size
            | "reserve"
            | "shrink_to_fit"
            | "resize"
            | "resize_with"
            // In-place updates and reordering
            | "set"
            | "swap"
            | "reverse"
            | "sort"
            | "sort_by"
            | "sort_by_key"
            | "sort_unstable"
            | "sort_unstable_by"
            | "sort_unstable_by_key"
    )
}
/// Returns the name bound by a simple parameter pattern.
///
/// Type ascriptions (`name: T`) are peeled off until an identifier pattern is
/// reached; any other kind of pattern yields `None`.
pub(super) fn extract_pattern_name(pat: &Pat) -> Option<String> {
    let mut current = pat;
    loop {
        match current {
            // Look through `pattern: Type` at the inner pattern.
            Pat::Type(typed) => current = &*typed.pat,
            Pat::Ident(binding) => return Some(binding.ident.to_string()),
            _ => return None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Turns a list of string literals into an owned name set.
    fn names(names: &[&str]) -> HashSet<String> {
        names.iter().copied().map(String::from).collect()
    }

    /// Runs the visitor over `expr` and returns the finalized captures.
    fn captures_for(
        expr: Expr,
        outer: &[&str],
        params: &[&str],
        is_move: bool,
    ) -> Vec<CaptureInfo> {
        let scope = names(outer);
        let bound = names(params);
        let mut visitor = ClosureCaptureVisitor::new(&scope, &bound, is_move);
        visitor.visit_expr(&expr);
        visitor.finalize_captures()
    }

    /// Looks up the capture for `name`, panicking when it is absent.
    fn capture<'a>(captures: &'a [CaptureInfo], name: &str) -> &'a CaptureInfo {
        match captures.iter().find(|info| info.var_name == name) {
            Some(info) => info,
            None => panic!("missing capture for {name}"),
        }
    }

    #[test]
    fn path_records_outer_variable_capture() {
        let expr: Expr = parse_quote!(outer_value);
        let found = captures_for(expr, &["outer_value"], &[], false);
        assert_eq!(capture(&found, "outer_value").mode, CaptureMode::ByRef);
    }

    #[test]
    fn path_skips_closure_parameters_and_special_names() {
        let expr: Expr = parse_quote!((outer_value, self, Self, item));
        let scope = ["outer_value", "self", "Self", "item"];
        let found = captures_for(expr, &scope, &["item"], false);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].var_name, "outer_value");
    }

    #[test]
    fn assignment_records_mutable_capture() {
        let expr: Expr = parse_quote!(counter = 1);
        let found = captures_for(expr, &["counter"], &[], false);
        let info = capture(&found, "counter");
        assert_eq!(info.mode, CaptureMode::ByMutRef);
        assert!(info.is_mutated);
    }

    #[test]
    fn move_assignment_stays_by_value_and_mutated() {
        let expr: Expr = parse_quote!(counter = 1);
        let found = captures_for(expr, &["counter"], &[], true);
        let info = capture(&found, "counter");
        assert_eq!(info.mode, CaptureMode::ByValue);
        assert!(info.is_mutated);
    }

    #[test]
    fn mutating_method_records_mutable_receiver_capture() {
        let expr: Expr = parse_quote!(items.push(value));
        let found = captures_for(expr, &["items", "value"], &[], false);
        let receiver = capture(&found, "items");
        assert_eq!(receiver.mode, CaptureMode::ByMutRef);
        assert!(receiver.is_mutated);
        assert_eq!(capture(&found, "value").mode, CaptureMode::ByRef);
    }

    #[test]
    fn non_mutating_method_records_receiver_and_args_by_ref() {
        let expr: Expr = parse_quote!(items.contains(value));
        let found = captures_for(expr, &["items", "value"], &[], false);
        assert_eq!(capture(&found, "items").mode, CaptureMode::ByRef);
        assert!(!capture(&found, "items").is_mutated);
        assert_eq!(capture(&found, "value").mode, CaptureMode::ByRef);
    }

    #[test]
    fn compound_assignment_records_mutable_capture() {
        let expr: Expr = parse_quote!(total += amount);
        let found = captures_for(expr, &["total", "amount"], &[], false);
        let target = capture(&found, "total");
        assert_eq!(target.mode, CaptureMode::ByMutRef);
        assert!(target.is_mutated);
        assert_eq!(capture(&found, "amount").mode, CaptureMode::ByRef);
    }

    #[test]
    fn nested_closure_propagates_unique_outer_captures() {
        let expr: Expr = parse_quote!(|| value + value);
        let found = captures_for(expr, &["value"], &[], false);
        assert_eq!(found.len(), 1);
        assert_eq!(capture(&found, "value").mode, CaptureMode::ByRef);
    }

    #[test]
    fn typed_pattern_name_is_extracted() {
        let closure: ExprClosure = parse_quote!(|item: usize| item);
        let extracted = extract_pattern_name(&closure.inputs[0]);
        assert_eq!(extracted, Some("item".to_string()));
    }
}