use ryo_analysis::SymbolKind;
use ryo_mutations::basic::{
EnumToTraitMutation, EnumToTraitStrategy, ExtractTraitMutation, InlineTraitMutation,
MatchHandling, RemoveTraitMutation,
};
use ryo_mutations::{Mutation, MutationResult};
use ryo_source::pure::{
MacroDelimiter, PureBlock, PureExpr, PureField, PureFields, PureFn, PureGenericParam,
PureGenerics, PureImpl, PureImplItem, PureItem, PureParam, PureStmt, PureStruct, PureTrait,
PureTraitItem, PureType, PureVis,
};
use crate::engine::{ASTMutationContext, ASTRegApply};
impl ASTRegApply for ExtractTraitMutation {
    /// Splits selected methods of an inherent `impl` block out into a new trait
    /// plus an `impl Trait for Type` block.
    ///
    /// `self.symbol_id` must identify an inherent impl. When `self.methods` is
    /// `Some`, only functions whose name is listed are extracted; when it is
    /// `None`, every function item is extracted. Non-function items always stay
    /// in the inherent impl.
    ///
    /// The returned `changes` counts: the trait registration, the trait impl
    /// registration, and the update (or removal, if it became empty) of the
    /// original impl.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        // Single place where results are assembled.
        let report = |changes, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        let impl_id = self.symbol_id;

        let impl_path = if let Some(path) = ctx.symbol_registry.path(impl_id) {
            path.clone()
        } else {
            return report(
                0,
                format!("SymbolId {:?} not found in registry", impl_id),
            );
        };

        // The target must be an impl block without a trait.
        let inherent_impl = match ctx.ast_registry.get(impl_id) {
            None => {
                return report(0, format!("No AST found for SymbolId {:?}", impl_id));
            }
            Some(PureItem::Impl(imp)) if imp.trait_.is_some() => {
                return report(
                    0,
                    format!(
                        "SymbolId {:?} is a trait impl, not an inherent impl",
                        impl_id
                    ),
                );
            }
            Some(PureItem::Impl(imp)) => imp.clone(),
            Some(_) => {
                return report(
                    0,
                    format!("SymbolId {:?} is not an impl block", impl_id),
                );
            }
        };

        let struct_name = inherent_impl.self_ty.clone();

        // Left: items that move to the trait. Right: items that stay behind.
        let (extracted_items, remaining_items): (Vec<_>, Vec<_>) =
            inherent_impl.items.into_iter().partition(|item| match item {
                PureImplItem::Fn(f) => self
                    .methods
                    .as_ref()
                    .map_or(true, |names| names.contains(&f.name)),
                _ => false,
            });

        if extracted_items.is_empty() {
            return report(0, "No methods to extract".to_string());
        }

        // Trait declarations mirror each extracted signature with an empty body.
        let trait_items: Vec<PureTraitItem> = extracted_items
            .iter()
            .filter_map(|item| match item {
                PureImplItem::Fn(f) => Some(PureTraitItem::Fn(PureFn {
                    attrs: f.attrs.clone(),
                    vis: PureVis::Private,
                    is_async: f.is_async,
                    is_async_inferred: f.is_async_inferred,
                    is_const: f.is_const,
                    is_unsafe: f.is_unsafe,
                    abi: None,
                    name: f.name.clone(),
                    generics: f.generics.clone(),
                    params: f.params.clone(),
                    ret: f.ret.clone(),
                    body: PureBlock::default(),
                })),
                _ => None,
            })
            .collect();

        let new_trait = PureTrait {
            attrs: Vec::new(),
            vis: PureVis::Public,
            is_unsafe: false,
            is_auto: false,
            name: self.trait_name.clone(),
            generics: PureGenerics::default(),
            supertraits: Vec::new(),
            items: trait_items,
        };

        // The trait becomes a sibling of the impl block.
        let trait_path = match impl_path
            .parent()
            .and_then(|parent| parent.child(&self.trait_name).ok())
        {
            Some(path) => path,
            None => {
                return report(
                    0,
                    format!("Failed to create path for trait '{}'", self.trait_name),
                );
            }
        };

        let mut changes = 0;
        let trait_registered = ctx
            .register_with_ast(
                trait_path.clone(),
                SymbolKind::Trait,
                PureItem::Trait(new_trait),
            )
            .is_some();
        if trait_registered {
            changes += 1;
        }

        // Function items in the trait impl get their visibility reset to private.
        let trait_impl_items: Vec<PureImplItem> = extracted_items
            .into_iter()
            .map(|item| match item {
                PureImplItem::Fn(mut f) => {
                    f.vis = PureVis::Private;
                    PureImplItem::Fn(f)
                }
                other => other,
            })
            .collect();

        let trait_impl = PureImpl {
            attrs: Vec::new(),
            generics: inherent_impl.generics.clone(),
            is_unsafe: false,
            trait_: Some(self.trait_name.clone()),
            self_ty: struct_name.clone(),
            items: trait_impl_items,
        };

        // Synthetic symbol name for the new impl; path/generic punctuation in the
        // self type is flattened.
        let flattened_self_ty = struct_name
            .replace("::", "_")
            .replace('<', "_")
            .replace('>', "");
        let trait_impl_name = format!("<impl {} for {}>", self.trait_name, flattened_self_ty);

        let trait_impl_path = match impl_path
            .parent()
            .and_then(|parent| parent.child(&trait_impl_name).ok())
        {
            Some(path) => path,
            None => {
                return report(changes, "Failed to create path for trait impl".to_string());
            }
        };

        let impl_registered = ctx
            .register_with_ast(
                trait_impl_path,
                SymbolKind::Impl,
                PureItem::Impl(trait_impl.clone()),
            )
            .is_some();
        if impl_registered {
            // Also append the new impl to the parent module's item list when one exists.
            if let Some(parent_path) = impl_path.parent() {
                if let Some(parent_id) = ctx.symbol_registry.lookup(&parent_path) {
                    if let Some(module_items) = ctx.ast_registry.get_module_items_mut(parent_id) {
                        module_items.push(PureItem::Impl(trait_impl));
                    }
                }
            }
            changes += 1;
        }

        // Either drop the now-empty inherent impl or store the trimmed version.
        if remaining_items.is_empty() {
            ctx.remove_symbol(impl_id);
        } else {
            ctx.set_ast(
                impl_id,
                PureItem::Impl(PureImpl {
                    attrs: inherent_impl.attrs,
                    generics: inherent_impl.generics,
                    is_unsafe: inherent_impl.is_unsafe,
                    trait_: None,
                    self_ty: struct_name.clone(),
                    items: remaining_items,
                }),
            );
        }
        changes += 1;

