1use anyhow::{anyhow, Result};
8use scirs2_core::ndarray_ext::Array1;
9use scirs2_core::random::Random;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use tracing::{debug, info};
13
14use crate::Triple;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum CommunityAlgorithm {
19 Louvain,
21 LabelPropagation,
23 GirvanNewman,
25 EmbeddingBased,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct CommunityConfig {
32 pub algorithm: CommunityAlgorithm,
34 pub max_iterations: usize,
36 pub resolution: f32,
38 pub min_community_size: usize,
40 pub similarity_threshold: f32,
42 pub random_seed: Option<u64>,
44}
45
46impl Default for CommunityConfig {
47 fn default() -> Self {
48 Self {
49 algorithm: CommunityAlgorithm::Louvain,
50 max_iterations: 100,
51 resolution: 1.0,
52 min_community_size: 2,
53 similarity_threshold: 0.7,
54 random_seed: None,
55 }
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct CommunityResult {
62 pub assignments: HashMap<String, usize>,
64 pub num_communities: usize,
66 pub community_sizes: Vec<usize>,
68 pub modularity: f32,
70 pub coverage: f32,
72 pub communities: HashMap<usize, HashSet<String>>,
74}
75
/// Undirected graph over string-named nodes.
///
/// Adjacency is kept as a set per node, so parallel edges between the same
/// pair collapse into one.
struct Graph {
    /// Node -> set of adjacent nodes (stored in both directions).
    edges: HashMap<String, HashSet<String>>,
    /// `(from, to)` -> weight, stored for both orientations of every edge.
    weights: HashMap<(String, String), f32>,
    /// Number of distinct undirected edges in `edges`.
    num_edges: usize,
}

impl Graph {
    /// Creates an empty graph.
    fn new() -> Self {
        Self {
            edges: HashMap::new(),
            weights: HashMap::new(),
            num_edges: 0,
        }
    }

    /// Adds the undirected edge `from`–`to`.
    ///
    /// Re-adding an existing pair (in either orientation) overwrites its
    /// weight but does not bump `num_edges`: adjacency is set-based, so
    /// counting repeats would make `num_edges` disagree with the degrees
    /// derived from `edges`.
    fn add_edge(&mut self, from: &str, to: &str, weight: f32) {
        // `HashSet::insert` reports whether the neighbor was newly added.
        // Adjacency is kept symmetric, so the forward insert alone tells us
        // whether this pair is new (and counts a self-loop exactly once).
        let is_new_edge = self
            .edges
            .entry(from.to_string())
            .or_default()
            .insert(to.to_string());

        self.edges
            .entry(to.to_string())
            .or_default()
            .insert(from.to_string());

        self.weights
            .insert((from.to_string(), to.to_string()), weight);
        self.weights
            .insert((to.to_string(), from.to_string()), weight);

        if is_new_edge {
            self.num_edges += 1;
        }
    }

    /// Returns the neighbor set of `node`, or `None` if the node is unknown.
    fn get_neighbors(&self, node: &str) -> Option<&HashSet<String>> {
        self.edges.get(node)
    }

    /// Returns the stored weight for `from` -> `to`, or `1.0` when no weight
    /// is recorded for that pair.
    fn get_weight(&self, from: &str, to: &str) -> f32 {
        self.weights
            .get(&(from.to_string(), to.to_string()))
            .copied()
            .unwrap_or(1.0)
    }

    /// Number of distinct neighbors of `node` (0 for an unknown node).
    fn degree(&self, node: &str) -> usize {
        self.edges.get(node).map_or(0, HashSet::len)
    }

    /// Returns an owned list of all node names, in unspecified order.
    fn nodes(&self) -> Vec<String> {
        self.edges.keys().cloned().collect()
    }
}
133
134pub struct CommunityDetector {
136 config: CommunityConfig,
137 rng: Random,
138}
139
140impl CommunityDetector {
141 pub fn new(config: CommunityConfig) -> Self {
143 let rng = Random::default();
144
145 Self { config, rng }
146 }
147
148 pub fn detect_from_triples(&mut self, triples: &[Triple]) -> Result<CommunityResult> {
150 let mut graph = Graph::new();
152
153 for triple in triples {
154 graph.add_edge(&triple.subject.to_string(), &triple.object.to_string(), 1.0);
156 }
157
158 info!(
159 "Detecting communities in graph with {} nodes and {} edges using {:?}",
160 graph.nodes().len(),
161 graph.num_edges,
162 self.config.algorithm
163 );
164
165 self.detect_from_graph(&graph)
166 }
167
168 pub fn detect_from_embeddings(
170 &mut self,
171 embeddings: &HashMap<String, Array1<f32>>,
172 ) -> Result<CommunityResult> {
173 info!("Detecting communities from {} embeddings", embeddings.len());
174
175 match self.config.algorithm {
176 CommunityAlgorithm::EmbeddingBased => self.embedding_based_detection(embeddings),
177 _ => {
178 let graph = self.build_similarity_graph(embeddings);
180 self.detect_from_graph(&graph)
181 }
182 }
183 }
184
185 fn detect_from_graph(&mut self, graph: &Graph) -> Result<CommunityResult> {
187 match self.config.algorithm {
188 CommunityAlgorithm::Louvain => self.louvain_detection(graph),
189 CommunityAlgorithm::LabelPropagation => self.label_propagation(graph),
190 CommunityAlgorithm::GirvanNewman => self.girvan_newman(graph),
191 CommunityAlgorithm::EmbeddingBased => {
192 Err(anyhow!("Embedding-based detection requires embeddings"))
193 }
194 }
195 }
196
197 fn louvain_detection(&mut self, graph: &Graph) -> Result<CommunityResult> {
199 let nodes = graph.nodes();
200 let m = graph.num_edges as f32;
201
202 let mut community: HashMap<String, usize> = nodes
204 .iter()
205 .enumerate()
206 .map(|(i, node)| (node.clone(), i))
207 .collect();
208
209 let mut improved = true;
210 let mut iteration = 0;
211
212 while improved && iteration < self.config.max_iterations {
213 improved = false;
214 iteration += 1;
215
216 for node in &nodes {
218 let current_comm = community[node];
219 let best_comm = self.find_best_community(node, current_comm, &community, graph, m);
220
221 if best_comm != current_comm {
222 community.insert(node.clone(), best_comm);
223 improved = true;
224 }
225 }
226
227 debug!("Louvain iteration {}: improved = {}", iteration, improved);
228 }
229
230 self.create_result(&community, graph)
231 }
232
233 fn find_best_community(
235 &self,
236 node: &str,
237 current_comm: usize,
238 community: &HashMap<String, usize>,
239 graph: &Graph,
240 m: f32,
241 ) -> usize {
242 let neighbors = match graph.get_neighbors(node) {
243 Some(n) => n,
244 None => return current_comm,
245 };
246
247 let mut neighbor_comms: HashSet<usize> = HashSet::new();
249 for neighbor in neighbors {
250 if let Some(&comm) = community.get(neighbor) {
251 neighbor_comms.insert(comm);
252 }
253 }
254
255 let current_modularity =
257 self.compute_modularity_contribution(node, current_comm, community, graph, m);
258
259 let mut best_comm = current_comm;
260 let mut best_modularity = current_modularity;
261
262 for &comm in &neighbor_comms {
263 if comm == current_comm {
264 continue;
265 }
266
267 let modularity = self.compute_modularity_contribution(node, comm, community, graph, m);
268
269 if modularity > best_modularity {
270 best_modularity = modularity;
271 best_comm = comm;
272 }
273 }
274
275 best_comm
276 }
277
278 fn compute_modularity_contribution(
280 &self,
281 node: &str,
282 comm: usize,
283 community: &HashMap<String, usize>,
284 graph: &Graph,
285 m: f32,
286 ) -> f32 {
287 let neighbors = match graph.get_neighbors(node) {
288 Some(n) => n,
289 None => return 0.0,
290 };
291
292 let k_i = graph.degree(node) as f32;
293
294 let mut e_ic = 0.0;
296 let mut k_c = 0.0;
297
298 for neighbor in neighbors {
299 if let Some(&neighbor_comm) = community.get(neighbor) {
300 if neighbor_comm == comm {
301 e_ic += graph.get_weight(node, neighbor);
302 k_c += graph.degree(neighbor) as f32;
303 }
304 }
305 }
306
307 (e_ic - (self.config.resolution * k_i * k_c) / (2.0 * m)) / m
308 }
309
310 fn label_propagation(&mut self, graph: &Graph) -> Result<CommunityResult> {
312 let nodes = graph.nodes();
313
314 let mut labels: HashMap<String, usize> = nodes
316 .iter()
317 .enumerate()
318 .map(|(i, node)| (node.clone(), i))
319 .collect();
320
321 for iteration in 0..self.config.max_iterations {
322 let mut changed = false;
323
324 let mut node_order = nodes.clone();
326 for i in (1..node_order.len()).rev() {
327 let j = self.rng.random_range(0..i + 1);
328 node_order.swap(i, j);
329 }
330
331 for node in &node_order {
333 let old_label = labels[node];
334 let new_label = self.majority_label(node, &labels, graph);
335
336 if new_label != old_label {
337 labels.insert(node.clone(), new_label);
338 changed = true;
339 }
340 }
341
342 debug!(
343 "Label propagation iteration {}: changed = {}",
344 iteration + 1,
345 changed
346 );
347
348 if !changed {
349 info!("Label propagation converged at iteration {}", iteration + 1);
350 break;
351 }
352 }
353
354 self.create_result(&labels, graph)
355 }
356
357 fn majority_label(&self, node: &str, labels: &HashMap<String, usize>, graph: &Graph) -> usize {
359 let neighbors = match graph.get_neighbors(node) {
360 Some(n) => n,
361 None => return labels[node],
362 };
363
364 let mut label_counts: HashMap<usize, usize> = HashMap::new();
365
366 for neighbor in neighbors {
367 if let Some(&label) = labels.get(neighbor) {
368 *label_counts.entry(label).or_insert(0) += 1;
369 }
370 }
371
372 label_counts
374 .into_iter()
375 .max_by_key(|(_, count)| *count)
376 .map(|(label, _)| label)
377 .unwrap_or_else(|| labels[node])
378 }
379
380 fn girvan_newman(&mut self, graph: &Graph) -> Result<CommunityResult> {
382 let nodes = graph.nodes();
386 let mut assignments: HashMap<String, usize> = HashMap::new();
387
388 let mut visited = HashSet::new();
390 let mut community_id = 0;
391
392 for node in &nodes {
393 if visited.contains(node) {
394 continue;
395 }
396
397 let component = self.bfs_component(node, graph, &visited);
399
400 for comp_node in &component {
401 assignments.insert(comp_node.clone(), community_id);
402 visited.insert(comp_node.clone());
403 }
404
405 community_id += 1;
406 }
407
408 self.create_result(&assignments, graph)
409 }
410
411 fn bfs_component(
413 &self,
414 start: &str,
415 graph: &Graph,
416 visited: &HashSet<String>,
417 ) -> HashSet<String> {
418 let mut component = HashSet::new();
419 let mut queue = VecDeque::new();
420 queue.push_back(start.to_string());
421 component.insert(start.to_string());
422
423 while let Some(node) = queue.pop_front() {
424 if let Some(neighbors) = graph.get_neighbors(&node) {
425 for neighbor in neighbors {
426 if !visited.contains(neighbor) && !component.contains(neighbor) {
427 component.insert(neighbor.clone());
428 queue.push_back(neighbor.clone());
429 }
430 }
431 }
432 }
433
434 component
435 }
436
437 fn embedding_based_detection(
439 &mut self,
440 embeddings: &HashMap<String, Array1<f32>>,
441 ) -> Result<CommunityResult> {
442 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
443 let mut assignments: HashMap<String, usize> = HashMap::new();
444 let mut community_id = 0;
445
446 let mut unassigned: HashSet<String> = entity_list.iter().cloned().collect();
447
448 while !unassigned.is_empty() {
449 let seed = unassigned
451 .iter()
452 .next()
453 .expect("unassigned should not be empty")
454 .clone();
455 let mut community = HashSet::new();
456 community.insert(seed.clone());
457 unassigned.remove(&seed);
458
459 let mut changed = true;
461 while changed {
462 changed = false;
463
464 for entity in &entity_list {
465 if community.contains(entity) || !unassigned.contains(entity) {
466 continue;
467 }
468
469 let avg_similarity =
471 self.average_similarity_to_community(entity, &community, embeddings);
472
473 if avg_similarity >= self.config.similarity_threshold {
474 community.insert(entity.clone());
475 unassigned.remove(entity);
476 changed = true;
477 }
478 }
479 }
480
481 if community.len() >= self.config.min_community_size {
483 for member in community {
484 assignments.insert(member, community_id);
485 }
486 community_id += 1;
487 } else {
488 for member in community {
490 assignments.insert(member, usize::MAX);
491 }
492 }
493 }
494
495 let mut graph = Graph::new();
497 for entity in &entity_list {
498 graph.edges.insert(entity.clone(), HashSet::new());
499 }
500
501 self.create_result(&assignments, &graph)
502 }
503
504 fn average_similarity_to_community(
506 &self,
507 entity: &str,
508 community: &HashSet<String>,
509 embeddings: &HashMap<String, Array1<f32>>,
510 ) -> f32 {
511 if community.is_empty() {
512 return 0.0;
513 }
514
515 let entity_emb = &embeddings[entity];
516
517 let total_sim: f32 = community
518 .iter()
519 .map(|member| {
520 let member_emb = &embeddings[member];
521 self.cosine_similarity(entity_emb, member_emb)
522 })
523 .sum();
524
525 total_sim / community.len() as f32
526 }
527
528 fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
530 let dot = a.dot(b);
531 let norm_a = a.dot(a).sqrt();
532 let norm_b = b.dot(b).sqrt();
533
534 if norm_a == 0.0 || norm_b == 0.0 {
535 0.0
536 } else {
537 dot / (norm_a * norm_b)
538 }
539 }
540
541 fn build_similarity_graph(&self, embeddings: &HashMap<String, Array1<f32>>) -> Graph {
543 let mut graph = Graph::new();
544 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
545
546 for i in 0..entity_list.len() {
547 for j in (i + 1)..entity_list.len() {
548 let sim = self
549 .cosine_similarity(&embeddings[&entity_list[i]], &embeddings[&entity_list[j]]);
550
551 if sim >= self.config.similarity_threshold {
552 graph.add_edge(&entity_list[i], &entity_list[j], sim);
553 }
554 }
555 }
556
557 graph
558 }
559
560 fn create_result(
562 &self,
563 assignments: &HashMap<String, usize>,
564 graph: &Graph,
565 ) -> Result<CommunityResult> {
566 let mut community_sizes: HashMap<usize, usize> = HashMap::new();
568 let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
569
570 for (entity, &comm) in assignments {
571 if comm != usize::MAX {
572 *community_sizes.entry(comm).or_insert(0) += 1;
573 communities.entry(comm).or_default().insert(entity.clone());
574 }
575 }
576
577 let num_communities = community_sizes.len();
578 let sizes: Vec<usize> = (0..num_communities)
579 .map(|i| community_sizes.get(&i).copied().unwrap_or(0))
580 .collect();
581
582 let modularity = self.compute_modularity(assignments, graph);
584
585 let coverage = self.compute_coverage(assignments, graph);
587
588 Ok(CommunityResult {
589 assignments: assignments.clone(),
590 num_communities,
591 community_sizes: sizes,
592 modularity,
593 coverage,
594 communities,
595 })
596 }
597
598 fn compute_modularity(&self, assignments: &HashMap<String, usize>, graph: &Graph) -> f32 {
600 let m = graph.num_edges as f32;
601 if m == 0.0 {
602 return 0.0;
603 }
604
605 let nodes = graph.nodes();
606 let mut modularity = 0.0;
607
608 for node_i in &nodes {
609 for node_j in &nodes {
610 if let (Some(&comm_i), Some(&comm_j)) =
611 (assignments.get(node_i), assignments.get(node_j))
612 {
613 if comm_i == comm_j && comm_i != usize::MAX {
614 let a_ij = if graph
615 .get_neighbors(node_i)
616 .map(|n| n.contains(node_j))
617 .unwrap_or(false)
618 {
619 1.0
620 } else {
621 0.0
622 };
623
624 let k_i = graph.degree(node_i) as f32;
625 let k_j = graph.degree(node_j) as f32;
626
627 modularity += a_ij - (k_i * k_j) / (2.0 * m);
628 }
629 }
630 }
631 }
632
633 modularity / (2.0 * m)
634 }
635
636 fn compute_coverage(&self, assignments: &HashMap<String, usize>, graph: &Graph) -> f32 {
638 if graph.num_edges == 0 {
639 return 0.0;
640 }
641
642 let mut internal_edges = 0;
643
644 for (node, neighbors) in &graph.edges {
645 if let Some(&comm) = assignments.get(node) {
646 if comm == usize::MAX {
647 continue;
648 }
649
650 for neighbor in neighbors {
651 if let Some(&neighbor_comm) = assignments.get(neighbor) {
652 if comm == neighbor_comm {
653 internal_edges += 1;
654 }
655 }
656 }
657 }
658 }
659
660 (internal_edges / 2) as f32 / graph.num_edges as f32
662 }
663}
664
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;
    use scirs2_core::ndarray_ext::array;

    /// Two disjoint chains (a-b-c and d-e) must yield at least one community
    /// and an assignment for each of the five entities.
    #[test]
    fn test_community_detection_from_triples() {
        let triples: Vec<Triple> = [("a", "b"), ("b", "c"), ("d", "e")]
            .iter()
            .map(|&(subject, object)| {
                Triple::new(
                    NamedNode::new(subject).unwrap(),
                    NamedNode::new("r").unwrap(),
                    NamedNode::new(object).unwrap(),
                )
            })
            .collect();

        let mut detector = CommunityDetector::new(CommunityConfig::default());
        let result = detector.detect_from_triples(&triples).unwrap();

        assert!(result.num_communities > 0);
        assert_eq!(result.assignments.len(), 5);
    }

    /// Near-parallel embeddings (e1, e2) must land in the same community.
    #[test]
    fn test_embedding_based_detection() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.9, 0.1]);
        embeddings.insert("e3".to_string(), array![0.0, 1.0]);
        embeddings.insert("e4".to_string(), array![0.1, 0.9]);

        let config = CommunityConfig {
            algorithm: CommunityAlgorithm::EmbeddingBased,
            similarity_threshold: 0.8,
            ..Default::default()
        };

        let mut detector = CommunityDetector::new(config);
        let result = detector.detect_from_embeddings(&embeddings).unwrap();

        assert!(result.num_communities >= 1);
        assert_eq!(result.assignments.get("e1"), result.assignments.get("e2"));
    }
}