use std::collections::{HashMap, HashSet, VecDeque};
use std::path::{Path, PathBuf};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::debug;
use tree_sitter::Node;
use crate::ast::AstExtractor;
use crate::callgraph::scanner::{ProjectScanner, ScanConfig};
use crate::error::{Result, BrrrError};
use crate::lang::LanguageRegistry;
/// Qualitative cohesion rating of a class.
///
/// `CohesionLevel::from_lcom3` derives the rating from an LCOM3 value.
/// Variants are serialized with snake_case names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CohesionLevel {
/// LCOM3 of 0 or 1.
High,
/// LCOM3 of 2.
Medium,
/// LCOM3 of 3 or 4.
Low,
/// LCOM3 of 5 or more.
VeryLow,
}
impl CohesionLevel {
    /// Converts an LCOM3 value into a rating.
    ///
    /// `0..=1` is `High`, `2` is `Medium`, `3..=4` is `Low` and anything
    /// larger is `VeryLow`.
    #[must_use]
    pub fn from_lcom3(lcom3: u32) -> Self {
        if lcom3 <= 1 {
            Self::High
        } else if lcom3 == 2 {
            Self::Medium
        } else if lcom3 <= 4 {
            Self::Low
        } else {
            Self::VeryLow
        }
    }

    /// Returns a short human-readable explanation of the rating.
    #[must_use]
    pub const fn description(&self) -> &'static str {
        match *self {
            Self::VeryLow => "Very low cohesion, strongly recommend splitting",
            Self::Low => "Low cohesion, consider splitting class",
            Self::Medium => "Minor cohesion issue, consider reviewing",
            Self::High => "Cohesive class, well-designed",
        }
    }

    /// Returns the ANSI escape sequence used to colour this rating.
    #[must_use]
    pub const fn color_code(&self) -> &'static str {
        match *self {
            Self::VeryLow => "\x1b[35m",
            Self::Low => "\x1b[31m",
            Self::Medium => "\x1b[33m",
            Self::High => "\x1b[32m",
        }
    }
}
impl std::fmt::Display for CohesionLevel {
    /// Writes the snake_case label of the rating (`high`, `medium`, `low`,
    /// `very_low`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::High => "high",
            Self::Medium => "medium",
            Self::Low => "low",
            Self::VeryLow => "very_low",
        };
        f.write_str(label)
    }
}
/// Cohesion measurements for a single class.
///
/// The field descriptions below reflect how `analyze_class_cohesion` fills
/// the struct; values built by hand are not validated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CohesionMetrics {
/// Name of the analyzed class.
pub class_name: String,
/// File the class was found in.
pub file: PathBuf,
/// Line on which the class starts, as reported by the AST extractor.
pub line: usize,
/// Line on which the class ends; equals `line` when the extractor gives no end line.
pub end_line: usize,
/// LCOM1: `P - Q` floored at zero, where `P` counts method pairs sharing no
/// attribute and `Q` counts pairs sharing at least one.
pub lcom1: u32,
/// LCOM2: the signed difference `P - Q` (may be negative).
pub lcom2: i32,
/// LCOM3: number of connected components when methods are linked by shared attributes.
pub lcom3: u32,
/// LCOM4: like LCOM3, but methods are additionally linked by calls between them.
pub lcom4: u32,
/// Number of methods that took part in the analysis.
pub methods: u32,
/// Number of distinct attributes accessed by those methods.
pub attributes: u32,
/// Rating derived from `lcom3`.
pub cohesion_level: CohesionLevel,
/// `true` when `lcom3 > 1`.
pub is_low_cohesion: bool,
/// Splitting hint, present when `lcom4 > 1`; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub suggestion: Option<String>,
/// Method names grouped by LCOM4 connected component; omitted from
/// serialized output when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub components: Vec<Vec<String>>,
}
/// Aggregate statistics over a set of `CohesionMetrics`.
///
/// Computed by `CohesionStats::from_metrics`; every numeric field is zero
/// and the distribution is empty when no classes were analyzed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CohesionStats {
/// Number of classes the statistics cover.
pub total_classes: usize,
/// Classes with `lcom3 <= 1`.
pub cohesive_classes: usize,
/// Classes with `lcom3 > 1`.
pub low_cohesion_classes: usize,
/// Arithmetic mean of `lcom3`.
pub average_lcom3: f64,
/// Largest `lcom3` seen.
pub max_lcom3: u32,
/// Class count per cohesion level, keyed by the level's `Display` string.
pub cohesion_distribution: HashMap<String, usize>,
/// Arithmetic mean of the per-class method count.
pub average_methods: f64,
/// Arithmetic mean of the per-class attribute count.
pub average_attributes: f64,
}
impl CohesionStats {
    /// Summarizes a slice of per-class metrics.
    ///
    /// Returns all-zero statistics with an empty distribution for an empty
    /// slice, so the averages never divide by zero.
    fn from_metrics(metrics: &[CohesionMetrics]) -> Self {
        let mut stats = Self {
            total_classes: metrics.len(),
            cohesive_classes: 0,
            low_cohesion_classes: 0,
            average_lcom3: 0.0,
            max_lcom3: 0,
            cohesion_distribution: HashMap::new(),
            average_methods: 0.0,
            average_attributes: 0.0,
        };
        if metrics.is_empty() {
            return stats;
        }

        // Accumulate every aggregate in one pass over the classes.
        let mut lcom3_total: u64 = 0;
        let mut method_total: u64 = 0;
        let mut attribute_total: u64 = 0;
        for class in metrics {
            if class.lcom3 <= 1 {
                stats.cohesive_classes += 1;
            } else {
                stats.low_cohesion_classes += 1;
            }
            stats.max_lcom3 = stats.max_lcom3.max(class.lcom3);
            lcom3_total += u64::from(class.lcom3);
            method_total += u64::from(class.methods);
            attribute_total += u64::from(class.attributes);
            *stats
                .cohesion_distribution
                .entry(class.cohesion_level.to_string())
                .or_insert(0) += 1;
        }

        let class_count = metrics.len() as f64;
        stats.average_lcom3 = lcom3_total as f64 / class_count;
        stats.average_methods = method_total as f64 / class_count;
        stats.average_attributes = attribute_total as f64 / class_count;
        stats
    }
}
/// Result of a cohesion analysis over a file or a directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CohesionAnalysis {
/// The file or directory that was analyzed.
pub path: PathBuf,
/// Language filter supplied by the caller (directory analysis) or the
/// language detected for the file (single-file analysis), if any.
pub language: Option<String>,
/// Metrics for every class that was analyzed.
pub classes: Vec<CohesionMetrics>,
/// Classes whose `lcom3` exceeds `threshold`; `None` (and omitted from
/// serialized output) when no threshold was given.
#[serde(skip_serializing_if = "Option::is_none")]
pub violations: Option<Vec<CohesionMetrics>>,
/// Aggregate statistics over `classes`.
pub stats: CohesionStats,
/// The LCOM3 threshold used to compute `violations`, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub threshold: Option<u32>,
/// Per-file failures collected during analysis; omitted from serialized
/// output when empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub errors: Vec<CohesionError>,
}
/// A non-fatal failure encountered while analyzing one file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CohesionError {
/// File that could not be analyzed.
pub file: PathBuf,
/// Human-readable description of what went wrong.
pub message: String,
}
/// Relationship graph between the methods of one class, used to compute the
/// LCOM family of metrics.
///
/// Methods are related when they access a common attribute and, for LCOM4,
/// when one calls the other.
#[derive(Debug, Default)]
struct MethodAttributeGraph {
    /// Method names in first-insertion order, without duplicates.
    methods: Vec<String>,
    /// Every attribute name recorded through `add_attribute_access`.
    attributes: HashSet<String>,
    /// Attributes accessed by each method.
    method_attributes: HashMap<String, HashSet<String>>,
    /// Callees recorded for each calling method.
    method_calls: HashMap<String, HashSet<String>>,
}

