use crate::activation::HybridEngine;
use crate::counterfactual::{CascadeResult, CounterfactualEngine, CounterfactualResult};
use crate::error::{M1ndError, M1ndResult};
use crate::graph::Graph;
use crate::topology::{Bridge, BridgeDetector, CommunityDetector, CommunityResult};
use crate::types::{CommunityId, FiniteF32, NodeId};
use serde::Serialize;
use std::collections::{HashMap, HashSet};
/// Tunable parameters for [`plan_refactoring`].
#[derive(Clone, Debug)]
pub struct RefactorConfig {
    /// Maximum number of extraction candidates to evaluate.
    pub max_communities: usize,
    /// Communities with fewer nodes than this are skipped.
    pub min_community_size: usize,
    /// Activation-loss fraction at or above which a candidate is rated "critical".
    pub max_acceptable_impact: f32,
    /// Optional substring filter on node external ids; `None` analyzes all nodes.
    pub scope: Option<String>,
}
impl Default for RefactorConfig {
fn default() -> Self {
Self {
max_communities: 10,
min_community_size: 3,
max_acceptable_impact: 0.30,
scope: None,
}
}
}
/// A bridge edge crossing the boundary of an extraction candidate —
/// part of the "interface" the extracted module would expose.
#[derive(Clone, Debug, Serialize)]
pub struct InterfaceEdge {
    /// External (string) id of the edge's source node.
    pub source_id: String,
    /// External (string) id of the edge's target node.
    pub target_id: String,
    /// Resolved relation label of the edge.
    pub relation: String,
    /// Bridge importance score for this edge.
    pub weight: f32,
    /// "outbound" when the source node lies inside the candidate community,
    /// "inbound" otherwise.
    pub direction: String,
}
/// Risk estimate for extracting one community, derived from a
/// counterfactual removal simulation plus a cascade analysis.
#[derive(Clone, Debug, Serialize)]
pub struct ExtractionRisk {
    /// One of "low", "medium", "high", "critical".
    pub level: String,
    /// Fraction of activation lost when the community is removed
    /// (from the counterfactual simulation).
    pub activation_loss: f32,
    /// Count of nodes the simulation reports as orphaned by the removal.
    pub orphaned_count: usize,
    /// Count of nodes the simulation reports as weakened by the removal.
    pub weakened_count: usize,
    /// Depth of the cascade seeded from the community's first node.
    pub cascade_depth: u8,
    /// Total nodes affected by that cascade.
    pub cascade_affected: u32,
}
/// One candidate community proposed for extraction into its own module.
#[derive(Clone, Debug, Serialize)]
pub struct ExtractionPlan {
    /// Id of the community being extracted.
    pub community_id: u32,
    /// External ids of the community's nodes.
    pub extracted_nodes: Vec<String>,
    /// Human-readable labels of the same nodes, in the same order.
    pub extracted_labels: Vec<String>,
    /// Bridge edges crossing the community boundary.
    pub interface_edges: Vec<InterfaceEdge>,
    /// Estimated extraction risk.
    pub risk: ExtractionRisk,
    /// Modularity of the detected partition (currently the graph-wide value).
    pub community_modularity: f32,
    /// internal / (internal + external) edge ratio; 1.0 for an edgeless community.
    pub cohesion: f32,
    /// external / (internal + external) edge ratio; 0.0 for an edgeless community.
    pub coupling: f32,
}
/// Full output of [`plan_refactoring`].
#[derive(Clone, Debug, Serialize)]
pub struct RefactorPlan {
    /// Extraction candidates, best-ranked first
    /// (low activation loss, high cohesion).
    pub candidates: Vec<ExtractionPlan>,
    /// Modularity of the detected community partition.
    pub graph_modularity: f32,
    /// Number of communities detected in the whole graph.
    pub num_communities: u32,
    /// Total node count of the analyzed graph.
    pub nodes_analyzed: usize,
    /// Wall-clock planning time in milliseconds.
    pub elapsed_ms: f64,
}
/// Produces a ranked refactoring plan: communities that could be extracted
/// as standalone modules, each with its boundary ("interface") edges and an
/// estimated extraction risk.
///
/// Pipeline:
/// 1. Detect communities and the bridge edges between them.
/// 2. Group node indices per community, honoring the optional
///    `config.scope` substring filter on external node ids.
/// 3. Count intra- vs. inter-community edges to derive cohesion/coupling.
/// 4. For each community of at least `config.min_community_size` nodes
///    (capped at `config.max_communities` candidates), simulate its removal
///    and run a cascade analysis to grade risk.
/// 5. Sort candidates so low-loss, high-cohesion extractions come first.
///
/// # Errors
/// Returns [`M1ndError::EmptyGraph`] when the graph has no nodes or has not
/// been finalized; propagates failures from community detection, bridge
/// detection, and the counterfactual engine.
pub fn plan_refactoring(graph: &Graph, config: &RefactorConfig) -> M1ndResult<RefactorPlan> {
    let start = std::time::Instant::now();
    let n = graph.num_nodes() as usize;
    if n == 0 || !graph.finalized {
        return Err(M1ndError::EmptyGraph);
    }

    // Dense node index -> external (string) id.
    let mut node_to_ext: Vec<String> = vec![String::new(); n];
    for (interned, node_id) in &graph.id_to_node {
        let idx = node_id.as_usize();
        if idx < n {
            node_to_ext[idx] = graph.strings.resolve(*interned).to_string();
        }
    }

    let detector = CommunityDetector::with_defaults();
    let communities = detector.detect(graph)?;
    let bridges = BridgeDetector::detect(graph, &communities)?;

    // Group (scope-filtered) node indices by community id. Indices are
    // pushed in ascending order, so nodes[0] is each community's lowest id.
    let mut community_nodes: HashMap<u32, Vec<usize>> = HashMap::new();
    for (i, ext_id) in node_to_ext.iter().enumerate() {
        if let Some(ref scope) = config.scope {
            if !ext_id.contains(scope.as_str()) {
                continue;
            }
        }
        let cid = communities.assignments[i].0;
        community_nodes.entry(cid).or_default().push(i);
    }

    // Per-community counts of out-edges staying inside vs. leaving the
    // community, attributed to the source node's community. Counted over
    // ALL nodes: the scope filter intentionally does not apply here.
    let mut internal_edges: HashMap<u32, u32> = HashMap::new();
    let mut external_edges: HashMap<u32, u32> = HashMap::new();
    for i in 0..n {
        let ci = communities.assignments[i].0;
        for j in graph.csr.out_range(NodeId::new(i as u32)) {
            let tgt = graph.csr.targets[j].as_usize();
            if tgt < n {
                if ci == communities.assignments[tgt].0 {
                    *internal_edges.entry(ci).or_insert(0) += 1;
                } else {
                    *external_edges.entry(ci).or_insert(0) += 1;
                }
            }
        }
    }

    let cf_engine = CounterfactualEngine::with_defaults();
    let hybrid_engine = HybridEngine::new();
    let prop_config = crate::types::PropagationConfig::default();

    // FIX: visit communities in sorted-id order. Iterating the HashMap
    // directly made the `max_communities` cutoff below depend on the
    // unspecified hash iteration order, so repeated runs on the same graph
    // could evaluate (and return) different candidate sets.
    let mut sorted_cids: Vec<u32> = community_nodes.keys().copied().collect();
    sorted_cids.sort_unstable();

    let mut candidates: Vec<ExtractionPlan> = Vec::new();
    for cid in sorted_cids {
        let nodes = &community_nodes[&cid];
        if nodes.len() < config.min_community_size {
            continue;
        }
        if candidates.len() >= config.max_communities {
            break;
        }
        let node_ids: Vec<NodeId> = nodes.iter().map(|&i| NodeId::new(i as u32)).collect();
        // What does the rest of the graph lose if this community vanishes?
        let cf_result =
            cf_engine.simulate_removal(graph, &hybrid_engine, &prop_config, &node_ids)?;
        // Cascade is seeded from the community's lowest-index node.
        let cascade =
            cf_engine.cascade_analysis(graph, &hybrid_engine, &prop_config, node_ids[0])?;

        // Bridge edges touching this community form the module interface.
        let interface: Vec<InterfaceEdge> = bridges
            .iter()
            .filter(|b| {
                b.source_community == CommunityId(cid) || b.target_community == CommunityId(cid)
            })
            .map(|b| {
                let direction = if b.source_community == CommunityId(cid) {
                    "outbound"
                } else {
                    "inbound"
                };
                InterfaceEdge {
                    source_id: node_to_ext[b.source.as_usize()].clone(),
                    target_id: node_to_ext[b.target.as_usize()].clone(),
                    relation: graph
                        .strings
                        .resolve(graph.csr.relations[b.edge_idx.as_usize()])
                        .to_string(),
                    weight: b.importance.get(),
                    direction: direction.to_string(),
                }
            })
            .collect();

        let int_e = *internal_edges.get(&cid).unwrap_or(&0) as f32;
        let ext_e = *external_edges.get(&cid).unwrap_or(&0) as f32;
        let total_e = int_e + ext_e;
        // An edgeless community is maximally cohesive by convention.
        let cohesion = if total_e > 0.0 { int_e / total_e } else { 1.0 };
        let coupling = if total_e > 0.0 { ext_e / total_e } else { 0.0 };

        let impact = cf_result.pct_activation_lost.get();
        let risk_level = if impact < 0.05 {
            "low"
        } else if impact < 0.15 {
            "medium"
        } else if impact < config.max_acceptable_impact {
            "high"
        } else {
            "critical"
        };
        candidates.push(ExtractionPlan {
            community_id: cid,
            extracted_nodes: nodes.iter().map(|&i| node_to_ext[i].clone()).collect(),
            extracted_labels: nodes
                .iter()
                .map(|&i| graph.strings.resolve(graph.nodes.label[i]).to_string())
                .collect(),
            interface_edges: interface,
            risk: ExtractionRisk {
                level: risk_level.to_string(),
                activation_loss: impact,
                orphaned_count: cf_result.orphaned_nodes.len(),
                weakened_count: cf_result.weakened_nodes.len(),
                cascade_depth: cascade.cascade_depth,
                cascade_affected: cascade.total_affected,
            },
            community_modularity: communities.modularity.get(),
            cohesion,
            coupling,
        });
    }

    // Rank: lower activation loss first, with a bonus for high cohesion.
    candidates.sort_by(|a, b| {
        let score_a = a.risk.activation_loss - a.cohesion * 0.5;
        let score_b = b.risk.activation_loss - b.cohesion * 0.5;
        score_a
            .partial_cmp(&score_b)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    Ok(RefactorPlan {
        candidates,
        graph_modularity: communities.modularity.get(),
        num_communities: communities.num_communities,
        nodes_analyzed: n,
        elapsed_ms: start.elapsed().as_secs_f64() * 1000.0,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::*;
    use crate::types::{EdgeDirection, FiniteF32, NodeId, NodeType};

    /// Six functions split into two tightly-knit clusters
    /// (a1→a2→a3 and b1→b2→b3) joined by one weak bridge edge a3→b1.
    fn build_two_cluster_graph() -> Graph {
        let mut graph = Graph::new();
        let node_specs = [
            ("a1", "handler_a", "cluster_a", 0.5),
            ("a2", "process_a", "cluster_a", 0.4),
            ("a3", "output_a", "cluster_a", 0.3),
            ("b1", "handler_b", "cluster_b", 0.5),
            ("b2", "process_b", "cluster_b", 0.4),
            ("b3", "output_b", "cluster_b", 0.3),
        ];
        for (ext_id, label, tag, score) in node_specs {
            graph
                .add_node(ext_id, label, NodeType::Function, &[tag], 0.0, score)
                .unwrap();
        }
        // (src, dst, weight, aux score); the final entry is the
        // deliberately weak inter-cluster bridge.
        let edge_specs = [
            (0u32, 1u32, 0.9, 0.5),
            (1, 2, 0.8, 0.5),
            (3, 4, 0.9, 0.5),
            (4, 5, 0.8, 0.5),
            (2, 3, 0.2, 0.3),
        ];
        for (src, dst, weight, aux) in edge_specs {
            graph
                .add_edge(
                    NodeId::new(src),
                    NodeId::new(dst),
                    "calls",
                    FiniteF32::new(weight),
                    EdgeDirection::Forward,
                    false,
                    FiniteF32::new(aux),
                )
                .unwrap();
        }
        graph.finalize().unwrap();
        graph
    }

    #[test]
    fn plan_empty_graph_error() {
        let empty = Graph::new();
        assert!(plan_refactoring(&empty, &RefactorConfig::default()).is_err());
    }

    #[test]
    fn plan_two_clusters_produces_candidates() {
        let config = RefactorConfig {
            min_community_size: 2,
            ..RefactorConfig::default()
        };
        let plan = plan_refactoring(&build_two_cluster_graph(), &config).unwrap();
        assert!(plan.nodes_analyzed == 6);
        assert!(plan.num_communities >= 1);
    }

    #[test]
    fn plan_high_cohesion_low_coupling() {
        let config = RefactorConfig {
            min_community_size: 2,
            ..RefactorConfig::default()
        };
        let plan = plan_refactoring(&build_two_cluster_graph(), &config).unwrap();
        if let Some(best) = plan.candidates.first() {
            assert!(best.cohesion >= 0.0, "Cohesion should be >= 0");
        }
    }

    #[test]
    fn plan_risk_levels_assigned() {
        let config = RefactorConfig {
            min_community_size: 2,
            ..RefactorConfig::default()
        };
        let plan = plan_refactoring(&build_two_cluster_graph(), &config).unwrap();
        let valid_levels = ["low", "medium", "high", "critical"];
        for candidate in &plan.candidates {
            assert!(
                valid_levels.contains(&candidate.risk.level.as_str()),
                "Invalid risk level: {}",
                candidate.risk.level
            );
        }
    }

    #[test]
    fn plan_scope_filter_limits_candidates() {
        let mut config = RefactorConfig::default();
        config.min_community_size = 1;
        config.scope = Some("nonexistent".to_string());
        let plan = plan_refactoring(&build_two_cluster_graph(), &config).unwrap();
        assert!(
            plan.candidates.is_empty(),
            "Nonexistent scope should yield no candidates"
        );
    }

    #[test]
    fn plan_interface_edges_on_bridge() {
        let config = RefactorConfig {
            min_community_size: 2,
            ..RefactorConfig::default()
        };
        let plan = plan_refactoring(&build_two_cluster_graph(), &config).unwrap();
        if plan.num_communities >= 2 {
            let has_interface = plan
                .candidates
                .iter()
                .any(|c| !c.interface_edges.is_empty());
            assert!(
                has_interface,
                "Split communities should have interface edges"
            );
        }
    }
}