use super::response::{CascadeResult, DiscoverError};
use crate::project::Project;
use ryo_analysis::cascade::CascadeSpec;
use ryo_analysis::{
AnalysisContext, DiscoveryEngine, DiscoveryQuery, DiscoveryResult, SymbolKind, SymbolPath,
SymbolRegistry,
};
use ryo_source::pure::{
PureBlock, PureExpr, PureImpl, PureImplItem, PureItem, PurePattern, PureStmt,
};
use std::path::Path;
pub struct DiscoverService {
ctx: AnalysisContext,
}
impl DiscoverService {
pub fn new(ctx: AnalysisContext) -> Self {
Self { ctx }
}
pub fn from_project(project: &Project) -> Result<Self, DiscoverError> {
let ctx = AnalysisContext::from_workspace_root(project.workspace_root())
.map_err(|e| DiscoverError::Project(e.to_string()))?;
Ok(Self { ctx })
}
pub fn from_path(path: &Path) -> Result<Self, DiscoverError> {
let ctx = AnalysisContext::from_workspace_root(path)
.map_err(|e| DiscoverError::Project(e.to_string()))?;
Ok(Self { ctx })
}
pub fn context(&self) -> &AnalysisContext {
&self.ctx
}
pub fn registry(&self) -> &SymbolRegistry {
&self.ctx.registry
}
pub fn discover(&self, query: &DiscoveryQuery) -> DiscoveryResult {
let engine = DiscoveryEngine::new(&self.ctx.code_graph, &self.ctx.registry, None);
engine.execute(query)
}
pub fn find_symbols(&self, pattern: &str) -> DiscoveryResult {
let query = DiscoveryQuery::symbol(pattern);
self.discover(&query)
}
pub fn find_symbols_by_kind(&self, pattern: &str, kind: SymbolKind) -> DiscoveryResult {
let query = DiscoveryQuery::symbol(pattern).kind(kind);
self.discover(&query)
}
pub fn find_with_relations(&self, pattern: &str, depth: usize) -> DiscoveryResult {
let query = DiscoveryQuery::symbol(pattern)
.with_relations()
.relation_depth(depth);
self.discover(&query)
}
pub fn find_cascade_effects(
&self,
enum_pattern: &str,
variant_name: Option<&str>,
variant_type: Option<&str>,
) -> CascadeResult {
find_cascade_effects(&self.ctx.files, enum_pattern, variant_name, variant_type)
}
pub fn find_remove_cascade_effects(
&self,
enum_pattern: &str,
variant_name: &str,
) -> CascadeResult {
find_remove_cascade_effects(&self.ctx.files, enum_pattern, variant_name)
}
}
/// Collects `AddMatchArm` specs for every `match` over `enum_pattern` that would
/// need a new arm if the enum gained a variant.
///
/// The workspace files searched are:
/// - every function body returned by `PureFile::functions()`, attributed to the
///   file's module path;
/// - every `fn` item inside `impl` blocks, attributed to the impl's path (see
///   `impl_symbol_path`).
///
/// `variant_name` defaults to `"NewVariant"` and `variant_type` to `"unit"`.
/// Leading and trailing `*` wildcards are stripped from `enum_pattern` before
/// matching.
///
/// At most one `AddMatchArm` is kept per `(target, function_name)` pair, namely
/// the first one found.
pub fn find_cascade_effects(
    files: &ryo_analysis::ImHashMap<
        ryo_symbol::WorkspaceFilePath,
        std::sync::Arc<ryo_source::PureFile>,
    >,
    enum_pattern: &str,
    variant_name: Option<&str>,
    variant_type: Option<&str>,
) -> CascadeResult {
    let variant = variant_name.unwrap_or("NewVariant");
    let vtype = variant_type.unwrap_or("unit");
    // Discovery patterns may be written as `*Status*`; arms are compared
    // against the bare enum name.
    let enum_name = enum_pattern.trim_start_matches('*').trim_end_matches('*');
    let mut specs = Vec::new();
    for (wfp, file) in files.iter() {
        let module_path = wfp_to_symbol_path(wfp);
        for func in file.functions() {
            find_match_in_block(
                &func.body,
                enum_name,
                variant,
                vtype,
                &module_path,
                &func.name,
                &mut specs,
            );
        }
        for item in &file.items {
            if let PureItem::Impl(impl_block) = item {
                let impl_path = impl_symbol_path(&module_path, impl_block);
                for impl_item in &impl_block.items {
                    if let PureImplItem::Fn(func) = impl_item {
                        find_match_in_block(
                            &func.body,
                            enum_name,
                            variant,
                            vtype,
                            &impl_path,
                            &func.name,
                            &mut specs,
                        );
                    }
                }
            }
        }
    }
    // A function needs at most one new arm per target. `HashSet::insert`
    // returns `true` only for the first occurrence of a key, which keeps this
    // pass O(n) instead of the O(n²) of a `Vec::contains` scan.
    let mut seen = std::collections::HashSet::new();
    specs.retain(|spec| match spec {
        CascadeSpec::AddMatchArm {
            target,
            function_name,
            ..
        } => seen.insert(format!("{}::{}", target, function_name)),
        _ => true,
    });
    CascadeResult {
        symbol: enum_pattern.to_string(),
        specs,
    }
}
/// Collects `RemoveMatchArm` specs for every match arm that names
/// `enum_pattern::variant_name`. These are the arms that would have to be
/// deleted if that variant were removed.
///
/// The same function bodies are scanned as in `find_cascade_effects`:
/// - free functions from `PureFile::functions()`, attributed to the module
///   path;
/// - methods in `impl` blocks, attributed to the impl's symbol path.
///
/// Leading and trailing `*` wildcards are stripped from `enum_pattern`.
pub fn find_remove_cascade_effects(
    files: &ryo_analysis::ImHashMap<
        ryo_symbol::WorkspaceFilePath,
        std::sync::Arc<ryo_source::PureFile>,
    >,
    enum_pattern: &str,
    variant_name: &str,
) -> CascadeResult {
    let enum_name = enum_pattern.trim_matches('*');
    let mut specs = Vec::new();
    for (wfp, file) in files.iter() {
        let module_path = wfp_to_symbol_path(wfp);

        // Free functions are reported under the file's module path.
        for func in file.functions() {
            find_removable_arms_in_block(
                &func.body,
                enum_name,
                variant_name,
                &module_path,
                &func.name,
                &mut specs,
            );
        }

        // Methods are reported under `module::Type` or
        // `module::<impl Trait for Type>`.
        let impl_blocks = file.items.iter().filter_map(|item| match item {
            PureItem::Impl(impl_block) => Some(impl_block),
            _ => None,
        });
        for impl_block in impl_blocks {
            let impl_path = impl_symbol_path(&module_path, impl_block);
            let methods = impl_block
                .items
                .iter()
                .filter_map(|impl_item| match impl_item {
                    PureImplItem::Fn(func) => Some(func),
                    _ => None,
                });
            for func in methods {
                find_removable_arms_in_block(
                    &func.body,
                    enum_name,
                    variant_name,
                    &impl_path,
                    &func.name,
                    &mut specs,
                );
            }
        }
    }
    CascadeResult {
        symbol: enum_pattern.to_string(),
        specs,
    }
}
/// Derives the module-level [`SymbolPath`] for a workspace source file.
///
/// The path below the first `/src/` directory (or below a leading `src/`) is
/// converted to `::` segments and prefixed with the crate name. For example,
/// `crates/foo/src/a/b.rs` becomes `foo::a::b`.
///
/// `src/lib.rs` and any `.../mod.rs` map to their parent module. Files outside
/// a `src/` directory fall back to the bare crate path. So does any module
/// path that `SymbolPath::parse` rejects.
fn wfp_to_symbol_path(wfp: &ryo_symbol::WorkspaceFilePath) -> SymbolPath {
    let crate_name = wfp.crate_name();
    let crate_only = || {
        SymbolPath::builder(crate_name.as_str())
            .build()
            .expect("valid crate name")
    };
    // Normalise Windows separators so the `src/` lookup and the `/` -> `::`
    // conversion below work on every platform.
    let path_str = wfp.as_relative().to_string_lossy().replace('\\', "/");
    let relative = if let Some(idx) = path_str.find("/src/") {
        &path_str[idx + "/src/".len()..]
    } else if let Some(rest) = path_str.strip_prefix("src/") {
        // Crate rooted at the workspace root. Only a leading `src/` counts,
        // so a directory such as `mysrc/` is not mistaken for it.
        rest
    } else {
        return SymbolPath::parse(crate_name.as_str()).unwrap_or_else(|_| crate_only());
    };
    let module_path = relative.trim_end_matches(".rs").replace('/', "::");
    let full_path = if module_path == "lib" || module_path.ends_with("::mod") {
        match module_path.trim_end_matches("::mod") {
            "" | "lib" => crate_name.as_str().to_string(),
            trimmed => format!("{}::{}", crate_name.as_str(), trimmed),
        }
    } else {
        format!("{}::{}", crate_name.as_str(), module_path)
    };
    SymbolPath::parse(&full_path).unwrap_or_else(|_| crate_only())
}
/// Returns the symbol path under which methods of `impl_block` are reported.
///
/// - Inherent impls append the self type: `module::User`.
/// - Trait impls append `<impl Trait for Type>`:
///   `module::<impl Display for User>`.
///
/// If `SymbolPath::child` returns an error for that segment, `module_path` is
/// returned unchanged.
fn impl_symbol_path(module_path: &SymbolPath, impl_block: &PureImpl) -> SymbolPath {
    let self_ty = &impl_block.self_ty;
    let segment = match &impl_block.trait_ {
        Some(trait_name) => format!("<impl {} for {}>", trait_name, self_ty),
        None => self_ty.clone(),
    };
    module_path
        .child(&segment)
        .unwrap_or_else(|_| module_path.clone())
}
/// Renders the match-arm pattern for a new `enum_name::variant_name` variant.
///
/// `variant_type` describes the variant's shape:
/// - `"unit"`, `""` or any unrecognised value gives `Enum::Variant`.
/// - `"tuple:T1,T2,..."` gives `Enum::Variant(_, _, ...)`, with one wildcard
///   per element as counted by `count_tuple_elements`. An empty list gives
///   `Enum::Variant`.
/// - `"struct:f1:T1,f2:T2,..."` gives `Enum::Variant { f1: _, f2: _ }`. No
///   field names gives `Enum::Variant { .. }`.
///
/// Struct field lists are split only on top-level commas, so field types such
/// as `HashMap<K, V>` or `(u8, u8)` do not produce bogus field names.
/// Whitespace around each field name is trimmed.
fn generate_match_pattern(enum_name: &str, variant_name: &str, variant_type: &str) -> String {
    // Splits `s` on commas that are not nested inside `<>`, `()`, `[]` or `{}`.
    // The `>` of `->` (as in `Fn(A) -> B`) is not a closing bracket.
    fn split_top_level_commas(s: &str) -> Vec<&str> {
        let mut parts = Vec::new();
        let mut depth = 0usize;
        let mut start = 0usize;
        let mut prev = '\0';
        for (idx, c) in s.char_indices() {
            match c {
                '<' | '(' | '[' | '{' => depth += 1,
                '>' if prev == '-' => {}
                '>' | ')' | ']' | '}' => depth = depth.saturating_sub(1),
                ',' if depth == 0 => {
                    parts.push(&s[start..idx]);
                    start = idx + c.len_utf8();
                }
                _ => {}
            }
            prev = c;
        }
        parts.push(&s[start..]);
        parts
    }

    let base = format!("{}::{}", enum_name, variant_name);
    if let Some(types_part) = variant_type.strip_prefix("tuple:") {
        match count_tuple_elements(types_part) {
            0 => base,
            count => format!("{}({})", base, vec!["_"; count].join(", ")),
        }
    } else if let Some(fields_part) = variant_type.strip_prefix("struct:") {
        let field_names: Vec<&str> = split_top_level_commas(fields_part)
            .into_iter()
            .filter_map(|field| field.split(':').next())
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .collect();
        if field_names.is_empty() {
            format!("{} {{ .. }}", base)
        } else {
            let pattern = field_names
                .iter()
                .map(|name| format!("{}: _", name))
                .collect::<Vec<_>>()
                .join(", ");
            format!("{} {{ {} }}", base, pattern)
        }
    } else {
        base
    }
}
/// Counts the element types in a comma-separated tuple type list such as
/// `"i32,Option<String>"`.
///
/// Only commas outside any `<>`, `()`, `[]` or `{}` nesting separate elements,
/// so `HashMap<K, V>`, `(u8, u8)`, `[T; N]` and `fn(A, B) -> C` each count as
/// one element. The `>` of a `->` arrow is not treated as a closing bracket.
/// Empty or whitespace-only segments do not count: `""` and `"  "` give 0,
/// and `"i32,"` gives 1.
fn count_tuple_elements(types_str: &str) -> usize {
    let mut count = 0usize;
    let mut depth = 0usize;
    let mut segment_has_type = false;
    let mut prev = '\0';
    for c in types_str.chars() {
        if c == ',' && depth == 0 {
            if segment_has_type {
                count += 1;
            }
            segment_has_type = false;
        } else {
            match c {
                '<' | '(' | '[' | '{' => depth += 1,
                // `->` in `fn(..) -> T` / `Fn(..) -> T` does not close a bracket.
                '>' if prev == '-' => {}
                '>' | ')' | ']' | '}' => depth = depth.saturating_sub(1),
                _ => {}
            }
            if !c.is_whitespace() {
                segment_has_type = true;
            }
        }
        prev = c;
    }
    if segment_has_type {
        count += 1;
    }
    count
}
/// Searches the statements of `block` for `match` expressions over
/// `enum_name`, appending `AddMatchArm` specs to `specs` via
/// `find_match_in_expr`.
///
/// Only expression statements and `let` initialisers are searched. Nested items
/// (inner `fn`s, `impl`s, ...) and `let` without an initialiser are skipped.
fn find_match_in_block(
    block: &PureBlock,
    enum_name: &str,
    variant_name: &str,
    variant_type: &str,
    module_path: &SymbolPath,
    function_name: &str,
    specs: &mut Vec<CascadeSpec>,
) {
    for stmt in &block.stmts {
        match stmt {
            PureStmt::Expr(inner) | PureStmt::Semi(inner) => find_match_in_expr(
                inner,
                enum_name,
                variant_name,
                variant_type,
                module_path,
                function_name,
                specs,
            ),
            PureStmt::Local {
                init: Some(initializer),
                ..
            } => find_match_in_expr(
                initializer,
                enum_name,
                variant_name,
                variant_type,
                module_path,
                function_name,
                specs,
            ),
            PureStmt::Local { init: None, .. } | PureStmt::Item(_) => {}
        }
    }
}
fn find_match_in_expr(
expr: &PureExpr,
enum_name: &str,
variant_name: &str,
variant_type: &str,
module_path: &SymbolPath,
function_name: &str,
specs: &mut Vec<CascadeSpec>,
) {
match expr {
PureExpr::Match {
expr: scrutinee,
arms,
} => {
let matches_enum = arms
.iter()
.any(|arm| pattern_contains_enum(&arm.pattern, enum_name));
if matches_enum {
let pattern = generate_match_pattern(enum_name, variant_name, variant_type);
specs.push(CascadeSpec::add_match_arm(
module_path.clone(),
function_name,
enum_name,
pattern,
"todo!()",
));
}
find_match_in_expr(
scrutinee,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
for arm in arms {
find_match_in_expr(
&arm.body,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::Block { block, .. } => {
find_match_in_block(
block,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::If {
cond,
then_branch,
else_branch,
} => {
find_match_in_expr(
cond,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
find_match_in_block(
then_branch,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
if let Some(else_expr) = else_branch {
find_match_in_expr(
else_expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::Loop { body: block, .. } | PureExpr::Unsafe(block) => {
find_match_in_block(
block,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Async { body, .. } => {
find_match_in_block(
body,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::While { cond, body, .. } => {
find_match_in_expr(
cond,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
find_match_in_block(
body,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::For {
expr: iter, body, ..
} => {
find_match_in_expr(
iter,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
find_match_in_block(
body,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Closure { body, .. } => {
find_match_in_expr(
body,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Call { func, args } => {
find_match_in_expr(
func,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
for arg in args {
find_match_in_expr(
arg,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::MethodCall { receiver, args, .. } => {
find_match_in_expr(
receiver,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
for arg in args {
find_match_in_expr(
arg,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::Binary { left, right, .. } => {
find_match_in_expr(
left,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
find_match_in_expr(
right,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Unary { expr, .. } => {
find_match_in_expr(
expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Field { expr: base, .. } | PureExpr::Index { expr: base, .. } => {
find_match_in_expr(
base,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Ref { expr, .. }
| PureExpr::Try(expr)
| PureExpr::Return(Some(expr))
| PureExpr::Await(expr) => {
find_match_in_expr(
expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
for e in exprs {
find_match_in_expr(
e,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::Struct { fields, .. } => {
for (_, field_expr) in fields {
find_match_in_expr(
field_expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::Range { start, end, .. } => {
if let Some(s) = start {
find_match_in_expr(
s,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
if let Some(e) = end {
find_match_in_expr(
e,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
}
PureExpr::Let { expr, .. } => {
find_match_in_expr(
expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Cast { expr, .. } => {
find_match_in_expr(
expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Repeat { expr, len } => {
find_match_in_expr(
expr,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
find_match_in_expr(
len,
enum_name,
variant_name,
variant_type,
module_path,
function_name,
specs,
);
}
PureExpr::Lit(_)
| PureExpr::Path(_)
| PureExpr::Return(None)
| PureExpr::Break { .. }
| PureExpr::Continue { .. }
| PureExpr::Macro { .. }
| PureExpr::Other(_) => {}
}
}
/// Returns `true` when `pattern` names `enum_name` as a whole `::` path
/// segment. For an or-pattern, it is enough for any alternative to do so.
///
/// - Path and struct patterns are checked on their path.
/// - `Other` patterns are checked on the text before the first `(`, `{` or
///   space.
/// - Every other pattern kind (wildcards, bindings, tuples, references,
///   literals, ...) yields `false`.
pub(crate) fn pattern_contains_enum(pattern: &PurePattern, enum_name: &str) -> bool {
    let head: &str = match pattern {
        PurePattern::Or(alternatives) => {
            return alternatives
                .iter()
                .any(|alternative| pattern_contains_enum(alternative, enum_name));
        }
        PurePattern::Path(path) | PurePattern::Struct { path, .. } => path.as_str(),
        PurePattern::Other(text) => text.split(&['(', '{', ' '][..]).next().unwrap_or(text),
        _ => return false,
    };
    path_has_enum_segment(head, enum_name)
}
/// Returns `true` if any `::`-separated segment of `path` equals `enum_name`
/// exactly. A partial segment such as `MyStatus` does not match `Status`.
fn path_has_enum_segment(path: &str, enum_name: &str) -> bool {
    for segment in path.split("::") {
        if segment == enum_name {
            return true;
        }
    }
    false
}
/// Searches the statements of `block` for match arms naming
/// `enum_name::variant_name`, appending `RemoveMatchArm` specs to `specs` via
/// `find_removable_arms_in_expr`.
///
/// Only expression statements and `let` initialisers are searched. Nested items
/// and `let` without an initialiser are skipped.
fn find_removable_arms_in_block(
    block: &PureBlock,
    enum_name: &str,
    variant_name: &str,
    module_path: &SymbolPath,
    function_name: &str,
    specs: &mut Vec<CascadeSpec>,
) {
    for stmt in &block.stmts {
        match stmt {
            PureStmt::Expr(inner) | PureStmt::Semi(inner) => find_removable_arms_in_expr(
                inner,
                enum_name,
                variant_name,
                module_path,
                function_name,
                specs,
            ),
            PureStmt::Local {
                init: Some(initializer),
                ..
            } => find_removable_arms_in_expr(
                initializer,
                enum_name,
                variant_name,
                module_path,
                function_name,
                specs,
            ),
            PureStmt::Local { init: None, .. } | PureStmt::Item(_) => {}
        }
    }
}
fn find_removable_arms_in_expr(
expr: &PureExpr,
enum_name: &str,
variant_name: &str,
module_path: &SymbolPath,
function_name: &str,
specs: &mut Vec<CascadeSpec>,
) {
match expr {
PureExpr::Match {
expr: scrutinee,
arms,
} => {
let matches_enum = arms
.iter()
.any(|arm| pattern_contains_enum(&arm.pattern, enum_name));
if matches_enum {
for arm in arms {
if pattern_contains_variant(&arm.pattern, enum_name, variant_name) {
let pattern_str = format_pattern(&arm.pattern);
specs.push(CascadeSpec::remove_match_arm(
module_path.clone(),
function_name,
enum_name,
pattern_str,
));
}
}
}
find_removable_arms_in_expr(
scrutinee,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
for arm in arms {
find_removable_arms_in_expr(
&arm.body,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::Block { block, .. } => {
find_removable_arms_in_block(
block,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::If {
cond,
then_branch,
else_branch,
} => {
find_removable_arms_in_expr(
cond,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
find_removable_arms_in_block(
then_branch,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
if let Some(else_expr) = else_branch {
find_removable_arms_in_expr(
else_expr,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::Loop { body: block, .. } | PureExpr::Unsafe(block) => {
find_removable_arms_in_block(
block,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Async { body, .. } => {
find_removable_arms_in_block(
body,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::While { cond, body, .. } => {
find_removable_arms_in_expr(
cond,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
find_removable_arms_in_block(
body,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::For {
expr: iter, body, ..
} => {
find_removable_arms_in_expr(
iter,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
find_removable_arms_in_block(
body,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Closure { body, .. } => {
find_removable_arms_in_expr(
body,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Call { func, args } => {
find_removable_arms_in_expr(
func,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
for arg in args {
find_removable_arms_in_expr(
arg,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::MethodCall { receiver, args, .. } => {
find_removable_arms_in_expr(
receiver,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
for arg in args {
find_removable_arms_in_expr(
arg,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::Binary { left, right, .. } => {
find_removable_arms_in_expr(
left,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
find_removable_arms_in_expr(
right,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Unary { expr, .. }
| PureExpr::Field { expr, .. }
| PureExpr::Index { expr, .. }
| PureExpr::Ref { expr, .. }
| PureExpr::Try(expr)
| PureExpr::Await(expr)
| PureExpr::Let { expr, .. }
| PureExpr::Cast { expr, .. } => {
find_removable_arms_in_expr(
expr,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Return(Some(expr)) => {
find_removable_arms_in_expr(
expr,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
for e in exprs {
find_removable_arms_in_expr(
e,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::Struct { fields, .. } => {
for (_, field_expr) in fields {
find_removable_arms_in_expr(
field_expr,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::Range { start, end, .. } => {
if let Some(s) = start {
find_removable_arms_in_expr(
s,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
if let Some(e) = end {
find_removable_arms_in_expr(
e,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
}
PureExpr::Repeat { expr, len } => {
find_removable_arms_in_expr(
expr,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
find_removable_arms_in_expr(
len,
enum_name,
variant_name,
module_path,
function_name,
specs,
);
}
PureExpr::Lit(_)
| PureExpr::Path(_)
| PureExpr::Return(None)
| PureExpr::Break { .. }
| PureExpr::Continue { .. }
| PureExpr::Macro { .. }
| PureExpr::Other(_) => {}
}
}
/// Returns `true` when `pattern` names `enum_name::variant_name`. For an
/// or-pattern, it is enough for any alternative to do so.
///
/// The enum and variant must appear as two adjacent whole `::` segments, so
/// `Status::Pending` and `crate::Status::Pending` match, while
/// `MyStatus::Pending` and `Status::PendingReview` do not.
///
/// - Path and struct patterns are checked on their path.
/// - `Other` patterns (e.g. tuple-variant patterns such as
///   `Status::Pending(x)`) are checked on the text before the first `(`, `{`
///   or space, mirroring `pattern_contains_enum`.
/// - All other pattern kinds yield `false`.
fn pattern_contains_variant(pattern: &PurePattern, enum_name: &str, variant_name: &str) -> bool {
    // Whole-segment comparison of adjacent `::` segments. A substring test
    // would misreport unrelated arms as removable.
    fn names_variant(path: &str, enum_name: &str, variant_name: &str) -> bool {
        path.split("::")
            .zip(path.split("::").skip(1))
            .any(|(enum_seg, variant_seg)| enum_seg == enum_name && variant_seg == variant_name)
    }

    match pattern {
        PurePattern::Path(path) | PurePattern::Struct { path, .. } => {
            names_variant(path, enum_name, variant_name)
        }
        PurePattern::Or(patterns) => patterns
            .iter()
            .any(|p| pattern_contains_variant(p, enum_name, variant_name)),
        PurePattern::Other(text) => {
            let head = text.split(&['(', '{', ' '][..]).next().unwrap_or(text);
            names_variant(head, enum_name, variant_name)
        }
        _ => false,
    }
}
/// Renders `pattern` back to Rust pattern source text.
///
/// Struct patterns use field shorthand (`P { x }`) when a field's sub-pattern
/// renders identical to the field name, and `name: pattern` otherwise. A
/// trailing `..` is appended when `rest` is set. A struct pattern with no
/// fields renders as `P { .. }` when `rest` is set and as `P {}` otherwise.
/// Bindings render as the bare name. `Other` patterns are returned verbatim.
fn format_pattern(pattern: &PurePattern) -> String {
    match pattern {
        PurePattern::Path(path) => path.clone(),
        PurePattern::Wild => "_".to_string(),
        PurePattern::Ident { name, .. } => name.clone(),
        PurePattern::Struct {
            path, fields, rest, ..
        } => {
            let mut parts: Vec<String> = fields
                .iter()
                .map(|(name, pat)| {
                    let pat_str = format_pattern(pat);
                    if pat_str == *name {
                        name.clone()
                    } else {
                        format!("{}: {}", name, pat_str)
                    }
                })
                .collect();
            if *rest {
                parts.push("..".to_string());
            }
            // Joining the parts avoids the invalid `P { , .. }` an empty
            // field list used to produce.
            if parts.is_empty() {
                format!("{} {{}}", path)
            } else {
                format!("{} {{ {} }}", path, parts.join(", "))
            }
        }
        PurePattern::Tuple(patterns) => {
            let inner: Vec<String> = patterns.iter().map(format_pattern).collect();
            format!("({})", inner.join(", "))
        }
        PurePattern::Or(patterns) => {
            let inner: Vec<String> = patterns.iter().map(format_pattern).collect();
            inner.join(" | ")
        }
        PurePattern::Lit(lit) => lit.clone(),
        PurePattern::Rest => "..".to_string(),
        PurePattern::Ref {
            is_mut, pattern, ..
        } => {
            let inner = format_pattern(pattern);
            if *is_mut {
                format!("&mut {}", inner)
            } else {
                format!("&{}", inner)
            }
        }
        PurePattern::Range {
            start,
            end,
            inclusive,
        } => {
            let s = start.as_deref().unwrap_or("");
            let e = end.as_deref().unwrap_or("");
            if *inclusive {
                format!("{}..={}", s, e)
            } else {
                format!("{}..{}", s, e)
            }
        }
        PurePattern::Slice(patterns) => {
            let inner: Vec<String> = patterns.iter().map(format_pattern).collect();
            format!("[{}]", inner.join(", "))
        }
        PurePattern::Other(s) => s.clone(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
fn create_test_project_with_enum() -> tempfile::TempDir {
let dir = tempdir().unwrap();
let src = dir.path().join("src");
fs::create_dir(&src).unwrap();
fs::write(
dir.path().join("Cargo.toml"),
r#"[package]
name = "test_enum_project"
version = "0.1.0"
edition = "2021"
"#,
)
.unwrap();
fs::write(
src.join("lib.rs"),
r#"
pub enum Status {
Active,
Inactive,
Pending,
}
pub fn handle_status(status: Status) -> &'static str {
match status {
Status::Active => "active",
Status::Inactive => "inactive",
Status::Pending => "pending",
}
}
pub struct User {
pub name: String,
pub status: Status,
}
impl User {
pub fn status_label(&self) -> &'static str {
match self.status {
Status::Active => "A",
Status::Inactive => "I",
Status::Pending => "P",
}
}
}
"#,
)
.unwrap();
dir
}
fn create_test_project_for_discovery() -> tempfile::TempDir {
let dir = tempdir().unwrap();
let src = dir.path().join("src");
fs::create_dir(&src).unwrap();
fs::write(
dir.path().join("Cargo.toml"),
r#"[package]
name = "test_project"
version = "0.1.0"
edition = "2021"
"#,
)
.unwrap();
fs::write(
src.join("lib.rs"),
r#"
pub fn foo() {}
pub fn bar() {}
pub fn foobar() {}
pub struct Foo;
pub struct Bar;
pub enum MyEnum {
Variant1,
Variant2,
}
pub trait MyTrait {
fn do_something(&self);
}
"#,
)
.unwrap();
dir
}
#[test]
fn test_from_path() {
let dir = create_test_project_for_discovery();
let result = DiscoverService::from_path(dir.path());
assert!(result.is_ok());
}
#[test]
fn test_from_path_nonexistent() {
let result = DiscoverService::from_path(Path::new("/nonexistent/path"));
assert!(result.is_err());
}
#[test]
fn test_from_project() {
let dir = create_test_project_for_discovery();
let project = Project::load(dir.path()).unwrap();
let service = DiscoverService::from_project(&project).unwrap();
assert!(!service.registry().is_empty());
}
#[test]
fn test_find_symbols_exact_match() {
let dir = create_test_project_for_discovery();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_symbols("foo");
assert!(!result.symbols.is_empty());
assert!(result.symbols.iter().any(|s| s.path.name() == "foo"));
}
#[test]
fn test_find_symbols_wildcard() {
let dir = create_test_project_for_discovery();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_symbols("*foo*");
assert!(result.symbols.len() >= 2);
}
#[test]
fn test_find_symbols_by_kind_function() {
let dir = create_test_project_for_discovery();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_symbols_by_kind("*", SymbolKind::Function);
assert!(result
.symbols
.iter()
.all(|s| s.kind == SymbolKind::Function));
}
#[test]
fn test_find_symbols_by_kind_struct() {
let dir = create_test_project_for_discovery();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_symbols_by_kind("*", SymbolKind::Struct);
assert!(result.symbols.iter().all(|s| s.kind == SymbolKind::Struct));
}
#[test]
fn test_find_symbols_by_kind_enum() {
let dir = create_test_project_for_discovery();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_symbols_by_kind("*", SymbolKind::Enum);
assert!(!result.symbols.is_empty());
assert!(result.symbols.iter().any(|s| s.path.name() == "MyEnum"));
}
#[test]
fn test_find_symbols_no_match() {
let dir = create_test_project_for_discovery();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_symbols("nonexistent_symbol");
assert!(result.symbols.is_empty());
}
#[test]
fn test_find_cascade_effects_basic() {
let dir = create_test_project_with_enum();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_cascade_effects("Status", Some("Cancelled"), None);
assert_eq!(result.symbol, "Status");
assert_eq!(result.len(), 2);
}
#[test]
fn test_find_cascade_effects_default_variant() {
let dir = create_test_project_with_enum();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_cascade_effects("Status", None, None);
assert!(!result.is_empty());
}
#[test]
fn test_find_cascade_effects_no_match() {
let dir = create_test_project_with_enum();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_cascade_effects("NonExistentEnum", Some("Variant"), None);
assert!(result.is_empty());
}
#[test]
fn test_find_cascade_effects_wildcard_stripped() {
let dir = create_test_project_with_enum();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_cascade_effects("*Status*", Some("Cancelled"), None);
assert!(!result.is_empty());
}
#[test]
fn test_cascade_spec_target_paths_include_function() {
let dir = create_test_project_with_enum();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_cascade_effects("Status", Some("Cancelled"), None);
assert_eq!(result.len(), 2, "Expected 2 cascade specs");
let mut entries: Vec<(String, String)> = result
.specs
.iter()
.map(|spec| match spec {
CascadeSpec::AddMatchArm {
target,
function_name,
..
} => (target.to_string(), function_name.clone()),
other => panic!("Expected AddMatchArm, got {:?}", other),
})
.collect();
entries.sort();
assert!(
entries.iter().any(|(_t, f)| f == "handle_status"),
"Should have handle_status cascade: {:?}",
entries
);
let method_entry = entries
.iter()
.find(|(_, f)| f == "status_label")
.expect("Should have status_label cascade");
assert!(
method_entry.0.contains("User"),
"Method target should include type name 'User', got: {}",
method_entry.0
);
}
#[test]
fn test_cascade_to_intent_constructs_full_path() {
use crate::intent::Intent;
let dir = create_test_project_with_enum();
let service = DiscoverService::from_path(dir.path()).unwrap();
let result = service.find_cascade_effects("Status", Some("Cancelled"), None);
assert_eq!(result.len(), 2);
let intents: Vec<Intent> = result.specs.into_iter().map(Intent::from).collect();
for intent in &intents {
match intent {
Intent::AddMatchArm {
symbol_path,
target_fn,
..
} => {
let path = symbol_path.as_ref().expect("symbol_path should be set");
assert!(
path.contains("handle_status") || path.contains("status_label"),
"symbol_path should contain function name, got: {}",
path
);
assert!(
target_fn.is_none(),
"target_fn should be None when symbol_path has full path"
);
}
other => panic!("Expected AddMatchArm intent, got {:?}", other),
}
}
let method_intent = intents.iter().find(|i| {
matches!(i, Intent::AddMatchArm { symbol_path: Some(p), .. } if p.contains("status_label"))
}).expect("Should have status_label intent");
if let Intent::AddMatchArm {
symbol_path: Some(path),
..
} = method_intent
{
assert!(
path.contains("User"),
"Method symbol_path should include type name, got: {}",
path
);
}
}
#[test]
fn test_generate_match_pattern_unit() {
assert_eq!(
generate_match_pattern("Status", "Active", "unit"),
"Status::Active"
);
assert_eq!(
generate_match_pattern("Status", "Active", ""),
"Status::Active"
);
}
#[test]
fn test_generate_match_pattern_tuple() {
assert_eq!(
generate_match_pattern("Result", "Ok", "tuple:T"),
"Result::Ok(_)"
);
assert_eq!(
generate_match_pattern("Option", "Some", "tuple:i32,String"),
"Option::Some(_, _)"
);
assert_eq!(
generate_match_pattern("PathSegment", "Slice", "tuple:Option<usize>,Option<usize>"),
"PathSegment::Slice(_, _)"
);
}
#[test]
fn test_generate_match_pattern_struct() {
assert_eq!(
generate_match_pattern(
"PathSegment",
"Slice",
"struct:start:Option<usize>,end:Option<usize>"
),
"PathSegment::Slice { start: _, end: _ }"
);
assert_eq!(
generate_match_pattern("Error", "Custom", "struct:code:u32"),
"Error::Custom { code: _ }"
);
}
#[test]
fn test_generate_match_pattern_struct_empty() {
assert_eq!(
generate_match_pattern("Foo", "Bar", "struct:"),
"Foo::Bar { .. }"
);
}
#[test]
fn test_pattern_contains_enum_path() {
let pattern = PurePattern::Path("Status::Active".to_string());
assert!(pattern_contains_enum(&pattern, "Status"));
assert!(!pattern_contains_enum(&pattern, "Other"));
}
#[test]
fn test_pattern_contains_enum_struct() {
let pattern = PurePattern::Struct {
path: "MyStruct".to_string(),
fields: vec![],
rest: false,
};
assert!(pattern_contains_enum(&pattern, "MyStruct"));
assert!(!pattern_contains_enum(&pattern, "Other"));
}
#[test]
fn test_pattern_contains_enum_or() {
let pattern = PurePattern::Or(vec![
PurePattern::Path("Status::Active".to_string()),
PurePattern::Path("Status::Inactive".to_string()),
]);
assert!(pattern_contains_enum(&pattern, "Status"));
assert!(!pattern_contains_enum(&pattern, "Other"));
}
#[test]
fn test_pattern_contains_enum_wild() {
let pattern = PurePattern::Wild;
assert!(!pattern_contains_enum(&pattern, "Status"));
}
#[test]
fn test_cascade_result_new() {
let result = CascadeResult::new("TestEnum".to_string());
assert_eq!(result.symbol, "TestEnum");
assert!(result.is_empty());
assert_eq!(result.len(), 0);
}
#[test]
fn test_count_tuple_elements_empty() {
assert_eq!(count_tuple_elements(""), 0);
}
#[test]
fn test_count_tuple_elements_single() {
assert_eq!(count_tuple_elements("i32"), 1);
assert_eq!(count_tuple_elements("String"), 1);
}
#[test]
fn test_count_tuple_elements_multiple() {
assert_eq!(count_tuple_elements("i32,String"), 2);
assert_eq!(count_tuple_elements("i32,i32,i32"), 3);
}
#[test]
fn test_count_tuple_elements_with_generics() {
assert_eq!(count_tuple_elements("Option<usize>"), 1);
assert_eq!(count_tuple_elements("Option<usize>,Option<usize>"), 2);
assert_eq!(count_tuple_elements("HashMap<K, V>,String"), 2);
assert_eq!(count_tuple_elements("Result<Vec<i32>, Error>"), 1);
}
#[test]
fn test_count_tuple_elements_nested_generics() {
assert_eq!(count_tuple_elements("Option<Result<i32, E>>,String"), 2);
assert_eq!(count_tuple_elements("Vec<HashMap<K, V>>,Option<T>"), 2);
}
#[test]
fn test_pattern_contains_variant_path_match() {
let pat = PurePattern::Path("Status::Active".to_string());
assert!(pattern_contains_variant(&pat, "Status", "Active"));
}
#[test]
fn test_pattern_contains_variant_path_no_match() {
let pat = PurePattern::Path("Status::Active".to_string());
assert!(!pattern_contains_variant(&pat, "Status", "Inactive"));
}
#[test]
fn test_pattern_contains_variant_struct_pattern() {
let pat = PurePattern::Struct {
path: "Status::Active".to_string(),
fields: vec![],
rest: false,
};
assert!(pattern_contains_variant(&pat, "Status", "Active"));
}
#[test]
fn test_pattern_contains_variant_or_pattern() {
let pat = PurePattern::Or(vec![
PurePattern::Path("Status::Active".to_string()),
PurePattern::Path("Status::Pending".to_string()),
]);
assert!(pattern_contains_variant(&pat, "Status", "Active"));
assert!(pattern_contains_variant(&pat, "Status", "Pending"));
assert!(!pattern_contains_variant(&pat, "Status", "Inactive"));
}
#[test]
fn test_pattern_contains_variant_wildcard() {
let pat = PurePattern::Wild;
assert!(!pattern_contains_variant(&pat, "Status", "Active"));
}
#[test]
fn test_format_pattern_path() {
let pat = PurePattern::Path("Status::Active".to_string());
assert_eq!(format_pattern(&pat), "Status::Active");
}
#[test]
fn test_format_pattern_wild() {
assert_eq!(format_pattern(&PurePattern::Wild), "_");
}
#[test]
fn test_format_pattern_ident() {
let pat = PurePattern::Ident {
name: "x".to_string(),
is_mut: false,
};
assert_eq!(format_pattern(&pat), "x");
}
#[test]
fn test_format_pattern_struct_with_rest() {
let pat = PurePattern::Struct {
path: "Point".to_string(),
fields: vec![(
"x".to_string(),
PurePattern::Ident {
name: "x".to_string(),
is_mut: false,
},
)],
rest: true,
};
assert_eq!(format_pattern(&pat), "Point { x, .. }");
}
#[test]
fn test_format_pattern_tuple() {
    // Tuple elements are parenthesised and separated by `, `.
    let elements = vec![
        PurePattern::Path(String::from("A")),
        PurePattern::Path(String::from("B")),
    ];
    let rendered = format_pattern(&PurePattern::Tuple(elements));
    assert_eq!(rendered, "(A, B)");
}
#[test]
fn test_format_pattern_or() {
    // Alternatives of an or-pattern are joined with ` | `.
    let alternatives: Vec<PurePattern> = ["A", "B"]
        .iter()
        .map(|name| PurePattern::Path(name.to_string()))
        .collect();
    let rendered = format_pattern(&PurePattern::Or(alternatives));
    assert_eq!(rendered, "A | B");
}
#[test]
fn test_format_pattern_ref() {
    // A shared-reference pattern is the inner pattern prefixed with `&`.
    let inner = PurePattern::Ident {
        name: String::from("x"),
        is_mut: false,
    };
    let ref_pattern = PurePattern::Ref {
        is_mut: false,
        pattern: Box::new(inner),
    };
    assert_eq!(format_pattern(&ref_pattern), "&x");
}
#[test]
fn test_format_pattern_ref_mut() {
    // A mutable-reference pattern is prefixed with `&mut ` (note the space).
    let inner = PurePattern::Ident {
        name: String::from("y"),
        is_mut: false,
    };
    let ref_mut_pattern = PurePattern::Ref {
        is_mut: true,
        pattern: Box::new(inner),
    };
    assert_eq!(format_pattern(&ref_mut_pattern), "&mut y");
}
#[test]
fn test_format_pattern_range_inclusive() {
    // With `inclusive: true` the two bounds are joined by `..=`.
    let range_pattern = PurePattern::Range {
        start: Some(String::from("1")),
        end: Some(String::from("5")),
        inclusive: true,
    };
    let rendered = format_pattern(&range_pattern);
    assert_eq!(rendered, "1..=5");
}
#[test]
fn test_format_pattern_slice() {
    // Slice patterns are bracketed, and `Rest` renders as `..`.
    let head = PurePattern::Ident {
        name: String::from("first"),
        is_mut: false,
    };
    let slice_pattern = PurePattern::Slice(vec![head, PurePattern::Rest]);
    assert_eq!(format_pattern(&slice_pattern), "[first, ..]");
}
#[test]
fn test_find_remove_cascade_effects_finds_matching_arms() {
    // Removing `Status::Pending` from the fixture project must yield at least
    // two specs, and every one of them must be a `RemoveMatchArm` on `Status`
    // whose pattern mentions `Pending`.
    let dir = create_test_project_with_enum();
    let service = DiscoverService::from_path(dir.path()).unwrap();
    let result = service.find_remove_cascade_effects("Status", "Pending");
    let specs = &result.specs;
    assert!(
        specs.len() >= 2,
        "Expected at least 2 RemoveMatchArm specs, got {}: {:?}",
        specs.len(),
        specs
    );
    for spec in specs.iter() {
        // Any other spec kind is a test failure, reported with its Debug form.
        let (enum_name, pattern) = match spec {
            CascadeSpec::RemoveMatchArm {
                enum_name, pattern, ..
            } => (enum_name, pattern),
            other => panic!("Expected RemoveMatchArm, got {:?}", other),
        };
        assert_eq!(enum_name, "Status");
        assert!(
            pattern.contains("Pending"),
            "Pattern should contain 'Pending': {}",
            pattern
        );
    }
}
#[test]
fn test_remove_cascade_target_paths_include_type_for_methods() {
    use crate::intent::Intent;

    let dir = create_test_project_with_enum();
    let service = DiscoverService::from_path(dir.path()).unwrap();
    let result = service.find_remove_cascade_effects("Status", "Pending");
    assert!(result.specs.len() >= 2);

    // The `status_label` arm sits in a method, so its spec target must be
    // qualified by the owning type `User`.
    let method_target = result
        .specs
        .iter()
        .find_map(|spec| match spec {
            CascadeSpec::RemoveMatchArm {
                function_name,
                target,
                ..
            } if function_name == "status_label" => Some(target),
            _ => None,
        })
        .expect("Should find status_label RemoveMatchArm");
    assert!(
        method_target.to_string().contains("User"),
        "Method target should include 'User', got: {}",
        method_target
    );

    // After conversion to intents, the full method path must be carried in
    // `symbol_path`, and `target_fn` must be left unset.
    let intents: Vec<Intent> = result.specs.into_iter().map(Intent::from).collect();
    let (path, target_fn) = intents
        .iter()
        .find_map(|intent| match intent {
            Intent::RemoveMatchArm {
                symbol_path: Some(p),
                target_fn,
                ..
            } if p.contains("status_label") => Some((p, target_fn)),
            _ => None,
        })
        .expect("Should have status_label RemoveMatchArm intent");
    assert!(
        path.contains("User") && path.contains("status_label"),
        "Full method path should contain both User and status_label, got: {}",
        path
    );
    assert!(target_fn.is_none(), "target_fn should be None");
}
#[test]
fn test_find_remove_cascade_effects_no_match_for_nonexistent_variant() {
    // A variant name that never appears in the fixture must produce no specs.
    let dir = create_test_project_with_enum();
    let service = DiscoverService::from_path(dir.path()).unwrap();
    let result = service.find_remove_cascade_effects("Status", "NonExistent");
    let specs = &result.specs;
    assert!(
        specs.is_empty(),
        "Expected no specs for nonexistent variant, got {}: {:?}",
        specs.len(),
        specs
    );
}
#[test]
fn test_path_segment_exact() {
    // Both a qualified `Enum::Variant` path and the bare enum name qualify.
    for &path in &["Filter::Recurse", "Filter"] {
        assert!(path_has_enum_segment(path, "Filter"));
    }
}
#[test]
fn test_path_segment_no_substring() {
    // `FilterKind` merely starts with `Filter`; a prefix is not a matching segment.
    let path = "FilterKind::Inclusive";
    assert!(!path_has_enum_segment(path, "Filter"));
}
#[test]
fn test_pattern_contains_enum_segment_level() {
    // Enum detection on patterns must compare whole path segments:
    // `Filter` does not match `FilterKind::...`, but `FilterKind` does.
    let kind_path = PurePattern::Path(String::from("FilterKind::Inclusive"));
    let matches_filter = pattern_contains_enum(&kind_path, "Filter");
    assert!(!matches_filter);
    let matches_filter_kind = pattern_contains_enum(&kind_path, "FilterKind");
    assert!(matches_filter_kind);
}
}