use std::collections::{HashMap, HashSet};
use crate::types::{DiamondPattern, InheritanceGraph};
/// Flags abstract base classes and structural protocols among the graph's nodes.
///
/// A node gets `is_abstract = Some(true)` when any declared base is `ABC` or
/// `ABCMeta`, either bare or module-qualified (e.g. `abc.ABC`). It gets
/// `protocol = Some(true)` when any declared base is `Protocol`, either bare or
/// module-qualified (e.g. `typing.Protocol`).
///
/// Subscripted bases such as `Protocol[T]` are matched on the name before the
/// `[`. Nodes matching neither rule keep their existing flag values.
pub fn detect_abc_protocol(graph: &mut InheritanceGraph) {
    for node in graph.nodes.values_mut() {
        if node.bases.iter().any(|b| is_abc_base(b)) {
            node.is_abstract = Some(true);
        }
        if node.bases.iter().any(|b| is_protocol_base(b)) {
            node.protocol = Some(true);
        }
    }
}

/// Returns `base` with any generic subscript removed (`"Protocol[T]"` -> `"Protocol"`).
fn base_name(base: &str) -> &str {
    base.split('[').next().unwrap_or(base).trim()
}

/// True when `base` names `ABC`/`ABCMeta`, bare or module-qualified.
fn is_abc_base(base: &str) -> bool {
    let name = base_name(base);
    // Accept qualified forms, mirroring how `Protocol` is matched below.
    matches!(name, "ABC" | "ABCMeta") || name.ends_with(".ABC") || name.ends_with(".ABCMeta")
}

