use ryo_source::pure::{PureBlock, PureExpr, PureFn, PureParam, PureStmt, PureType};
use ryo_symbol::SymbolId;
use std::collections::HashSet;
use crate::Mutation;
/// Mutation that drops `.clone()` calls whose receiver is a `Copy` value
/// (a copyable literal, or a variable known/inferred to hold a `Copy` type),
/// e.g. `42.clone()` → `42`.
#[derive(Clone, Debug, Default)]
pub struct CloneOnCopyMutation {
    /// Function this mutation is intended for; set via `in_function`.
    pub target_fn: Option<SymbolId>,
    /// When `true`, every zero-argument `.clone()` call is removed regardless
    /// of its receiver.
    pub aggressive: bool,
    /// Additional variable names whose values are treated as `Copy`.
    pub copy_vars: Vec<String>,
}
impl CloneOnCopyMutation {
    /// Creates a mutation with no target function, non-aggressive mode and no
    /// extra known-`Copy` variables.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records the function this mutation is meant for.
    ///
    /// NOTE(review): `transform_fn` does not read `target_fn`; any filtering by
    /// function presumably happens in the caller — confirm.
    pub fn in_function(mut self, id: SymbolId) -> Self {
        self.target_fn = Some(id);
        self
    }

    /// Enables aggressive mode: every zero-argument `.clone()` call is removed,
    /// whatever its receiver.
    pub fn aggressive(mut self) -> Self {
        self.aggressive = true;
        self
    }

    /// Declares `var` as a variable holding a `Copy` value, so `var.clone()` is
    /// removed even in non-aggressive mode.
    pub fn with_copy_var(mut self, var: impl Into<String>) -> Self {
        self.copy_vars.push(var.into());
        self
    }

    /// Returns `true` if `expr` is a literal of a `Copy` type for which
    /// `lit.clone()` has the same type as `lit`, so the call can be dropped.
    ///
    /// Recognised (on the literal's source text):
    /// - numeric literals: anything starting with an ASCII digit. Every Rust
    ///   integer/float literal (decimal, `0x`/`0o`/`0b`, `_` separators,
    ///   exponents, type suffixes) starts with a digit, and all numeric
    ///   primitives are `Copy`;
    /// - `true` / `false`, as a literal or as a path;
    /// - char literals `'…'` and byte literals `b'…'`.
    ///
    /// String and byte-string literals are deliberately not recognised.
    fn is_copy_literal(expr: &PureExpr) -> bool {
        // A quoted char token with at least one character between the quotes.
        fn is_quoted_char(text: &str) -> bool {
            text.len() >= 3 && text.starts_with('\'') && text.ends_with('\'')
        }

        match expr {
            PureExpr::Lit(lit) => {
                lit.starts_with(|c: char| c.is_ascii_digit())
                    || lit == "true"
                    || lit == "false"
                    || is_quoted_char(lit)
                    || matches!(lit.strip_prefix('b'), Some(rest) if is_quoted_char(rest))
            }
            PureExpr::Path(path) => path == "true" || path == "false",
            _ => false,
        }
    }

    /// Returns `true` if `path` names a variable registered through
    /// `with_copy_var` or inferred from the current function's parameters.
    fn is_known_copy_var(&self, path: &str, fn_copy_vars: &HashSet<String>) -> bool {
        self.copy_vars.iter().any(|v| v == path) || fn_copy_vars.contains(path)
    }