impl MethodAttributeGraph {
    /// Creates an empty graph.
    fn new() -> Self {
        Self::default()
    }

    /// Registers a method; registering the same name again is a no-op.
    fn add_method(&mut self, name: &str) {
        // Compare by reference so the membership test does not allocate.
        if !self.methods.iter().any(|known| known == name) {
            self.methods.push(name.to_string());
        }
    }

    /// Records that `method` reads or writes `attribute`.
    fn add_attribute_access(&mut self, method: &str, attribute: &str) {
        self.attributes.insert(attribute.to_string());
        self.method_attributes
            .entry(method.to_string())
            .or_default()
            .insert(attribute.to_string());
    }

    /// Records that `caller` invokes `callee`.
    fn add_method_call(&mut self, caller: &str, callee: &str) {
        self.method_calls
            .entry(caller.to_string())
            .or_default()
            .insert(callee.to_string());
    }

    /// LCOM1: `P - Q` floored at zero (see `count_sharing_pairs`).
    fn calculate_lcom1(&self) -> u32 {
        let (p, q) = self.count_sharing_pairs();
        // Saturate instead of silently truncating on the narrowing cast.
        p.saturating_sub(q).min(u32::MAX as usize) as u32
    }

    /// LCOM2: the signed difference `P - Q` (see `count_sharing_pairs`).
    fn calculate_lcom2(&self) -> i32 {
        let (p, q) = self.count_sharing_pairs();
        // Work on the magnitude so neither narrowing cast can wrap.
        if p >= q {
            (p - q).min(i32::MAX as usize) as i32
        } else {
            -((q - p).min(i32::MAX as usize) as i32)
        }
    }

    /// Returns `true` when both methods have recorded attribute accesses and
    /// at least one attribute is common to both.
    fn shares_attribute(&self, first: &str, second: &str) -> bool {
        match (
            self.method_attributes.get(first),
            self.method_attributes.get(second),
        ) {
            (Some(a), Some(b)) => !a.is_disjoint(b),
            _ => false,
        }
    }