        report(
            changes,
            format!(
                "Extracted trait '{}' from '{}'",
                self.trait_name, struct_name
            ),
        )
    }
}
impl ASTRegApply for InlineTraitMutation {
    /// Moves the items of `impl Trait for Struct` into an inherent `impl Struct`.
    ///
    /// The trait is identified by `self.symbol_id`, the target type by
    /// `self.struct_name`. If an inherent impl for the type already exists the
    /// trait impl's items are appended to it; otherwise a new inherent impl is
    /// registered next to the trait impl. The trait impl is then removed, and
    /// the trait itself is removed too when `self.remove_trait` is set.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let report = |changes, description: String| MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description,
        };

        let trait_name = match ctx.symbol_registry.path(self.symbol_id) {
            Some(path) => path.name().to_string(),
            None => {
                return report(
                    0,
                    format!("Trait symbol {:?} not found in registry", self.symbol_id),
                );
            }
        };

        // First impl symbol whose AST is `impl <trait_name> for <struct_name>`.
        let located_trait_impl = ctx
            .symbol_registry
            .iter()
            .find(|(id, _path)| {
                matches!(ctx.symbol_registry.kind(*id), Some(SymbolKind::Impl))
                    && matches!(
                        ctx.ast_registry.get(*id),
                        Some(PureItem::Impl(imp))
                            if imp.trait_.as_ref() == Some(&trait_name)
                                && imp.self_ty == self.struct_name
                    )
            })
            .map(|(id, path)| (id, path.clone()));

        let (trait_impl_id, trait_impl_path) = match located_trait_impl {
            Some(found) => found,
            None => {
                return report(
                    0,
                    format!(
                        "No impl of '{}' for '{}' found",
                        trait_name, self.struct_name
                    ),
                );
            }
        };

        let trait_impl = match ctx.ast_registry.get(trait_impl_id) {
            Some(PureItem::Impl(imp)) => imp.clone(),
            _ => return report(0, "No AST found for trait impl".to_string()),
        };

        let mut changes = 0;

        // First impl symbol whose AST is an inherent `impl <struct_name>`.
        let existing_inherent_id = ctx
            .symbol_registry
            .iter()
            .find(|(id, _path)| {
                matches!(ctx.symbol_registry.kind(*id), Some(SymbolKind::Impl))
                    && matches!(
                        ctx.ast_registry.get(*id),
                        Some(PureItem::Impl(imp))
                            if imp.trait_.is_none() && imp.self_ty == self.struct_name
                    )
            })
            .map(|(id, _)| id);

        match existing_inherent_id {
            Some(inherent_impl_id) => {
                if let Some(PureItem::Impl(mut merged)) =
                    ctx.ast_registry.get(inherent_impl_id).cloned()
                {
                    merged.items.extend(trait_impl.items.clone());
                    ctx.set_ast(inherent_impl_id, PureItem::Impl(merged));

                    // Mirror the merge into the first matching inherent impl held in
                    // the parent module's item list, if there is one.
                    if let Some(parent_path) = trait_impl_path.parent() {
                        if let Some(parent_id) = ctx.symbol_registry.lookup(&parent_path) {
                            if let Some(module_items) =
                                ctx.ast_registry.get_module_items_mut(parent_id)
                            {
                                for item in module_items.iter_mut() {
                                    let impl_block = match item {
                                        PureItem::Impl(block) => block,
                                        _ => continue,
                                    };
                                    if impl_block.trait_.is_some()
                                        || impl_block.self_ty != self.struct_name
                                    {
                                        continue;
                                    }
                                    impl_block.items.extend(trait_impl.items.clone());
                                    break;
                                }
                            }
                        }
                    }
                    changes += 1;
                }
            }
            None => {
                // No inherent impl yet: synthesize one from the trait impl's items.
                let new_inherent_impl = PureImpl {
                    attrs: Vec::new(),
                    generics: trait_impl.generics.clone(),
                    is_unsafe: false,
                    trait_: None,
                    self_ty: self.struct_name.clone(),
                    items: trait_impl.items.clone(),
                };

                let flattened_self_ty = self
                    .struct_name
                    .replace("::", "_")
                    .replace('<', "_")
                    .replace('>', "");
                let impl_name = format!("<impl {}>", flattened_self_ty);

                let impl_path = match trait_impl_path
                    .parent()
                    .and_then(|parent| parent.child(&impl_name).ok())
                {
                    Some(path) => path,
                    None => {
                        return report(
                            0,
                            "Failed to create path for inherent impl".to_string(),
                        );
                    }
                };

                let registered = ctx
                    .register_with_ast(
                        impl_path,
                        SymbolKind::Impl,
                        PureItem::Impl(new_inherent_impl.clone()),
                    )
                    .is_some();
                if registered {
                    if let Some(parent_path) = trait_impl_path.parent() {
                        if let Some(parent_id) = ctx.symbol_registry.lookup(&parent_path) {
                            if let Some(module_items) =
                                ctx.ast_registry.get_module_items_mut(parent_id)
                            {
                                module_items.push(PureItem::Impl(new_inherent_impl));
                            }
                        }
                    }
                    changes += 1;
                }
            }
        }

        // The trait impl is always removed once its items were handled.
        ctx.remove_symbol(trait_impl_id);
        changes += 1;

        if self.remove_trait {
            ctx.remove_symbol(self.symbol_id);
            changes += 1;
        }

        let removal_note = if self.remove_trait {
            " (trait removed)"
        } else {
            ""
        };
        report(
            changes,
            format!(
                "Inlined trait '{}' into '{}'{}",
                trait_name, self.struct_name, removal_note
            ),
        )
    }
}
impl ASTRegApply for RemoveTraitMutation {
    /// Removes the AST entry of the trait identified by `self.trait_id`.
    ///
    /// Reports zero changes (and touches nothing) when the symbol's kind in the
    /// symbol registry is not `SymbolKind::Trait`; otherwise removes the entry
    /// from the AST registry and reports one change.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let trait_id = self.trait_id;
        let is_trait = ctx.symbol_registry.kind(trait_id) == Some(SymbolKind::Trait);

        let (changes, description) = if is_trait {
            ctx.ast_registry.remove(trait_id);
            (1, format!("Removed trait {}", trait_id))
        } else {
            (0, format!("Symbol {} is not a trait", trait_id))
        };