    /// Primitive type names treated as `Copy`, matched either bare or as the
    /// last segment of a `::`-qualified path.
    const COPY_TYPES: &'static [&'static str] = &[
        "i8",
        "i16",
        "i32",
        "i64",
        "i128",
        "isize",
        "u8",
        "u16",
        "u32",
        "u64",
        "u128",
        "usize",
        "f32",
        "f64",
        "bool",
        "char",
        "NonZeroI8",
        "NonZeroI16",
        "NonZeroI32",
        "NonZeroI64",
        "NonZeroI128",
        "NonZeroIsize",
        "NonZeroU8",
        "NonZeroU16",
        "NonZeroU32",
        "NonZeroU64",
        "NonZeroU128",
        "NonZeroUsize",
    ];

    /// Returns `true` if a value of type `ty` is `Copy` and removing
    /// `.clone()` on it is type-preserving.
    ///
    /// Paths match `COPY_TYPES`. Tuples and arrays match when all of their
    /// element types do. Everything else is conservatively `false`.
    fn is_copy_type(ty: &PureType) -> bool {
        match ty {
            PureType::Path(path) => Self::COPY_TYPES.iter().any(|&t| {
                // `path == t` or `path` ends with `::t`, without allocating.
                matches!(
                    path.strip_suffix(t),
                    Some(prefix) if prefix.is_empty() || prefix.ends_with("::")
                )
            }),
            // References are NOT treated as removable:
            // - `&mut T` is never `Copy` (and mutability isn't inspected here);
            // - for `x: &T`, method resolution auto-derefs `x.clone()` to
            //   `T::clone` whenever `T: Clone`, yielding an owned `T`, so
            //   dropping `.clone()` would change the expression's type.
            PureType::Ref { .. } => false,
            PureType::Tuple(types) => types.iter().all(Self::is_copy_type),
            PureType::Array { ty, .. } => Self::is_copy_type(ty),
            _ => false,
        }
    }

    /// Collects the names of typed parameters whose type is `Copy` per
    /// `is_copy_type`. Other parameter forms are ignored.
    fn collect_copy_vars_from_params(params: &[PureParam]) -> HashSet<String> {
        let mut copy_vars = HashSet::new();
        for param in params {
            if let PureParam::Typed { name, ty } = param {
                if Self::is_copy_type(ty) {
                    copy_vars.insert(name.clone());
                }
            }
        }
        copy_vars
    }

    /// Rewrites `expr` in place, replacing eligible `recv.clone()` calls with
    /// `recv`. Returns the number of calls removed.
    ///
    /// A zero-argument `.clone()` is removed when aggressive mode is on, when
    /// the receiver is a copyable literal, or when it is a path naming a known
    /// `Copy` variable.
    ///
    /// NOTE(review): only the sub-expressions matched below are visited. For
    /// example, a `while` condition or a match guard, if the AST has them, is
    /// not traversed.
    fn transform_expr(&self, expr: &mut PureExpr, fn_copy_vars: &HashSet<String>) -> usize {
        let mut changes = 0;
        if let PureExpr::MethodCall {
            receiver,
            method,
            args,
            ..
        } = expr
        {
            if method == "clone" && args.is_empty() {
                let should_remove = self.aggressive
                    || Self::is_copy_literal(receiver)
                    || matches!(receiver.as_ref(), PureExpr::Path(p) if self.is_known_copy_var(p, fn_copy_vars));
                if should_remove {
                    // Move the receiver out (leaving a throwaway placeholder)
                    // and splice it in place of the whole method call.
                    let inner = std::mem::replace(
                        receiver.as_mut(),
                        PureExpr::Path("__placeholder".to_string()),
                    );
                    *expr = inner;
                    // The unwrapped receiver may contain further clones (e.g.
                    // `x.clone().clone()` or `f(a.clone()).clone()` in
                    // aggressive mode), so keep rewriting it.
                    return 1 + self.transform_expr(expr, fn_copy_vars);
                }
            }
        }
        match expr {
            PureExpr::Binary { left, right, .. } => {
                changes += self.transform_expr(left, fn_copy_vars);
                changes += self.transform_expr(right, fn_copy_vars);
            }
            PureExpr::Unary { expr: inner, .. } => {
                changes += self.transform_expr(inner, fn_copy_vars);
            }
            PureExpr::Call { func, args } => {
                changes += self.transform_expr(func, fn_copy_vars);
                for arg in args {
                    changes += self.transform_expr(arg, fn_copy_vars);
                }
            }
            PureExpr::MethodCall { receiver, args, .. } => {
                changes += self.transform_expr(receiver, fn_copy_vars);
                for arg in args {
                    changes += self.transform_expr(arg, fn_copy_vars);
                }
            }
            PureExpr::Field { expr: inner, .. } => {
                changes += self.transform_expr(inner, fn_copy_vars);
            }
            PureExpr::Index { expr: inner, index } => {
                changes += self.transform_expr(inner, fn_copy_vars);
                changes += self.transform_expr(index, fn_copy_vars);
            }
            PureExpr::Block { block, .. } => {
                changes += self.transform_block(block, fn_copy_vars);
            }
            PureExpr::If {
                cond,
                then_branch,
                else_branch,
            } => {
                changes += self.transform_expr(cond, fn_copy_vars);
                changes += self.transform_block(then_branch, fn_copy_vars);
                if let Some(else_expr) = else_branch {
                    changes += self.transform_expr(else_expr, fn_copy_vars);
                }
            }
            PureExpr::Match { expr: e, arms } => {
                changes += self.transform_expr(e, fn_copy_vars);
                for arm in arms {
                    changes += self.transform_expr(&mut arm.body, fn_copy_vars);
                }
            }
            PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
                changes += self.transform_block(block, fn_copy_vars);
            }
            PureExpr::For {
                expr: iter_expr,
                body,
                ..
            } => {
                changes += self.transform_expr(iter_expr, fn_copy_vars);
                changes += self.transform_block(body, fn_copy_vars);
            }
            PureExpr::Closure { body, .. } => {
                changes += self.transform_expr(body, fn_copy_vars);
            }
            PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
                for e in exprs {
                    changes += self.transform_expr(e, fn_copy_vars);
                }
            }
            PureExpr::Struct { fields, .. } => {
                for (_, e) in fields {
                    changes += self.transform_expr(e, fn_copy_vars);
                }
            }
            PureExpr::Ref { expr: inner, .. } => {
                changes += self.transform_expr(inner, fn_copy_vars);
            }
            PureExpr::Return(Some(inner)) => {
                changes += self.transform_expr(inner, fn_copy_vars);
            }
            PureExpr::Try(inner) | PureExpr::Await(inner) => {
                changes += self.transform_expr(inner, fn_copy_vars);
            }
            _ => {}
        }
        changes
    }

    /// Applies `transform_stmt` to every statement in `block`, summing the
    /// number of removed clones.
    fn transform_block(&self, block: &mut PureBlock, fn_copy_vars: &HashSet<String>) -> usize {
        let mut changes = 0;
        for stmt in &mut block.stmts {
            changes += self.transform_stmt(stmt, fn_copy_vars);
        }
        changes
    }

    /// Rewrites the expression carried by a statement: a `let` initializer,
    /// or an expression statement with or without a trailing semicolon.
    /// Other statements are left untouched.
    fn transform_stmt(&self, stmt: &mut PureStmt, fn_copy_vars: &HashSet<String>) -> usize {
        match stmt {
            PureStmt::Local { init: Some(e), .. } => self.transform_expr(e, fn_copy_vars),
            PureStmt::Semi(e) | PureStmt::Expr(e) => self.transform_expr(e, fn_copy_vars),
            _ => 0,
        }
    }

    /// Removes unnecessary `.clone()` calls from `func`'s body and returns how
    /// many were removed.
    ///
    /// Copy-ness of variables is inferred from the function's typed parameters
    /// plus `copy_vars` only. `let`, closure and pattern bindings are not
    /// tracked. As a result, a local that shadows a `Copy` parameter with a
    /// non-`Copy` value still has its `.clone()` removed.
    pub fn transform_fn(&self, func: &mut PureFn) -> usize {
        let fn_copy_vars = Self::collect_copy_vars_from_params(&func.params);
        self.transform_block(&mut func.body, &fn_copy_vars)
    }
}
impl Mutation for CloneOnCopyMutation {
    /// Human-readable summary of this mutation.
    fn describe(&self) -> String {
        String::from("Remove unnecessary .clone() on Copy types")
    }