    /// Counts unordered method pairs as `(P, Q)`: `P` pairs share no
    /// attribute, `Q` pairs share at least one.
    fn count_sharing_pairs(&self) -> (usize, usize) {
        let mut disjoint = 0;
        let mut sharing = 0;
        for (index, first) in self.methods.iter().enumerate() {
            for second in &self.methods[index + 1..] {
                if self.shares_attribute(first, second) {
                    sharing += 1;
                } else {
                    disjoint += 1;
                }
            }
        }
        (disjoint, sharing)
    }

    /// LCOM3: connected components when methods are linked by shared
    /// attributes only.
    fn calculate_lcom3(&self) -> (u32, Vec<Vec<String>>) {
        self.find_connected_components(false)
    }

    /// LCOM4: connected components when methods are linked by shared
    /// attributes and by calls between registered methods.
    fn calculate_lcom4(&self) -> (u32, Vec<Vec<String>>) {
        self.find_connected_components(true)
    }

    /// Computes the connected components of the method graph.
    ///
    /// Returns the component count together with the method names of each
    /// component. The output is deterministic: components appear in the
    /// order of their earliest-registered method, and the names inside a
    /// component follow registration order. Calls whose caller or callee is
    /// not a registered method are ignored.
    fn find_connected_components(&self, include_method_calls: bool) -> (u32, Vec<Vec<String>>) {
        let method_count = self.methods.len();
        if method_count == 0 {
            return (0, Vec::new());
        }

        // Work on indices: lookups become O(1) and traversal no longer
        // depends on hash iteration order.
        let index_of: HashMap<&str, usize> = self
            .methods
            .iter()
            .enumerate()
            .map(|(index, name)| (name.as_str(), index))
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); method_count];

        for first in 0..method_count {
            for second in (first + 1)..method_count {
                if self.shares_attribute(&self.methods[first], &self.methods[second]) {
                    adjacency[first].push(second);
                    adjacency[second].push(first);
                }
            }
        }

        if include_method_calls {
            for (caller, callees) in &self.method_calls {
                if let Some(&from) = index_of.get(caller.as_str()) {
                    for callee in callees {
                        if let Some(&to) = index_of.get(callee.as_str()) {
                            adjacency[from].push(to);
                            adjacency[to].push(from);
                        }
                    }
                }
            }
        }

        let mut visited = vec![false; method_count];
        let mut components: Vec<Vec<String>> = Vec::new();
        for start in 0..method_count {
            if visited[start] {
                continue;
            }
            visited[start] = true;
            let mut members = vec![start];
            let mut queue = VecDeque::new();
            queue.push_back(start);
            while let Some(current) = queue.pop_front() {
                for &neighbor in &adjacency[current] {
                    if !visited[neighbor] {
                        visited[neighbor] = true;
                        members.push(neighbor);
                        queue.push_back(neighbor);
                    }
                }
            }
            // Call edges are inserted in hash order, so normalise the
            // discovery order back to registration order.
            members.sort_unstable();
            components.push(
                members
                    .into_iter()
                    .map(|index| self.methods[index].clone())
                    .collect(),
            );
        }

        let count = components.len().min(u32::MAX as usize) as u32;
        (count, components)
    }
}
/// Returns the names of the receiver attributes accessed anywhere inside
/// `node`, using the member-access rules of `language`.
fn extract_attribute_accesses(node: Node, source: &[u8], language: &str) -> HashSet<String> {
    let mut found = HashSet::new();
    extract_attributes_recursive(node, source, language, &mut found);
    found
}
/// Walks `node` and its descendants, adding to `attributes` the name of every
/// member accessed through the instance receiver.
///
/// What counts as a receiver depends on `language`: the text `self` for
/// Python and Rust, `this` for the JavaScript/TypeScript, Java/Kotlin/C# and
/// C/C++ groups, and for Go an identifier that is at most three bytes long or
/// starts with `this`/`self`. Python names that both start and end with `__`
/// are skipped. Unknown languages contribute nothing.
fn extract_attributes_recursive(
    node: Node,
    source: &[u8],
    language: &str,
    attributes: &mut HashSet<String>,
) {
    let kind = node.kind();

    // Shape of a member access for the current language, if this node is one:
    // (field holding the receiver, field holding the member name,
    //  alternative member-name field tried when the first is absent).
    let shape: Option<(&str, &str, Option<&str>)> = match language {
        "python" if kind == "attribute" => Some(("object", "attribute", None)),
        "typescript" | "javascript" | "tsx" | "jsx" if kind == "member_expression" => {
            Some(("object", "property", None))
        }
        "rust" if kind == "field_expression" => Some(("value", "field", None)),
        "go" if kind == "selector_expression" => Some(("operand", "field", None)),
        "java" | "kotlin" | "csharp"
            if kind == "field_access" || kind == "member_access_expression" =>
        {
            Some(("object", "field", Some("name")))
        }
        "cpp" | "c" if kind == "field_expression" => Some(("argument", "field", None)),
        _ => None,
    };

    if let Some((receiver_field, member_field, fallback_field)) = shape {
        if let Some(receiver) = node.child_by_field_name(receiver_field) {
            let is_receiver = match language {
                "python" | "rust" => node_text(receiver, source) == "self",
                "go" => {
                    // Go receivers have no keyword; short identifiers and
                    // this/self-prefixed names are treated as the receiver.
                    receiver.kind() == "identifier" && {
                        let name = node_text(receiver, source);
                        name.len() <= 3 || name.starts_with("this") || name.starts_with("self")
                    }
                }
                _ => node_text(receiver, source) == "this",
            };

            if is_receiver {
                let member = node.child_by_field_name(member_field).or_else(|| {
                    fallback_field.and_then(|field| node.child_by_field_name(field))
                });
                if let Some(member) = member {
                    let name = node_text(member, source);
                    let is_python_dunder =
                        language == "python" && name.starts_with("__") && name.ends_with("__");
                    if !is_python_dunder {
                        attributes.insert(name.to_string());
                    }
                }
            }
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_attributes_recursive(child, source, language, attributes);
    }
}
/// Returns the names from `class_methods` that are invoked on the instance
/// receiver anywhere inside `node`.
fn extract_method_calls(
    node: Node,
    source: &[u8],
    language: &str,
    class_methods: &HashSet<String>,
) -> HashSet<String> {
    let mut invoked = HashSet::new();
    extract_calls_recursive(node, source, language, class_methods, &mut invoked);
    invoked
}
/// Walks `node` and its descendants, adding to `calls` every method from
/// `class_methods` that is invoked on the instance receiver.
///
/// Only Python (`self.m()`), the JavaScript/TypeScript group (`this.m()`) and
/// Rust (`self.m()`) are recognised; other languages contribute nothing.
fn extract_calls_recursive(
    node: Node,
    source: &[u8],
    language: &str,
    class_methods: &HashSet<String>,
    calls: &mut HashSet<String>,
) {
    // (call node kind, callee node kind, receiver field, receiver keyword,
    //  method-name field) for the current language.
    let shape = match language {
        "python" => Some(("call", "attribute", "object", "self", "attribute")),
        "typescript" | "javascript" | "tsx" | "jsx" => Some((
            "call_expression",
            "member_expression",
            "object",
            "this",
            "property",
        )),
        "rust" => Some((
            "call_expression",
            "field_expression",
            "value",
            "self",
            "field",
        )),
        _ => None,
    };

    if let Some((call_kind, callee_kind, receiver_field, receiver_keyword, name_field)) = shape {
        let callee = if node.kind() == call_kind {
            node.child_by_field_name("function")
                .filter(|func| func.kind() == callee_kind)
        } else {
            None
        };

        if let Some(callee) = callee {
            let targets_receiver = callee
                .child_by_field_name(receiver_field)
                .map_or(false, |receiver| {
                    node_text(receiver, source) == receiver_keyword
                });
            if targets_receiver {
                if let Some(name_node) = callee.child_by_field_name(name_field) {
                    let method_name = node_text(name_node, source);
                    if class_methods.contains(method_name) {
                        calls.insert(method_name.to_string());
                    }
                }
            }
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_calls_recursive(child, source, language, class_methods, calls);
    }
}
/// Returns the source text covered by `node`, or an empty string when those
/// bytes are not valid UTF-8.
fn node_text<'a>(node: Node<'a>, source: &'a [u8]) -> &'a str {
    let bytes = &source[node.start_byte()..node.end_byte()];
    match std::str::from_utf8(bytes) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Analyzes class cohesion for a file or for every source file below a
/// directory.
///
/// A file path is delegated to `analyze_file_cohesion` (the `language`
/// filter is not applied in that case). For a directory, the project is
/// scanned, optionally restricted to `language`, and every file is analyzed
/// in parallel. When `threshold` is given, classes whose LCOM3 exceeds it
/// are additionally reported as violations.
///
/// # Errors
///
/// Returns an error when the path does not exist, is not valid UTF-8, the
/// scan fails, or the scan finds no source files.
pub fn analyze_cohesion(
    path: impl AsRef<Path>,
    language: Option<&str>,
    threshold: Option<u32>,
) -> Result<CohesionAnalysis> {
    let path = path.as_ref();
    if !path.exists() {
        let not_found = std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("Path not found: {}", path.display()),
        );
        return Err(BrrrError::Io(not_found));
    }
    if path.is_file() {
        return analyze_file_cohesion(path, threshold);
    }

    let path_str = path
        .to_str()
        .ok_or_else(|| BrrrError::InvalidArgument("Invalid path encoding".to_string()))?;
    let scanner = ProjectScanner::new(path_str)?;
    let config = match language {
        Some(lang) => ScanConfig::for_language(lang),
        None => ScanConfig::default(),
    };
    let scan_result = scanner.scan_with_config(&config)?;
    if scan_result.files.is_empty() {
        return Err(BrrrError::InvalidArgument(format!(
            "No source files found in {} (filter: {:?})",
            path.display(),
            language
        )));
    }

    debug!("Analyzing {} files for cohesion", scan_result.files.len());

    let per_file: Vec<(Vec<CohesionMetrics>, Vec<CohesionError>)> = scan_result
        .files
        .par_iter()
        .map(|file| analyze_file_classes(file, threshold))
        .collect();

    // Merge the per-file results, keeping the scan order.
    let mut all_classes: Vec<CohesionMetrics> = Vec::new();
    let mut all_errors: Vec<CohesionError> = Vec::new();
    per_file.into_iter().for_each(|(classes, errors)| {
        all_classes.extend(classes);
        all_errors.extend(errors);
    });

    let stats = CohesionStats::from_metrics(&all_classes);
    let violations = match threshold {
        Some(limit) => Some(
            all_classes
                .iter()
                .filter(|class| class.lcom3 > limit)
                .cloned()
                .collect::<Vec<_>>(),
        ),
        None => None,
    };

    Ok(CohesionAnalysis {
        path: path.to_path_buf(),
        language: language.map(String::from),
        classes: all_classes,
        violations,
        stats,
        threshold,
        errors: all_errors,
    })
}
/// Analyzes class cohesion for a single source file.
///
/// Parse and read failures are reported through the `errors` field of the
/// returned analysis rather than as an `Err`. When `threshold` is given,
/// classes whose LCOM3 exceeds it are additionally reported as violations.
///
/// # Errors
///
/// Returns an error when `file` does not exist or is not a regular file.
pub fn analyze_file_cohesion(
    file: impl AsRef<Path>,
    threshold: Option<u32>,
) -> Result<CohesionAnalysis> {
    let file = file.as_ref();
    if !file.exists() {
        let not_found = std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("File not found: {}", file.display()),
        );
        return Err(BrrrError::Io(not_found));
    }
    if !file.is_file() {
        return Err(BrrrError::InvalidArgument(format!(
            "Expected a file, got directory: {}",
            file.display()
        )));
    }

    let (classes, errors) = analyze_file_classes(file, threshold);
    let stats = CohesionStats::from_metrics(&classes);
    let violations = match threshold {
        Some(limit) => Some(
            classes
                .iter()
                .filter(|class| class.lcom3 > limit)
                .cloned()
                .collect::<Vec<_>>(),
        ),
        None => None,
    };

    let registry = LanguageRegistry::global();
    let language = registry
        .detect_language(file)
        .map(|detected| detected.name().to_string());

    Ok(CohesionAnalysis {
        path: file.to_path_buf(),
        language,
        classes,
        violations,
        stats,
        threshold,
        errors,
    })
}
fn analyze_file_classes(
file: &Path,
_threshold: Option<u32>,
) -> (Vec<CohesionMetrics>, Vec<CohesionError>) {
let mut results = Vec::new();
let mut errors = Vec::new();
let module = match AstExtractor::extract_file(file) {
Ok(m) => m,
Err(e) => {
errors.push(CohesionError {
file: file.to_path_buf(),
message: format!("Failed to parse file: {}", e),
});
return (results, errors);
}
};
let source = match std::fs::read(file) {
Ok(s) => s,
Err(e) => {
errors.push(CohesionError {
file: file.to_path_buf(),
message: format!("Failed to read file: {}", e),
});
return (results, errors);
}
};
let language = &module.language;
let registry = LanguageRegistry::global();
let lang_impl = match registry.detect_language(file) {
Some(l) => l,
None => {
errors.push(CohesionError {
file: file.to_path_buf(),
message: "Unsupported language".to_string(),
});
return (results, errors);
}
};
let mut parser = match lang_impl.parser() {
Ok(p) => p,
Err(e) => {
errors.push(CohesionError {
file: file.to_path_buf(),
message: format!("Failed to create parser: {}", e),
});
return (results, errors);
}
};
let tree = match parser.parse(&source, None) {
Some(t) => t,
None => {
errors.push(CohesionError {
file: file.to_path_buf(),
message: "Failed to parse file".to_string(),
});
return (results, errors);
}
};
for class in &module.classes {
if let Some(metrics) = analyze_class_cohesion(
file,
class,
&tree,
&source,
language,
) {
results.push(metrics);
}
for inner in &class.inner_classes {
if let Some(metrics) = analyze_class_cohesion(
file,
inner,
&tree,
&source,
language,
) {
results.push(metrics);
}
}
}
(results, errors)
}
/// Computes the LCOM metrics of a single class.
///
/// Static methods and class methods (recognised by decorator) are excluded,
/// as are Python dunder methods other than `__init__`. Returns `None` when
/// fewer than two methods remain.
///
/// NOTE(review): when the class node cannot be located in `tree`, no
/// attribute accesses or calls are recorded and every method ends up in its
/// own component. Confirm that this is the intended outcome.
fn analyze_class_cohesion(
    file: &Path,
    class: &crate::ast::types::ClassInfo,
    tree: &tree_sitter::Tree,
    source: &[u8],
    language: &str,
) -> Option<CohesionMetrics> {
    if class.methods.is_empty() {
        return None;
    }

    let mut graph = MethodAttributeGraph::new();
    let mut class_method_names: HashSet<String> = HashSet::new();
    for method in &class.methods {
        let excluded_by_decorator = method.decorators.iter().any(|d| {
            d == "staticmethod"
                || d == "static"
                || d.contains("@staticmethod")
                || d == "classmethod"
                || d.contains("@classmethod")
        });
        let is_python_dunder =
            language == "python" && method.name.starts_with("__") && method.name.ends_with("__");
        if excluded_by_decorator || (is_python_dunder && method.name != "__init__") {
            continue;
        }
        class_method_names.insert(method.name.clone());
        graph.add_method(&method.name);
    }
    if graph.methods.len() <= 1 {
        return None;
    }

    if let Some(class_node) = find_class_node(tree.root_node(), &class.name, class.line_number) {
        let analyzed = class
            .methods
            .iter()
            .filter(|method| class_method_names.contains(&method.name));
        for method in analyzed {
            let method_node =
                match find_method_node(class_node, &method.name, method.line_number, language) {
                    Some(found) => found,
                    None => continue,
                };
            for attr in extract_attribute_accesses(method_node, source, language) {
                graph.add_attribute_access(&method.name, &attr);
            }
            for callee in
                extract_method_calls(method_node, source, language, &class_method_names)
            {
                graph.add_method_call(&method.name, &callee);
            }
        }
    }

    let (lcom3, _) = graph.calculate_lcom3();
    let (lcom4, components) = graph.calculate_lcom4();
    let suggestion = (lcom4 > 1).then(|| {
        format!(
            "Consider splitting into {} classes based on connected components",
            lcom4
        )
    });

    Some(CohesionMetrics {
        class_name: class.name.clone(),
        file: file.to_path_buf(),
        line: class.line_number,
        end_line: class.end_line_number.unwrap_or(class.line_number),
        lcom1: graph.calculate_lcom1(),
        lcom2: graph.calculate_lcom2(),
        lcom3,
        lcom4,
        methods: graph.methods.len() as u32,
        attributes: graph.attributes.len() as u32,
        cohesion_level: CohesionLevel::from_lcom3(lcom3),
        is_low_cohesion: lcom3 > 1,
        suggestion,
        components,
    })
}
/// Depth-first search for the first class-like node that starts on `line`
/// (compared against the tree-sitter row plus one).
///
/// A node qualifies when its kind is one of the class-like kinds and it has
/// a `name` field or an `identifier`/`type_identifier` child.
///
/// NOTE(review): `class_name` is only forwarded to the recursive calls; the
/// node's name is never compared with it, so the match is by line alone.
fn find_class_node<'a>(node: Node<'a>, class_name: &str, line: usize) -> Option<Node<'a>> {
    const CLASS_LIKE_KINDS: [&str; 6] = [
        "class_definition",
        "class_declaration",
        "class",
        "impl_item",
        "struct_item",
        "type_declaration",
    ];

    if CLASS_LIKE_KINDS.contains(&node.kind()) {
        let has_name = node.child_by_field_name("name").is_some() || {
            let mut cursor = node.walk();
            // Bound to a local so the iterator borrowing `cursor` is dropped
            // before `cursor` itself.
            let has_identifier_child = node
                .children(&mut cursor)
                .any(|child| matches!(child.kind(), "identifier" | "type_identifier"));
            has_identifier_child
        };
        if has_name && node.start_position().row + 1 == line {
            return Some(node);
        }
    }

    let mut cursor = node.walk();
    // Bound to a local for the same drop-order reason as above.
    let found = node
        .children(&mut cursor)
        .find_map(|child| find_class_node(child, class_name, line));
    found
}
/// Finds the method node inside `class_node` that starts on `line`.
///
/// The node kinds accepted as methods depend on `language`; unknown
/// languages fall back to `function_definition` and `method_definition`.
/// `_method_name` is currently unused: the match is by node kind and start
/// line only.
fn find_method_node<'a>(
    class_node: Node<'a>,
    _method_name: &str,
    line: usize,
    language: &str,
) -> Option<Node<'a>> {
    // Static slices: the kind tables are constants, so there is no reason to
    // heap-allocate a fresh `Vec` on every lookup.
    let method_kinds: &[&str] = match language {
        "python" => &["function_definition"],
        "typescript" | "javascript" | "tsx" | "jsx" => {
            &["method_definition", "function_declaration", "function"]
        }
        "rust" => &["function_item"],
        "go" => &["function_declaration", "method_declaration"],
        "java" | "kotlin" | "csharp" => &["method_declaration", "function_declaration"],
        "cpp" | "c" => &["function_definition", "function_declarator"],
        _ => &["function_definition", "method_definition"],
    };
    find_method_recursive(class_node, method_kinds, line)
}
/// Depth-first search for the first node whose kind is listed in
/// `method_kinds` and which starts on `line` (tree-sitter row plus one).
fn find_method_recursive<'a>(
    node: Node<'a>,
    method_kinds: &[&str],
    line: usize,
) -> Option<Node<'a>> {
    let starts_on_line = node.start_position().row + 1 == line;
    if starts_on_line && method_kinds.contains(&node.kind()) {
        return Some(node);
    }

    let mut cursor = node.walk();
    // Bound to a local so the iterator borrowing `cursor` is dropped before
    // `cursor` itself.
    let found = node
        .children(&mut cursor)
        .find_map(|child| find_method_recursive(child, method_kinds, line));
    found
}
// Unit tests: pure-logic checks for the rating, the method graph and the
// statistics, plus end-to-end checks that analyze temporary Python files.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
// Writes `content` to a fresh temporary file whose name ends in `extension`
// and returns the handle; the caller keeps it while the path is in use.
fn create_temp_file(content: &str, extension: &str) -> NamedTempFile {
let mut file = tempfile::Builder::new()
.suffix(extension)
.tempfile()
.expect("Failed to create temp file");
file.write_all(content.as_bytes())
.expect("Failed to write to temp file");
file
}
// Every LCOM3 bucket boundary maps to the expected rating.
#[test]
fn test_cohesion_level_classification() {
assert_eq!(CohesionLevel::from_lcom3(0), CohesionLevel::High);
assert_eq!(CohesionLevel::from_lcom3(1), CohesionLevel::High);
assert_eq!(CohesionLevel::from_lcom3(2), CohesionLevel::Medium);
assert_eq!(CohesionLevel::from_lcom3(3), CohesionLevel::Low);
assert_eq!(CohesionLevel::from_lcom3(4), CohesionLevel::Low);
assert_eq!(CohesionLevel::from_lcom3(5), CohesionLevel::VeryLow);
assert_eq!(CohesionLevel::from_lcom3(10), CohesionLevel::VeryLow);
}
// `Display` yields the snake_case labels.
#[test]
fn test_cohesion_level_display() {
assert_eq!(CohesionLevel::High.to_string(), "high");
assert_eq!(CohesionLevel::Medium.to_string(), "medium");
assert_eq!(CohesionLevel::Low.to_string(), "low");
assert_eq!(CohesionLevel::VeryLow.to_string(), "very_low");
}
// LCOM1 = max(P - Q, 0): one sharing pair gives 0; adding a method that
// shares nothing gives P = 2, Q = 1, hence 1.
#[test]
fn test_method_attribute_graph_lcom1() {
let mut graph = MethodAttributeGraph::new();
graph.add_method("method_a");
graph.add_method("method_b");
graph.add_attribute_access("method_a", "attr1");
graph.add_attribute_access("method_b", "attr1");
assert_eq!(graph.calculate_lcom1(), 0);
graph.add_method("method_c");
graph.add_attribute_access("method_c", "attr2");
assert_eq!(graph.calculate_lcom1(), 1);
}
// LCOM2 = P - Q: three methods sharing one attribute give P = 0, Q = 3.
#[test]
fn test_method_attribute_graph_lcom2() {
let mut graph = MethodAttributeGraph::new();
graph.add_method("m1");
graph.add_method("m2");
graph.add_method("m3");
graph.add_attribute_access("m1", "a");
graph.add_attribute_access("m2", "a");
graph.add_attribute_access("m3", "a");
assert_eq!(graph.calculate_lcom2(), -3);
}
// Two disjoint attribute groups form two connected components.
#[test]
fn test_method_attribute_graph_connected_components() {
let mut graph = MethodAttributeGraph::new();
graph.add_method("m1");
graph.add_method("m2");
graph.add_attribute_access("m1", "attr1");
graph.add_attribute_access("m2", "attr1");
graph.add_method("m3");
graph.add_method("m4");
graph.add_attribute_access("m3", "attr2");
graph.add_attribute_access("m4", "attr2");
let (lcom3, components) = graph.calculate_lcom3();
assert_eq!(lcom3, 2);
assert_eq!(components.len(), 2);
}
// A call edge merges components for LCOM4 that LCOM3 keeps apart.
#[test]
fn test_method_attribute_graph_lcom4_with_calls() {
let mut graph = MethodAttributeGraph::new();
graph.add_method("m1");
graph.add_method("m2");
graph.add_method("m3");
graph.add_attribute_access("m1", "attr1");
graph.add_attribute_access("m2", "attr1");
graph.add_attribute_access("m3", "attr2");
let (lcom3, _) = graph.calculate_lcom3();
assert_eq!(lcom3, 2);
graph.add_method_call("m2", "m3");
let (lcom4, _) = graph.calculate_lcom4();
assert_eq!(lcom4, 1); }
// End-to-end: a class whose methods all touch `value`/`history` is
// expected to come out with LCOM3 of at most 2.
#[test]
fn test_cohesive_python_class() {
let source = r#"
class Calculator:
def __init__(self):
self.value = 0
self.history = []
def add(self, x):
self.value += x
self.history.append(('add', x))
def subtract(self, x):
self.value -= x
self.history.append(('sub', x))
def get_value(self):
return self.value
def get_history(self):
return self.history
"#;
let file = create_temp_file(source, ".py");
let result = analyze_file_cohesion(file.path(), None);
assert!(result.is_ok());
let analysis = result.unwrap();
assert_eq!(analysis.classes.len(), 1);
let metrics = &analysis.classes[0];
assert_eq!(metrics.class_name, "Calculator");
assert!(metrics.lcom3 <= 2, "Expected cohesive class, got LCOM3 = {}", metrics.lcom3);
}
// End-to-end: four methods that each use their own attribute are expected
// to yield LCOM3 >= 3 and the low-cohesion flag.
#[test]
fn test_low_cohesion_python_class() {
let source = r#"
class GodObject:
def process_users(self):
self.users = []
return self.users
def handle_payments(self):
self.payments = []
return self.payments
def send_emails(self):
self.emails = []
return self.emails
def generate_reports(self):
self.reports = []
return self.reports
"#;
let file = create_temp_file(source, ".py");
let result = analyze_file_cohesion(file.path(), None);
assert!(result.is_ok());
let analysis = result.unwrap();
assert_eq!(analysis.classes.len(), 1);
let metrics = &analysis.classes[0];
assert_eq!(metrics.class_name, "GodObject");
assert!(metrics.lcom3 >= 3, "Expected low cohesion, got LCOM3 = {}", metrics.lcom3);
assert!(metrics.is_low_cohesion);
}
// Aggregates over one cohesive (LCOM3 = 1) and one non-cohesive
// (LCOM3 = 3) class: counts, maximum and mean.
#[test]
fn test_statistics_calculation() {
let metrics = vec![
CohesionMetrics {
class_name: "A".to_string(),
file: PathBuf::from("a.py"),
line: 1,
end_line: 10,
lcom1: 0,
lcom2: -2,
lcom3: 1,
lcom4: 1,
methods: 3,
attributes: 2,
cohesion_level: CohesionLevel::High,
is_low_cohesion: false,
suggestion: None,
components: vec![],
},
CohesionMetrics {
class_name: "B".to_string(),
file: PathBuf::from("b.py"),
line: 1,
end_line: 20,
lcom1: 5,
lcom2: 3,
lcom3: 3,
lcom4: 2,
methods: 4,
attributes: 4,
cohesion_level: CohesionLevel::Low,
is_low_cohesion: true,
suggestion: Some("Consider splitting".to_string()),
components: vec![],
},
];
let stats = CohesionStats::from_metrics(&metrics);
assert_eq!(stats.total_classes, 2);
assert_eq!(stats.cohesive_classes, 1);
assert_eq!(stats.low_cohesion_classes, 1);
assert_eq!(stats.max_lcom3, 3);
assert!((stats.average_lcom3 - 2.0).abs() < 0.01);
}
// A class without methods produces no metrics.
#[test]
fn test_empty_class_skipped() {
let source = r#"
class Empty:
pass
"#;
let file = create_temp_file(source, ".py");
let result = analyze_file_cohesion(file.path(), None);
assert!(result.is_ok());
let analysis = result.unwrap();
assert_eq!(analysis.classes.len(), 0);
}
// A class with a single analyzable method produces no metrics.
#[test]
fn test_single_method_class_skipped() {
let source = r#"
class SingleMethod:
def only_method(self):
self.attr = 1
"#;
let file = create_temp_file(source, ".py");
let result = analyze_file_cohesion(file.path(), None);
assert!(result.is_ok());
let analysis = result.unwrap();
assert_eq!(analysis.classes.len(), 0);
}
// `@staticmethod` methods are left out of the method count.
#[test]
fn test_static_methods_excluded() {
let source = r#"
class WithStatic:
@staticmethod
def static_helper():
return 42
def instance_method(self):
self.value = 1
def another_instance(self):
self.value = 2
"#;
let file = create_temp_file(source, ".py");
let result = analyze_file_cohesion(file.path(), None);
assert!(result.is_ok());
let analysis = result.unwrap();
assert_eq!(analysis.classes.len(), 1);
let metrics = &analysis.classes[0];
assert_eq!(metrics.methods, 2);
}
// A missing file is reported as an error.
#[test]
fn test_nonexistent_file() {
let result = analyze_file_cohesion("/nonexistent/path/file.py", None);
assert!(result.is_err());
}
// With a threshold of 1, the class whose methods share nothing must be
// listed among the violations.
#[test]
fn test_threshold_filtering() {
let source = r#"
class Cohesive:
def method_a(self):
self.shared = 1
def method_b(self):
self.shared = 2
class NotCohesive:
def isolated_a(self):
self.a = 1
def isolated_b(self):
self.b = 1
def isolated_c(self):
self.c = 1
"#;
let file = create_temp_file(source, ".py");
let result = analyze_file_cohesion(file.path(), Some(1));
assert!(result.is_ok());
let analysis = result.unwrap();
assert!(analysis.violations.is_some());
let violations = analysis.violations.unwrap();
assert!(violations.iter().any(|v| v.class_name == "NotCohesive"));
}
}