use ryo_analysis::SymbolKind;
use ryo_mutations::basic::stmt::{
InsertPosition, InsertStatementMutation, RemoveStatementMutation, ReplaceExprAtMutation,
ReplaceExprMutation, ReplaceStatementMutation, WrapExprMutation,
};
use ryo_mutations::{Mutation, MutationResult};
use ryo_source::pure::{MacroDelimiter, PureBlock, PureExpr, PureItem, PureStmt};
use crate::engine::{ASTMutationContext, ASTRegApply, ModificationType};
impl ASTRegApply for ReplaceExprMutation {
    /// Replaces `old_expr` with `new_expr` inside the body of `target_fn`.
    ///
    /// Only the first match is replaced unless `replace_all` is set. The stored
    /// AST is updated (and `BodyModified` emitted) only when at least one
    /// replacement happened; otherwise a zero-change result is returned.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let fn_id = self.target_fn;
        let outcome = |changes: usize, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        if !matches!(ctx.symbol_registry.kind(fn_id), Some(SymbolKind::Function)) {
            return outcome(0, format!("Target function '{:?}' not found", fn_id));
        }

        let mut new_func = match ctx.ast_registry.get(fn_id) {
            Some(PureItem::Fn(func)) => func.clone(),
            _ => return outcome(0, "No matching expressions found".to_string()),
        };

        let replaced = replace_expr_in_block(
            &mut new_func.body,
            &self.old_expr,
            &self.new_expr,
            self.replace_all,
        );
        if replaced == 0 {
            return outcome(0, "No matching expressions found".to_string());
        }

        ctx.set_ast(fn_id, PureItem::Fn(new_func));
        ctx.emit_modified(fn_id, ModificationType::BodyModified);
        outcome(replaced, format!("Replaced {} expression(s)", replaced))
    }
}
impl ASTRegApply for WrapExprMutation {
    /// Wraps occurrences of `target_expr` in `target_fn` with `wrapper_macro!(...)`.
    ///
    /// Only the first occurrence is wrapped unless `wrap_all` is set. The AST is
    /// written back (and `BodyModified` emitted) only if something was wrapped.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let fn_id = self.target_fn;
        let outcome = |changes: usize, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        if !matches!(ctx.symbol_registry.kind(fn_id), Some(SymbolKind::Function)) {
            return outcome(0, format!("Target function '{:?}' not found", fn_id));
        }

        let mut new_func = match ctx.ast_registry.get(fn_id) {
            Some(PureItem::Fn(func)) => func.clone(),
            _ => return outcome(0, "No matching expressions found".to_string()),
        };

        let wrapped = wrap_expr_in_block(
            &mut new_func.body,
            &self.target_expr,
            &self.wrapper_macro,
            self.wrap_all,
        );
        if wrapped == 0 {
            return outcome(0, "No matching expressions found".to_string());
        }

        ctx.set_ast(fn_id, PureItem::Fn(new_func));
        ctx.emit_modified(fn_id, ModificationType::BodyModified);
        outcome(
            wrapped,
            format!(
                "Wrapped {} expression(s) with {}!()",
                wrapped, self.wrapper_macro
            ),
        )
    }
}
impl ASTRegApply for RemoveStatementMutation {
    /// Removes statements of `target_fn` that match the textual `pattern`.
    ///
    /// Matching is delegated to `remove_stmts_in_block`; `remove_all` controls
    /// whether more than one match may be removed. The AST is written back only
    /// when at least one statement was removed.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let fn_id = self.target_fn;
        let outcome = |changes: usize, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        if !matches!(ctx.symbol_registry.kind(fn_id), Some(SymbolKind::Function)) {
            return outcome(0, format!("Target function '{:?}' not found", fn_id));
        }

        let mut new_func = match ctx.ast_registry.get(fn_id) {
            Some(PureItem::Fn(func)) => func.clone(),
            _ => return outcome(0, "No matching statements found".to_string()),
        };

        let removed = remove_stmts_in_block(&mut new_func.body, &self.pattern, self.remove_all);
        if removed == 0 {
            return outcome(0, "No matching statements found".to_string());
        }

        ctx.set_ast(fn_id, PureItem::Fn(new_func));
        ctx.emit_modified(fn_id, ModificationType::BodyModified);
        outcome(removed, format!("Removed {} statement(s)", removed))
    }
}
impl ASTRegApply for InsertStatementMutation {
    /// Inserts `stmt` into the top-level body of `target_fn`.
    ///
    /// * `Start` — before the first statement.
    /// * `End` — at the index chosen by `find_return_index` (before a trailing
    ///   `return`/tail expression, otherwise after the last statement).
    /// * `BeforePattern` / `AfterPattern` — next to the first statement equal to
    ///   `reference_stmt`; fails if no reference is given or none matches.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let fn_id = self.target_fn;
        let outcome = |changes: usize, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        if !matches!(ctx.symbol_registry.kind(fn_id), Some(SymbolKind::Function)) {
            return outcome(0, format!("Target function '{:?}' not found", fn_id));
        }

        let mut new_func = match ctx.ast_registry.get(fn_id) {
            Some(PureItem::Fn(func)) => func.clone(),
            _ => return outcome(0, "Failed to insert statement".to_string()),
        };

