use ryo_source::pure::{PureBlock, PureClosureParam, PureExpr, PurePattern, PureStmt};
use ryo_symbol::SymbolId;
use crate::Mutation;
/// Mutation that replaces closures which merely forward their parameters to a
/// call (`|x| foo(x)`) with the callee itself (`foo`).
#[derive(Clone, Debug, Default)]
pub struct RedundantClosureMutation {
    /// Optional function symbol recorded by `in_function`.
    ///
    /// NOTE(review): nothing visible in this file reads this field, so the
    /// traversal methods currently run regardless of its value; confirm
    /// whether a caller elsewhere uses it to scope the mutation.
    pub target_fn: Option<SymbolId>,
}
impl RedundantClosureMutation {
pub fn new() -> Self {
Self::default()
}
pub fn in_function(mut self, id: SymbolId) -> Self {
self.target_fn = Some(id);
self
}
/// Collects the names of all closure parameters whose pattern is a plain
/// identifier, in declaration order.
///
/// Parameters with any other pattern are skipped, so the returned vector may
/// be shorter than `params`; callers compare the lengths to detect that.
fn get_param_names(params: &[PureClosureParam]) -> Vec<String> {
    let mut names = Vec::new();
    for param in params {
        if let PurePattern::Ident { name, .. } = &param.pattern {
            names.push(name.clone());
        }
    }
    names
}
/// Returns `true` when `expr` is a `PureExpr::Path` whose text equals `name`.
fn is_path_to(expr: &PureExpr, name: &str) -> bool {
    match expr {
        PureExpr::Path(path) => path == name,
        _ => false,
    }
}
/// Decides whether a closure body just forwards the closure's parameters to a
/// function, and if so returns a clone of that function expression.
///
/// `params` are the closure's parameter names in declaration order. The body
/// is accepted when it is a call that
///
/// * has a plain path as callee, and that path is not one of `params`, and
/// * passes exactly the parameters, each as a bare path, in the same order.
///
/// A block containing a single trailing expression statement is unwrapped and
/// checked recursively. Everything else yields `None`.
///
/// The callee restriction is what keeps the rewrite sound: replacing
/// `|f| f(f)` by `f` would leave `f` unbound, and replacing `|x| make()(x)`
/// by `make()` would evaluate `make()` once up front instead of on every
/// invocation of the closure.
fn is_redundant_call(params: &[String], body: &PureExpr) -> Option<PureExpr> {
    match body {
        PureExpr::Call { func, args } => {
            // Only a path that does not name a closure parameter can be
            // substituted for the closure without changing meaning.
            let callee_is_free_path = matches!(
                &**func,
                PureExpr::Path(path) if !params.iter().any(|param| param == path)
            );
            if !callee_is_free_path {
                return None;
            }

            // Arguments must be exactly the parameters, positionally.
            let forwards_params = args.len() == params.len()
                && params
                    .iter()
                    .zip(args)
                    .all(|(param, arg)| Self::is_path_to(arg, param));

            forwards_params.then(|| (**func).clone())
        }
        // NOTE(review): any other fields of `Block` (hidden behind `..`) are
        // ignored here; confirm none of them changes the block's meaning.
        PureExpr::Block { block, .. } => match block.stmts.as_slice() {
            [PureStmt::Expr(inner)] => Self::is_redundant_call(params, inner),
            _ => None,
        },
        _ => None,
    }
}
/// Rewrites redundant closures inside `expr` in place and returns how many
/// closures were replaced.
///
/// If `expr` itself is a closure whose parameters are all plain identifiers
/// and whose body is accepted by `is_redundant_call`, `expr` is overwritten
/// with the callee and `1` is returned without descending further.
/// Otherwise the sub-expressions of the variants matched below are visited
/// recursively. Any other variant, and any field hidden behind `..` in these
/// patterns, is left untouched.
fn transform_expr(&self, expr: &mut PureExpr) -> usize {
    let mut changes = 0;

    if let PureExpr::Closure { params, body, .. } = expr {
        let param_names = Self::get_param_names(params);
        // A length mismatch means at least one parameter is a non-identifier
        // pattern, which cannot be matched against the call arguments.
        if param_names.len() == params.len() {
            if let Some(func_ref) = Self::is_redundant_call(&param_names, body) {
                *expr = func_ref;
                return 1;
            }
        }
    }

    match expr {
        PureExpr::Binary { left, right, .. } => {
            changes += self.transform_expr(left);
            changes += self.transform_expr(right);
        }
        PureExpr::Unary { expr: inner, .. } => {
            changes += self.transform_expr(inner);
        }
        PureExpr::Call { func, args } => {
            changes += self.transform_expr(func);
            for arg in args {
                changes += self.transform_expr(arg);
            }
        }
        PureExpr::MethodCall { receiver, args, .. } => {
            changes += self.transform_expr(receiver);
            for arg in args {
                changes += self.transform_expr(arg);
            }
        }
        PureExpr::Field { expr: inner, .. } => {
            changes += self.transform_expr(inner);
        }
        PureExpr::Index { expr: inner, index } => {
            changes += self.transform_expr(inner);
            changes += self.transform_expr(index);
        }
        PureExpr::Block { block, .. } => {
            changes += self.transform_block(block);
        }
        PureExpr::If {
            cond,
            then_branch,
            else_branch,
        } => {
            changes += self.transform_expr(cond);
            changes += self.transform_block(then_branch);
            if let Some(else_expr) = else_branch {
                changes += self.transform_expr(else_expr);
            }
        }
        PureExpr::Match { expr: e, arms } => {
            changes += self.transform_expr(e);
            // Only arm bodies are visited here.
            for arm in arms {
                changes += self.transform_expr(&mut arm.body);
            }
        }
        // Only the loop body is visited; other fields are skipped via `..`.
        PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
            changes += self.transform_block(block);
        }
        PureExpr::For {
            expr: iter_expr,
            body,
            ..
        } => {
            changes += self.transform_expr(iter_expr);
            changes += self.transform_block(body);
        }
        // Reached only for closures that were not replaced above: look for
        // redundant closures nested in the body.
        PureExpr::Closure { body, .. } => {
            changes += self.transform_expr(body);
        }
        PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
            for e in exprs {
                changes += self.transform_expr(e);
            }
        }
        PureExpr::Struct { fields, .. } => {
            for (_, e) in fields {
                changes += self.transform_expr(e);
            }
        }
        PureExpr::Ref { expr: inner, .. } => {
            changes += self.transform_expr(inner);
        }
        PureExpr::Return(Some(inner)) => {
            changes += self.transform_expr(inner);
        }
        PureExpr::Try(inner) | PureExpr::Await(inner) => {
            changes += self.transform_expr(inner);
        }
        _ => {}
    }

    changes
}
/// Applies the mutation to every statement of `block` in order and returns
/// the total number of closures that were replaced.
pub fn transform_block(&self, block: &mut PureBlock) -> usize {
    block
        .stmts
        .iter_mut()
        .map(|stmt| self.transform_stmt(stmt))
        .sum()
}
/// Applies the mutation to the expression carried by `stmt`, if any, and
/// returns the number of closures replaced.
///
/// Handles `Local` statements that have an initializer and both `Semi` and
/// `Expr` statements; every other statement yields `0`.
fn transform_stmt(&self, stmt: &mut PureStmt) -> usize {
    if let PureStmt::Local {
        init: Some(init_expr),
        ..
    } = stmt
    {
        return self.transform_expr(init_expr);
    }
    if let PureStmt::Semi(inner) | PureStmt::Expr(inner) = stmt {
        return self.transform_expr(inner);
    }
    0
}
}
impl Mutation for RedundantClosureMutation {
    /// Human-readable summary of what this mutation does.
    fn describe(&self) -> String {
        String::from("Simplify redundant closures (|x| foo(x) → foo)")
    }

