manifoldb_graph/analytics/
community.rs1use std::collections::HashMap;
45
46use manifoldb_core::EntityId;
47use manifoldb_storage::Transaction;
48
49use crate::index::AdjacencyIndex;
50use crate::store::{EdgeStore, GraphError, GraphResult, NodeStore};
51use crate::traversal::Direction;
52
53use super::pagerank::DEFAULT_MAX_GRAPH_NODES;
54
/// Configuration for the community detection routines in this module.
///
/// Create one with [`CommunityDetectionConfig::new`] or [`Default::default`] and
/// adjust it through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CommunityDetectionConfig {
    /// Upper bound on the number of label propagation passes.
    pub max_iterations: usize,

    /// Which edge directions are followed when collecting a node's neighbours.
    pub direction: Direction,

    /// Seed for the internal pseudo-random generator. When `None`, the
    /// algorithms fall back to a fixed built-in seed.
    pub seed: Option<u64>,

    /// Minimum improvement threshold.
    ///
    /// NOTE(review): the label propagation routines in this file never read
    /// this field — confirm whether it is reserved for another algorithm.
    pub min_improvement: f64,

    /// Optional cap on the number of nodes `label_propagation` accepts before
    /// returning `GraphError::GraphTooLarge`. `None` disables the check.
    pub max_graph_nodes: Option<usize>,
}
80
81impl Default for CommunityDetectionConfig {
82 fn default() -> Self {
83 Self {
84 max_iterations: 100,
85 direction: Direction::Both,
86 seed: None,
87 min_improvement: 0.0,
88 max_graph_nodes: Some(DEFAULT_MAX_GRAPH_NODES),
89 }
90 }
91}
92
93impl CommunityDetectionConfig {
94 pub fn new() -> Self {
96 Self::default()
97 }
98
99 pub const fn with_max_iterations(mut self, max_iterations: usize) -> Self {
101 self.max_iterations = max_iterations;
102 self
103 }
104
105 pub const fn with_direction(mut self, direction: Direction) -> Self {
107 self.direction = direction;
108 self
109 }
110
111 pub const fn with_seed(mut self, seed: u64) -> Self {
113 self.seed = Some(seed);
114 self
115 }
116
117 pub const fn with_min_improvement(mut self, min_improvement: f64) -> Self {
119 self.min_improvement = min_improvement;
120 self
121 }
122
123 pub const fn with_max_graph_nodes(mut self, limit: Option<usize>) -> Self {
133 self.max_graph_nodes = limit;
134 self
135 }
136}
137
/// Outcome of a community detection run.
#[derive(Debug, Clone)]
pub struct CommunityResult {
    /// Community id assigned to each processed node. The label propagation
    /// routines renumber ids so they run from `0` up to `num_communities - 1`.
    pub assignments: HashMap<EntityId, u64>,

    /// Number of passes that were executed.
    pub iterations: usize,

    /// `true` when the run stopped because a full pass changed no label
    /// (also reported for empty input); `false` when the pass limit was hit.
    pub converged: bool,