/// True when `base` names `Protocol`, bare or module-qualified.
fn is_protocol_base(base: &str) -> bool {
    let name = base_name(base);
    name == "Protocol" || name.ends_with(".Protocol")
}
/// Marks classes that look like mixins by setting `mixin = Some(true)`.
///
/// A class is flagged when either:
/// - its name ends with `"mixin"` (compared after lowercasing), or
/// - it declares no bases itself and appears as a non-first base of at least
///   two classes in the graph.
///
/// Classes matching neither rule keep their existing `mixin` value.
pub fn detect_mixins(graph: &mut InheritanceGraph) {
    // Number of times each name is listed after some class's first base.
    let mut usage: HashMap<String, usize> = HashMap::new();
    for secondary in graph.nodes.values().flat_map(|n| n.bases.iter().skip(1)) {
        *usage.entry(secondary.clone()).or_default() += 1;
    }

    for (name, node) in graph.nodes.iter_mut() {
        let flagged = name.to_lowercase().ends_with("mixin")
            || (node.bases.is_empty() && usage.get(name).copied().unwrap_or(0) >= 2);
        if flagged {
            node.mixin = Some(true);
        }
    }
}
pub fn detect_diamonds(graph: &InheritanceGraph) -> Vec<DiamondPattern> {
let mut diamonds = Vec::new();
for (class_name, parents) in graph.multi_parent_classes() {
if parents.len() < 2 {
continue;
}
let ancestor_sets: Vec<HashSet<String>> = parents
.iter()
.map(|parent| graph.ancestors_bfs(parent))
.collect();
if ancestor_sets.is_empty() {
continue;
}
let common: HashSet<String> = if ancestor_sets.len() == 1 {
ancestor_sets[0].clone()
} else {
ancestor_sets[1..]
.iter()
.fold(ancestor_sets[0].clone(), |acc, s| {
acc.intersection(s).cloned().collect()
})
};
for ancestor in common {
let paths = compute_paths_to_ancestor(graph, class_name, &ancestor, parents);
if paths.len() >= 2 {
diamonds.push(DiamondPattern {
class_name: class_name.clone(),
common_ancestor: ancestor,
paths,
});
}
}
}
diamonds
}
/// Builds one path per parent that can reach `ancestor`.
///
/// Each path is `[class_name, parent, ..., ancestor]`, taken from
/// `find_path_to_ancestor`. Parents with no route to `ancestor` are skipped,
/// and the parents' order is preserved.
fn compute_paths_to_ancestor(
    graph: &InheritanceGraph,
    class_name: &str,
    ancestor: &str,
    parents: &[String],
) -> Vec<Vec<String>> {
    parents
        .iter()
        .filter_map(|parent| find_path_to_ancestor(graph, parent, ancestor))
        .map(|tail| {
            std::iter::once(class_name.to_string())
                .chain(tail)
                .collect::<Vec<String>>()
        })
        .collect()
}
/// Returns a shortest chain of parent links from `start` up to `ancestor`.
///
/// The returned path includes both ends, e.g. `[start, ..., ancestor]`. If
/// `start == ancestor`, the path is just `[start]`. Returns `None` when
/// `ancestor` cannot be reached by following `graph.parents`.
///
/// Breadth-first search means the returned path has the fewest edges.
fn find_path_to_ancestor(
    graph: &InheritanceGraph,
    start: &str,
    ancestor: &str,
) -> Option<Vec<String>> {
    use std::collections::VecDeque;

    if start == ancestor {
        return Some(vec![start.to_string()]);
    }

    let mut queue = VecDeque::new();
    let mut visited = HashSet::new();
    // Maps each discovered class to the class it was reached from (its child).
    // This lets the path be walked back to `start` once `ancestor` is found.
    // `start` is never inserted, which terminates that walk.
    let mut came_from: HashMap<String, String> = HashMap::new();
    queue.push_back(start.to_string());
    visited.insert(start.to_string());

    while let Some(current) = queue.pop_front() {
        if let Some(parents) = graph.parents.get(&current) {
            for parent in parents {
                if visited.contains(parent) {
                    continue;
                }
                visited.insert(parent.clone());
                came_from.insert(parent.clone(), current.clone());
                if parent == ancestor {
                    let mut path = vec![ancestor.to_string()];
                    let mut cursor: &str = ancestor;
                    while let Some(child) = came_from.get(cursor) {
                        path.push(child.clone());
                        cursor = child.as_str();
                    }
                    path.reverse();
                    return Some(path);
                }
                queue.push_back(parent.clone());
            }
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{InheritanceNode, Language};
    use std::path::PathBuf;

    /// Builds a Python `InheritanceNode` named `name` with the given bases.
    fn create_node(name: &str, bases: Vec<&str>) -> InheritanceNode {
        let mut node = InheritanceNode::new(name, PathBuf::from("test.py"), 1, Language::Python);
        node.bases = bases.iter().map(|b| b.to_string()).collect();
        node
    }

    /// Builds a graph from `(name, bases)` specs.
    ///
    /// First adds one node per spec, in order. When `wire_edges` is set, it
    /// then adds a `child -> base` edge for every declared base, in the same
    /// order.
    fn build_graph(specs: &[(&str, Vec<&str>)], wire_edges: bool) -> InheritanceGraph {
        let mut graph = InheritanceGraph::new();
        for (name, bases) in specs {
            graph.add_node(create_node(name, bases.clone()));
        }
        if wire_edges {
            for (name, bases) in specs {
                for base in bases {
                    graph.add_edge(*name, *base);
                }
            }
        }
        graph
    }

    #[test]
    fn test_abc_detection() {
        let mut graph = build_graph(&[("Animal", vec!["ABC"]), ("Dog", vec!["Animal"])], true);
        detect_abc_protocol(&mut graph);
        assert_eq!(graph.nodes.get("Animal").unwrap().is_abstract, Some(true));
    }

    #[test]
    fn test_protocol_detection() {
        let mut graph = build_graph(&[("Serializable", vec!["Protocol"])], false);
        detect_abc_protocol(&mut graph);
        assert_eq!(graph.nodes.get("Serializable").unwrap().protocol, Some(true));
    }

    #[test]
    fn test_mixin_detection_by_name() {
        let mut graph = build_graph(
            &[("TimestampMixin", vec![]), ("User", vec!["Base", "TimestampMixin"])],
            false,
        );
        detect_mixins(&mut graph);
        assert_eq!(graph.nodes.get("TimestampMixin").unwrap().mixin, Some(true));
    }

    #[test]
    fn test_mixin_detection_by_usage() {
        let mut graph = build_graph(
            &[
                ("Auditable", vec![]),
                ("User", vec!["Base", "Auditable"]),
                ("Post", vec!["Base", "Auditable"]),
                ("Comment", vec!["Base", "Auditable"]),
            ],
            false,
        );
        detect_mixins(&mut graph);
        assert_eq!(graph.nodes.get("Auditable").unwrap().mixin, Some(true));
    }

    #[test]
    fn test_diamond_detection() {
        let graph = build_graph(
            &[
                ("A", vec![]),
                ("B", vec!["A"]),
                ("C", vec!["A"]),
                ("D", vec!["B", "C"]),
            ],
            true,
        );
        let found = detect_diamonds(&graph);
        assert_eq!(found.len(), 1);
        let diamond = &found[0];
        assert_eq!(diamond.class_name, "D");
        assert_eq!(diamond.common_ancestor, "A");
        assert_eq!(diamond.paths.len(), 2);
    }

    #[test]
    fn test_no_diamond_single_inheritance() {
        let graph = build_graph(&[("A", vec![]), ("B", vec!["A"]), ("C", vec!["B"])], true);
        assert!(detect_diamonds(&graph).is_empty());
    }

    #[test]
    fn test_no_diamond_disjoint_parents() {
        let graph = build_graph(&[("A", vec![]), ("B", vec![]), ("D", vec!["A", "B"])], true);
        assert!(detect_diamonds(&graph).is_empty());
    }
}