    /// Stable identifier for this mutation kind.
    fn mutation_type(&self) -> &'static str {
        "RedundantClosure"
    }

    /// Returns a boxed copy of this mutation as a trait object.
    fn box_clone(&self) -> Box<dyn Mutation> {
        let copy: Self = self.clone();
        Box::new(copy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifier patterns are collected in declaration order.
    #[test]
    fn test_get_param_names() {
        let params = vec![
            PureClosureParam::untyped(PurePattern::Ident {
                name: "x".to_string(),
                is_mut: false,
            }),
            PureClosureParam::untyped(PurePattern::Ident {
                name: "y".to_string(),
                is_mut: false,
            }),
        ];
        let names = RedundantClosureMutation::get_param_names(&params);
        assert_eq!(names, vec!["x", "y"]);
    }

    /// `|x| foo(x)` reduces to `foo`.
    #[test]
    fn test_is_redundant_call_single_param() {
        let params = vec!["x".to_string()];
        let body = PureExpr::Call {
            func: Box::new(PureExpr::Path("foo".to_string())),
            args: vec![PureExpr::Path("x".to_string())],
        };
        let result = RedundantClosureMutation::is_redundant_call(&params, &body);
        assert!(result.is_some());
        assert!(matches!(result.unwrap(), PureExpr::Path(s) if s == "foo"));
    }

    /// `|a, b| func(a, b)` is redundant.
    #[test]
    fn test_is_redundant_call_multi_param() {
        let params = vec!["a".to_string(), "b".to_string()];
        let body = PureExpr::Call {
            func: Box::new(PureExpr::Path("func".to_string())),
            args: vec![
                PureExpr::Path("a".to_string()),
                PureExpr::Path("b".to_string()),
            ],
        };
        let result = RedundantClosureMutation::is_redundant_call(&params, &body);
        assert!(result.is_some());
    }

    /// `|| foo()` forwards zero parameters to zero arguments and is redundant.
    #[test]
    fn test_is_redundant_call_no_params() {
        let params: Vec<String> = Vec::new();
        let body = PureExpr::Call {
            func: Box::new(PureExpr::Path("foo".to_string())),
            args: Vec::new(),
        };
        let result = RedundantClosureMutation::is_redundant_call(&params, &body);
        assert!(matches!(result, Some(PureExpr::Path(s)) if s == "foo"));
    }

    /// `|a, b| func(b, a)` reorders its arguments and must be kept.
    #[test]
    fn test_is_not_redundant_wrong_order() {
        let params = vec!["a".to_string(), "b".to_string()];
        let body = PureExpr::Call {
            func: Box::new(PureExpr::Path("func".to_string())),
            args: vec![
                PureExpr::Path("b".to_string()),
                PureExpr::Path("a".to_string()),
            ],
        };
        let result = RedundantClosureMutation::is_redundant_call(&params, &body);
        assert!(result.is_none());
    }

    /// `|x| foo(x, y)` passes an extra captured argument and must be kept.
    #[test]
    fn test_is_not_redundant_extra_args() {
        let params = vec!["x".to_string()];
        let body = PureExpr::Call {
            func: Box::new(PureExpr::Path("foo".to_string())),
            args: vec![
                PureExpr::Path("x".to_string()),
                PureExpr::Path("y".to_string()),
            ],
        };
        let result = RedundantClosureMutation::is_redundant_call(&params, &body);
        assert!(result.is_none());
    }

    /// `|x| x` has a body that is not a call at all and must be kept.
    #[test]
    fn test_is_not_redundant_non_call_body() {
        let params = vec!["x".to_string()];
        let body = PureExpr::Path("x".to_string());
        let result = RedundantClosureMutation::is_redundant_call(&params, &body);
        assert!(result.is_none());
    }
}