Skip to main content

cyanea_omics/
network.rs

1//! Network biology — weighted graphs, centrality metrics, and community detection.
2//!
3//! Build graphs from correlation matrices, compute degree/betweenness/closeness
4//! centrality, and detect communities with the Louvain algorithm.
5
6use std::collections::VecDeque;
7
8use cyanea_core::{CyaneaError, Result};
9
10use crate::sparse::SparseMatrix;
11
/// A weighted graph over the nodes `0..n_nodes`, either directed or undirected.
#[derive(Clone, Debug)]
pub struct Graph {
    /// Node count; valid node indices are `0..n_nodes`.
    n_nodes: usize,
    /// Every edge in insertion order, stored as `(from, to, weight)`.
    edges: Vec<(usize, usize, f64)>,
    /// Per-node `(neighbor, weight)` lists; undirected graphs record each edge
    /// at both of its endpoints.
    adjacency: Vec<Vec<(usize, f64)>>,
    /// `true` when edges are one-way.
    directed: bool,
}
20
/// Per-node centrality metrics; each vector is indexed by node.
#[derive(Clone, Debug)]
pub struct CentralityScores {
    /// Degree centrality of every node.
    pub degree: Vec<f64>,
    /// Betweenness centrality of every node.
    pub betweenness: Vec<f64>,
    /// Closeness centrality of every node.
    pub closeness: Vec<f64>,
}
31
/// Result of community detection: a partition of the nodes plus its quality.
#[derive(Clone, Debug)]
pub struct Community {
    /// Community label of every node, indexed by node.
    pub assignments: Vec<usize>,
    /// Modularity Q of this partition.
    pub modularity: f64,
}
40
41impl Graph {
42    /// Create an empty graph with `n_nodes` nodes.
43    pub fn new(n_nodes: usize, directed: bool) -> Self {
44        Self {
45            n_nodes,
46            edges: Vec::new(),
47            adjacency: vec![Vec::new(); n_nodes],
48            directed,
49        }
50    }
51
52    /// Add a weighted edge.
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if either node index is out of range.
57    pub fn add_edge(&mut self, from: usize, to: usize, weight: f64) -> Result<()> {
58        if from >= self.n_nodes || to >= self.n_nodes {
59            return Err(CyaneaError::InvalidInput(format!(
60                "node index out of range: from={}, to={}, n_nodes={}",
61                from, to, self.n_nodes
62            )));
63        }
64        self.edges.push((from, to, weight));
65        self.adjacency[from].push((to, weight));
66        if !self.directed {
67            self.adjacency[to].push((from, weight));
68        }
69        Ok(())
70    }
71
72    /// Build an undirected graph from a correlation matrix with a threshold.
73    ///
74    /// Adds an edge between nodes i and j if |correlation[i][j]| ≥ threshold.
75    /// Edge weight is the absolute correlation.
76    pub fn from_correlation_matrix(matrix: &[Vec<f64>], threshold: f64) -> Self {
77        let n = matrix.len();
78        let mut graph = Self::new(n, false);
79        for i in 0..n {
80            for j in (i + 1)..n {
81                let corr = matrix[i][j].abs();
82                if corr >= threshold {
83                    let _ = graph.add_edge(i, j, corr);
84                }
85            }
86        }
87        graph
88    }
89
90    /// Build an undirected graph from a sparse matrix.
91    ///
92    /// For a symmetric matrix (e.g. adjacency/connectivity), each stored entry
93    /// `(i, j, w)` with `i < j` becomes an edge. Entries with `i >= j` are
94    /// skipped to avoid double-counting in undirected graphs.
95    pub fn from_sparse_matrix(matrix: &SparseMatrix) -> Self {
96        let (n, _) = matrix.shape();
97        let mut graph = Self::new(n, false);
98        for (i, j, w) in matrix.iter() {
99            if i < j && w != 0.0 {
100                let _ = graph.add_edge(i, j, w);
101            }
102        }
103        graph
104    }
105
106    /// Louvain community detection with a resolution parameter.
107    ///
108    /// The resolution parameter γ controls community granularity:
109    /// - γ < 1.0 → fewer, larger communities
110    /// - γ = 1.0 → standard modularity
111    /// - γ > 1.0 → more, smaller communities
112    ///
113    /// Implements both Phase 1 (local moves) and Phase 2 (hierarchical aggregation).
114    pub fn louvain_with_resolution(&self, resolution: f64) -> Community {
115        if self.n_nodes == 0 {
116            return Community {
117                assignments: Vec::new(),
118                modularity: 0.0,
119            };
120        }
121
122        let n = self.n_nodes;
123        let mut assignments: Vec<usize> = (0..n).collect();
124
125        // Total edge weight (2m for undirected)
126        let m2: f64 = if self.directed {
127            self.edges.iter().map(|(_, _, w)| w).sum()
128        } else {
129            self.edges.iter().map(|(_, _, w)| 2.0 * w).sum()
130        };
131
132        if m2 == 0.0 {
133            return Community {
134                assignments,
135                modularity: 0.0,
136            };
137        }
138
139        // Phase 1 + Phase 2 loop
140        let mut global_improved = true;
141        while global_improved {
142            global_improved = false;
143
144            // Phase 1: local moves with resolution
145            let mut improved = true;
146            while improved {
147                improved = false;
148                for i in 0..n {
149                    let current_community = assignments[i];
150                    let k_i: f64 = self.adjacency[i].iter().map(|(_, w)| w).sum();
151
152                    let mut community_weights: Vec<(usize, f64)> = Vec::new();
153                    for &(j, w) in &self.adjacency[i] {
154                        let cj = assignments[j];
155                        if let Some(entry) = community_weights.iter_mut().find(|(c, _)| *c == cj) {
156                            entry.1 += w;
157                        } else {
158                            community_weights.push((cj, w));
159                        }
160                    }
161
162                    let sigma_tot_current =
163                        self.community_total_weight(&assignments, current_community);
164                    let k_i_in_current = community_weights
165                        .iter()
166                        .find(|(c, _)| *c == current_community)
167                        .map_or(0.0, |(_, w)| *w);
168
169                    let mut best_community = current_community;
170                    let mut best_delta_q = 0.0;
171
172                    for &(cj, k_i_in) in &community_weights {
173                        if cj == current_community {
174                            continue;
175                        }
176                        let sigma_tot = self.community_total_weight(&assignments, cj);
177
178                        // ΔQ with resolution parameter γ
179                        let delta_q = (k_i_in - k_i_in_current) / m2
180                            - resolution * k_i * (sigma_tot - sigma_tot_current + k_i)
181                                / (m2 * m2)
182                                * 2.0;
183
184                        if delta_q > best_delta_q {
185                            best_delta_q = delta_q;
186                            best_community = cj;
187                        }
188                    }
189
190                    if best_community != current_community {
191                        assignments[i] = best_community;
192                        improved = true;
193                        global_improved = true;
194                    }
195                }
196            }
197
198            // Phase 2: check if further refinement would help
199            // Renumber communities and check if aggregation reduced count
200            let mut community_map: Vec<usize> = Vec::new();
201            for a in assignments.iter_mut() {
202                let pos = community_map.iter().position(|&c| c == *a);
203                *a = match pos {
204                    Some(idx) => idx,
205                    None => {
206                        community_map.push(*a);
207                        community_map.len() - 1
208                    }
209                };
210            }
211
212            // If we only have one community or no improvement, stop
213            if community_map.len() >= n {
214                break;
215            }
216        }
217
218        let modularity = self.modularity_with_resolution(&assignments, resolution);
219        Community {
220            assignments,
221            modularity,
222        }
223    }
224
225    /// Compute modularity Q with a resolution parameter.
226    ///
227    /// Q = Σ_c [L_c / m - γ * (d_c / 2m)²]
228    pub fn modularity_with_resolution(&self, assignments: &[usize], resolution: f64) -> f64 {
229        let m: f64 = self.edges.iter().map(|(_, _, w)| w).sum();
230        if m == 0.0 {
231            return 0.0;
232        }
233
234        let mut communities: Vec<usize> = assignments.to_vec();
235        communities.sort_unstable();
236        communities.dedup();
237
238        let mut q = 0.0;
239        for &c in &communities {
240            let l_c: f64 = self
241                .edges
242                .iter()
243                .filter(|&&(i, j, _)| {
244                    assignments.get(i) == Some(&c) && assignments.get(j) == Some(&c)
245                })
246                .map(|(_, _, w)| w)
247                .sum();
248
249            let d_c: f64 = (0..self.n_nodes)
250                .filter(|&v| assignments.get(v) == Some(&c))
251                .map(|v| -> f64 { self.adjacency[v].iter().map(|(_, w)| w).sum() })
252                .sum();
253
254            q += l_c / m - resolution * (d_c / (2.0 * m)).powi(2);
255        }
256        q
257    }
258
259    /// Number of nodes.
260    pub fn n_nodes(&self) -> usize {
261        self.n_nodes
262    }
263
264    /// Number of edges.
265    pub fn n_edges(&self) -> usize {
266        self.edges.len()
267    }
268
269    /// Neighbors of a node with edge weights.
270    pub fn neighbors(&self, node: usize) -> &[(usize, f64)] {
271        if node < self.n_nodes {
272            &self.adjacency[node]
273        } else {
274            &[]
275        }
276    }
277
278    /// Degree centrality: degree(v) / (n - 1).
279    pub fn degree_centrality(&self) -> Vec<f64> {
280        let n = self.n_nodes;
281        if n <= 1 {
282            return vec![0.0; n];
283        }
284        let denom = (n - 1) as f64;
285        (0..n)
286            .map(|v| self.adjacency[v].len() as f64 / denom)
287            .collect()
288    }
289
290    /// Betweenness centrality using Brandes' algorithm.
291    pub fn betweenness_centrality(&self) -> Vec<f64> {
292        let n = self.n_nodes;
293        let mut cb = vec![0.0f64; n];
294
295        for s in 0..n {
296            // BFS from s (unweighted shortest paths).
297            let mut stack = Vec::new();
298            let mut pred: Vec<Vec<usize>> = vec![Vec::new(); n];
299            let mut sigma = vec![0.0f64; n];
300            sigma[s] = 1.0;
301            let mut dist = vec![-1i64; n];
302            dist[s] = 0;
303
304            let mut queue = VecDeque::new();
305            queue.push_back(s);
306
307            while let Some(v) = queue.pop_front() {
308                stack.push(v);
309                for &(w, _) in &self.adjacency[v] {
310                    // First visit?
311                    if dist[w] < 0 {
312                        dist[w] = dist[v] + 1;
313                        queue.push_back(w);
314                    }
315                    // Shortest path through v?
316                    if dist[w] == dist[v] + 1 {
317                        sigma[w] += sigma[v];
318                        pred[w].push(v);
319                    }
320                }
321            }
322
323            // Accumulate dependencies.
324            let mut delta = vec![0.0f64; n];
325            while let Some(w) = stack.pop() {
326                for &v in &pred[w] {
327                    delta[v] += (sigma[v] / sigma[w]) * (1.0 + delta[w]);
328                }
329                if w != s {
330                    cb[w] += delta[w];
331                }
332            }
333        }
334
335        // For undirected graphs, each pair is counted twice.
336        if !self.directed {
337            for v in &mut cb {
338                *v /= 2.0;
339            }
340        }
341
342        // Normalize by (n-1)(n-2) for comparability.
343        let norm = if n > 2 {
344            ((n - 1) * (n - 2)) as f64
345        } else {
346            1.0
347        };
348        for v in &mut cb {
349            *v /= norm;
350        }
351
352        cb
353    }
354
355    /// Closeness centrality: (n-1) / Σ d(v, u).
356    pub fn closeness_centrality(&self) -> Vec<f64> {
357        let n = self.n_nodes;
358        if n <= 1 {
359            return vec![0.0; n];
360        }
361
362        (0..n)
363            .map(|v| {
364                let distances = self.bfs_distances(v);
365                let sum_dist: usize = distances.iter().filter(|&&d| d > 0).sum();
366                let reachable = distances.iter().filter(|&&d| d > 0).count();
367                if sum_dist > 0 && reachable > 0 {
368                    reachable as f64 / sum_dist as f64
369                } else {
370                    0.0
371                }
372            })
373            .collect()
374    }
375
376    /// Compute all three centrality metrics.
377    pub fn centrality(&self) -> CentralityScores {
378        CentralityScores {
379            degree: self.degree_centrality(),
380            betweenness: self.betweenness_centrality(),
381            closeness: self.closeness_centrality(),
382        }
383    }
384
385    /// Louvain community detection.
386    ///
387    /// Iteratively moves nodes between communities to maximize modularity.
388    pub fn louvain(&self) -> Community {
389        if self.n_nodes == 0 {
390            return Community {
391                assignments: Vec::new(),
392                modularity: 0.0,
393            };
394        }
395
396        let n = self.n_nodes;
397        let mut assignments: Vec<usize> = (0..n).collect();
398
399        // Total edge weight (2m for undirected).
400        let m2: f64 = if self.directed {
401            self.edges.iter().map(|(_, _, w)| w).sum()
402        } else {
403            self.edges.iter().map(|(_, _, w)| 2.0 * w).sum()
404        };
405
406        if m2 == 0.0 {
407            return Community {
408                assignments,
409                modularity: 0.0,
410            };
411        }
412
413        // Phase 1: local moves.
414        let mut improved = true;
415        while improved {
416            improved = false;
417
418            for i in 0..n {
419                let current_community = assignments[i];
420
421                // Compute k_i (weighted degree of node i).
422                let k_i: f64 = self.adjacency[i].iter().map(|(_, w)| w).sum();
423
424                // Sum of weights to each neighbor community.
425                let mut community_weights: Vec<(usize, f64)> = Vec::new();
426                for &(j, w) in &self.adjacency[i] {
427                    let cj = assignments[j];
428                    if let Some(entry) = community_weights.iter_mut().find(|(c, _)| *c == cj) {
429                        entry.1 += w;
430                    } else {
431                        community_weights.push((cj, w));
432                    }
433                }
434
435                // Σ_tot for current community.
436                let sigma_tot_current = self.community_total_weight(&assignments, current_community);
437                let k_i_in_current = community_weights
438                    .iter()
439                    .find(|(c, _)| *c == current_community)
440                    .map_or(0.0, |(_, w)| *w);
441
442                let mut best_community = current_community;
443                let mut best_delta_q = 0.0;
444
445                for &(cj, k_i_in) in &community_weights {
446                    if cj == current_community {
447                        continue;
448                    }
449                    let sigma_tot = self.community_total_weight(&assignments, cj);
450
451                    // ΔQ for moving i from current to cj.
452                    let delta_q = (k_i_in - k_i_in_current) / m2
453                        - k_i * (sigma_tot - sigma_tot_current + k_i) / (m2 * m2) * 2.0;
454
455                    if delta_q > best_delta_q {
456                        best_delta_q = delta_q;
457                        best_community = cj;
458                    }
459                }
460
461                if best_community != current_community {
462                    assignments[i] = best_community;
463                    improved = true;
464                }
465            }
466        }
467
468        // Renumber communities to be contiguous.
469        let mut community_map: Vec<usize> = Vec::new();
470        for a in &mut assignments {
471            let pos = community_map.iter().position(|&c| c == *a);
472            *a = match pos {
473                Some(idx) => idx,
474                None => {
475                    community_map.push(*a);
476                    community_map.len() - 1
477                }
478            };
479        }
480
481        let modularity = self.modularity(&assignments);
482        Community {
483            assignments,
484            modularity,
485        }
486    }
487
488    /// Compute modularity Q for a given community assignment.
489    ///
490    /// Uses the community-level formula:
491    /// Q = Σ_c [L_c / m - (d_c / 2m)²]
492    /// where L_c = sum of edge weights within community c, m = total edge weight,
493    /// and d_c = sum of node degrees in community c.
494    pub fn modularity(&self, assignments: &[usize]) -> f64 {
495        let m: f64 = self.edges.iter().map(|(_, _, w)| w).sum();
496        if m == 0.0 {
497            return 0.0;
498        }
499
500        // Find unique communities.
501        let mut communities: Vec<usize> = assignments.to_vec();
502        communities.sort_unstable();
503        communities.dedup();
504
505        let mut q = 0.0;
506        for &c in &communities {
507            // L_c: sum of edge weights where both endpoints are in community c.
508            let l_c: f64 = self
509                .edges
510                .iter()
511                .filter(|&&(i, j, _)| {
512                    assignments.get(i) == Some(&c) && assignments.get(j) == Some(&c)
513                })
514                .map(|(_, _, w)| w)
515                .sum();
516
517            // d_c: sum of weighted degrees of nodes in community c.
518            let d_c: f64 = (0..self.n_nodes)
519                .filter(|&v| assignments.get(v) == Some(&c))
520                .map(|v| -> f64 { self.adjacency[v].iter().map(|(_, w)| w).sum() })
521                .sum();
522
523            q += l_c / m - (d_c / (2.0 * m)).powi(2);
524        }
525        q
526    }
527
528    fn bfs_distances(&self, start: usize) -> Vec<usize> {
529        let n = self.n_nodes;
530        let mut dist = vec![usize::MAX; n];
531        dist[start] = 0;
532        let mut queue = VecDeque::new();
533        queue.push_back(start);
534        while let Some(v) = queue.pop_front() {
535            for &(w, _) in &self.adjacency[v] {
536                if dist[w] == usize::MAX {
537                    dist[w] = dist[v] + 1;
538                    queue.push_back(w);
539                }
540            }
541        }
542        dist
543    }
544
545    fn community_total_weight(&self, assignments: &[usize], community: usize) -> f64 {
546        let mut total = 0.0;
547        for (v, neighbors) in self.adjacency.iter().enumerate() {
548            if assignments[v] == community {
549                for &(_, w) in neighbors {
550                    total += w;
551                }
552            }
553        }
554        total
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Two triangles {0,1,2} and {3,4,5} joined by a weak 2–3 bridge.
    fn bridged_triangles() -> Graph {
        let mut g = Graph::new(6, false);
        for &(u, v, w) in &[
            (0, 1, 1.0),
            (1, 2, 1.0),
            (0, 2, 1.0),
            (3, 4, 1.0),
            (4, 5, 1.0),
            (3, 5, 1.0),
            (2, 3, 0.1),
        ] {
            g.add_edge(u, v, w).unwrap();
        }
        g
    }

    /// Connect every pair of nodes in `nodes` with a unit-weight edge.
    fn add_clique(g: &mut Graph, nodes: std::ops::Range<usize>) {
        for i in nodes.clone() {
            for j in (i + 1)..nodes.end {
                g.add_edge(i, j, 1.0).unwrap();
            }
        }
    }

    /// Number of communities, assuming contiguous labels.
    fn community_count(c: &Community) -> usize {
        *c.assignments.iter().max().unwrap() + 1
    }

    #[test]
    fn graph_creation() {
        let mut g = Graph::new(5, false);
        g.add_edge(0, 1, 1.0).unwrap();
        g.add_edge(1, 2, 1.0).unwrap();
        assert_eq!((g.n_nodes(), g.n_edges()), (5, 2));
    }

    #[test]
    fn degree_centrality_star() {
        // Hub 0 linked to leaves 1..=4.
        let mut g = Graph::new(5, false);
        (1..5).for_each(|leaf| g.add_edge(0, leaf, 1.0).unwrap());
        let dc = g.degree_centrality();
        // Hub: 4 / (5 - 1) = 1.0; every leaf: 1 / 4 = 0.25.
        assert!((dc[0] - 1.0).abs() < 1e-10);
        assert!(dc[1..].iter().all(|&c| (c - 0.25).abs() < 1e-10));
    }

    #[test]
    fn betweenness_centrality_line() {
        // Path 0-1-2-3-4: the middle node lies on the most shortest paths.
        let mut g = Graph::new(5, false);
        for i in 0..4 {
            g.add_edge(i, i + 1, 1.0).unwrap();
        }
        let bc = g.betweenness_centrality();
        let (argmax, _) = bc
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        assert_eq!(argmax, 2);
    }

    #[test]
    fn closeness_centrality_complete() {
        // In K4 every node is one hop from the rest: closeness = 3 / 3 = 1.0.
        let mut g = Graph::new(4, false);
        add_clique(&mut g, 0..4);
        let cc = g.closeness_centrality();
        assert!(cc.iter().all(|&c| (c - cc[0]).abs() < 1e-10));
        assert!((cc[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn louvain_two_cliques() {
        let community = bridged_triangles().louvain();
        let a = &community.assignments;
        // Each triangle stays together...
        assert_eq!(a[0], a[1]);
        assert_eq!(a[1], a[2]);
        assert_eq!(a[3], a[4]);
        assert_eq!(a[4], a[5]);
        // ...and the two triangles are split apart.
        assert_ne!(a[0], a[3]);
        assert!(community.modularity > 0.0);
    }

    #[test]
    fn modularity_single_community() {
        // One community holding every node scores Q ≈ 0.
        let mut g = Graph::new(3, false);
        g.add_edge(0, 1, 1.0).unwrap();
        g.add_edge(1, 2, 1.0).unwrap();
        let q = g.modularity(&[0, 0, 0]);
        assert!(q.abs() < 1e-10, "Q should be ~0 for single community, got {}", q);
    }

    #[test]
    fn from_sparse_matrix() {
        let mut sm = SparseMatrix::new(4, 4);
        sm.insert(0, 1, 1.0).unwrap();
        sm.insert(1, 0, 1.0).unwrap();
        sm.insert(2, 3, 0.5).unwrap();
        sm.insert(3, 2, 0.5).unwrap();
        let g = Graph::from_sparse_matrix(&sm);
        // Mirrored entries collapse into one undirected edge each.
        assert_eq!((g.n_nodes(), g.n_edges()), (4, 2));
    }

    #[test]
    fn louvain_with_resolution_default() {
        let community = bridged_triangles().louvain_with_resolution(1.0);
        let a = &community.assignments;
        assert_eq!(a[0], a[1]);
        assert_eq!(a[1], a[2]);
        assert_ne!(a[0], a[3]);
    }

    #[test]
    fn louvain_with_high_resolution() {
        // Two K4 groups with a near-zero bridge; higher γ must not yield fewer communities.
        let mut g = Graph::new(8, false);
        add_clique(&mut g, 0..4);
        add_clique(&mut g, 4..8);
        g.add_edge(3, 4, 0.01).unwrap();

        let coarse = g.louvain_with_resolution(0.5);
        let fine = g.louvain_with_resolution(3.0);
        assert!(community_count(&fine) >= community_count(&coarse));
    }

    #[test]
    fn modularity_with_resolution() {
        // At γ = 1 the resolution variant agrees with plain modularity.
        let mut g = Graph::new(4, false);
        g.add_edge(0, 1, 1.0).unwrap();
        g.add_edge(2, 3, 1.0).unwrap();
        let assignments = [0, 0, 1, 1];
        let q_gamma_one = g.modularity_with_resolution(&assignments, 1.0);
        let q_plain = g.modularity(&assignments);
        assert!((q_gamma_one - q_plain).abs() < 1e-10);
    }

    #[test]
    fn from_correlation_matrix_threshold() {
        let matrix = vec![
            vec![1.0, 0.9, 0.2],
            vec![0.9, 1.0, 0.3],
            vec![0.2, 0.3, 1.0],
        ];
        let g = Graph::from_correlation_matrix(&matrix, 0.5);
        // Only |0.9| clears the 0.5 threshold.
        assert_eq!(g.n_edges(), 1);
        assert_eq!(g.n_nodes(), 3);
    }
}