reddb_server/storage/engine/algorithms/
community.rs1use std::collections::{HashMap, HashSet};
8
9use super::super::graph_store::GraphStore;
10
/// Configuration for label-propagation community detection.
///
/// Build one with [`LabelPropagation::new`] (or `Default`), optionally adjust it
/// with the builder-style setter, then call `run` on a graph.
#[derive(Debug, Clone)]
pub struct LabelPropagation {
    /// Upper bound on the number of propagation rounds `run` executes before
    /// giving up on convergence.
    pub max_iterations: usize,
}
23
24impl Default for LabelPropagation {
25 fn default() -> Self {
26 Self {
27 max_iterations: 100,
28 }
29 }
30}
31
/// One group of nodes that ended up sharing the same label.
#[derive(Clone, Debug)]
pub struct Community {
    /// The label shared by every member of this community.
    pub label: String,
    /// Identifiers of the member nodes.
    pub nodes: Vec<String>,
    /// Number of members; `LabelPropagation::run` sets this to `nodes.len()`.
    pub size: usize,
}
42
43#[derive(Debug, Clone)]
45pub struct CommunitiesResult {
46 pub communities: Vec<Community>,
48 pub iterations: usize,
50 pub converged: bool,
52}
53
54impl CommunitiesResult {
55 pub fn largest(&self) -> Option<&Community> {
57 self.communities.first()
58 }
59
60 pub fn community_of(&self, node_id: &str) -> Option<&Community> {
62 self.communities
63 .iter()
64 .find(|c| c.nodes.contains(&node_id.to_string()))
65 }
66}
67
68impl LabelPropagation {
69 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn max_iterations(mut self, max: usize) -> Self {
76 self.max_iterations = max;
77 self
78 }
79
80 pub fn run(&self, graph: &GraphStore) -> CommunitiesResult {
82 let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
83
84 if nodes.is_empty() {
85 return CommunitiesResult {
86 communities: Vec::new(),
87 iterations: 0,
88 converged: true,
89 };
90 }
91
92 let mut labels: HashMap<String, String> =
94 nodes.iter().map(|id| (id.clone(), id.clone())).collect();
95
96 let mut converged = false;
97 let mut iterations = 0;
98
99 for iter in 0..self.max_iterations {
100 iterations = iter + 1;
101 let mut changed = false;
102
103 for node_id in &nodes {
105 let mut label_counts: HashMap<String, usize> = HashMap::new();
107
108 for (_, neighbor, _) in graph.outgoing_edges(node_id) {
110 if let Some(label) = labels.get(&neighbor) {
111 *label_counts.entry(label.clone()).or_insert(0) += 1;
112 }
113 }
114
115 for (_, neighbor, _) in graph.incoming_edges(node_id) {
117 if let Some(label) = labels.get(&neighbor) {
118 *label_counts.entry(label.clone()).or_insert(0) += 1;
119 }
120 }
121
122 if let Some((best_label, _)) =
124 label_counts.into_iter().max_by_key(|(_, count)| *count)
125 {
126 let current = labels.get(node_id).cloned().unwrap_or_default();
127 if best_label != current {
128 labels.insert(node_id.clone(), best_label);
129 changed = true;
130 }
131 }
132 }
133
134 if !changed {
135 converged = true;
136 break;
137 }
138 }
139
140 let mut groups: HashMap<String, Vec<String>> = HashMap::new();
142 for (node_id, label) in &labels {
143 groups
144 .entry(label.clone())
145 .or_default()
146 .push(node_id.clone());
147 }
148
149 let mut communities: Vec<Community> = groups
151 .into_iter()
152 .map(|(label, nodes)| {
153 let size = nodes.len();
154 Community { label, nodes, size }
155 })
156 .collect();
157
158 communities.sort_by_key(|b| std::cmp::Reverse(b.size));
160
161 CommunitiesResult {
162 communities,
163 iterations,
164 converged,
165 }
166 }
167}
168
/// Configuration for Louvain-style modularity optimisation.
///
/// Build one with [`Louvain::new`] (or `Default`), optionally adjust it with the
/// builder-style setters, then call `run` on a graph.
#[derive(Debug, Clone)]
pub struct Louvain {
    /// Multiplier applied to the null-model (expected-weight) term of the
    /// modularity; `1.0` gives the standard modularity.
    pub resolution: f64,
    /// Upper bound on the number of passes `run` makes over all nodes.
    pub max_iterations: usize,
    /// Minimum amount by which a move's modularity gain must exceed the best
    /// gain seen so far for that node before the move is preferred.
    pub min_improvement: f64,
}
186
187impl Default for Louvain {
188 fn default() -> Self {
189 Self {
190 resolution: 1.0,
191 max_iterations: 10,
192 min_improvement: 1e-6,
193 }
194 }
195}
196
/// Outcome of a Louvain run.
#[derive(Clone, Debug)]
pub struct LouvainResult {
    /// Maps each node id to the index of the community it was assigned to.
    /// `Louvain::run` numbers communities contiguously from `0`.
    pub communities: HashMap<String, usize>,
    /// Number of distinct communities.
    pub count: usize,
    /// Modularity score computed for the final assignment.
    pub modularity: f64,
    /// Number of passes over the nodes that were executed.
    pub passes: usize,
}
209
210impl LouvainResult {
211 pub fn get_community(&self, community_id: usize) -> Vec<String> {
213 self.communities
214 .iter()
215 .filter(|(_, &c)| c == community_id)
216 .map(|(n, _)| n.clone())
217 .collect()
218 }
219
220 pub fn community_sizes(&self) -> HashMap<usize, usize> {
222 let mut sizes: HashMap<usize, usize> = HashMap::new();
223 for &c in self.communities.values() {
224 *sizes.entry(c).or_insert(0) += 1;
225 }
226 sizes
227 }
228}
229
230impl Louvain {
231 pub fn new() -> Self {
233 Self::default()
234 }
235
236 pub fn resolution(mut self, resolution: f64) -> Self {
238 self.resolution = resolution;
239 self
240 }
241
242 pub fn max_iterations(mut self, max: usize) -> Self {
244 self.max_iterations = max;
245 self
246 }
247
248 pub fn run(&self, graph: &GraphStore) -> LouvainResult {
250 let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
251
252 if nodes.is_empty() {
253 return LouvainResult {
254 communities: HashMap::new(),
255 count: 0,
256 modularity: 0.0,
257 passes: 0,
258 };
259 }
260
261 let mut weights: HashMap<(String, String), f64> = HashMap::new();
263 let mut node_strength: HashMap<String, f64> = HashMap::new();
264 let mut total_weight = 0.0;
265
266 for node in &nodes {
267 for (_, target, _) in graph.outgoing_edges(node) {
268 if node != &target {
269 let key = if node < &target {
270 (node.clone(), target.clone())
271 } else {
272 (target.clone(), node.clone())
273 };
274
275 let w = weights.entry(key).or_insert(0.0);
276 *w += 1.0; }
278 }
279 }
280
281 for ((a, b), w) in &weights {
283 *node_strength.entry(a.clone()).or_insert(0.0) += w;
284 *node_strength.entry(b.clone()).or_insert(0.0) += w;
285 total_weight += w;
286 }
287
288 if total_weight == 0.0 {
289 let communities: HashMap<String, usize> = nodes
291 .iter()
292 .enumerate()
293 .map(|(i, n)| (n.clone(), i))
294 .collect();
295 return LouvainResult {
296 count: nodes.len(),
297 communities,
298 modularity: 0.0,
299 passes: 0,
300 };
301 }
302
303 let mut communities: HashMap<String, usize> = nodes
305 .iter()
306 .enumerate()
307 .map(|(i, n)| (n.clone(), i))
308 .collect();
309
310 let mut comm_total: HashMap<usize, f64> = nodes
312 .iter()
313 .enumerate()
314 .map(|(i, n)| (i, *node_strength.get(n).unwrap_or(&0.0)))
315 .collect();
316
317 let mut comm_internal: HashMap<usize, f64> = HashMap::new();
319
320 let mut passes = 0;
321 let mut improved = true;
322
323 while improved && passes < self.max_iterations {
324 improved = false;
325 passes += 1;
326
327 for node in &nodes {
328 let current_comm = *communities.get(node).unwrap();
329 let node_w = *node_strength.get(node).unwrap_or(&0.0);
330
331 let mut neighbor_comm_weights: HashMap<usize, f64> = HashMap::new();
333
334 for ((a, b), w) in &weights {
335 if a == node {
336 let neighbor_comm = *communities.get(b).unwrap();
337 *neighbor_comm_weights.entry(neighbor_comm).or_insert(0.0) += w;
338 } else if b == node {
339 let neighbor_comm = *communities.get(a).unwrap();
340 *neighbor_comm_weights.entry(neighbor_comm).or_insert(0.0) += w;
341 }
342 }
343
344 let mut best_comm = current_comm;
346 let mut best_delta = 0.0;
347
348 let current_internal = neighbor_comm_weights
350 .get(¤t_comm)
351 .copied()
352 .unwrap_or(0.0);
353 let current_total = *comm_total.get(¤t_comm).unwrap_or(&0.0);
354
355 for (&target_comm, &weight_to_target) in &neighbor_comm_weights {
356 if target_comm == current_comm {
357 continue;
358 }
359
360 let target_total = *comm_total.get(&target_comm).unwrap_or(&0.0);
361
362 let delta = (weight_to_target - current_internal) / total_weight
363 - self.resolution * node_w * (target_total - current_total + node_w)
364 / (2.0 * total_weight * total_weight);
365
366 if delta > best_delta + self.min_improvement {
367 best_delta = delta;
368 best_comm = target_comm;
369 }
370 }
371
372 if best_comm != current_comm {
374 improved = true;
375
376 *comm_total.entry(current_comm).or_insert(0.0) -= node_w;
378 *comm_total.entry(best_comm).or_insert(0.0) += node_w;
379
380 let current_internal = neighbor_comm_weights
382 .get(¤t_comm)
383 .copied()
384 .unwrap_or(0.0);
385 *comm_internal.entry(current_comm).or_insert(0.0) -= current_internal;
386
387 let new_internal = neighbor_comm_weights
388 .get(&best_comm)
389 .copied()
390 .unwrap_or(0.0);
391 *comm_internal.entry(best_comm).or_insert(0.0) += new_internal;
392
393 communities.insert(node.clone(), best_comm);
394 }
395 }
396 }
397
398 let unique_communities: Vec<usize> = {
400 let c: HashSet<usize> = communities.values().copied().collect();
401 let mut v: Vec<usize> = c.into_iter().collect();
402 v.sort();
403 v
404 };
405
406 let comm_map: HashMap<usize, usize> = unique_communities
407 .iter()
408 .enumerate()
409 .map(|(new, &old)| (old, new))
410 .collect();
411
412 let remapped: HashMap<String, usize> = communities
413 .into_iter()
414 .map(|(n, c)| (n, *comm_map.get(&c).unwrap_or(&0)))
415 .collect();
416
417 let modularity =
419 self.calculate_modularity(&remapped, &weights, &node_strength, total_weight);
420
421 LouvainResult {
422 count: unique_communities.len(),
423 communities: remapped,
424 modularity,
425 passes,
426 }
427 }
428
429 fn calculate_modularity(
431 &self,
432 communities: &HashMap<String, usize>,
433 weights: &HashMap<(String, String), f64>,
434 node_strength: &HashMap<String, f64>,
435 total_weight: f64,
436 ) -> f64 {
437 if total_weight == 0.0 {
438 return 0.0;
439 }
440
441 let mut q = 0.0;
442
443 for ((a, b), w) in weights {
445 let ca = communities.get(a).unwrap();
446 let cb = communities.get(b).unwrap();
447
448 if ca == cb {
449 let ka = node_strength.get(a).unwrap_or(&0.0);
450 let kb = node_strength.get(b).unwrap_or(&0.0);
451 q += w - self.resolution * ka * kb / (2.0 * total_weight);
452 }
453 }
454
455 q / total_weight
456 }
457}