        let stmts = &mut new_func.body.stmts;
        let inserted = match self.position {
            InsertPosition::Start => {
                stmts.insert(0, self.stmt.clone());
                true
            }
            InsertPosition::End => {
                let at = find_return_index(stmts.as_slice());
                stmts.insert(at, self.stmt.clone());
                true
            }
            InsertPosition::BeforePattern | InsertPosition::AfterPattern => {
                match &self.reference_stmt {
                    Some(reference) => insert_relative_to_stmt(
                        stmts,
                        &self.stmt,
                        reference,
                        self.position == InsertPosition::AfterPattern,
                    ),
                    None => false,
                }
            }
        };
        if !inserted {
            return outcome(0, "Failed to insert statement".to_string());
        }

        ctx.set_ast(fn_id, PureItem::Fn(new_func));
        ctx.emit_modified(fn_id, ModificationType::BodyModified);
        outcome(
            1,
            format!(
                "Inserted statement in '{}' at {:?}",
                self.target_fn, self.position
            ),
        )
    }
}
impl ASTRegApply for ReplaceStatementMutation {
    /// Replaces every statement equal to `old_stmt` in `target_fn` with `new_stmt`.
    ///
    /// Nested statement blocks are searched via `replace_stmts_in_block`. The AST
    /// is written back only when at least one statement was replaced.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let fn_id = self.target_fn;
        let outcome = |changes: usize, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        if !matches!(ctx.symbol_registry.kind(fn_id), Some(SymbolKind::Function)) {
            return outcome(0, format!("Target function '{:?}' not found", fn_id));
        }

        let mut new_func = match ctx.ast_registry.get(fn_id) {
            Some(PureItem::Fn(func)) => func.clone(),
            _ => return outcome(0, "No matching statements found".to_string()),
        };

        let replaced = replace_stmts_in_block(&mut new_func.body, &self.old_stmt, &self.new_stmt);
        if replaced == 0 {
            return outcome(0, "No matching statements found".to_string());
        }