    /// Number of distinct communities found.
    pub num_communities: usize,
}
153
154impl CommunityResult {
155 pub fn community(&self, node: EntityId) -> Option<u64> {
157 self.assignments.get(&node).copied()
158 }
159
160 pub fn members(&self, community_id: u64) -> Vec<EntityId> {
162 self.assignments.iter().filter(|(_, &c)| c == community_id).map(|(&node, _)| node).collect()
163 }
164
165 pub fn community_sizes(&self) -> HashMap<u64, usize> {
167 let mut sizes: HashMap<u64, usize> = HashMap::new();
168 for &community in self.assignments.values() {
169 *sizes.entry(community).or_insert(0) += 1;
170 }
171 sizes
172 }
173
174 pub fn communities_by_size(&self) -> Vec<(u64, usize)> {
176 let mut sizes: Vec<_> = self.community_sizes().into_iter().collect();
177 sizes.sort_by(|a, b| b.1.cmp(&a.1));
178 sizes
179 }
180
181 pub fn largest_community(&self) -> Option<(u64, usize)> {
183 self.communities_by_size().into_iter().next()
184 }
185
186 pub fn smallest_community(&self) -> Option<(u64, usize)> {
188 self.communities_by_size().into_iter().last()
189 }
190
191 pub fn same_community(&self, node1: EntityId, node2: EntityId) -> bool {
193 match (self.community(node1), self.community(node2)) {
194 (Some(c1), Some(c2)) => c1 == c2,
195 _ => false,
196 }
197 }
198}
199
/// Namespace type for the community detection algorithms. It holds no state;
/// every routine is an associated function that takes the transaction to read from.
pub struct CommunityDetection;
202
203impl CommunityDetection {
204 pub fn label_propagation<T: Transaction>(
218 tx: &T,
219 config: &CommunityDetectionConfig,
220 ) -> GraphResult<CommunityResult> {
221 if let Some(limit) = config.max_graph_nodes {
223 let node_count = NodeStore::count(tx)?;
224 if node_count > limit {
225 return Err(GraphError::GraphTooLarge { node_count, limit });
226 }
227 }
228
229 let mut nodes: Vec<EntityId> = Vec::new();
231 NodeStore::for_each(tx, |entity| {
232 nodes.push(entity.id);
233 true
234 })?;
235
236 let n = nodes.len();
237 if n == 0 {
238 return Ok(CommunityResult {
239 assignments: HashMap::new(),
240 iterations: 0,
241 converged: true,
242 num_communities: 0,
243 });
244 }
245
246 let node_index: HashMap<EntityId, usize> =
248 nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
249
250 let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
252 for (i, &node) in nodes.iter().enumerate() {
253 let neighbor_ids = Self::get_neighbors(tx, node, config.direction)?;
254 for neighbor_id in neighbor_ids {
255 if let Some(&j) = node_index.get(&neighbor_id) {
256 neighbors[i].push(j);
257 }
258 }
259 }
260
261 let mut labels: Vec<u64> = (0..n as u64).collect();
263
264 let mut rng_state = config.seed.unwrap_or(12345);
266 let simple_random = |state: &mut u64| -> u64 {
267 *state = state
269 .wrapping_mul(6_364_136_223_846_793_005)
270 .wrapping_add(1_442_695_040_888_963_407);
271 *state
272 };
273
274 let mut iterations = 0;
275 let mut converged = false;
276
277 let mut order: Vec<usize> = (0..n).collect();
279
280 while iterations < config.max_iterations {
281 iterations += 1;
282
283 for i in (1..n).rev() {
285 let j = (simple_random(&mut rng_state) as usize) % (i + 1);
286 order.swap(i, j);
287 }
288
289 let mut changed = false;
290
291 for &i in &order {
292 if neighbors[i].is_empty() {
293 continue;
294 }
295
296 let mut label_counts: HashMap<u64, usize> = HashMap::new();
298 for &j in &neighbors[i] {
299 *label_counts.entry(labels[j]).or_insert(0) += 1;
300 }
301
302 let max_count = *label_counts.values().max().unwrap_or(&0);
304
305 let max_labels: Vec<u64> = label_counts
307 .iter()
308 .filter(|(_, &count)| count == max_count)
309 .map(|(&label, _)| label)
310 .collect();
311
312 let new_label = if max_labels.contains(&labels[i]) {
314 labels[i]
315 } else if max_labels.len() == 1 {
316 max_labels[0]
317 } else {
318 let idx = (simple_random(&mut rng_state) as usize) % max_labels.len();
320 max_labels[idx]
321 };
322
323 if new_label != labels[i] {
324 labels[i] = new_label;
325 changed = true;
326 }
327 }
328
329 if !changed {
330 converged = true;
331 break;
332 }
333 }
334
335 let mut community_map: HashMap<u64, u64> = HashMap::new();
337 let mut next_id = 0u64;
338 for label in &mut labels {
339 let new_id = *community_map.entry(*label).or_insert_with(|| {
340 let id = next_id;
341 next_id += 1;
342 id
343 });
344 *label = new_id;
345 }
346
347 let assignments: HashMap<EntityId, u64> = nodes.into_iter().zip(labels).collect();
349
350 let num_communities = community_map.len();
351
352 Ok(CommunityResult { assignments, iterations, converged, num_communities })
353 }
354
355 pub fn label_propagation_for_nodes<T: Transaction>(
365 tx: &T,
366 nodes: &[EntityId],
367 config: &CommunityDetectionConfig,
368 ) -> GraphResult<CommunityResult> {
369 let n = nodes.len();
370 if n == 0 {
371 return Ok(CommunityResult {
372 assignments: HashMap::new(),
373 iterations: 0,
374 converged: true,
375 num_communities: 0,
376 });
377 }
378
379 let node_set: std::collections::HashSet<EntityId> = nodes.iter().copied().collect();
381 let node_index: HashMap<EntityId, usize> =
382 nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
383
384 let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
386 for (i, &node) in nodes.iter().enumerate() {
387 let neighbor_ids = Self::get_neighbors(tx, node, config.direction)?;
388 for neighbor_id in neighbor_ids {
389 if node_set.contains(&neighbor_id) {
390 if let Some(&j) = node_index.get(&neighbor_id) {
391 neighbors[i].push(j);
392 }
393 }
394 }
395 }
396
397 let mut labels: Vec<u64> = (0..n as u64).collect();
399
400 let mut rng_state = config.seed.unwrap_or(12345);
402 let simple_random = |state: &mut u64| -> u64 {
403 *state = state
404 .wrapping_mul(6_364_136_223_846_793_005)
405 .wrapping_add(1_442_695_040_888_963_407);
406 *state
407 };
408
409 let mut iterations = 0;
410 let mut converged = false;
411 let mut order: Vec<usize> = (0..n).collect();
412
413 while iterations < config.max_iterations {
414 iterations += 1;
415
416 for i in (1..n).rev() {
418 let j = (simple_random(&mut rng_state) as usize) % (i + 1);
419 order.swap(i, j);
420 }
421
422 let mut changed = false;
423
424 for &i in &order {
425 if neighbors[i].is_empty() {
426 continue;
427 }
428
429 let mut label_counts: HashMap<u64, usize> = HashMap::new();
430 for &j in &neighbors[i] {
431 *label_counts.entry(labels[j]).or_insert(0) += 1;
432 }
433
434 let max_count = *label_counts.values().max().unwrap_or(&0);
435 let max_labels: Vec<u64> = label_counts
436 .iter()
437 .filter(|(_, &count)| count == max_count)
438 .map(|(&label, _)| label)
439 .collect();
440
441 let new_label = if max_labels.contains(&labels[i]) {
442 labels[i]
443 } else if max_labels.len() == 1 {
444 max_labels[0]
445 } else {
446 let idx = (simple_random(&mut rng_state) as usize) % max_labels.len();
447 max_labels[idx]
448 };
449
450 if new_label != labels[i] {
451 labels[i] = new_label;
452 changed = true;
453 }
454 }
455
456 if !changed {
457 converged = true;
458 break;
459 }
460 }
461
462 let mut community_map: HashMap<u64, u64> = HashMap::new();
464 let mut next_id = 0u64;
465 for label in &mut labels {
466 let new_id = *community_map.entry(*label).or_insert_with(|| {
467 let id = next_id;
468 next_id += 1;
469 id
470 });
471 *label = new_id;
472 }
473
474 let assignments: HashMap<EntityId, u64> = nodes.iter().copied().zip(labels).collect();
475
476 let num_communities = community_map.len();
477
478 Ok(CommunityResult { assignments, iterations, converged, num_communities })
479 }
480
481 fn get_neighbors<T: Transaction>(
483 tx: &T,
484 node: EntityId,
485 direction: Direction,
486 ) -> GraphResult<Vec<EntityId>> {
487 let mut neighbors = Vec::new();
488
489 if direction.includes_outgoing() {
490 let edges = AdjacencyIndex::get_outgoing_edge_ids(tx, node)?;
491 for edge_id in edges {
492 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
493 neighbors.push(edge.target);
494 }
495 }
496 }
497
498 if direction.includes_incoming() {
499 let edges = AdjacencyIndex::get_incoming_edge_ids(tx, node)?;
500 for edge_id in edges {
501 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
502 neighbors.push(edge.source);
503 }
504 }
505 }
506
507 Ok(neighbors)
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented baseline values,
    /// including the shared node cap.
    #[test]
    fn config_defaults() {
        let config = CommunityDetectionConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.direction, Direction::Both);
        assert!(config.seed.is_none());
        assert!((config.min_improvement - 0.0).abs() < f64::EPSILON);
        assert_eq!(config.max_graph_nodes, Some(DEFAULT_MAX_GRAPH_NODES));
    }

    /// Every builder method overrides exactly the field it is named after.
    #[test]
    fn config_builder() {
        let config = CommunityDetectionConfig::new()
            .with_max_iterations(50)
            .with_direction(Direction::Outgoing)
            .with_seed(42)
            .with_min_improvement(0.001)
            .with_max_graph_nodes(None);

        assert_eq!(config.max_iterations, 50);
        assert_eq!(config.direction, Direction::Outgoing);
        assert_eq!(config.seed, Some(42));
        assert!((config.min_improvement - 0.001).abs() < f64::EPSILON);
        assert_eq!(config.max_graph_nodes, None);

        let capped = CommunityDetectionConfig::new().with_max_graph_nodes(Some(10));
        assert_eq!(capped.max_graph_nodes, Some(10));
    }

    /// Accessors on a result without assignments return empty values.
    #[test]
    fn result_empty() {
        let result = CommunityResult {
            assignments: HashMap::new(),
            iterations: 0,
            converged: true,
            num_communities: 0,
        };

        assert!(result.community(EntityId::new(1)).is_none());
        assert!(result.members(0).is_empty());
        assert!(result.community_sizes().is_empty());
        assert!(result.communities_by_size().is_empty());
        assert!(result.largest_community().is_none());
        assert!(result.smallest_community().is_none());
    }

    /// Lookup, membership, size and comparison helpers agree with the
    /// hand-built assignment table.
    #[test]
    fn result_community_operations() {
        let mut assignments = HashMap::new();
        assignments.insert(EntityId::new(1), 0);
        assignments.insert(EntityId::new(2), 0);
        assignments.insert(EntityId::new(3), 1);
        assignments.insert(EntityId::new(4), 1);
        assignments.insert(EntityId::new(5), 1);

        let result =
            CommunityResult { assignments, iterations: 10, converged: true, num_communities: 2 };

        assert_eq!(result.community(EntityId::new(1)), Some(0));
        assert_eq!(result.community(EntityId::new(3)), Some(1));
        assert_eq!(result.community(EntityId::new(99)), None);

        // Check the contents, not just the length; member order is unspecified.
        let members_0 = result.members(0);
        assert_eq!(members_0.len(), 2);
        assert!(members_0.contains(&EntityId::new(1)));
        assert!(members_0.contains(&EntityId::new(2)));

        let members_1 = result.members(1);
        assert_eq!(members_1.len(), 3);
        assert!(members_1.contains(&EntityId::new(3)));
        assert!(members_1.contains(&EntityId::new(4)));
        assert!(members_1.contains(&EntityId::new(5)));

        assert!(result.members(7).is_empty());

        let sizes = result.community_sizes();
        assert_eq!(sizes.len(), 2);
        assert_eq!(sizes.get(&0), Some(&2));
        assert_eq!(sizes.get(&1), Some(&3));

        assert_eq!(result.communities_by_size(), vec![(1, 3), (0, 2)]);
        assert_eq!(result.largest_community(), Some((1, 3)));
        assert_eq!(result.smallest_community(), Some((0, 2)));

        assert!(result.same_community(EntityId::new(1), EntityId::new(2)));
        assert!(result.same_community(EntityId::new(3), EntityId::new(4)));
        assert!(!result.same_community(EntityId::new(1), EntityId::new(3)));
        // A node without an assignment is never in the same community as anything.
        assert!(!result.same_community(EntityId::new(1), EntityId::new(99)));
        assert!(!result.same_community(EntityId::new(98), EntityId::new(99)));
    }
}