    /// Short name identifying this mutation kind.
    fn mutation_type(&self) -> &'static str {
        "CloneOnCopy"
    }

    /// Copies this mutation, with its full configuration, into a boxed trait
    /// object.
    fn box_clone(&self) -> Box<dyn Mutation> {
        let duplicate = self.clone();
        Box::new(duplicate)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `is_copy_literal` on a `PureExpr::Lit` built from `text`.
    fn lit_is_copy(text: &str) -> bool {
        CloneOnCopyMutation::is_copy_literal(&PureExpr::Lit(text.to_string()))
    }

    /// Runs `is_copy_literal` on a `PureExpr::Path` built from `text`.
    fn path_is_copy(text: &str) -> bool {
        CloneOnCopyMutation::is_copy_literal(&PureExpr::Path(text.to_string()))
    }

    #[test]
    fn test_is_copy_literal_integer() {
        assert!(lit_is_copy("42"));
        assert!(lit_is_copy("1_000"));
    }

    #[test]
    fn test_is_copy_literal_bool() {
        assert!(lit_is_copy("true"));
        assert!(lit_is_copy("false"));
        assert!(path_is_copy("true"));
    }

    #[test]
    fn test_is_copy_literal_char() {
        assert!(lit_is_copy("'a'"));
        assert!(lit_is_copy("'\\n'"));
    }

    #[test]
    fn test_is_not_copy_literal() {
        assert!(!lit_is_copy("\"string\""));
        assert!(!path_is_copy("variable"));
    }

    #[test]
    fn test_known_copy_var() {
        let mutation = CloneOnCopyMutation::new()
            .with_copy_var("x")
            .with_copy_var("count");
        let no_fn_vars = std::collections::HashSet::<String>::new();
        for known in ["x", "count"] {
            assert!(mutation.is_known_copy_var(known, &no_fn_vars));
        }
        assert!(!mutation.is_known_copy_var("other", &no_fn_vars));
    }
}