        ctx.set_ast(fn_id, PureItem::Fn(new_func));
        ctx.emit_modified(fn_id, ModificationType::BodyModified);
        outcome(replaced, format!("Replaced {} statement(s)", replaced))
    }
}
/// Replaces `old` with `new` in the statements of `block`, in order.
///
/// When `replace_all` is `false`, statements after the first one that produced
/// a replacement are not visited. Returns the number of replacements made.
fn replace_expr_in_block(
    block: &mut PureBlock,
    old: &PureExpr,
    new: &PureExpr,
    replace_all: bool,
) -> usize {
    let mut total = 0;
    for stmt in block.stmts.iter_mut() {
        if total > 0 && !replace_all {
            break;
        }
        total += replace_expr_in_stmt(stmt, old, new, replace_all);
    }
    total
}
/// Wraps expressions equal to `target` in the statements of `block`, in order.
///
/// When `wrap_all` is `false`, statements after the first one that produced a
/// wrap are not visited. Returns the number of expressions wrapped.
fn wrap_expr_in_block(
    block: &mut PureBlock,
    target: &PureExpr,
    wrapper_macro: &str,
    wrap_all: bool,
) -> usize {
    let mut total = 0;
    for stmt in block.stmts.iter_mut() {
        if total > 0 && !wrap_all {
            break;
        }
        total += wrap_expr_in_stmt(stmt, target, wrapper_macro, wrap_all);
    }
    total
}
/// Applies `wrap_expr` to the expression carried by `stmt`.
///
/// Only `let` initializers and expression statements carry a searchable
/// expression; `let` without initializer and item statements yield 0.
fn wrap_expr_in_stmt(
    stmt: &mut PureStmt,
    target: &PureExpr,
    wrapper_macro: &str,
    wrap_all: bool,
) -> usize {
    let carried = match stmt {
        PureStmt::Local { init: Some(expr), .. } => expr,
        PureStmt::Semi(expr) | PureStmt::Expr(expr) => expr,
        PureStmt::Local { init: None, .. } | PureStmt::Item(_) => return 0,
    };
    wrap_expr(carried, target, wrapper_macro, wrap_all)
}
/// Wraps sub-expressions of `expr` that are equal to `target` in a
/// `wrapper_macro!( ... )` invocation (parenthesis delimiter) and returns the
/// number of expressions wrapped.
///
/// Traversal is pre-order: a node equal to `target` is replaced as a whole and
/// its children are not visited. When `wrap_all` is `false`, each compound arm
/// stops visiting further children as soon as its running count `c` is > 0,
/// so the result is at most 1. Children are visited in the order written in
/// each arm below (e.g. `match` arm bodies before their guards). Macro
/// invocations are opaque: their tokens are never searched.
///
/// NOTE(review): the new macro's `tokens` are built with
/// `format!("{:?}", expr)`, i.e. the `Debug` representation of the
/// `PureExpr`, not Rust source text. Unless `PureExpr`'s `Debug` impl
/// prints source syntax, the emitted macro body will not be valid code.
/// Confirm this and consider using the crate's source printer instead.
fn wrap_expr(expr: &mut PureExpr, target: &PureExpr, wrapper_macro: &str, wrap_all: bool) -> usize {
    if expr == target {
        let wrapped = PureExpr::Macro {
            name: wrapper_macro.to_string(),
            delimiter: MacroDelimiter::Paren,
            tokens: format!("{:?}", expr),
        };
        *expr = wrapped;
        // Do not descend into the node we just wrapped.
        return 1;
    }
    match expr {
        PureExpr::Binary { left, right, .. } => {
            let mut c = wrap_expr(left, target, wrapper_macro, wrap_all);
            // `wrap_all || c == 0`: keep searching only if wrapping everything
            // or nothing has been wrapped yet in this subtree.
            if wrap_all || c == 0 {
                c += wrap_expr(right, target, wrapper_macro, wrap_all);
            }
            c
        }
        PureExpr::Unary { expr: inner, .. } => wrap_expr(inner, target, wrapper_macro, wrap_all),
        PureExpr::Call { func, args } => {
            let mut c = wrap_expr(func, target, wrapper_macro, wrap_all);
            for arg in args {
                if wrap_all || c == 0 {
                    c += wrap_expr(arg, target, wrapper_macro, wrap_all);
                }
            }
            c
        }
        PureExpr::MethodCall { receiver, args, .. } => {
            let mut c = wrap_expr(receiver, target, wrapper_macro, wrap_all);
            for arg in args {
                if wrap_all || c == 0 {
                    c += wrap_expr(arg, target, wrapper_macro, wrap_all);
                }
            }
            c
        }
        PureExpr::Field { expr: inner, .. } => wrap_expr(inner, target, wrapper_macro, wrap_all),
        PureExpr::Index { expr: e, index } => {
            let mut c = wrap_expr(e, target, wrapper_macro, wrap_all);
            if wrap_all || c == 0 {
                c += wrap_expr(index, target, wrapper_macro, wrap_all);
            }
            c
        }
        PureExpr::Block { block: b, .. } => wrap_expr_in_block(b, target, wrapper_macro, wrap_all),
        PureExpr::If {
            cond,
            then_branch,
            else_branch,
        } => {
            let mut c = wrap_expr(cond, target, wrapper_macro, wrap_all);
            if wrap_all || c == 0 {
                c += wrap_expr_in_block(then_branch, target, wrapper_macro, wrap_all);
            }
            if let Some(else_expr) = else_branch {
                if wrap_all || c == 0 {
                    c += wrap_expr(else_expr, target, wrapper_macro, wrap_all);
                }
            }
            c
        }
        PureExpr::Match { expr: e, arms } => {
            let mut c = wrap_expr(e, target, wrapper_macro, wrap_all);
            for arm in arms {
                // Arm body is searched before its guard.
                if wrap_all || c == 0 {
                    c += wrap_expr(&mut arm.body, target, wrapper_macro, wrap_all);
                }
                if let Some(guard) = &mut arm.guard {
                    if wrap_all || c == 0 {
                        c += wrap_expr(guard, target, wrapper_macro, wrap_all);
                    }
                }
            }
            c
        }
        PureExpr::Loop { body: b, .. } => wrap_expr_in_block(b, target, wrapper_macro, wrap_all),
        PureExpr::While { cond, body, .. } => {
            let mut c = wrap_expr(cond, target, wrapper_macro, wrap_all);
            if wrap_all || c == 0 {
                c += wrap_expr_in_block(body, target, wrapper_macro, wrap_all);
            }
            c
        }
        PureExpr::For { expr: e, body, .. } => {
            let mut c = wrap_expr(e, target, wrapper_macro, wrap_all);
            if wrap_all || c == 0 {
                c += wrap_expr_in_block(body, target, wrapper_macro, wrap_all);
            }
            c
        }
        PureExpr::Return(Some(inner))
        | PureExpr::Break {
            expr: Some(inner), ..
        } => wrap_expr(inner, target, wrapper_macro, wrap_all),
        PureExpr::Closure { body, .. } => wrap_expr(body, target, wrapper_macro, wrap_all),
        PureExpr::Struct { fields, .. } => {
            let mut c = 0;
            for (_, field_expr) in fields {
                if wrap_all || c == 0 {
                    c += wrap_expr(field_expr, target, wrapper_macro, wrap_all);
                }
            }
            c
        }
        PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
            let mut c = 0;
            for e in exprs {
                if wrap_all || c == 0 {
                    c += wrap_expr(e, target, wrapper_macro, wrap_all);
                }
            }
            c
        }
        PureExpr::Ref { expr: inner, .. } => wrap_expr(inner, target, wrapper_macro, wrap_all),
        PureExpr::Await(inner) | PureExpr::Try(inner) => {
            wrap_expr(inner, target, wrapper_macro, wrap_all)
        }
        PureExpr::Range { start, end, .. } => {
            let mut c = 0;
            if let Some(s) = start {
                c += wrap_expr(s, target, wrapper_macro, wrap_all);
            }
            if let Some(e) = end {
                if wrap_all || c == 0 {
                    c += wrap_expr(e, target, wrapper_macro, wrap_all);
                }
            }
            c
        }
        PureExpr::Cast { expr: inner, .. } => wrap_expr(inner, target, wrapper_macro, wrap_all),
        PureExpr::Let { expr: inner, .. } => wrap_expr(inner, target, wrapper_macro, wrap_all),
        PureExpr::Async { body, .. } | PureExpr::Unsafe(body) => {
            wrap_expr_in_block(body, target, wrapper_macro, wrap_all)
        }
        PureExpr::Repeat { expr: e, len } => {
            let mut c = wrap_expr(e, target, wrapper_macro, wrap_all);
            if wrap_all || c == 0 {
                c += wrap_expr(len, target, wrapper_macro, wrap_all);
            }
            c
        }
        // Leaves and opaque nodes (macro tokens are not searched).
        PureExpr::Lit(_)
        | PureExpr::Path(_)
        | PureExpr::Macro { .. }
        | PureExpr::Return(None)
        | PureExpr::Break { expr: None, .. }
        | PureExpr::Continue { .. }
        | PureExpr::Other(_) => 0,
    }
}
/// Applies `replace_expr` to the expression carried by `stmt`.
///
/// Only `let` initializers and expression statements carry a searchable
/// expression; `let` without initializer and item statements yield 0.
fn replace_expr_in_stmt(
    stmt: &mut PureStmt,
    old: &PureExpr,
    new: &PureExpr,
    replace_all: bool,
) -> usize {
    let carried = match stmt {
        PureStmt::Local { init: Some(expr), .. } => expr,
        PureStmt::Semi(expr) | PureStmt::Expr(expr) => expr,
        PureStmt::Local { init: None, .. } | PureStmt::Item(_) => return 0,
    };
    replace_expr(carried, old, new, replace_all)
}
/// Replaces sub-expressions of `expr` that are structurally equal to `old`
/// with a clone of `new`, returning the number of replacements.
///
/// Traversal is pre-order: a node equal to `old` is replaced as a whole and
/// the replacement is not searched again. When `replace_all` is `false`, each
/// compound arm stops visiting further children as soon as its running count
/// `c` is > 0, so the result is at most 1. Children are visited in the order
/// written in each arm below (e.g. `match` arm bodies before their guards).
/// Macro invocations are opaque: their tokens are never searched.
fn replace_expr(expr: &mut PureExpr, old: &PureExpr, new: &PureExpr, replace_all: bool) -> usize {
    if expr == old {
        *expr = new.clone();
        // Do not descend into the freshly inserted replacement.
        return 1;
    }
    match expr {
        PureExpr::Binary { left, right, .. } => {
            let mut c = replace_expr(left, old, new, replace_all);
            // `replace_all || c == 0`: keep searching only if replacing
            // everything or nothing has been replaced yet in this subtree.
            if replace_all || c == 0 {
                c += replace_expr(right, old, new, replace_all);
            }
            c
        }
        PureExpr::Unary { expr: inner, .. } => replace_expr(inner, old, new, replace_all),
        PureExpr::Call { func, args } => {
            let mut c = replace_expr(func, old, new, replace_all);
            for arg in args {
                if replace_all || c == 0 {
                    c += replace_expr(arg, old, new, replace_all);
                }
            }
            c
        }
        PureExpr::MethodCall { receiver, args, .. } => {
            let mut c = replace_expr(receiver, old, new, replace_all);
            for arg in args {
                if replace_all || c == 0 {
                    c += replace_expr(arg, old, new, replace_all);
                }
            }
            c
        }
        PureExpr::Field { expr: inner, .. } => replace_expr(inner, old, new, replace_all),
        PureExpr::Index { expr: e, index } => {
            let mut c = replace_expr(e, old, new, replace_all);
            if replace_all || c == 0 {
                c += replace_expr(index, old, new, replace_all);
            }
            c
        }
        PureExpr::Block { block: b, .. } => replace_expr_in_block(b, old, new, replace_all),
        PureExpr::If {
            cond,
            then_branch,
            else_branch,
        } => {
            let mut c = replace_expr(cond, old, new, replace_all);
            if replace_all || c == 0 {
                c += replace_expr_in_block(then_branch, old, new, replace_all);
            }
            if let Some(else_expr) = else_branch {
                if replace_all || c == 0 {
                    c += replace_expr(else_expr, old, new, replace_all);
                }
            }
            c
        }
        PureExpr::Match { expr: e, arms } => {
            let mut c = replace_expr(e, old, new, replace_all);
            for arm in arms {
                // Arm body is searched before its guard.
                if replace_all || c == 0 {
                    c += replace_expr(&mut arm.body, old, new, replace_all);
                }
                if let Some(guard) = &mut arm.guard {
                    if replace_all || c == 0 {
                        c += replace_expr(guard, old, new, replace_all);
                    }
                }
            }
            c
        }
        PureExpr::Loop { body: b, .. } => replace_expr_in_block(b, old, new, replace_all),
        PureExpr::While { cond, body, .. } => {
            let mut c = replace_expr(cond, old, new, replace_all);
            if replace_all || c == 0 {
                c += replace_expr_in_block(body, old, new, replace_all);
            }
            c
        }
        PureExpr::For { expr: e, body, .. } => {
            let mut c = replace_expr(e, old, new, replace_all);
            if replace_all || c == 0 {
                c += replace_expr_in_block(body, old, new, replace_all);
            }
            c
        }
        PureExpr::Return(Some(inner))
        | PureExpr::Break {
            expr: Some(inner), ..
        } => replace_expr(inner, old, new, replace_all),
        PureExpr::Closure { body, .. } => replace_expr(body, old, new, replace_all),
        PureExpr::Struct { fields, .. } => {
            let mut c = 0;
            for (_, field_expr) in fields {
                if replace_all || c == 0 {
                    c += replace_expr(field_expr, old, new, replace_all);
                }
            }
            c
        }
        PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
            let mut c = 0;
            for e in exprs {
                if replace_all || c == 0 {
                    c += replace_expr(e, old, new, replace_all);
                }
            }
            c
        }
        PureExpr::Ref { expr: inner, .. } => replace_expr(inner, old, new, replace_all),
        PureExpr::Await(inner) | PureExpr::Try(inner) => replace_expr(inner, old, new, replace_all),
        PureExpr::Range { start, end, .. } => {
            let mut c = 0;
            if let Some(s) = start {
                c += replace_expr(s, old, new, replace_all);
            }
            if let Some(e) = end {
                if replace_all || c == 0 {
                    c += replace_expr(e, old, new, replace_all);
                }
            }
            c
        }
        PureExpr::Cast { expr: inner, .. } => replace_expr(inner, old, new, replace_all),
        PureExpr::Let { expr: inner, .. } => replace_expr(inner, old, new, replace_all),
        PureExpr::Async { body, .. } | PureExpr::Unsafe(body) => {
            replace_expr_in_block(body, old, new, replace_all)
        }
        PureExpr::Repeat { expr: e, len } => {
            let mut c = replace_expr(e, old, new, replace_all);
            if replace_all || c == 0 {
                c += replace_expr(len, old, new, replace_all);
            }
            c
        }
        // Leaves and opaque nodes (macro tokens are not searched).
        PureExpr::Lit(_)
        | PureExpr::Path(_)
        | PureExpr::Macro { .. }
        | PureExpr::Return(None)
        | PureExpr::Break { expr: None, .. }
        | PureExpr::Continue { .. }
        | PureExpr::Other(_) => 0,
    }
}
/// Decides whether `stmt` is the reference statement `target`.
///
/// Matching is exact structural equality (`PartialEq` on `PureStmt`).
fn stmt_matches(target: &PureStmt, stmt: &PureStmt) -> bool {
    PartialEq::eq(target, stmt)
}
/// Returns `true` when `stmt` matches the textual `pattern`.
///
/// The pattern forms are tried in order:
/// 1. `name!` or `name!(..)` — if `stmt` is a macro-invocation statement, the
///    answer is decided solely by whether the macro is called `name`.
/// 2. `let <binding> ...` — true for a `let` whose pattern is a plain
///    identifier equal to `<binding>` (the text before the first ` = `, or
///    else the first `=`, trimmed).
/// 3. Fallback — substring search for `pattern` in the statement's `Debug`
///    output.
fn stmt_matches_pattern(stmt: &PureStmt, pattern: &str) -> bool {
    let wanted_macro = pattern
        .strip_suffix("!(..)")
        .or_else(|| pattern.strip_suffix('!'));
    if let Some(wanted) = wanted_macro {
        match stmt {
            PureStmt::Semi(PureExpr::Macro { name, .. })
            | PureStmt::Expr(PureExpr::Macro { name, .. }) => return name == wanted,
            _ => {}
        }
    }

    if let Some(rest) = pattern.strip_prefix("let ") {
        let split_at = rest.find(" = ").or_else(|| rest.find('='));
        let binding = split_at.map_or(rest, |pos| &rest[..pos]).trim();
        let binds_same_name = matches!(
            stmt,
            PureStmt::Local {
                pattern: ryo_source::pure::PurePattern::Ident { name, .. },
                ..
            } if name == binding
        );
        if binds_same_name {
            return true;
        }
    }

    format!("{:?}", stmt).contains(pattern)
}
/// Removes statements matching `pattern` (see `stmt_matches_pattern`) from
/// `block`, then from the statement blocks nested in its expression statements.
///
/// The block's own statements are examined first, then nested blocks in order.
/// When `remove_all` is `false`, at most ONE statement is removed across the
/// whole traversal. Returns the number of statements removed.
fn remove_stmts_in_block(block: &mut PureBlock, pattern: &str, remove_all: bool) -> usize {
    let mut removed = 0;
    block.stmts.retain(|stmt| {
        // Once the single allowed removal is done, keep everything else.
        let may_remove = remove_all || removed == 0;
        if may_remove && stmt_matches_pattern(stmt, pattern) {
            removed += 1;
            false
        } else {
            true
        }
    });
    for stmt in &mut block.stmts {
        // Previously nested blocks were still searched after a removal, so
        // `remove_all == false` could delete one statement per nested block.
        if !remove_all && removed > 0 {
            break;
        }
        if let PureStmt::Semi(expr) | PureStmt::Expr(expr) = stmt {
            removed += remove_stmts_in_expr(expr, pattern, remove_all);
        }
    }
    removed
}
/// Recursive helper for `remove_stmts_in_block`: searches the statement blocks
/// reachable from `expr` (block bodies, `if`/`else` branches, loop bodies,
/// `match` arm bodies, closure bodies, `async`/`unsafe` blocks).
///
/// Honors `remove_all == false` by stopping after the first removal. Returns
/// the number of statements removed.
fn remove_stmts_in_expr(expr: &mut PureExpr, pattern: &str, remove_all: bool) -> usize {
    match expr {
        PureExpr::Block { block, .. } => remove_stmts_in_block(block, pattern, remove_all),
        PureExpr::If {
            then_branch,
            else_branch,
            ..
        } => {
            let mut c = remove_stmts_in_block(then_branch, pattern, remove_all);
            if let Some(else_expr) = else_branch {
                if remove_all || c == 0 {
                    c += remove_stmts_in_expr(else_expr, pattern, remove_all);
                }
            }
            c
        }
        PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
            remove_stmts_in_block(block, pattern, remove_all)
        }
        PureExpr::For { body, .. } => remove_stmts_in_block(body, pattern, remove_all),
        PureExpr::Match { arms, .. } => {
            let mut c = 0;
            for arm in arms {
                if !remove_all && c > 0 {
                    break;
                }
                c += remove_stmts_in_expr(&mut arm.body, pattern, remove_all);
            }
            c
        }
        PureExpr::Closure { body, .. } => remove_stmts_in_expr(body, pattern, remove_all),
        PureExpr::Async { body, .. } | PureExpr::Unsafe(body) => {
            remove_stmts_in_block(body, pattern, remove_all)
        }
        _ => 0,
    }
}
/// Replaces every statement of `block` equal to `old` with a clone of `new`.
///
/// Statements that do not match are searched for nested blocks via
/// `replace_stmts_in_expr`; a replaced statement is not searched again.
/// Returns the total number of replacements.
fn replace_stmts_in_block(block: &mut PureBlock, old: &PureStmt, new: &PureStmt) -> usize {
    let mut total = 0;
    for stmt in block.stmts.iter_mut() {
        if *stmt == *old {
            *stmt = new.clone();
            total += 1;
            continue;
        }
        if let PureStmt::Semi(inner) | PureStmt::Expr(inner) = stmt {
            total += replace_stmts_in_expr(inner, old, new);
        }
    }
    total
}
/// Recursive helper for `replace_stmts_in_block`: visits the statement blocks
/// reachable from `expr` (block bodies, loop bodies, `async`/`unsafe` blocks,
/// `if`/`else` branches, `match` arm bodies, closure bodies). Conditions and
/// other operands are not searched. Returns the number of replacements.
fn replace_stmts_in_expr(expr: &mut PureExpr, old: &PureStmt, new: &PureStmt) -> usize {
    match expr {
        PureExpr::Block { block, .. } => replace_stmts_in_block(block, old, new),
        PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
            replace_stmts_in_block(block, old, new)
        }
        PureExpr::For { body, .. } => replace_stmts_in_block(body, old, new),
        PureExpr::Async { body, .. } | PureExpr::Unsafe(body) => {
            replace_stmts_in_block(body, old, new)
        }
        PureExpr::If {
            then_branch,
            else_branch,
            ..
        } => {
            let in_then = replace_stmts_in_block(then_branch, old, new);
            let in_else = match else_branch {
                Some(else_expr) => replace_stmts_in_expr(else_expr, old, new),
                None => 0,
            };
            in_then + in_else
        }
        PureExpr::Match { arms, .. } => arms
            .iter_mut()
            .map(|arm| replace_stmts_in_expr(&mut arm.body, old, new))
            .sum::<usize>(),
        PureExpr::Closure { body, .. } => replace_stmts_in_expr(body, old, new),
        _ => 0,
    }
}
/// Picks the index at which an "append" should be inserted into `stmts`.
///
/// Scanning from the back, the first statement that is a `return` (with or
/// without semicolon), or a trailing semicolon-less expression in last
/// position, is returned so the new statement lands before it. If none is
/// found, `stmts.len()` (true end) is returned.
fn find_return_index(stmts: &[PureStmt]) -> usize {
    let last = stmts.len().checked_sub(1);
    stmts
        .iter()
        .enumerate()
        .rev()
        .find(|&(idx, stmt)| match stmt {
            PureStmt::Expr(PureExpr::Return(_)) | PureStmt::Semi(PureExpr::Return(_)) => true,
            PureStmt::Expr(_) => Some(idx) == last,
            _ => false,
        })
        .map_or(stmts.len(), |(idx, _)| idx)
}
/// Inserts a clone of `stmt` next to the first element of `stmts` that matches
/// `reference_stmt` (before it, or after it when `after` is set).
///
/// Returns `false`, leaving `stmts` untouched, when no element matches.
fn insert_relative_to_stmt(
    stmts: &mut Vec<PureStmt>,
    stmt: &PureStmt,
    reference_stmt: &PureStmt,
    after: bool,
) -> bool {
    let anchor = stmts
        .iter()
        .position(|candidate| stmt_matches(reference_stmt, candidate));
    match anchor {
        Some(idx) => {
            let at = if after { idx + 1 } else { idx };
            stmts.insert(at, stmt.clone());
            true
        }
        None => false,
    }
}
impl ASTRegApply for ReplaceExprAtMutation {
    /// Replaces the expression addressed by `body_indices` inside `target_fn`.
    ///
    /// `body_indices[0]` selects a top-level statement of the function body.
    /// With no further indices, the statement's own expression (or `let`
    /// initializer) is replaced. Otherwise the remaining indices form a child
    /// path resolved by `replace_in_stmt_at_path`.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let fn_id = self.target_fn;
        let outcome = |changes: usize, description: String| MutationResult {
            mutation_type: "ReplaceExprAt".to_string(),
            changes,
            description,
        };

