Skip to main content

cgx_engine/
cluster.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::graph::{Edge, GraphDb, Node};
4
5/// Partition nodes into communities using the Louvain modularity algorithm.
6///
7/// Returns a map of `node_id → community_id` (IDs start at 1).  Nodes with no
8/// edges are each placed in their own singleton community.
9pub fn detect_communities(nodes: &[Node], edges: &[Edge]) -> HashMap<String, i64> {
10    if nodes.is_empty() {
11        return HashMap::new();
12    }
13
14    let mut communities: HashMap<String, i64> = HashMap::new();
15    let node_ids: Vec<&str> = nodes.iter().map(|n| n.id.as_str()).collect();
16    let node_idx: HashMap<&str, usize> = node_ids
17        .iter()
18        .enumerate()
19        .map(|(i, id)| (*id, i))
20        .collect();
21    let n = node_ids.len();
22
23    let mut part: Vec<usize> = (0..n).collect();
24
25    let mut adj: Vec<HashMap<usize, f64>> = vec![HashMap::new(); n];
26    let mut k: Vec<f64> = vec![0.0; n];
27    let mut total_weight: f64 = 0.0;
28
29    for edge in edges {
30        if let (Some(&src_idx), Some(&dst_idx)) = (
31            node_idx.get(edge.src.as_str()),
32            node_idx.get(edge.dst.as_str()),
33        ) {
34            if src_idx == dst_idx {
35                continue;
36            }
37            let w = edge.weight.max(0.01) * edge.confidence.max(0.01);
38            *adj[src_idx].entry(dst_idx).or_insert(0.0) += w;
39            *adj[dst_idx].entry(src_idx).or_insert(0.0) += w;
40            k[src_idx] += w;
41            k[dst_idx] += w;
42            total_weight += w;
43        }
44    }
45
46    if total_weight == 0.0 {
47        for (i, id) in node_ids.iter().enumerate() {
48            communities.insert(id.to_string(), i as i64 + 1);
49        }
50        return communities;
51    }
52
53    let m2 = total_weight;
54
55    let mut improved = true;
56    let mut iteration = 0;
57    let max_iterations = 50;
58    while improved && iteration < max_iterations {
59        improved = false;
60        iteration += 1;
61
62        let mut order: Vec<usize> = (0..n).collect();
63        fastrand::shuffle(&mut order);
64
65        for &node in &order {
66            let current_comm = part[node];
67            let mut best_comm = current_comm;
68            let mut best_delta: f64 = 0.0;
69
70            let ki = k[node];
71            let ki_in_current = community_edge_weight(node, current_comm, &adj, &part);
72            let sigma_tot_current = total_community_weight(current_comm, &part, &k) - ki;
73
74            let mut neighbor_comms: HashSet<usize> = HashSet::new();
75            for &neighbor in adj[node].keys() {
76                neighbor_comms.insert(part[neighbor]);
77            }
78
79            for &neighbor_comm in &neighbor_comms {
80                if neighbor_comm == current_comm {
81                    continue;
82                }
83
84                let ki_in_neighbor = community_edge_weight(node, neighbor_comm, &adj, &part);
85                let sigma_tot_neighbor = total_community_weight(neighbor_comm, &part, &k);
86
87                let delta = modularity_delta(ki_in_neighbor, sigma_tot_neighbor, ki, m2)
88                    - modularity_delta(ki_in_current, sigma_tot_current, ki, m2);
89
90                if delta > best_delta {
91                    best_delta = delta;
92                    best_comm = neighbor_comm;
93                }
94            }
95
96            if best_comm != current_comm {
97                part[node] = best_comm;
98                improved = true;
99            }
100        }
101    }
102
103    let mut comm_map: HashMap<usize, i64> = HashMap::new();
104    let mut next_comm_id: i64 = 1;
105    for &c in &part {
106        comm_map.entry(c).or_insert_with(|| {
107            let id = next_comm_id;
108            next_comm_id += 1;
109            id
110        });
111    }
112
113    for (i, id) in node_ids.iter().enumerate() {
114        let comm = comm_map[&part[i]];
115        communities.insert(id.to_string(), comm);
116    }
117
118    communities
119}
120
/// Sum the weights of the edges that connect `node` to members of `community`.
///
/// `adj[node]` maps each neighbour index to the weight of the connecting edge,
/// and `part[i]` is the community currently assigned to node `i`. Returns
/// `0.0` when no neighbour of `node` belongs to `community`.
fn community_edge_weight(
    node: usize,
    community: usize,
    adj: &[HashMap<usize, f64>],
    part: &[usize],
) -> f64 {
    adj[node]
        .iter()
        .filter(|&(&neighbor, _)| part[neighbor] == community)
        .fold(0.0, |acc, (_, &w)| acc + w)
}
135
/// Sum the `k` entries of all nodes currently assigned to `community`.
///
/// `part[i]` is the community of node `i` and `k[i]` is the value added for
/// that node. Returns `0.0` when no node belongs to `community`.
fn total_community_weight(community: usize, part: &[usize], k: &[f64]) -> f64 {
    part.iter()
        .enumerate()
        .filter(|&(_, &assigned)| assigned == community)
        .fold(0.0, |sum, (idx, _)| sum + k[idx])
}
145
/// Modularity gain term for placing a node into a community.
///
/// Computes `ki_in / m2 - sigma_tot * ki / (2 * m2²)`, where `ki_in` is the
/// weight linking the node to the community, `sigma_tot` is the community's
/// total weight, `ki` is the node's own weight and `m2` is the normalising
/// graph weight supplied by the caller. Returns `0.0` when `m2` is zero so the
/// division is never attempted on an empty graph.
fn modularity_delta(ki_in: f64, sigma_tot: f64, ki: f64, m2: f64) -> f64 {
    if m2 != 0.0 {
        let attraction = ki_in / m2;
        let expected = sigma_tot * ki / (2.0 * m2 * m2);
        attraction - expected
    } else {
        0.0
    }
}
152
153/// Run community detection on the graph stored in `db` and persist the results.
154///
155/// Clears any existing community assignments, recomputes them with
156/// [`detect_communities`], writes them back, and returns the number of
157/// distinct communities found.
158pub fn run_clustering(db: &GraphDb) -> anyhow::Result<usize> {
159    db.clear_communities()?;
160
161    let nodes = db.get_all_nodes()?;
162    let edges = db.get_all_edges()?;
163
164    if nodes.is_empty() {
165        return Ok(0);
166    }
167
168    let community_map = detect_communities(&nodes, &edges);
169
170    let community_count = {
171        let mut unique: HashSet<i64> = HashSet::new();
172        for &c in community_map.values() {
173            unique.insert(c);
174        }
175        unique.len()
176    };
177
178    db.update_node_communities(&community_map)?;
179
180    Ok(community_count)
181}