1use anyhow::{anyhow, Result};
8use scirs2_core::ndarray_ext::Array1;
9use scirs2_core::random::Random;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use tracing::{debug, info};
13
14use crate::Triple;
15
/// Strategy used by `CommunityDetector` to partition entities into communities.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunityAlgorithm {
    /// Greedy local moving: nodes are repeatedly moved to the neighbouring
    /// community with the best modularity-contribution score.
    Louvain,
    /// Label propagation: every node repeatedly adopts the most frequent
    /// label among its neighbours.
    LabelPropagation,
    /// Dispatched to `girvan_newman`, which in this file assigns one
    /// community per connected component of the graph.
    GirvanNewman,
    /// Greedy grouping by average cosine similarity of embeddings. Only
    /// usable through `detect_from_embeddings`; graph-based entry points
    /// return an error for this variant.
    EmbeddingBased,
}
28
/// Tunable parameters for `CommunityDetector`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunityConfig {
    /// Which detection algorithm to run.
    pub algorithm: CommunityAlgorithm,
    /// Upper bound on the number of sweeps performed by the Louvain and
    /// label-propagation loops.
    pub max_iterations: usize,
    /// Multiplier applied to the expected-links term of the Louvain
    /// modularity-contribution score.
    pub resolution: f32,
    /// Minimum number of members a group needs in embedding-based detection
    /// to receive a community id; smaller groups are labelled `usize::MAX`.
    pub min_community_size: usize,
    /// Cosine-similarity cut-off used both when growing embedding-based
    /// communities and when building the similarity graph.
    pub similarity_threshold: f32,
    /// Optional RNG seed.
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// whether seeding is wired up elsewhere.
    pub random_seed: Option<u64>,
}
45
46impl Default for CommunityConfig {
47 fn default() -> Self {
48 Self {
49 algorithm: CommunityAlgorithm::Louvain,
50 max_iterations: 100,
51 resolution: 1.0,
52 min_community_size: 2,
53 similarity_threshold: 0.7,
54 random_seed: None,
55 }
56 }
57}
58
/// Outcome of a community-detection run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunityResult {
    /// Community id of every entity. `usize::MAX` marks an entity that was
    /// not placed in any community (embedding-based groups below the
    /// configured minimum size).
    pub assignments: HashMap<String, usize>,
    /// Number of distinct community ids, not counting `usize::MAX`.
    pub num_communities: usize,
    /// Vector of length `num_communities` holding community member counts.
    pub community_sizes: Vec<usize>,
    /// Modularity score of the partition with respect to the graph it was
    /// computed on (0.0 when that graph has no edges).
    pub modularity: f32,
    /// Fraction of graph edges whose two endpoints share a community
    /// (0.0 when the graph has no edges).
    pub coverage: f32,
    /// Members of each community, keyed by community id; entities labelled
    /// `usize::MAX` are not listed.
    pub communities: HashMap<usize, HashSet<String>>,
}
75
/// Undirected, weighted graph over string-named nodes.
///
/// Adjacency is stored symmetrically: adding `a - b` records `b` as a
/// neighbour of `a` and `a` as a neighbour of `b`.
struct Graph {
    /// Neighbour set of every node.
    edges: HashMap<String, HashSet<String>>,
    /// Edge weights keyed by ordered `(from, to)` pair; both orientations
    /// of an undirected edge are stored.
    weights: HashMap<(String, String), f32>,
    /// Number of distinct undirected edges (a self-loop counts as one).
    num_edges: usize,
}

impl Graph {
    /// Creates an empty graph.
    fn new() -> Self {
        Self {
            edges: HashMap::new(),
            weights: HashMap::new(),
            num_edges: 0,
        }
    }

    /// Adds the undirected edge `from - to` with the given weight.
    ///
    /// Re-adding an existing edge (in either direction) overwrites its weight
    /// but does not increase `num_edges`. The adjacency sets already
    /// deduplicate, so counting every call would let the edge total drift
    /// away from the node degrees and skew modularity and coverage.
    fn add_edge(&mut self, from: &str, to: &str, weight: f32) {
        // Adjacency is kept symmetric, so the forward insertion alone tells
        // us whether this edge has been seen before.
        let newly_added = self
            .edges
            .entry(from.to_string())
            .or_default()
            .insert(to.to_string());

        self.edges
            .entry(to.to_string())
            .or_default()
            .insert(from.to_string());

        self.weights
            .insert((from.to_string(), to.to_string()), weight);
        self.weights
            .insert((to.to_string(), from.to_string()), weight);

        if newly_added {
            self.num_edges += 1;
        }
    }