        MutationResult {
            mutation_type: "RemoveTrait".to_string(),
            changes,
            description,
        }
    }
}
impl ASTRegApply for EnumToTraitMutation {
    /// Converts an enum into a trait plus one struct (and one trait impl) per variant.
    ///
    /// Steps, in order:
    /// 1. Validate that `self.symbol_id` is an enum with a path, an AST and a parent.
    /// 2. Pick the trait name: `self.trait_name` if given; otherwise
    ///    `"<Enum>Trait"` for `MarkerOnly`, or the enum's own name for every
    ///    other strategy.
    /// 3. Build the trait's path *before* mutating anything, so a path failure
    ///    leaves the registries untouched and the reported `changes: 0` is accurate.
    /// 4. Collect the `self`-taking methods of the enum's inherent impl (if any).
    /// 5. If the trait reuses the enum's name and `remove_enum` is set, remove the
    ///    enum (and its inherent impl) first so the name is free.
    /// 6. Register the trait, then a struct and a `todo!()`-bodied trait impl for
    ///    every variant.
    /// 7. Unless the strategy is `MarkerOnly`, rewrite `Enum::Variant` usages and
    ///    type annotations according to the strategy.
    /// 8. For `MatchHandling::WarnOnly`, count match expressions that still
    ///    reference the enum and mention them in the description.
    /// 9. Remove the enum at the end if requested and not already removed.
    fn apply_to_registry(&self, ctx: &mut ASTMutationContext) -> MutationResult {
        let enum_id = self.symbol_id;
        if !matches!(ctx.symbol_registry.kind(enum_id), Some(SymbolKind::Enum)) {
            return MutationResult {
                mutation_type: self.mutation_type().to_string(),
                changes: 0,
                description: format!("Symbol {:?} is not an enum or not found", enum_id),
            };
        }
        let enum_path = match ctx.symbol_registry.path(enum_id) {
            Some(p) => p.clone(),
            None => {
                return MutationResult {
                    mutation_type: self.mutation_type().to_string(),
                    changes: 0,
                    description: format!("Path not found for symbol {:?}", enum_id),
                };
            }
        };
        let enum_name = enum_path.name().to_string();

        // `default_trait_name` is declared up front so the borrowed `trait_name`
        // can outlive the match arm that fills it in.
        let default_trait_name;
        let trait_name = match &self.trait_name {
            Some(name) => name.as_str(),
            None => {
                match self.strategy {
                    EnumToTraitStrategy::MarkerOnly => {
                        default_trait_name = format!("{}Trait", enum_name);
                        &default_trait_name
                    }
                    _ => &enum_name,
                }
            }
        };

        let enum_def = match ctx.ast_registry.get(enum_id) {
            Some(PureItem::Enum(e)) => e.clone(),
            _ => {
                return MutationResult {
                    mutation_type: self.mutation_type().to_string(),
                    changes: 0,
                    description: format!("No AST found for enum '{}'", enum_name),
                };
            }
        };
        let mut changes = 0;
        let parent_path = match enum_path.parent() {
            Some(p) => p,
            None => {
                return MutationResult {
                    mutation_type: self.mutation_type().to_string(),
                    changes: 0,
                    description: "Cannot determine parent module".to_string(),
                };
            }
        };

        // Validate the trait path before any symbol is removed or registered.
        // Previously this check ran after the early enum removal below, so a
        // failure returned `changes: 0` even though symbols had been removed.
        let trait_path = match parent_path.child(trait_name) {
            Ok(path) => path,
            Err(_) => {
                return MutationResult {
                    mutation_type: self.mutation_type().to_string(),
                    changes: 0,
                    description: format!("Failed to create path for trait '{}'", trait_name),
                };
            }
        };

        let variant_names: Vec<String> =
            enum_def.variants.iter().map(|v| v.name.clone()).collect();

        // Inherent `impl Enum { .. }`, if one is registered.
        let enum_impl_id: Option<_> = ctx
            .symbol_registry
            .iter()
            .find(|(id, _)| {
                if !matches!(ctx.symbol_registry.kind(*id), Some(SymbolKind::Impl)) {
                    return false;
                }
                if let Some(PureItem::Impl(imp)) = ctx.ast_registry.get(*id) {
                    imp.trait_.is_none() && imp.self_ty == enum_name
                } else {
                    false
                }
            })
            .map(|(id, _)| id);

        // Only functions with a `PureParam::SelfValue` parameter are carried over
        // to the trait.
        let enum_methods: Vec<PureFn> = if let Some(impl_id) = enum_impl_id {
            if let Some(PureItem::Impl(imp)) = ctx.ast_registry.get(impl_id) {
                imp.items
                    .iter()
                    .filter_map(|item| {
                        if let PureImplItem::Fn(f) = item {
                            let has_self = f
                                .params
                                .iter()
                                .any(|p| matches!(p, PureParam::SelfValue { .. }));
                            if has_self {
                                Some(f.clone())
                            } else {
                                None
                            }
                        } else {
                            None
                        }
                    })
                    .collect()
            } else {
                Vec::new()
            }
        } else {
            Vec::new()
        };

        // When the trait takes over the enum's name, the enum has to go before
        // the trait is registered.
        let enum_removed_early = match self.strategy {
            EnumToTraitStrategy::MarkerOnly => false,
            _ if self.remove_enum && trait_name == enum_name => {
                ctx.remove_symbol(enum_id);
                changes += 1;
                if let Some(impl_id) = enum_impl_id {
                    ctx.remove_symbol(impl_id);
                    changes += 1;
                }
                true
            }
            _ => false,
        };

        // Trait method declarations: same signatures, empty bodies, no attributes.
        let trait_items: Vec<PureTraitItem> = enum_methods
            .iter()
            .map(|f| {
                let trait_fn = PureFn {
                    attrs: Vec::new(),
                    vis: PureVis::Private,
                    is_async: f.is_async,
                    is_async_inferred: f.is_async_inferred,
                    is_const: f.is_const,
                    is_unsafe: f.is_unsafe,
                    abi: None,
                    name: f.name.clone(),
                    generics: f.generics.clone(),
                    params: f.params.clone(),
                    ret: f.ret.clone(),
                    body: PureBlock::default(),
                };
                PureTraitItem::Fn(trait_fn)
            })
            .collect();
        let new_trait = PureTrait {
            attrs: Vec::new(),
            vis: PureVis::Public,
            is_unsafe: false,
            is_auto: false,
            name: trait_name.to_string(),
            generics: PureGenerics::default(),
            supertraits: Vec::new(),
            items: trait_items,
        };
        if ctx
            .register_with_ast(trait_path, SymbolKind::Trait, PureItem::Trait(new_trait))
            .is_some()
        {
            changes += 1;
        }

        // One public struct and one trait impl per variant.
        for variant in &enum_def.variants {
            let struct_fields = match &variant.fields {
                PureFields::Named(fields) => PureFields::Named(
                    fields
                        .iter()
                        .map(|f| PureField {
                            attrs: Vec::new(),
                            vis: PureVis::Public,
                            name: f.name.clone(),
                            ty: f.ty.clone(),
                        })
                        .collect(),
                ),
                PureFields::Tuple(types) => PureFields::Tuple(types.clone()),
                PureFields::Unit => PureFields::Unit,
            };
            let new_struct = PureStruct {
                attrs: Vec::new(),
                vis: PureVis::Public,
                name: variant.name.clone(),
                generics: PureGenerics::default(),
                fields: struct_fields,
            };
            // A variant whose path cannot be built is skipped entirely.
            let struct_path = match parent_path.child(&variant.name) {
                Ok(path) => path,
                Err(_) => continue,
            };
            if ctx
                .register_with_ast(struct_path, SymbolKind::Struct, PureItem::Struct(new_struct))
                .is_some()
            {
                changes += 1;
            }

            // Each trait method gets a `todo!("Trait::Variant::method")` placeholder body.
            let impl_items: Vec<PureImplItem> = enum_methods
                .iter()
                .map(|f| {
                    let impl_fn = PureFn {
                        attrs: Vec::new(),
                        vis: PureVis::Private,
                        is_async: f.is_async,
                        is_async_inferred: f.is_async_inferred,
                        is_const: f.is_const,
                        is_unsafe: f.is_unsafe,
                        abi: None,
                        name: f.name.clone(),
                        generics: f.generics.clone(),
                        params: f.params.clone(),
                        ret: f.ret.clone(),
                        body: PureBlock {
                            stmts: vec![PureStmt::Expr(PureExpr::Macro {
                                name: "todo".to_string(),
                                delimiter: MacroDelimiter::Paren,
                                tokens: format!("\"{}::{}::{}\"", trait_name, variant.name, f.name),
                            })],
                        },
                    };
                    PureImplItem::Fn(impl_fn)
                })
                .collect();
            let trait_impl = PureImpl {
                attrs: Vec::new(),
                generics: PureGenerics::default(),
                is_unsafe: false,
                trait_: Some(trait_name.to_string()),
                self_ty: variant.name.clone(),
                items: impl_items,
            };
            let impl_name = format!("<impl {} for {}>", trait_name, variant.name);
            let impl_path = match parent_path.child(&impl_name) {
                Ok(path) => path,
                Err(_) => continue,
            };
            if ctx
                .register_with_ast(impl_path, SymbolKind::Impl, PureItem::Impl(trait_impl))
                .is_some()
            {
                changes += 1;
            }
        }

        // `Enum::Variant` -> `Variant` in function bodies (skipped for MarkerOnly).
        let usage_changes = match self.strategy {
            EnumToTraitStrategy::MarkerOnly => 0,
            _ => replace_enum_usages(ctx, &enum_name, &variant_names),
        };
        changes += usage_changes;

        // Rewrite type annotations that mention the enum, per strategy.
        let type_changes = match self.strategy {
            EnumToTraitStrategy::Dynamic => {
                replace_type_annotations(ctx, &enum_name, trait_name, TypeReplacement::BoxDyn)
            }
            EnumToTraitStrategy::Static => {
                replace_type_annotations(ctx, &enum_name, trait_name, TypeReplacement::ImplTrait)
            }
            EnumToTraitStrategy::Generic => {
                replace_type_annotations(ctx, &enum_name, trait_name, TypeReplacement::Generic)
            }
            EnumToTraitStrategy::MarkerOnly => 0,
        };
        changes += type_changes;

        // Only WarnOnly inspects match expressions; the count is reported in the
        // description and is not added to `changes`.
        let match_changes = match self.match_handling {
            MatchHandling::WarnOnly => count_match_expressions(ctx, &enum_name),
            MatchHandling::Downcast => 0,
            MatchHandling::BlockOnMatch => 0,
        };

        let should_remove = match self.strategy {
            EnumToTraitStrategy::MarkerOnly => false,
            _ => self.remove_enum && !enum_removed_early,
        };
        if should_remove {
            ctx.remove_symbol(enum_id);
            changes += 1;
            if let Some(impl_id) = enum_impl_id {
                ctx.remove_symbol(impl_id);
                changes += 1;
            }
        }

        let strategy_desc = match self.strategy {
            EnumToTraitStrategy::Dynamic => " with Box<dyn>",
            EnumToTraitStrategy::Static => " with impl Trait",
            EnumToTraitStrategy::Generic => " with generics",
            EnumToTraitStrategy::MarkerOnly => " (marker only)",
        };
        let match_warning = if match_changes > 0 {
            format!(
                " ({} match expression(s) need manual migration)",
                match_changes
            )
        } else {
            String::new()
        };
        let enum_actually_removed = enum_removed_early || should_remove;
        MutationResult {
            mutation_type: self.mutation_type().to_string(),
            changes,
            description: format!(
                "Converted enum '{}' to trait '{}' with {} variants{}{}{}",
                enum_name,
                trait_name,
                variant_names.len(),
                strategy_desc,
                if enum_actually_removed {
                    " (enum removed)"
                } else {
                    ""
                },
                match_warning
            ),
        }
    }
}
/// Selects what an occurrence of the enum type is rewritten to in a type position.
enum TypeReplacement {
/// Rewrite to `Box<dyn Trait>`.
BoxDyn,
/// Rewrite to `impl Trait`.
ImplTrait,
/// Rewrite to the type parameter `T`.
Generic,
}
/// Rewrites type annotations that mention `enum_name` across every symbol that has an AST.
///
/// Free functions, struct fields, and functions inside impl and trait blocks
/// are visited. Struct fields never receive `impl Trait`: for the `ImplTrait`
/// mode they fall back to `BoxDyn`. In `Generic` mode a struct whose fields
/// changed also gains the generic parameter.
///
/// Returns a change count: per-annotation for functions and structs, and one
/// per impl/trait block that had any annotation rewritten.
fn replace_type_annotations(
    ctx: &mut ASTMutationContext,
    enum_name: &str,
    trait_name: &str,
    replacement: TypeReplacement,
) -> usize {
    // Snapshot the ids first: the loop below writes back into `ctx`.
    let all_ids: Vec<_> = ctx.symbol_registry.iter().map(|(id, _)| id).collect();
    let mut total = 0;

    for id in all_ids {
        let snapshot = match ctx.ast_registry.get(id) {
            Some(item) => item.clone(),
            None => continue,
        };

        let rewritten = match snapshot {
            PureItem::Fn(mut func) => {
                let hits = replace_types_in_fn(&mut func, enum_name, trait_name, &replacement);
                total += hits;
                if hits > 0 {
                    Some(PureItem::Fn(func))
                } else {
                    None
                }
            }
            PureItem::Struct(mut strukt) => {
                let field_mode = if matches!(replacement, TypeReplacement::ImplTrait) {
                    &TypeReplacement::BoxDyn
                } else {
                    &replacement
                };
                let hits =
                    replace_types_in_fields(&mut strukt.fields, enum_name, trait_name, field_mode);
                if hits == 0 {
                    None
                } else {
                    if matches!(replacement, TypeReplacement::Generic) {
                        add_generic_param(&mut strukt.generics, trait_name);
                    }
                    total += hits;
                    Some(PureItem::Struct(strukt))
                }
            }
            PureItem::Impl(mut block) => {
                let mut touched = false;
                for member in block.items.iter_mut() {
                    if let PureImplItem::Fn(func) = member {
                        touched |=
                            replace_types_in_fn(func, enum_name, trait_name, &replacement) > 0;
                    }
                }
                if touched {
                    total += 1;
                    Some(PureItem::Impl(block))
                } else {
                    None
                }
            }
            PureItem::Trait(mut decl) => {
                let mut touched = false;
                for member in decl.items.iter_mut() {
                    if let PureTraitItem::Fn(func) = member {
                        touched |=
                            replace_types_in_fn(func, enum_name, trait_name, &replacement) > 0;
                    }
                }
                if touched {
                    total += 1;
                    Some(PureItem::Trait(decl))
                } else {
                    None
                }
            }
            _ => None,
        };

        if let Some(item) = rewritten {
            ctx.set_ast(id, item);
        }
    }

    total
}
/// Rewrites enum references in a function's typed parameters and return type.
///
/// Returns how many parameter/return annotations were changed. In `Generic`
/// mode, a function with at least one change also gets the generic parameter
/// added via `add_generic_param`.
fn replace_types_in_fn(
    f: &mut PureFn,
    enum_name: &str,
    trait_name: &str,
    replacement: &TypeReplacement,
) -> usize {
    // Every typed parameter is visited; `count` drives the whole iterator.
    let param_hits = f
        .params
        .iter_mut()
        .filter_map(|param| match param {
            PureParam::Typed { ty, .. } => Some(ty),
            _ => None,
        })
        .map(|ty| replace_type(ty, enum_name, trait_name, replacement))
        .filter(|&changed| changed)
        .count();

    let ret_hit = match &mut f.ret {
        Some(ret) => replace_type(ret, enum_name, trait_name, replacement),
        None => false,
    };

    let hits = param_hits + usize::from(ret_hit);
    if hits > 0 && matches!(replacement, TypeReplacement::Generic) {
        add_generic_param(&mut f.generics, trait_name);
    }
    hits
}
/// Rewrites enum references in struct fields and returns the number of fields changed.
///
/// Handles named and tuple fields; unit structs have nothing to rewrite.
fn replace_types_in_fields(
    fields: &mut PureFields,
    enum_name: &str,
    trait_name: &str,
    replacement: &TypeReplacement,
) -> usize {
    match fields {
        PureFields::Named(named) => named
            .iter_mut()
            .map(|field| replace_type(&mut field.ty, enum_name, trait_name, replacement))
            .filter(|&changed| changed)
            .count(),
        PureFields::Tuple(types) => types
            .iter_mut()
            .map(|ty| replace_type(ty, enum_name, trait_name, replacement))
            .filter(|&changed| changed)
            .count(),
        PureFields::Unit => 0,
    }
}
/// Appends a type parameter `T: <trait_name>` unless a type parameter named `T`
/// is already present (in which case the generics are left untouched).
fn add_generic_param(generics: &mut PureGenerics, trait_name: &str) {
    let already_has_t = generics.params.iter().any(|param| match param {
        PureGenericParam::Type { name, .. } => name == "T",
        _ => false,
    });
    if already_has_t {
        return;
    }

    generics.params.push(PureGenericParam::Type {
        name: "T".to_string(),
        bounds: vec![trait_name.to_string()],
    });
}
/// Rewrites occurrences of `enum_name` inside `ty` in place; returns whether anything changed.
///
/// * A path that is exactly the enum (or whose last `::` segment is the enum)
///   is replaced wholesale according to `replacement`.
/// * A path containing `<` and the enum name has the enum rewritten textually
///   inside the generic arguments.
/// * References, tuples, arrays, slices and function types are searched
///   recursively; any other type is left alone.
fn replace_type(
    ty: &mut PureType,
    enum_name: &str,
    trait_name: &str,
    replacement: &TypeReplacement,
) -> bool {
    match ty {
        PureType::Path(path) => {
            let names_the_enum = {
                let type_name = path.split("::").last().unwrap_or(path);
                type_name == enum_name || path == enum_name
            };
            if names_the_enum {
                *ty = match replacement {
                    TypeReplacement::BoxDyn => PureType::Path(format!("Box<dyn {}>", trait_name)),
                    TypeReplacement::ImplTrait => PureType::ImplTrait(vec![trait_name.to_string()]),
                    TypeReplacement::Generic => PureType::Path("T".to_string()),
                };
                return true;
            }

            // Otherwise only generic paths mentioning the enum are candidates.
            if !(path.contains('<') && path.contains(enum_name)) {
                return false;
            }
            let substitute = match replacement {
                TypeReplacement::BoxDyn => format!("Box<dyn {}>", trait_name),
                TypeReplacement::ImplTrait => format!("impl {}", trait_name),
                TypeReplacement::Generic => "T".to_string(),
            };
            let rewritten = replace_type_in_generic_path(path, enum_name, &substitute);
            if rewritten == *path {
                return false;
            }
            *path = rewritten;
            true
        }
        PureType::Ref { ty: inner, .. } => replace_type(inner, enum_name, trait_name, replacement),
        PureType::Tuple(types) => types.iter_mut().fold(false, |any_changed, elem| {
            // The call is on the left so every element is visited.
            replace_type(elem, enum_name, trait_name, replacement) || any_changed
        }),
        PureType::Array { ty: inner, .. } => {
            replace_type(inner, enum_name, trait_name, replacement)
        }
        PureType::Slice(inner) => replace_type(inner, enum_name, trait_name, replacement),
        PureType::Fn { params, ret } => {
            let params_changed = params.iter_mut().fold(false, |any_changed, param| {
                replace_type(param, enum_name, trait_name, replacement) || any_changed
            });
            let ret_changed = match ret {
                Some(r) => replace_type(r, enum_name, trait_name, replacement),
                None => false,
            };
            params_changed || ret_changed
        }
        _ => false,
    }
}
/// Replaces every whole-identifier occurrence of `enum_name` in `path` with `replacement`.
///
/// The path is scanned character by character. Runs of alphanumeric characters
/// and `_` form identifiers; an identifier equal to `enum_name` is swapped for
/// `replacement`, everything else (including all punctuation such as `<`, `>`,
/// `,`, `::` and whitespace) is copied through unchanged. Identifiers that
/// merely contain `enum_name` as a substring (e.g. `StatusCode` for `Status`)
/// are not touched.
fn replace_type_in_generic_path(path: &str, enum_name: &str, replacement: &str) -> String {
    let mut result = String::with_capacity(path.len());
    let mut current_word = String::new();

    for c in path.chars() {
        if c.is_alphanumeric() || c == '_' {
            current_word.push(c);
            continue;
        }
        // A delimiter ends the current identifier: emit it (or its replacement),
        // then the delimiter itself.
        if current_word == enum_name {
            result.push_str(replacement);
        } else {
            result.push_str(&current_word);
        }
        current_word.clear();
        result.push(c);
    }

    // Flush the trailing identifier, if the path did not end on a delimiter.
    if current_word == enum_name {
        result.push_str(replacement);
    } else {
        result.push_str(&current_word);
    }
    result
}
/// Counts, over every symbol of kind `Function` whose AST is a function item,
/// the match expressions in its body that reference `enum_name` in an arm pattern.
fn count_match_expressions(ctx: &ASTMutationContext, enum_name: &str) -> usize {
    ctx.symbol_registry
        .iter()
        .filter(|(id, _)| matches!(ctx.symbol_registry.kind(*id), Some(SymbolKind::Function)))
        .filter_map(|(id, _)| match ctx.ast_registry.get(id) {
            Some(PureItem::Fn(func)) => Some(count_matches_in_block(&func.body, enum_name)),
            _ => None,
        })
        .sum()
}
/// Sums the enum-referencing match expressions found in each statement of `block`.
fn count_matches_in_block(block: &PureBlock, enum_name: &str) -> usize {
    block
        .stmts
        .iter()
        .map(|stmt| count_matches_in_stmt(stmt, enum_name))
        .sum()
}
/// Counts enum-referencing match expressions inside a single statement.
///
/// `let` statements contribute through their initializer (if any); nested
/// item statements are not inspected.
fn count_matches_in_stmt(stmt: &PureStmt, enum_name: &str) -> usize {
    match stmt {
        PureStmt::Local {
            init: Some(expr), ..
        } => count_matches_in_expr(expr, enum_name),
        PureStmt::Local { init: None, .. } => 0,
        PureStmt::Semi(expr) | PureStmt::Expr(expr) => count_matches_in_expr(expr, enum_name),
        PureStmt::Item(_) => 0,
    }
}
/// Recursively counts match expressions that have at least one arm pattern
/// referencing `enum_name`.
///
/// A match contributes at most one for itself, plus whatever is found in its
/// scrutinee and arm bodies. Sub-expressions of `if`, blocks, calls, method
/// calls, closures and loops are searched; every other expression kind counts
/// as zero.
fn count_matches_in_expr(expr: &PureExpr, enum_name: &str) -> usize {
    match expr {
        PureExpr::Match {
            expr: scrutinee,
            arms,
        } => {
            let in_scrutinee = count_matches_in_expr(scrutinee, enum_name);
            let this_match = usize::from(
                arms.iter()
                    .any(|arm| pattern_references_enum(&arm.pattern, enum_name)),
            );
            let in_bodies: usize = arms
                .iter()
                .map(|arm| count_matches_in_expr(&arm.body, enum_name))
                .sum();
            in_scrutinee + this_match + in_bodies
        }
        PureExpr::If {
            cond,
            then_branch,
            else_branch,
        } => {
            let in_else = match else_branch {
                Some(else_expr) => count_matches_in_expr(else_expr, enum_name),
                None => 0,
            };
            count_matches_in_expr(cond, enum_name)
                + count_matches_in_block(then_branch, enum_name)
                + in_else
        }
        PureExpr::Block { block, .. } => count_matches_in_block(block, enum_name),
        PureExpr::Call { func, args, .. } => {
            let in_args: usize = args
                .iter()
                .map(|arg| count_matches_in_expr(arg, enum_name))
                .sum();
            count_matches_in_expr(func, enum_name) + in_args
        }
        PureExpr::MethodCall { receiver, args, .. } => {
            let in_args: usize = args
                .iter()
                .map(|arg| count_matches_in_expr(arg, enum_name))
                .sum();
            count_matches_in_expr(receiver, enum_name) + in_args
        }
        PureExpr::Closure { body, .. } => count_matches_in_expr(body, enum_name),
        PureExpr::Loop { body: block, .. } => count_matches_in_block(block, enum_name),
        PureExpr::While { cond, body, .. } => {
            count_matches_in_expr(cond, enum_name) + count_matches_in_block(body, enum_name)
        }
        PureExpr::For { expr, body, .. } => {
            count_matches_in_expr(expr, enum_name) + count_matches_in_block(body, enum_name)
        }
        _ => 0,
    }
}
/// Returns `true` when `pattern` (or a nested sub-pattern of a tuple, slice,
/// or-pattern or reference pattern) has a path starting with `"<enum_name>::"`.
///
/// Fields nested inside a struct pattern are not inspected; only the struct
/// pattern's own path is.
fn pattern_references_enum(pattern: &PurePattern, enum_name: &str) -> bool {
    let prefix = format!("{}::", enum_name);
    match pattern {
        PurePattern::Path(path) => path.starts_with(prefix.as_str()),
        PurePattern::Struct { path, .. } => path.starts_with(prefix.as_str()),
        PurePattern::Tuple(elements) | PurePattern::Slice(elements) => {
            for element in elements {
                if pattern_references_enum(element, enum_name) {
                    return true;
                }
            }
            false
        }
        PurePattern::Or(patterns) => {
            for alternative in patterns {
                if pattern_references_enum(alternative, enum_name) {
                    return true;
                }
            }
            false
        }
        PurePattern::Ref { pattern: inner, .. } => pattern_references_enum(inner, enum_name),
        _ => false,
    }
}
/// Rewrites `Enum::Variant` references to bare `Variant` in the body of every
/// symbol of kind `Function`, storing back only the functions that changed.
///
/// Returns the total number of rewritten references.
fn replace_enum_usages(
    ctx: &mut ASTMutationContext,
    enum_name: &str,
    variant_names: &[String],
) -> usize {
    // Ids are collected up front because the loop mutates `ctx`.
    let function_ids: Vec<_> = ctx
        .symbol_registry
        .iter()
        .filter(|(id, _)| matches!(ctx.symbol_registry.kind(*id), Some(SymbolKind::Function)))
        .map(|(id, _)| id)
        .collect();

    let mut total = 0;
    for fn_id in function_ids {
        let mut func = match ctx.ast_registry.get(fn_id).cloned() {
            Some(PureItem::Fn(func)) => func,
            _ => continue,
        };
        let rewritten = replace_in_block(&mut func.body, enum_name, variant_names);
        if rewritten == 0 {
            continue;
        }
        ctx.set_ast(fn_id, PureItem::Fn(func));
        total += rewritten;
    }
    total
}
/// Applies the `Enum::Variant` -> `Variant` rewrite to every statement of `block`
/// and returns the number of rewrites performed.
fn replace_in_block(block: &mut PureBlock, enum_name: &str, variant_names: &[String]) -> usize {
    block
        .stmts
        .iter_mut()
        .map(|stmt| replace_in_stmt(stmt, enum_name, variant_names))
        .sum()
}
/// Applies the `Enum::Variant` -> `Variant` rewrite to one statement.
///
/// `let` statements are rewritten through their initializer only; nested item
/// statements are left untouched.
fn replace_in_stmt(stmt: &mut PureStmt, enum_name: &str, variant_names: &[String]) -> usize {
    match stmt {
        PureStmt::Local {
            init: Some(expr), ..
        } => replace_in_expr(expr, enum_name, variant_names),
        PureStmt::Local { init: None, .. } => 0,
        PureStmt::Semi(expr) | PureStmt::Expr(expr) => {
            replace_in_expr(expr, enum_name, variant_names)
        }
        PureStmt::Item(_) => 0,
    }
}
/// Recursively rewrites path expressions of the form `Enum::Variant` to `Variant`.
///
/// A path is rewritten only when the text after the `"<enum_name>::"` prefix is
/// exactly one of `variant_names`. Match arms also have their patterns
/// rewritten via `replace_in_pattern`. Expression kinds not listed here are
/// left untouched. Returns the number of rewrites.
fn replace_in_expr(expr: &mut PureExpr, enum_name: &str, variant_names: &[String]) -> usize {
    match expr {
        PureExpr::Path(path) => {
            let prefix = format!("{}::", enum_name);
            let bare_variant = match path.strip_prefix(prefix.as_str()) {
                Some(rest) if variant_names.iter().any(|name| name.as_str() == rest) => {
                    rest.to_string()
                }
                _ => return 0,
            };
            *path = bare_variant;
            1
        }
        PureExpr::Call { func, args, .. } => {
            let in_callee = replace_in_expr(func, enum_name, variant_names);
            let in_args: usize = args
                .iter_mut()
                .map(|arg| replace_in_expr(arg, enum_name, variant_names))
                .sum();
            in_callee + in_args
        }
        PureExpr::MethodCall { receiver, args, .. } => {
            let in_receiver = replace_in_expr(receiver, enum_name, variant_names);
            let in_args: usize = args
                .iter_mut()
                .map(|arg| replace_in_expr(arg, enum_name, variant_names))
                .sum();
            in_receiver + in_args
        }
        PureExpr::Binary { left, right, .. } => {
            let in_left = replace_in_expr(left, enum_name, variant_names);
            let in_right = replace_in_expr(right, enum_name, variant_names);
            in_left + in_right
        }
        PureExpr::Unary { expr, .. } => replace_in_expr(expr, enum_name, variant_names),
        PureExpr::If {
            cond,
            then_branch,
            else_branch,
        } => {
            let in_cond = replace_in_expr(cond, enum_name, variant_names);
            let in_then = replace_in_block(then_branch, enum_name, variant_names);
            let in_else = match else_branch {
                Some(else_expr) => replace_in_expr(else_expr, enum_name, variant_names),
                None => 0,
            };
            in_cond + in_then + in_else
        }
        PureExpr::Match { expr, arms } => {
            let in_scrutinee = replace_in_expr(expr, enum_name, variant_names);
            let in_arms: usize = arms
                .iter_mut()
                .map(|arm| {
                    let in_pattern = replace_in_pattern(&mut arm.pattern, enum_name, variant_names);
                    let in_body = replace_in_expr(&mut arm.body, enum_name, variant_names);
                    in_pattern + in_body
                })
                .sum();
            in_scrutinee + in_arms
        }
        PureExpr::Block { block, .. } => replace_in_block(block, enum_name, variant_names),
        PureExpr::Return(Some(v)) => replace_in_expr(v, enum_name, variant_names),
        PureExpr::Return(None) => 0,
        PureExpr::Struct { fields, .. } => fields
            .iter_mut()
            .map(|(_, field_expr)| replace_in_expr(field_expr, enum_name, variant_names))
            .sum(),
        PureExpr::Tuple(elements) => elements
            .iter_mut()
            .map(|elem| replace_in_expr(elem, enum_name, variant_names))
            .sum(),
        PureExpr::Array(elements) => elements
            .iter_mut()
            .map(|elem| replace_in_expr(elem, enum_name, variant_names))
            .sum(),
        PureExpr::Index { expr, index, .. } => {
            let in_base = replace_in_expr(expr, enum_name, variant_names);
            let in_index = replace_in_expr(index, enum_name, variant_names);
            in_base + in_index
        }
        PureExpr::Field { expr, .. } => replace_in_expr(expr, enum_name, variant_names),
        PureExpr::Ref { expr, .. } => replace_in_expr(expr, enum_name, variant_names),
        PureExpr::Try(inner) => replace_in_expr(inner, enum_name, variant_names),
        PureExpr::Await(inner) => replace_in_expr(inner, enum_name, variant_names),
        PureExpr::Closure { body, .. } => replace_in_expr(body, enum_name, variant_names),
        PureExpr::Loop { body, .. } => replace_in_block(body, enum_name, variant_names),
        PureExpr::While { cond, body, .. } => {
            let in_cond = replace_in_expr(cond, enum_name, variant_names);
            let in_body = replace_in_block(body, enum_name, variant_names);
            in_cond + in_body
        }
        PureExpr::For { expr, body, .. } => {
            let in_iterable = replace_in_expr(expr, enum_name, variant_names);
            let in_body = replace_in_block(body, enum_name, variant_names);
            in_iterable + in_body
        }
        PureExpr::Let { expr, .. } => replace_in_expr(expr, enum_name, variant_names),
        PureExpr::Range { start, end, .. } => {
            let in_start = match start {
                Some(s) => replace_in_expr(s, enum_name, variant_names),
                None => 0,
            };
            let in_end = match end {
                Some(e) => replace_in_expr(e, enum_name, variant_names),
                None => 0,
            };
            in_start + in_end
        }
        PureExpr::Cast { expr, .. } => replace_in_expr(expr, enum_name, variant_names),
        _ => 0,
    }
}
use ryo_source::pure::PurePattern;
/// Recursively rewrites `Enum::Variant` paths inside a pattern to `Variant`.
///
/// * Path patterns are rewritten when the remainder after `"<enum_name>::"` is
///   exactly a known variant.
/// * Struct patterns are rewritten when the leading identifier of the remainder
///   is a known variant; the whole remainder becomes the new path. Their field
///   sub-patterns are then processed as well.
/// * Tuple, slice, reference and or-patterns are searched recursively.
///
/// Returns the number of rewrites.
fn replace_in_pattern(
    pattern: &mut PurePattern,
    enum_name: &str,
    variant_names: &[String],
) -> usize {
    match pattern {
        PurePattern::Struct { path, fields, .. } => {
            let prefix = format!("{}::", enum_name);
            let renamed = path.strip_prefix(prefix.as_str()).and_then(|variant| {
                // Only the leading identifier has to name a variant.
                let base_variant = variant
                    .split(|c: char| !c.is_alphanumeric() && c != '_')
                    .next()
                    .unwrap_or(variant);
                if variant_names.iter().any(|name| name.as_str() == base_variant) {
                    Some(variant.to_string())
                } else {
                    None
                }
            });

            let mut rewrites = 0;
            if let Some(new_path) = renamed {
                *path = new_path;
                rewrites += 1;
            }
            rewrites += fields
                .iter_mut()
                .map(|(_, field_pattern)| {
                    replace_in_pattern(field_pattern, enum_name, variant_names)
                })
                .sum::<usize>();
            rewrites
        }
        PurePattern::Tuple(elements) | PurePattern::Slice(elements) => elements
            .iter_mut()
            .map(|elem| replace_in_pattern(elem, enum_name, variant_names))
            .sum(),
        PurePattern::Path(path) => {
            let prefix = format!("{}::", enum_name);
            let bare_variant = match path.strip_prefix(prefix.as_str()) {
                Some(rest) if variant_names.iter().any(|name| name.as_str() == rest) => {
                    rest.to_string()
                }
                _ => return 0,
            };
            *path = bare_variant;
            1
        }
        PurePattern::Ref { pattern: inner, .. } => {
            replace_in_pattern(inner, enum_name, variant_names)
        }
        PurePattern::Or(patterns) => patterns
            .iter_mut()
            .map(|alternative| replace_in_pattern(alternative, enum_name, variant_names))
            .sum(),
        _ => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::ASTMutationEngine;
    use ryo_analysis::testing::ContextBuilder;

    /// Extracting `get_value` and `set_value` from the inherent `impl Foo`
    /// into a `ValueAccessor` trait must report at least two changes.
    #[test]
    fn test_v2_extract_trait() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
struct Foo {
value: i32,
}
impl Foo {
fn get_value(&self) -> i32 {
self.value
}
fn set_value(&mut self, v: i32) {
self.value = v;
}
fn helper(&self) {}
}
"#,
            )
            .build();

        // Locate the inherent (non-trait) impl whose self type is `Foo`.
        // The AST lookup is only attempted for symbols registered as impls.
        let foo_impl = context
            .registry
            .iter()
            .find_map(|(id, _)| {
                let is_inherent_foo = matches!(context.registry.kind(id), Some(SymbolKind::Impl))
                    && matches!(
                        context.ast_registry.get(id),
                        Some(PureItem::Impl(imp)) if imp.trait_.is_none() && imp.self_ty == "Foo"
                    );
                if is_inherent_foo {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Should find impl Foo");

        let selected_methods = vec![String::from("get_value"), String::from("set_value")];
        let mutation =
            ExtractTraitMutation::new(foo_impl, "ValueAccessor").with_methods(selected_methods);

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("ExtractTrait result: {:?}", report);
        assert!(report.changes >= 2, "Expected at least 2 changes");
    }

    /// Inlining trait `Greet` into `Foo` must report at least two changes.
    #[test]
    fn test_v2_inline_trait() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
struct Foo;
trait Greet {
fn greet(&self) -> String;
}
impl Greet for Foo {
fn greet(&self) -> String {
"Hello".to_string()
}
}
"#,
            )
            .build();

        // The fixture declares a single trait, so the first trait symbol is `Greet`.
        let greet_trait = context
            .registry
            .iter()
            .find_map(|(id, _)| {
                if matches!(context.registry.kind(id), Some(SymbolKind::Trait)) {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Should find trait Greet");

        let mutation = InlineTraitMutation::new(greet_trait, "Foo");

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("InlineTrait result: {:?}", report);
        assert!(report.changes >= 2, "Expected at least 2 changes");
    }

    /// Same scenario as `test_v2_inline_trait`, but with the `keep_trait`
    /// option enabled; at least two changes must still be reported.
    #[test]
    fn test_v2_inline_trait_keep_trait() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
struct Foo;
trait Greet {
fn greet(&self) -> String;
}
impl Greet for Foo {
fn greet(&self) -> String {
"Hello".to_string()
}
}
"#,
            )
            .build();

        // The fixture declares a single trait, so the first trait symbol is `Greet`.
        let greet_trait = context
            .registry
            .iter()
            .find_map(|(id, _)| {
                if matches!(context.registry.kind(id), Some(SymbolKind::Trait)) {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Should find trait Greet");

        let mutation = InlineTraitMutation::new(greet_trait, "Foo").keep_trait();

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("InlineTrait (keep_trait) result: {:?}", report);
        assert!(report.changes >= 2, "Expected at least 2 changes");
    }

    /// Converting `Status` with the `Dynamic` strategy must report at least
    /// five changes and a description that mentions `Box<dyn>`.
    #[test]
    fn test_v2_enum_to_trait_dynamic_strategy() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
enum Status {
Running,
Stopped,
}
fn process(status: Status) -> Status {
status
}
struct Config {
current_status: Status,
}
"#,
            )
            .build();

        // Pick the symbol that is both named `Status` and registered as an enum.
        let status_enum = context
            .registry
            .iter()
            .find_map(|(id, path)| {
                if path.name() == "Status"
                    && matches!(context.registry.kind(id), Some(SymbolKind::Enum))
                {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Enum 'Status' should exist");

        let mutation = EnumToTraitMutation::from_symbol_id(status_enum)
            .with_strategy(EnumToTraitStrategy::Dynamic);

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("EnumToTrait (Dynamic) result: {:?}", report);
        assert!(report.changes >= 5, "Expected at least 5 changes");
        assert!(
            report.description.contains("Box<dyn>"),
            "Should mention Box<dyn> strategy"
        );
    }

    /// Converting `Filter` with the `Static` strategy must report at least
    /// five changes and a description that mentions `impl Trait`.
    #[test]
    fn test_v2_enum_to_trait_static_strategy() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
enum Filter {
Active,
Inactive,
}
fn apply_filter(filter: Filter) {
let _ = filter;
}
"#,
            )
            .build();

        // Pick the symbol that is both named `Filter` and registered as an enum.
        let filter_enum = context
            .registry
            .iter()
            .find_map(|(id, path)| {
                if path.name() == "Filter"
                    && matches!(context.registry.kind(id), Some(SymbolKind::Enum))
                {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Enum 'Filter' should exist");

        let mutation = EnumToTraitMutation::from_symbol_id(filter_enum)
            .with_strategy(EnumToTraitStrategy::Static);

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("EnumToTrait (Static) result: {:?}", report);
        assert!(report.changes >= 5, "Expected at least 5 changes");
        assert!(
            report.description.contains("impl Trait"),
            "Should mention impl Trait strategy"
        );
    }

    /// Converting `Mode` with the `MarkerOnly` strategy must report at least
    /// five changes and a description that mentions `marker only`.
    #[test]
    fn test_v2_enum_to_trait_marker_only_strategy() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
enum Mode {
Fast,
Slow,
}
fn get_mode() -> Mode {
Mode::Fast
}
"#,
            )
            .build();

        // Pick the symbol that is both named `Mode` and registered as an enum.
        let mode_enum = context
            .registry
            .iter()
            .find_map(|(id, path)| {
                if path.name() == "Mode"
                    && matches!(context.registry.kind(id), Some(SymbolKind::Enum))
                {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Enum 'Mode' should exist");

        let mutation = EnumToTraitMutation::from_symbol_id(mode_enum)
            .with_strategy(EnumToTraitStrategy::MarkerOnly);

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("EnumToTrait (MarkerOnly) result: {:?}", report);
        assert!(report.changes >= 5, "Expected at least 5 changes");
        assert!(
            report.description.contains("marker only"),
            "Should mention marker only strategy"
        );
    }

    /// Converting `Status` with the `Generic` strategy must report at least
    /// five changes and a description that mentions `generics`.
    #[test]
    fn test_v2_enum_to_trait_generic_strategy() {
        let mut context = ContextBuilder::new()
            .with_file(
                "src/lib.rs",
                r#"
enum Status {
Running,
Stopped,
}
fn process(status: Status) -> Status {
status
}
struct Config {
current_status: Status,
}
"#,
            )
            .build();

        // Pick the symbol that is both named `Status` and registered as an enum.
        let status_enum = context
            .registry
            .iter()
            .find_map(|(id, path)| {
                if path.name() == "Status"
                    && matches!(context.registry.kind(id), Some(SymbolKind::Enum))
                {
                    Some(id)
                } else {
                    None
                }
            })
            .expect("Enum 'Status' should exist");

        let mutation = EnumToTraitMutation::from_symbol_id(status_enum)
            .with_strategy(EnumToTraitStrategy::Generic);

        let outcome = ASTMutationEngine::execute_ast_reg(&mutation, &mut context);
        let report = &outcome.result;
        println!("EnumToTrait (Generic) result: {:?}", report);
        assert!(report.changes >= 5, "Expected at least 5 changes");
        assert!(
            report.description.contains("generics"),
            "Should mention generics strategy"
        );
    }
}