1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::fmt::Debug;
12
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ClusteringError, Result};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Graph<F: Float> {
20 pub n_nodes: usize,
22 pub adjacency: Vec<Vec<(usize, F)>>,
24 pub node_features: Option<Array2<F>>,
26}
27
28impl<
29 F: Float
30 + FromPrimitive
31 + Debug
32 + ScalarOperand
33 + std::iter::Sum
34 + std::cmp::Eq
35 + std::hash::Hash
36 + 'static,
37 > Graph<F>
38{
39 pub fn new(_nnodes: usize) -> Self {
41 Self {
42 n_nodes: _nnodes,
43 adjacency: vec![Vec::new(); _nnodes],
44 node_features: None,
45 }
46 }
47
48 pub fn from_adjacencymatrix(_adjacencymatrix: ArrayView2<F>) -> Result<Self> {
50 let n_nodes = _adjacencymatrix.shape()[0];
51 if _adjacencymatrix.shape()[1] != n_nodes {
52 return Err(ClusteringError::InvalidInput(
53 "Adjacency _matrix must be square".to_string(),
54 ));
55 }
56
57 let mut graph = Self::new(n_nodes);
58
59 for i in 0..n_nodes {
60 for j in 0..n_nodes {
61 let weight = _adjacencymatrix[[i, j]];
62 if weight > F::zero() && i != j {
63 graph.add_edge(i, j, weight)?;
64 }
65 }
66 }
67
68 Ok(graph)
69 }
70
71 pub fn from_knngraph(data: ArrayView2<F>, k: usize) -> Result<Self> {
73 let n_samples = data.shape()[0];
74 let mut graph = Self::new(n_samples);
75 graph.node_features = Some(data.to_owned());
76
77 for i in 0..n_samples {
79 let mut distances: Vec<(usize, F)> = Vec::new();
80
81 for j in 0..n_samples {
82 if i != j {
83 let dist = euclidean_distance(data.row(i), data.row(j));
84 distances.push((j, dist));
85 }
86 }
87
88 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
90
91 for &(neighbor_idx, distance) in distances.iter().take(k) {
92 let similarity = F::one() / (F::one() + distance);
94 graph.add_edge(i, neighbor_idx, similarity)?;
95 }
96 }
97
98 Ok(graph)
99 }
100
101 pub fn add_edge(&mut self, node1: usize, node2: usize, weight: F) -> Result<()> {
103 if node1 >= self.n_nodes || node2 >= self.n_nodes {
104 return Err(ClusteringError::InvalidInput(
105 "Node index out of bounds".to_string(),
106 ));
107 }
108
109 if node1 != node2 {
110 self.adjacency[node1].push((node2, weight));
111 self.adjacency[node2].push((node1, weight)); }
113
114 Ok(())
115 }
116
117 pub fn degree(&self, node: usize) -> usize {
119 if node < self.n_nodes {
120 self.adjacency[node].len()
121 } else {
122 0
123 }
124 }
125
126 pub fn weighted_degree(&self, node: usize) -> F {
128 if node < self.n_nodes {
129 self.adjacency[node].iter().map(|(_, weight)| *weight).sum()
130 } else {
131 F::zero()
132 }
133 }
134
135 pub fn neighbor_s(&self, node: usize) -> &[(usize, F)] {
137 if node < self.n_nodes {
138 &self.adjacency[node]
139 } else {
140 &[]
141 }
142 }
143
144 pub fn modularity(&self, communities: &[usize]) -> F {
146 let total_weight = self.total_edge_weight();
147 if total_weight == F::zero() {
148 return F::zero();
149 }
150
151 let mut modularity = F::zero();
152
153 for i in 0..self.n_nodes {
154 for j in 0..self.n_nodes {
155 if communities[i] == communities[j] {
156 let edge_weight = self.get_edge_weight(i, j);
157 let degree_i = self.weighted_degree(i);
158 let degree_j = self.weighted_degree(j);
159
160 let expected = degree_i * degree_j
161 / (F::from(2.0).expect("Failed to convert constant to float")
162 * total_weight);
163 modularity = modularity + edge_weight - expected;
164 }
165 }
166 }
167
168 modularity / (F::from(2.0).expect("Failed to convert constant to float") * total_weight)
169 }
170
171 fn get_edge_weight(&self, node1: usize, node2: usize) -> F {
173 if node1 < self.n_nodes {
174 for &(neighbor_, weight) in &self.adjacency[node1] {
175 if neighbor_ == node2 {
176 return weight;
177 }
178 }
179 }
180 F::zero()
181 }
182
183 fn total_edge_weight(&self) -> F {
185 let mut total = F::zero();
186 for node in 0..self.n_nodes {
187 for &(_, weight) in &self.adjacency[node] {
188 total = total + weight;
189 }
190 }
191 total / F::from(2.0).expect("Failed to convert constant to float") }
193}
194
195#[allow(dead_code)]
232pub fn louvain<F>(graph: &Graph<F>, resolution: f64, max_iterations: usize) -> Result<Array1<usize>>
233where
234 F: Float
235 + FromPrimitive
236 + Debug
237 + ScalarOperand
238 + std::iter::Sum
239 + std::cmp::Eq
240 + std::hash::Hash
241 + 'static,
242 f64: From<F>,
243{
244 let n_nodes = graph.n_nodes;
245 let mut communities: Array1<usize> = Array1::from_iter(0..n_nodes);
246 let mut improved = true;
247 let mut iteration = 0;
248
249 while improved && iteration < max_iterations {
250 improved = false;
251 iteration += 1;
252
253 for node in 0..n_nodes {
255 let current_community = communities[node];
256 let mut best_community = current_community;
257 let mut best_gain = F::zero();
258
259 let mut candidate_communities = HashSet::new();
261 candidate_communities.insert(current_community);
262
263 for &(neighbor_id, _weight) in graph.neighbor_s(node) {
264 candidate_communities.insert(communities[neighbor_id]);
265 }
266
267 for &candidate_community in &candidate_communities {
268 if candidate_community != current_community {
269 let gain = modularity_gain(
271 graph,
272 &communities,
273 node,
274 current_community,
275 candidate_community,
276 resolution,
277 );
278
279 if gain > best_gain {
280 best_gain = gain;
281 best_community = candidate_community;
282 }
283 }
284 }
285
286 if best_community != current_community && best_gain > F::zero() {
288 communities[node] = best_community;
289 improved = true;
290 }
291 }
292 }
293
294 Ok(communities)
295}
296
297#[allow(dead_code)]
299fn modularity_gain<F>(
300 graph: &Graph<F>,
301 communities: &Array1<usize>,
302 node: usize,
303 from_community: usize,
304 to_community: usize,
305 resolution: f64,
306) -> F
307where
308 F: Float
309 + FromPrimitive
310 + Debug
311 + ScalarOperand
312 + std::iter::Sum
313 + std::cmp::Eq
314 + std::hash::Hash
315 + 'static,
316 f64: From<F>,
317{
318 let total_weight = graph.total_edge_weight();
319 if total_weight == F::zero() {
320 return F::zero();
321 }
322
323 let node_degree = graph.weighted_degree(node);
324 let resolution_f = F::from(resolution).expect("Failed to convert to float");
325
326 let mut edges_to_target = F::zero();
328 let mut edges_from_source = F::zero();
329
330 for &(neighbor_, weight) in graph.neighbor_s(node) {
331 if communities[neighbor_] == to_community {
332 edges_to_target = edges_to_target + weight;
333 }
334 if communities[neighbor_] == from_community && neighbor_ != node {
335 edges_from_source = edges_from_source + weight;
336 }
337 }
338
339 let target_community_weight = calculate_community_weight(graph, communities, to_community);
341 let source_community_weight = calculate_community_weight(graph, communities, from_community);
342
343 let gain_to = edges_to_target
345 - resolution_f * node_degree * target_community_weight
346 / (F::from(2.0).expect("Failed to convert constant to float") * total_weight);
347 let loss_from = edges_from_source
348 - resolution_f * node_degree * (source_community_weight - node_degree)
349 / (F::from(2.0).expect("Failed to convert constant to float") * total_weight);
350
351 gain_to - loss_from
352}
353
354#[allow(dead_code)]
356fn calculate_community_weight<F>(
357 graph: &Graph<F>,
358 communities: &Array1<usize>,
359 community: usize,
360) -> F
361where
362 F: Float
363 + FromPrimitive
364 + Debug
365 + ScalarOperand
366 + std::iter::Sum
367 + std::cmp::Eq
368 + std::hash::Hash
369 + 'static,
370{
371 let mut weight = F::zero();
372 for node in 0..graph.n_nodes {
373 if communities[node] == community {
374 weight = weight + graph.weighted_degree(node);
375 }
376 }
377 weight
378}
379
380#[allow(dead_code)]
395pub fn label_propagation<F>(
396 graph: &Graph<F>,
397 max_iterations: usize,
398 tolerance: f64,
399) -> Result<Array1<usize>>
400where
401 F: Float
402 + FromPrimitive
403 + Debug
404 + ScalarOperand
405 + std::iter::Sum
406 + std::cmp::Eq
407 + std::hash::Hash
408 + 'static,
409 f64: From<F>,
410{
411 let n_nodes = graph.n_nodes;
412 let mut labels: Array1<usize> = Array1::from_iter(0..n_nodes);
413 let tolerance_f = F::from(tolerance).expect("Failed to convert to float");
414
415 for _iteration in 0..max_iterations {
416 let mut new_labels = labels.clone();
417 let mut changed_nodes = 0;
418
419 let mut node_order: Vec<usize> = (0..n_nodes).collect();
421 node_order.sort_by_key(|&i| i * 17 % n_nodes);
423
424 for &node in &node_order {
425 let mut label_weights: HashMap<usize, F> = HashMap::new();
427
428 for &(neighbor_, weight) in graph.neighbor_s(node) {
429 let label = labels[neighbor_];
430 let entry = label_weights.entry(label).or_insert(F::zero());
431 *entry = *entry + weight;
432 }
433
434 if let Some((&best_label_, _)) = label_weights
436 .iter()
437 .max_by(|a, b| a.1.partial_cmp(b.1).expect("Operation failed"))
438 {
439 if best_label_ != labels[node] {
440 new_labels[node] = best_label_;
441 changed_nodes += 1;
442 }
443 }
444 }
445
446 labels = new_labels;
447
448 let change_ratio = changed_nodes as f64 / n_nodes as f64;
450 if change_ratio < tolerance {
451 break;
452 }
453 }
454
455 let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
457 let label_mapping: HashMap<usize, usize> = unique_labels
458 .into_iter()
459 .enumerate()
460 .map(|(new_label, old_label)| (old_label, new_label))
461 .collect();
462
463 for label in labels.iter_mut() {
464 *label = label_mapping[label];
465 }
466
467 Ok(labels)
468}
469
470#[allow(dead_code)]
485pub fn girvan_newman<F>(graph: &Graph<F>, ncommunities: usize) -> Result<Array1<usize>>
486where
487 F: Float
488 + FromPrimitive
489 + Debug
490 + ScalarOperand
491 + std::iter::Sum
492 + std::cmp::Eq
493 + std::hash::Hash
494 + 'static,
495{
496 if ncommunities > graph.n_nodes {
497 return Err(ClusteringError::InvalidInput(
498 "Number of _communities cannot exceed number of nodes".to_string(),
499 ));
500 }
501
502 let mut workinggraph = graph.clone();
503 let mut _communities = find_connected_components(&workinggraph);
504
505 while count_communities(&_communities) < ncommunities && has_edges(&workinggraph) {
506 let edge_betweenness = calculate_edge_betweenness(&workinggraph)?;
508
509 if let Some((max_edge_, _)) = edge_betweenness
511 .iter()
512 .max_by(|a, b| a.1.partial_cmp(b.1).expect("Operation failed"))
513 {
514 remove_edge(&mut workinggraph, max_edge_.0, max_edge_.1);
516
517 _communities = find_connected_components(&workinggraph);
519 } else {
520 break; }
522 }
523
524 Ok(Array1::from_vec(_communities))
525}
526
527#[allow(dead_code)]
529fn calculate_edge_betweenness<F>(graph: &Graph<F>) -> Result<HashMap<(usize, usize), f64>>
530where
531 F: Float
532 + FromPrimitive
533 + Debug
534 + ScalarOperand
535 + std::iter::Sum
536 + std::cmp::Eq
537 + std::hash::Hash
538 + 'static,
539{
540 let mut edge_betweenness = HashMap::new();
541
542 for node in 0..graph.n_nodes {
544 for &(neighbor_, _) in graph.neighbor_s(node) {
545 if node < neighbor_ {
546 edge_betweenness.insert((node, neighbor_), 0.0);
548 }
549 }
550 }
551
552 for source in 0..graph.n_nodes {
554 for target in (source + 1)..graph.n_nodes {
555 let paths = find_all_shortest_paths(graph, source, target);
556
557 if !paths.is_empty() {
558 let contribution = 1.0 / paths.len() as f64;
559
560 for path in paths {
561 for i in 0..(path.len() - 1) {
562 let (u, v) = if path[i] < path[i + 1] {
563 (path[i], path[i + 1])
564 } else {
565 (path[i + 1], path[i])
566 };
567
568 *edge_betweenness.entry((u, v)).or_insert(0.0) += contribution;
569 }
570 }
571 }
572 }
573 }
574
575 Ok(edge_betweenness)
576}
577
578#[allow(dead_code)]
580fn find_all_shortest_paths<F>(graph: &Graph<F>, source: usize, target: usize) -> Vec<Vec<usize>>
581where
582 F: Float
583 + FromPrimitive
584 + Debug
585 + ScalarOperand
586 + std::iter::Sum
587 + std::cmp::Eq
588 + std::hash::Hash
589 + 'static,
590{
591 let mut distances = vec![None; graph.n_nodes];
592 let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); graph.n_nodes];
593 let mut queue = VecDeque::new();
594
595 distances[source] = Some(0);
596 queue.push_back(source);
597
598 while let Some(current) = queue.pop_front() {
599 let current_dist = distances[current].expect("Operation failed");
600
601 for &(neighbor_, _) in graph.neighbor_s(current) {
602 if distances[neighbor_].is_none() {
603 distances[neighbor_] = Some(current_dist + 1);
605 predecessors[neighbor_].push(current);
606 queue.push_back(neighbor_);
607 } else if distances[neighbor_] == Some(current_dist + 1) {
608 predecessors[neighbor_].push(current);
610 }
611 }
612 }
613
614 if distances[target].is_none() {
616 return Vec::new(); }
618
619 let mut paths = Vec::new();
620 let mut current_paths = vec![vec![target]];
621
622 while !current_paths.is_empty() {
623 let mut next_paths = Vec::new();
624
625 for path in current_paths {
626 let last_node = path[path.len() - 1];
627
628 if last_node == source {
629 let mut complete_path = path.clone();
630 complete_path.reverse();
631 paths.push(complete_path);
632 } else {
633 for &pred in &predecessors[last_node] {
634 let mut new_path = path.clone();
635 new_path.push(pred);
636 next_paths.push(new_path);
637 }
638 }
639 }
640
641 current_paths = next_paths;
642 }
643
644 paths
645}
646
647#[allow(dead_code)]
649fn remove_edge<F>(graph: &mut Graph<F>, node1: usize, node2: usize)
650where
651 F: Float
652 + FromPrimitive
653 + Debug
654 + ScalarOperand
655 + std::iter::Sum
656 + std::cmp::Eq
657 + std::hash::Hash
658 + 'static,
659{
660 graph.adjacency[node1].retain(|(neighbor_, _)| *neighbor_ != node2);
661 graph.adjacency[node2].retain(|(neighbor_, _)| *neighbor_ != node1);
662}
663
664#[allow(dead_code)]
666fn has_edges<F>(graph: &Graph<F>) -> bool
667where
668 F: Float
669 + FromPrimitive
670 + Debug
671 + ScalarOperand
672 + std::iter::Sum
673 + std::cmp::Eq
674 + std::hash::Hash
675 + 'static,
676{
677 graph
678 .adjacency
679 .iter()
680 .any(|neighbor_s| !neighbor_s.is_empty())
681}
682
683#[allow(dead_code)]
685fn find_connected_components<F>(graph: &Graph<F>) -> Vec<usize>
686where
687 F: Float
688 + FromPrimitive
689 + Debug
690 + ScalarOperand
691 + std::iter::Sum
692 + std::cmp::Eq
693 + std::hash::Hash
694 + 'static,
695{
696 let mut visited = vec![false; graph.n_nodes];
697 let mut components = vec![0; graph.n_nodes];
698 let mut component_id = 0;
699
700 for node in 0..graph.n_nodes {
701 if !visited[node] {
702 dfs_component(graph, node, component_id, &mut visited, &mut components);
703 component_id += 1;
704 }
705 }
706
707 components
708}
709
710#[allow(dead_code)]
712fn dfs_component<F>(
713 graph: &Graph<F>,
714 node: usize,
715 component_id: usize,
716 visited: &mut [bool],
717 components: &mut [usize],
718) where
719 F: Float
720 + FromPrimitive
721 + Debug
722 + ScalarOperand
723 + std::iter::Sum
724 + std::cmp::Eq
725 + std::hash::Hash
726 + 'static,
727{
728 visited[node] = true;
729 components[node] = component_id;
730
731 for &(neighbor_, _) in graph.neighbor_s(node) {
732 if !visited[neighbor_] {
733 dfs_component(graph, neighbor_, component_id, visited, components);
734 }
735 }
736}
737
738#[allow(dead_code)]
740fn count_communities(communities: &[usize]) -> usize {
741 let mut unique: HashSet<usize> = HashSet::new();
742 for &community in communities {
743 unique.insert(community);
744 }
745 unique.len()
746}
747
748#[allow(dead_code)]
750fn euclidean_distance<F>(a: ArrayView1<F>, b: ArrayView1<F>) -> F
751where
752 F: Float + std::iter::Sum + 'static,
753{
754 let diff = &a.to_owned() - &b.to_owned();
755 diff.dot(&diff).sqrt()
756}
757
758#[derive(Debug, Clone, Serialize, Deserialize)]
760pub struct GraphClusteringConfig {
761 pub algorithm: GraphClusteringAlgorithm,
763 pub max_iterations: usize,
765 pub tolerance: f64,
767 pub resolution: f64,
769 pub ncommunities: Option<usize>,
771}
772
773#[derive(Debug, Clone, Serialize, Deserialize)]
775pub enum GraphClusteringAlgorithm {
776 Louvain,
778 LabelPropagation,
780 GirvanNewman,
782}
783
784impl Default for GraphClusteringConfig {
785 fn default() -> Self {
786 Self {
787 algorithm: GraphClusteringAlgorithm::Louvain,
788 max_iterations: 100,
789 tolerance: 1e-6,
790 resolution: 1.0,
791 ncommunities: None,
792 }
793 }
794}
795
796#[allow(dead_code)]
807pub fn graph_clustering<F>(
808 graph: &Graph<F>,
809 config: &GraphClusteringConfig,
810) -> Result<Array1<usize>>
811where
812 F: Float
813 + FromPrimitive
814 + Debug
815 + ScalarOperand
816 + std::iter::Sum
817 + std::cmp::Eq
818 + std::hash::Hash
819 + 'static,
820 f64: From<F>,
821{
822 match config.algorithm {
823 GraphClusteringAlgorithm::Louvain => {
824 louvain(graph, config.resolution, config.max_iterations)
825 }
826 GraphClusteringAlgorithm::LabelPropagation => {
827 label_propagation(graph, config.max_iterations, config.tolerance)
828 }
829 GraphClusteringAlgorithm::GirvanNewman => {
830 let ncommunities = config.ncommunities.unwrap_or(2);
831 girvan_newman(graph, ncommunities)
832 }
833 }
834}
835
836#[cfg(test)]
837mod tests {
838 use super::*;
839 use scirs2_core::ndarray::Array2;
840
841 }