        if !matches!(ctx.symbol_registry.kind(fn_id), Some(SymbolKind::Function)) {
            return outcome(0, format!("Target function '{:?}' not found", fn_id));
        }

        let func = match ctx.ast_registry.get(fn_id) {
            Some(PureItem::Fn(func)) => func,
            _ => return outcome(0, "Failed to replace expression at position".to_string()),
        };
        let (&stmt_idx, child_path) = match self.body_indices.split_first() {
            Some(split) => split,
            None => return outcome(0, "Empty body indices".to_string()),
        };

        let mut new_func = func.clone();
        let replaced = match new_func.body.stmts.get_mut(stmt_idx) {
            Some(stmt) if child_path.is_empty() => replace_stmt_expr_at(stmt, &self.new_expr),
            Some(stmt) => replace_in_stmt_at_path(stmt, child_path, &self.new_expr),
            None => false,
        };
        if !replaced {
            return outcome(0, "Failed to replace expression at position".to_string());
        }

        ctx.set_ast(fn_id, PureItem::Fn(new_func));
        ctx.emit_modified(fn_id, ModificationType::BodyModified);
        outcome(
            1,
            format!(
                "Replaced expression at {:?} in '{}'",
                self.body_indices, self.target_fn
            ),
        )
    }
}
/// Overwrites the expression carried directly by `stmt` with a clone of
/// `new_expr`.
///
/// Returns `false` for statements that carry no expression: a `let` without
/// initializer, or an item statement.
fn replace_stmt_expr_at(stmt: &mut PureStmt, new_expr: &PureExpr) -> bool {
    let slot = match stmt {
        PureStmt::Local { init: Some(existing), .. } => existing,
        PureStmt::Semi(existing) | PureStmt::Expr(existing) => existing,
        PureStmt::Local { init: None, .. } | PureStmt::Item(_) => return false,
    };
    *slot = new_expr.clone();
    true
}
/// Replaces the sub-expression of `stmt`'s expression addressed by `path`.
///
/// All but the last index are followed with `navigate_expr_mut`. The last index
/// selects the child to overwrite via `replace_child_at`.
///
/// Returns `false` in these cases:
/// * the statement carries no expression;
/// * `path` is empty — previously `path.len() - 1` underflowed and panicked;
/// * any index does not resolve to a child.
fn replace_in_stmt_at_path(stmt: &mut PureStmt, path: &[usize], new_expr: &PureExpr) -> bool {
    let expr = match stmt {
        PureStmt::Local {
            init: Some(expr), ..
        } => expr,
        PureStmt::Semi(expr) | PureStmt::Expr(expr) => expr,
        _ => return false,
    };
    let (&child_idx, parent_path) = match path.split_last() {
        Some(split) => split,
        None => return false,
    };
    // An empty `parent_path` makes `navigate_expr_mut` return `expr` itself,
    // so single-index paths need no special case.
    match navigate_expr_mut(expr, parent_path) {
        Some(parent) => replace_child_at(parent, child_idx, new_expr),
        None => false,
    }
}
/// Follows `path` through `expr`'s children and returns the node it addresses
/// (an empty path yields `expr` itself).
///
/// Child numbering per node kind:
/// * `Binary`: 0 = left, 1 = right.
/// * `Unary`: 0 = operand.
/// * `Call`: i = i-th argument (the callee is not addressable).
/// * `MethodCall`: 0 = receiver, i = argument i-1.
/// * `Tuple` / `Array`: i = i-th element.
///
/// Any other node kind, or an out-of-range index, yields `None`.
fn navigate_expr_mut<'a>(expr: &'a mut PureExpr, path: &[usize]) -> Option<&'a mut PureExpr> {
    let (&step, remaining) = match path.split_first() {
        Some(parts) => parts,
        None => return Some(expr),
    };
    let next: &'a mut PureExpr = match expr {
        PureExpr::Binary { left, .. } if step == 0 => left.as_mut(),
        PureExpr::Binary { right, .. } if step == 1 => right.as_mut(),
        PureExpr::Unary { expr: inner, .. } if step == 0 => inner.as_mut(),
        PureExpr::Call { args, .. } => args.get_mut(step)?,
        PureExpr::MethodCall { receiver, .. } if step == 0 => receiver.as_mut(),
        PureExpr::MethodCall { args, .. } => args.get_mut(step - 1)?,
        PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => exprs.get_mut(step)?,
        _ => return None,
    };
    navigate_expr_mut(next, remaining)
}
fn replace_child_at(expr: &mut PureExpr, idx: usize, new_expr: &PureExpr) -> bool {
match expr {
PureExpr::Binary { left, right, .. } => match idx {
0 => {
**left = new_expr.clone();
true
}
1 => {
**right = new_expr.clone();
true
}
_ => false,
},
PureExpr::Unary { expr: inner, .. } if idx == 0 => {
**inner = new_expr.clone();
true
}
PureExpr::Call { args, .. } => {
if let Some(arg) = args.get_mut(idx) {
*arg = new_expr.clone();
true
} else {
false
}
}
PureExpr::MethodCall { receiver, args, .. } => {
if idx == 0 {
**receiver = new_expr.clone();
true
} else if let Some(arg) = args.get_mut(idx - 1) {
*arg = new_expr.clone();
true
} else {
false
}
}
PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
if let Some(e) = exprs.get_mut(idx) {
*e = new_expr.clone();
true
} else {
false
}
}
_ => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::ASTMutationEngine;
    use ryo_analysis::testing::ContextBuilder;
    use ryo_source::pure::ToPure;

    /// Resolves the id of the function named `$name` in a test context's symbol
    /// registry, panicking with "<name> function not found" if it is absent.
    macro_rules! find_fn {
        ($ctx:ident, $name:literal) => {
            $ctx.registry
                .iter()
                .find(|(id, path)| {
                    matches!($ctx.registry.kind(*id), Some(SymbolKind::Function))
                        && path.name() == $name
                })
                .map(|(id, _)| id)
                .expect(concat!($name, " function not found"))
        };
    }

    /// Diagnostic dump: prints `stmts`, the `reference` statement, and which
    /// statements compare equal to it. libtest captures this output unless the
    /// test fails (or `--nocapture` is passed).
    fn dump_against_reference(label: &str, stmts: &[PureStmt], reference: &PureStmt) {
        eprintln!("=== AST body stmts ({}) ===", label);
        for (i, s) in stmts.iter().enumerate() {
            eprintln!("  [{}] {:?}", i, s);
        }
        eprintln!("=== reference_stmt ===");
        eprintln!("  {:?}", reference);
        eprintln!("=== match result ===");
        for (i, s) in stmts.iter().enumerate() {
            eprintln!("  [{}] == ref? {}", i, reference == s);
        }
    }

    #[test]
    fn test_v2_replace_expr() {
        let mut ctx = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
fn compute() -> i32 {
1 + 2
}
"#,
            )
            .build();
        let compute_id = find_fn!(ctx, "compute");
        let old = PureExpr::Binary {
            op: "+".to_string(),
            left: Box::new(PureExpr::Lit("1".to_string())),
            right: Box::new(PureExpr::Lit("2".to_string())),
        };
        let new = PureExpr::Lit("3".to_string());
        let mutation = ReplaceExprMutation::new(old, new, compute_id);
        let result = ASTMutationEngine::execute_ast_reg(&mutation, &mut ctx);
        // Smoke test only: the hand-built `old` expression is not guaranteed to
        // match how the parser lowers `1 + 2`, so the count is printed, not asserted.
        println!("ReplaceExpr result: {:?}", result.result);
    }

    #[test]
    fn test_v2_insert_statement_start() {
        let mut ctx = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
fn greet() {
println!("World");
}
"#,
            )
            .build();
        let greet_id = find_fn!(ctx, "greet");
        let stmt = PureStmt::Semi(PureExpr::Macro {
            name: "println".to_string(),
            delimiter: MacroDelimiter::Paren,
            tokens: "\"Hello\"".to_string(),
        });
        let mutation = InsertStatementMutation {
            stmt,
            target_fn: greet_id,
            position: InsertPosition::Start,
            reference_stmt: None,
        };
        let result = ASTMutationEngine::execute_ast_reg(&mutation, &mut ctx);
        assert_eq!(result.result.changes, 1);
    }

    #[test]
    fn test_v2_insert_statement_after_pattern_let() {
        let mut ctx = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
fn process() {
let x = 1;
let y = 2;
}
"#,
            )
            .build();
        let process_id = find_fn!(ctx, "process");
        let new_stmt: syn::Stmt = syn::parse_str("let z = 3;").unwrap();
        let ref_stmt: syn::Stmt = syn::parse_str("let x = 1;").unwrap();
        let ref_pure = ref_stmt.to_pure();
        if let Some(PureItem::Fn(func)) = ctx.ast_registry.get(process_id) {
            dump_against_reference("let test", &func.body.stmts, &ref_pure);
        }
        let mutation = InsertStatementMutation {
            stmt: new_stmt.to_pure(),
            target_fn: process_id,
            position: InsertPosition::AfterPattern,
            reference_stmt: Some(ref_pure),
        };
        let result = ASTMutationEngine::execute_ast_reg(&mutation, &mut ctx);
        assert_eq!(
            result.result.changes, 1,
            "AfterPattern should insert 1 statement, got description: {}",
            result.result.description
        );
    }

    #[test]
    fn test_v2_insert_statement_after_pattern_macro() {
        let mut ctx = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
fn greet() {
println!("hello");
println!("world");
}
"#,
            )
            .build();
        let greet_id = find_fn!(ctx, "greet");
        let new_stmt: syn::Stmt = syn::parse_str("println!(\"inserted\");").unwrap();
        let ref_stmt: syn::Stmt = syn::parse_str("println!(\"hello\");").unwrap();
        let ref_pure = ref_stmt.to_pure();
        if let Some(PureItem::Fn(func)) = ctx.ast_registry.get(greet_id) {
            dump_against_reference("macro test", &func.body.stmts, &ref_pure);
        }
        let mutation = InsertStatementMutation {
            stmt: new_stmt.to_pure(),
            target_fn: greet_id,
            position: InsertPosition::AfterPattern,
            reference_stmt: Some(ref_pure),
        };
        let result = ASTMutationEngine::execute_ast_reg(&mutation, &mut ctx);
        assert_eq!(
            result.result.changes, 1,
            "AfterPattern should insert 1 statement, got description: {}",
            result.result.description
        );
    }

    #[test]
    fn test_v2_insert_statement_after_pattern_method_chain() {
        let mut ctx = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
fn process(config: &mut Config) {
config.set_timeout(Duration::from_secs(30));
config.set_retries(3);
}
"#,
            )
            .build();
        let process_id = find_fn!(ctx, "process");
        let ref_stmt: syn::Stmt =
            syn::parse_str("config.set_timeout(Duration::from_secs(30));").unwrap();
        let new_stmt: syn::Stmt = syn::parse_str("config.enable_logging();").unwrap();
        let ref_pure = ref_stmt.to_pure();
        if let Some(PureItem::Fn(func)) = ctx.ast_registry.get(process_id) {
            dump_against_reference("method chain test", &func.body.stmts, &ref_pure);
        }
        let mutation = InsertStatementMutation {
            stmt: new_stmt.to_pure(),
            target_fn: process_id,
            position: InsertPosition::AfterPattern,
            reference_stmt: Some(ref_pure),
        };
        let result = ASTMutationEngine::execute_ast_reg(&mutation, &mut ctx);
        assert_eq!(
            result.result.changes, 1,
            "AfterPattern for method chain should work, got: {}",
            result.result.description
        );
    }

    #[test]
    fn test_v2_remove_statement_no_match() {
        let mut ctx = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
fn simple() -> i32 {
42
}
"#,
            )
            .build();
        let simple_id = find_fn!(ctx, "simple");
        let target_stmt: syn::Stmt = syn::parse_str("nonexistent;").unwrap();
        let mutation = RemoveStatementMutation::new(
            target_stmt.to_pure(),
            "nonexistent;".to_string(),
            simple_id,
        );
        let result = ASTMutationEngine::execute_ast_reg(&mutation, &mut ctx);
        assert_eq!(result.result.changes, 0);
    }
}