    /// Returns the neighbour set of `node`, or `None` for an unknown node.
    fn get_neighbors(&self, node: &str) -> Option<&HashSet<String>> {
        self.edges.get(node)
    }

    /// Returns the weight stored for `from -> to`, defaulting to `1.0` when
    /// no weight has been recorded for that pair.
    fn get_weight(&self, from: &str, to: &str) -> f32 {
        self.weights
            .get(&(from.to_string(), to.to_string()))
            .copied()
            .unwrap_or(1.0)
    }

    /// Number of distinct neighbours of `node` (0 for an unknown node).
    fn degree(&self, node: &str) -> usize {
        self.edges.get(node).map_or(0, HashSet::len)
    }

    /// Owned list of all node names, in unspecified order.
    fn nodes(&self) -> Vec<String> {
        self.edges.keys().cloned().collect()
    }
}
133
/// Runs the community-detection algorithm selected in its configuration,
/// either on a graph derived from triples or on entity embeddings.
pub struct CommunityDetector {
    /// Algorithm choice and tuning parameters.
    config: CommunityConfig,
    /// Random source; used by label propagation to shuffle the node visiting
    /// order on every sweep.
    rng: Random,
}
139
140impl CommunityDetector {
141 pub fn new(config: CommunityConfig) -> Self {
143 let rng = Random::default();
144
145 Self { config, rng }
146 }
147
148 pub fn detect_from_triples(&mut self, triples: &[Triple]) -> Result<CommunityResult> {
150 let mut graph = Graph::new();
152
153 for triple in triples {
154 graph.add_edge(&triple.subject.to_string(), &triple.object.to_string(), 1.0);
156 }
157
158 info!(
159 "Detecting communities in graph with {} nodes and {} edges using {:?}",
160 graph.nodes().len(),
161 graph.num_edges,
162 self.config.algorithm
163 );
164
165 self.detect_from_graph(&graph)
166 }
167
168 pub fn detect_from_embeddings(
170 &mut self,
171 embeddings: &HashMap<String, Array1<f32>>,
172 ) -> Result<CommunityResult> {
173 info!("Detecting communities from {} embeddings", embeddings.len());
174
175 match self.config.algorithm {
176 CommunityAlgorithm::EmbeddingBased => self.embedding_based_detection(embeddings),
177 _ => {
178 let graph = self.build_similarity_graph(embeddings);
180 self.detect_from_graph(&graph)
181 }
182 }
183 }
184
185 fn detect_from_graph(&mut self, graph: &Graph) -> Result<CommunityResult> {
187 match self.config.algorithm {
188 CommunityAlgorithm::Louvain => self.louvain_detection(graph),
189 CommunityAlgorithm::LabelPropagation => self.label_propagation(graph),
190 CommunityAlgorithm::GirvanNewman => self.girvan_newman(graph),
191 CommunityAlgorithm::EmbeddingBased => {
192 Err(anyhow!("Embedding-based detection requires embeddings"))
193 }
194 }
195 }
196
197 fn louvain_detection(&mut self, graph: &Graph) -> Result<CommunityResult> {
199 let nodes = graph.nodes();
200 let m = graph.num_edges as f32;
201
202 let mut community: HashMap<String, usize> = nodes
204 .iter()
205 .enumerate()
206 .map(|(i, node)| (node.clone(), i))
207 .collect();
208
209 let mut improved = true;
210 let mut iteration = 0;
211
212 while improved && iteration < self.config.max_iterations {
213 improved = false;
214 iteration += 1;
215
216 for node in &nodes {
218 let current_comm = community[node];
219 let best_comm = self.find_best_community(node, current_comm, &community, graph, m);
220
221 if best_comm != current_comm {
222 community.insert(node.clone(), best_comm);
223 improved = true;
224 }
225 }
226
227 debug!("Louvain iteration {}: improved = {}", iteration, improved);
228 }
229
230 self.create_result(&community, graph)
231 }
232
233 fn find_best_community(
235 &self,
236 node: &str,
237 current_comm: usize,
238 community: &HashMap<String, usize>,
239 graph: &Graph,
240 m: f32,
241 ) -> usize {
242 let neighbors = match graph.get_neighbors(node) {
243 Some(n) => n,
244 None => return current_comm,
245 };
246
247 let mut neighbor_comms: HashSet<usize> = HashSet::new();
249 for neighbor in neighbors {
250 if let Some(&comm) = community.get(neighbor) {
251 neighbor_comms.insert(comm);
252 }
253 }
254
255 let current_modularity =
257 self.compute_modularity_contribution(node, current_comm, community, graph, m);
258
259 let mut best_comm = current_comm;
260 let mut best_modularity = current_modularity;
261
262 for &comm in &neighbor_comms {
263 if comm == current_comm {
264 continue;
265 }
266
267 let modularity = self.compute_modularity_contribution(node, comm, community, graph, m);
268
269 if modularity > best_modularity {
270 best_modularity = modularity;
271 best_comm = comm;
272 }
273 }
274
275 best_comm
276 }
277
278 fn compute_modularity_contribution(
280 &self,
281 node: &str,
282 comm: usize,
283 community: &HashMap<String, usize>,
284 graph: &Graph,
285 m: f32,
286 ) -> f32 {
287 let neighbors = match graph.get_neighbors(node) {
288 Some(n) => n,
289 None => return 0.0,
290 };
291
292 let k_i = graph.degree(node) as f32;
293
294 let mut e_ic = 0.0;
296 let mut k_c = 0.0;
297
298 for neighbor in neighbors {
299 if let Some(&neighbor_comm) = community.get(neighbor) {
300 if neighbor_comm == comm {
301 e_ic += graph.get_weight(node, neighbor);
302 k_c += graph.degree(neighbor) as f32;
303 }
304 }
305 }
306
307 (e_ic - (self.config.resolution * k_i * k_c) / (2.0 * m)) / m
308 }
309
310 fn label_propagation(&mut self, graph: &Graph) -> Result<CommunityResult> {
312 let nodes = graph.nodes();
313
314 let mut labels: HashMap<String, usize> = nodes
316 .iter()
317 .enumerate()
318 .map(|(i, node)| (node.clone(), i))
319 .collect();
320
321 for iteration in 0..self.config.max_iterations {
322 let mut changed = false;
323
324 let mut node_order = nodes.clone();
326 for i in (1..node_order.len()).rev() {
327 let j = self.rng.random_range(0..i + 1);
328 node_order.swap(i, j);
329 }
330
331 for node in &node_order {
333 let old_label = labels[node];
334 let new_label = self.majority_label(node, &labels, graph);
335
336 if new_label != old_label {
337 labels.insert(node.clone(), new_label);
338 changed = true;
339 }
340 }
341
342 debug!(
343 "Label propagation iteration {}: changed = {}",
344 iteration + 1,
345 changed
346 );
347
348 if !changed {
349 info!("Label propagation converged at iteration {}", iteration + 1);
350 break;
351 }
352 }
353
354 self.create_result(&labels, graph)
355 }
356
357 fn majority_label(&self, node: &str, labels: &HashMap<String, usize>, graph: &Graph) -> usize {
359 let neighbors = match graph.get_neighbors(node) {
360 Some(n) => n,
361 None => return labels[node],
362 };
363
364 let mut label_counts: HashMap<usize, usize> = HashMap::new();
365
366 for neighbor in neighbors {
367 if let Some(&label) = labels.get(neighbor) {
368 *label_counts.entry(label).or_insert(0) += 1;
369 }
370 }
371
372 label_counts
374 .into_iter()
375 .max_by_key(|(_, count)| *count)
376 .map(|(label, _)| label)
377 .unwrap_or_else(|| labels[node])
378 }
379
380 fn girvan_newman(&mut self, graph: &Graph) -> Result<CommunityResult> {
382 let nodes = graph.nodes();
386 let mut assignments: HashMap<String, usize> = HashMap::new();
387
388 let mut visited = HashSet::new();
390 let mut community_id = 0;
391
392 for node in &nodes {
393 if visited.contains(node) {
394 continue;
395 }
396
397 let component = self.bfs_component(node, graph, &visited);
399
400 for comp_node in &component {
401 assignments.insert(comp_node.clone(), community_id);
402 visited.insert(comp_node.clone());
403 }
404
405 community_id += 1;
406 }
407
408 self.create_result(&assignments, graph)
409 }
410
411 fn bfs_component(
413 &self,
414 start: &str,
415 graph: &Graph,
416 visited: &HashSet<String>,
417 ) -> HashSet<String> {
418 let mut component = HashSet::new();
419 let mut queue = VecDeque::new();
420 queue.push_back(start.to_string());
421 component.insert(start.to_string());
422
423 while let Some(node) = queue.pop_front() {
424 if let Some(neighbors) = graph.get_neighbors(&node) {
425 for neighbor in neighbors {
426 if !visited.contains(neighbor) && !component.contains(neighbor) {
427 component.insert(neighbor.clone());
428 queue.push_back(neighbor.clone());
429 }
430 }
431 }
432 }
433
434 component
435 }
436
437 fn embedding_based_detection(
439 &mut self,
440 embeddings: &HashMap<String, Array1<f32>>,
441 ) -> Result<CommunityResult> {
442 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
443 let mut assignments: HashMap<String, usize> = HashMap::new();
444 let mut community_id = 0;
445
446 let mut unassigned: HashSet<String> = entity_list.iter().cloned().collect();
447
448 while !unassigned.is_empty() {
449 let seed = unassigned.iter().next().unwrap().clone();
451 let mut community = HashSet::new();
452 community.insert(seed.clone());
453 unassigned.remove(&seed);
454
455 let mut changed = true;
457 while changed {
458 changed = false;
459
460 for entity in &entity_list {
461 if community.contains(entity) || !unassigned.contains(entity) {
462 continue;
463 }
464
465 let avg_similarity =
467 self.average_similarity_to_community(entity, &community, embeddings);
468
469 if avg_similarity >= self.config.similarity_threshold {
470 community.insert(entity.clone());
471 unassigned.remove(entity);
472 changed = true;
473 }
474 }
475 }
476
477 if community.len() >= self.config.min_community_size {
479 for member in community {
480 assignments.insert(member, community_id);
481 }
482 community_id += 1;
483 } else {
484 for member in community {
486 assignments.insert(member, usize::MAX);
487 }
488 }
489 }
490
491 let mut graph = Graph::new();
493 for entity in &entity_list {
494 graph.edges.insert(entity.clone(), HashSet::new());
495 }
496
497 self.create_result(&assignments, &graph)
498 }
499
500 fn average_similarity_to_community(
502 &self,
503 entity: &str,
504 community: &HashSet<String>,
505 embeddings: &HashMap<String, Array1<f32>>,
506 ) -> f32 {
507 if community.is_empty() {
508 return 0.0;
509 }
510
511 let entity_emb = &embeddings[entity];
512
513 let total_sim: f32 = community
514 .iter()
515 .map(|member| {
516 let member_emb = &embeddings[member];
517 self.cosine_similarity(entity_emb, member_emb)
518 })
519 .sum();
520
521 total_sim / community.len() as f32
522 }
523
524 fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
526 let dot = a.dot(b);
527 let norm_a = a.dot(a).sqrt();
528 let norm_b = b.dot(b).sqrt();
529
530 if norm_a == 0.0 || norm_b == 0.0 {
531 0.0
532 } else {
533 dot / (norm_a * norm_b)
534 }
535 }
536
537 fn build_similarity_graph(&self, embeddings: &HashMap<String, Array1<f32>>) -> Graph {
539 let mut graph = Graph::new();
540 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
541
542 for i in 0..entity_list.len() {
543 for j in (i + 1)..entity_list.len() {
544 let sim = self
545 .cosine_similarity(&embeddings[&entity_list[i]], &embeddings[&entity_list[j]]);
546
547 if sim >= self.config.similarity_threshold {
548 graph.add_edge(&entity_list[i], &entity_list[j], sim);
549 }
550 }
551 }
552
553 graph
554 }
555
556 fn create_result(
558 &self,
559 assignments: &HashMap<String, usize>,
560 graph: &Graph,
561 ) -> Result<CommunityResult> {
562 let mut community_sizes: HashMap<usize, usize> = HashMap::new();
564 let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
565
566 for (entity, &comm) in assignments {
567 if comm != usize::MAX {
568 *community_sizes.entry(comm).or_insert(0) += 1;
569 communities.entry(comm).or_default().insert(entity.clone());
570 }
571 }
572
573 let num_communities = community_sizes.len();
574 let sizes: Vec<usize> = (0..num_communities)
575 .map(|i| community_sizes.get(&i).copied().unwrap_or(0))
576 .collect();
577
578 let modularity = self.compute_modularity(assignments, graph);
580
581 let coverage = self.compute_coverage(assignments, graph);
583
584 Ok(CommunityResult {
585 assignments: assignments.clone(),
586 num_communities,
587 community_sizes: sizes,
588 modularity,
589 coverage,
590 communities,
591 })
592 }
593
594 fn compute_modularity(&self, assignments: &HashMap<String, usize>, graph: &Graph) -> f32 {
596 let m = graph.num_edges as f32;
597 if m == 0.0 {
598 return 0.0;
599 }
600
601 let nodes = graph.nodes();
602 let mut modularity = 0.0;
603
604 for node_i in &nodes {
605 for node_j in &nodes {
606 if let (Some(&comm_i), Some(&comm_j)) =
607 (assignments.get(node_i), assignments.get(node_j))
608 {
609 if comm_i == comm_j && comm_i != usize::MAX {
610 let a_ij = if graph
611 .get_neighbors(node_i)
612 .map(|n| n.contains(node_j))
613 .unwrap_or(false)
614 {
615 1.0
616 } else {
617 0.0
618 };
619
620 let k_i = graph.degree(node_i) as f32;
621 let k_j = graph.degree(node_j) as f32;
622
623 modularity += a_ij - (k_i * k_j) / (2.0 * m);
624 }
625 }
626 }
627 }
628
629 modularity / (2.0 * m)
630 }
631
632 fn compute_coverage(&self, assignments: &HashMap<String, usize>, graph: &Graph) -> f32 {
634 if graph.num_edges == 0 {
635 return 0.0;
636 }
637
638 let mut internal_edges = 0;
639
640 for (node, neighbors) in &graph.edges {
641 if let Some(&comm) = assignments.get(node) {
642 if comm == usize::MAX {
643 continue;
644 }
645
646 for neighbor in neighbors {
647 if let Some(&neighbor_comm) = assignments.get(neighbor) {
648 if comm == neighbor_comm {
649 internal_edges += 1;
650 }
651 }
652 }
653 }
654 }
655
656 (internal_edges / 2) as f32 / graph.num_edges as f32
658 }
659}
660
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;
    use scirs2_core::ndarray_ext::array;

    /// Three triples over five distinct entities must yield at least one
    /// community and an assignment for every entity.
    #[test]
    fn test_community_detection_from_triples() {
        let link = |from: &str, to: &str| {
            Triple::new(
                NamedNode::new(from).unwrap(),
                NamedNode::new("r").unwrap(),
                NamedNode::new(to).unwrap(),
            )
        };
        let triples = vec![link("a", "b"), link("b", "c"), link("d", "e")];

        let mut detector = CommunityDetector::new(CommunityConfig::default());
        let outcome = detector.detect_from_triples(&triples).unwrap();

        assert!(outcome.num_communities > 0);
        assert_eq!(outcome.assignments.len(), 5);
    }

    /// Two nearly parallel embeddings must end up with the same assignment.
    #[test]
    fn test_embedding_based_detection() {
        let mut embeddings = HashMap::new();
        for (name, vector) in [
            ("e1", array![1.0, 0.0]),
            ("e2", array![0.9, 0.1]),
            ("e3", array![0.0, 1.0]),
            ("e4", array![0.1, 0.9]),
        ] {
            embeddings.insert(name.to_string(), vector);
        }

        let mut detector = CommunityDetector::new(CommunityConfig {
            algorithm: CommunityAlgorithm::EmbeddingBased,
            similarity_threshold: 0.8,
            ..Default::default()
        });
        let outcome = detector.detect_from_embeddings(&embeddings).unwrap();

        assert!(outcome.num_communities >= 1);
        assert_eq!(outcome.assignments.get("e1"), outcome.assignments.get("e2"));
    }
}