use ryo_source::pure::{PureImpl, PureItem};
use ryo_symbol::{SymbolKind, SymbolPath};
use crate::engine::ASTMutationContext;
/// Outcome of registering an impl block with `register_impl_block`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImplRegistrationResult {
    /// Number of associated items (methods, associated consts and associated type
    /// aliases) whose symbols were successfully registered.
    pub methods_added: usize,
    /// Human-readable summary of what was registered.
    pub description: String,
}
pub fn register_impl_block(
ctx: &mut ASTMutationContext,
module_path: &SymbolPath,
impl_block: &PureImpl,
) -> Result<ImplRegistrationResult, String> {
let module_str = module_path.to_string();
let base_self_ty = strip_generics(&impl_block.self_ty);
let (impl_path_str, method_base_path) = if let Some(ref trait_name) = impl_block.trait_ {
let path = format!(
"{}::<impl {} for {}>",
module_str, trait_name, impl_block.self_ty
);
(path.clone(), path)
} else {
let impl_path = format!("{}::<impl {}>", module_str, impl_block.self_ty);
let method_path = format!("{}::{}", module_str, base_self_ty);
(impl_path, method_path)
};
let impl_path = SymbolPath::parse(&impl_path_str)
.map_err(|e| format!("Invalid impl path '{}': {:?}", impl_path_str, e))?;
let impl_id = if let Some(existing_id) = ctx.symbol_registry.lookup(&impl_path) {
if let Some(PureItem::Impl(existing_impl)) = ctx.ast_registry.get_mut(existing_id) {
existing_impl.items.extend(impl_block.items.clone());
}
existing_id
} else {
ctx.register_with_ast(
impl_path.clone(),
SymbolKind::Impl,
PureItem::Impl(impl_block.clone()),
)
.ok_or_else(|| format!("Failed to register impl block for '{}'", impl_block.self_ty))?
};
let mut methods_added = 0;
for impl_item in &impl_block.items {
let (item_name, item_kind, pure_item) = match impl_item {
ryo_source::pure::PureImplItem::Fn(m) => {
(m.name.clone(), SymbolKind::Method, PureItem::Fn(m.clone()))
}
ryo_source::pure::PureImplItem::Const(c) => (
c.name.clone(),
SymbolKind::Const,
PureItem::Const(c.clone()),
),
ryo_source::pure::PureImplItem::Type(t) => (
t.name.clone(),
SymbolKind::TypeAlias,
PureItem::Type(t.clone()),
),
ryo_source::pure::PureImplItem::Other(_) => continue,
};
let item_path_str = format!("{}::{}", method_base_path, item_name);
let item_path = SymbolPath::parse(&item_path_str)
.map_err(|e| format!("Invalid method path '{}': {:?}", item_path_str, e))?;
if ctx
.register_with_ast(item_path, item_kind, pure_item)
.is_some()
{
methods_added += 1;
}
}
if let Some(module_id) = ctx.symbol_registry.lookup(module_path) {
let updated_impl = match ctx.ast_registry.get(impl_id) {
Some(PureItem::Impl(i)) => i.clone(),
_ => impl_block.clone(), };
let mut items = ctx
.ast_registry
.get_module_items(module_id)
.cloned()
.unwrap_or_default();
let impl_key = (impl_block.self_ty.clone(), impl_block.trait_.clone());
if let Some(pos) = items.iter().position(|item| {
if let PureItem::Impl(i) = item {
(i.self_ty.clone(), i.trait_.clone()) == impl_key
} else {
false
}
}) {
items[pos] = PureItem::Impl(updated_impl);
} else {
items.push(PureItem::Impl(updated_impl));
}
ctx.ast_registry.set_module_items(module_id, items);
}
let description = if methods_added > 0 {
format!(
"Added impl block for '{}' with {} method(s)",
impl_block.self_ty, methods_added
)
} else {
format!("Added empty impl block for '{}'", impl_block.self_ty)
};
Ok(ImplRegistrationResult {
methods_added,
description,
})
}
/// Returns the bare type name: everything before the first `<`, with surrounding
/// whitespace trimmed.
///
/// `"Vec<T>"` yields `"Vec"`; a name without generic arguments is only trimmed.
fn strip_generics(type_name: &str) -> &str {
    type_name
        .split_once('<')
        .map_or(type_name, |(head, _generics)| head)
        .trim()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ryo_analysis::{ASTRegistry, SymbolRegistry};
    use ryo_source::pure::PureFile;

    /// Parses `source` and returns a clone of its first item, panicking unless it is an
    /// impl block.
    fn first_impl(source: &str) -> PureImpl {
        let parsed = PureFile::from_source(source).unwrap();
        match &parsed.items[0] {
            PureItem::Impl(found) => found.clone(),
            _ => panic!("Expected impl block"),
        }
    }

    /// Parses a symbol path the test knows to be well-formed.
    fn sym(raw: &str) -> SymbolPath {
        SymbolPath::parse(raw).unwrap()
    }

    #[test]
    fn test_register_plain_impl_block() {
        let mut ast_registry = ASTRegistry::new();
        let mut symbol_registry = SymbolRegistry::new();
        let mut ctx = ASTMutationContext::new(&mut ast_registry, &mut symbol_registry);
        ctx.register(sym("test_crate::Counter"), SymbolKind::Struct);
        let module = sym("test_crate");
        let block = first_impl(
            r#"
impl Counter {
pub fn new() -> Self {
Self { count: 0 }
}
pub fn increment(&mut self) {
self.count += 1;
}
}
"#,
        );

        let outcome = register_impl_block(&mut ctx, &module, &block).unwrap();
        assert_eq!(outcome.methods_added, 2);

        let registered = ctx.lookup(&sym("test_crate::<impl Counter>"));
        assert!(registered.is_some(), "Impl block should be registered");
        assert_eq!(ctx.kind(registered.unwrap()), Some(SymbolKind::Impl));

        assert!(
            ctx.lookup(&sym("test_crate::Counter::new")).is_some(),
            "Counter::new should exist"
        );
        assert!(
            ctx.lookup(&sym("test_crate::Counter::increment")).is_some(),
            "Counter::increment should exist"
        );
    }

    #[test]
    fn test_register_trait_impl_block() {
        let mut ast_registry = ASTRegistry::new();
        let mut symbol_registry = SymbolRegistry::new();
        let mut ctx = ASTMutationContext::new(&mut ast_registry, &mut symbol_registry);
        let module = sym("test_crate");
        let block = first_impl(
            r#"
impl Display for Counter {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.count)
}
}
"#,
        );

        let outcome = register_impl_block(&mut ctx, &module, &block).unwrap();
        assert_eq!(outcome.methods_added, 1);

        assert!(
            ctx.lookup(&sym("test_crate::<impl Display for Counter>")).is_some(),
            "Trait impl should be registered"
        );
        assert!(
            ctx.lookup(&sym("test_crate::<impl Display for Counter>::fmt"))
                .is_some(),
            "fmt method should exist"
        );
    }

    #[test]
    fn test_register_empty_impl_block() {
        let mut ast_registry = ASTRegistry::new();
        let mut symbol_registry = SymbolRegistry::new();
        let mut ctx = ASTMutationContext::new(&mut ast_registry, &mut symbol_registry);
        let module = sym("test_crate");
        let block = first_impl("impl MyType {}");

        let outcome = register_impl_block(&mut ctx, &module, &block).unwrap();
        assert_eq!(outcome.methods_added, 0);
        assert!(outcome.description.contains("empty"));
    }

    #[test]
    fn test_register_multiple_impl_blocks_merge() {
        let mut ast_registry = ASTRegistry::new();
        let mut symbol_registry = SymbolRegistry::new();
        let mut ctx = ASTMutationContext::new(&mut ast_registry, &mut symbol_registry);
        ctx.register(sym("test_crate"), SymbolKind::Mod);
        let module = sym("test_crate");

        let first_block = first_impl(
            r#"
impl TodoList {
pub fn add(&mut self, item: String) {
self.items.push(item);
}
}
"#,
        );
        let first_outcome = register_impl_block(&mut ctx, &module, &first_block).unwrap();
        assert_eq!(first_outcome.methods_added, 1);

        let second_block = first_impl(
            r#"
impl TodoList {
pub fn new() -> Self {
Self::default()
}
}
"#,
        );
        let second_outcome = register_impl_block(&mut ctx, &module, &second_block).unwrap();
        assert_eq!(second_outcome.methods_added, 1);

        assert!(
            ctx.lookup(&sym("test_crate::TodoList::add")).is_some(),
            "add method should exist"
        );
        assert!(
            ctx.lookup(&sym("test_crate::TodoList::new")).is_some(),
            "new method should exist"
        );

        let merged_id = ctx.lookup(&sym("test_crate::<impl TodoList>")).unwrap();
        match ctx.get_ast(merged_id) {
            Some(PureItem::Impl(merged)) => assert_eq!(
                merged.items.len(),
                2,
                "Impl block should have 2 methods after merge"
            ),
            _ => panic!("Expected merged impl block"),
        }
    }

    #[test]
    fn test_register_generic_impl_block() {
        let mut ast_registry = ASTRegistry::new();
        let mut symbol_registry = SymbolRegistry::new();
        let mut ctx = ASTMutationContext::new(&mut ast_registry, &mut symbol_registry);
        ctx.register(sym("test_crate"), SymbolKind::Mod);
        ctx.register(sym("test_crate::SagaOrchestrator"), SymbolKind::Struct);
        let module = sym("test_crate");
        let block = first_impl(
            r#"
impl<C, E: Debug> SagaOrchestrator<C, E> {
pub fn new(name: impl Into<String>) -> Self {
Self { name: name.into() }
}
pub fn add_step(&mut self) {
}
}
"#,
        );
        assert!(block.self_ty.starts_with("SagaOrchestrator"));

        let outcome = register_impl_block(&mut ctx, &module, &block).unwrap();
        assert_eq!(outcome.methods_added, 2, "Should add 2 methods");

        assert!(
            ctx.lookup(&sym("test_crate::SagaOrchestrator::new")).is_some(),
            "SagaOrchestrator::new should exist"
        );
        assert!(
            ctx.lookup(&sym("test_crate::SagaOrchestrator::add_step"))
                .is_some(),
            "SagaOrchestrator::add_step should exist"
        );
    }

    #[test]
    fn test_strip_generics() {
        let cases = [
            ("SagaOrchestrator<C, E: Debug>", "SagaOrchestrator"),
            ("Vec<T>", "Vec"),
            ("HashMap<K, V>", "HashMap"),
            ("Counter", "Counter"),
            ("Option<Box<dyn Trait>>", "Option"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(strip_generics(input), expected);
        }
    }
}