Skip to main content

leiden_rs/
quality.rs

1//! Quality functions for community detection.
2
3pub use crate::graph::{GraphData, MoveComponents};
4
5/// Trait for quality functions used by the Leiden algorithm.
6pub trait QualityFunction {
7    /// Compute the quality delta of moving a node, given precomputed components.
8    fn delta_move_from_components(&self, c: &MoveComponents) -> f64;
9
10    /// Compute the total quality of a partition.
11    fn total_quality(&self, data: &GraphData, partition: &crate::partition::Partition) -> f64;
12}
13
/// Modularity: Q = Σ_c [e_c/m - γ*(Σ_c/(2m))²]
///
/// `e_c` is the internal edge weight of community `c` (each edge counted once),
/// `Σ_c` the total degree of its members, `m` the total edge weight and `γ` the
/// resolution. For directed graphs the expected term becomes
/// `γ * Σ_c^out * Σ_c^in / m²`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Modularity {
    /// Resolution parameter γ.
    pub resolution: f64,
}

impl Modularity {
    /// Create a new Modularity with default resolution (1.0).
    #[must_use = "constructor returns a new instance"]
    pub fn new() -> Self {
        Self { resolution: 1.0 }
    }

    /// Create a new Modularity with a custom resolution parameter.
    #[must_use = "constructor returns a new instance"]
    pub fn with_resolution(resolution: f64) -> Self {
        Self { resolution }
    }
}

impl Default for Modularity {
    /// Same as [`Modularity::new`]: resolution 1.0.
    fn default() -> Self {
        Self::new()
    }
}
39
40#[inline]
41fn modularity_delta(resolution: f64, c: &MoveComponents) -> f64 {
42    if c.two_m == 0.0 {
43        return 0.0;
44    }
45    if !c.directed {
46        (c.k_v_to_target_out - c.k_v_to_current_out) * 2.0 / c.two_m
47            - resolution
48                * c.k_v_out
49                * (c.sigma_tot_target_out - c.sigma_tot_current_out + c.k_v_out)
50                * 2.0
51                / (c.two_m * c.two_m)
52    } else {
53        let m = c.two_m / 2.0;
54        let d_internal = (c.k_v_to_target_out + c.k_v_to_target_in)
55            - (c.k_v_to_current_out + c.k_v_to_current_in);
56        let d_expected = c.k_v_in * (c.sigma_tot_target_out - c.sigma_tot_current_out)
57            + c.k_v_out * (c.sigma_tot_target_in - c.sigma_tot_current_in)
58            + 2.0 * c.k_v_out * c.k_v_in;
59        d_internal / m - resolution * d_expected / (m * m)
60    }
61}
62
63fn modularity_total_quality(
64    resolution: f64,
65    data: &GraphData,
66    partition: &crate::partition::Partition,
67) -> f64 {
68    let n = data.node_count();
69    let m = data.total_weight();
70    if m == 0.0 {
71        return 0.0;
72    }
73
74    let num_comms = partition.num_communities();
75
76    if !data.is_directed() {
77        let mut sigma_tot: Vec<f64> = vec![0.0; num_comms];
78        let mut e_c: Vec<f64> = vec![0.0; num_comms];
79
80        for node in 0..n {
81            let comm = partition.community_of(node);
82            if comm >= num_comms {
83                continue;
84            }
85            sigma_tot[comm] += data.degree_of(node);
86            for (neighbor, weight) in data.neighbors(node) {
87                if neighbor >= node && partition.community_of(neighbor) == comm {
88                    e_c[comm] += weight;
89                }
90            }
91        }
92
93        let two_m = 2.0 * m;
94        let mut q = 0.0;
95        for c in 0..num_comms {
96            q += e_c[c] / m - resolution * (sigma_tot[c] / two_m).powi(2);
97        }
98        q
99    } else {
100        let mut sigma_tot_out: Vec<f64> = vec![0.0; num_comms];
101        let mut sigma_tot_in: Vec<f64> = vec![0.0; num_comms];
102        let mut e_c: Vec<f64> = vec![0.0; num_comms];
103
104        for node in 0..n {
105            let comm = partition.community_of(node);
106            if comm >= num_comms {
107                continue;
108            }
109            sigma_tot_out[comm] += data.out_degree_of(node);
110            sigma_tot_in[comm] += data.in_degree_of(node);
111            for (neighbor, weight) in data.out_neighbors(node) {
112                if partition.community_of(neighbor) == comm {
113                    e_c[comm] += weight;
114                }
115            }
116        }
117
118        let mut q = 0.0;
119        for c in 0..num_comms {
120            q += e_c[c] / m - resolution * sigma_tot_out[c] * sigma_tot_in[c] / (m * m);
121        }
122        q
123    }
124}
125
126impl QualityFunction for Modularity {
127    #[inline]
128    fn delta_move_from_components(&self, c: &MoveComponents) -> f64 {
129        modularity_delta(self.resolution, c)
130    }
131
132    fn total_quality(&self, data: &GraphData, partition: &crate::partition::Partition) -> f64 {
133        modularity_total_quality(self.resolution, data, partition)
134    }
135}
136
/// CPM (Constant Potts Model): H = Σ_c [e_c - γ * n_c * (n_c - 1) / 2]
///
/// `e_c` is the internal edge weight of community `c` and `n_c` is the sum of
/// the node weights of its members.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CPM {
    /// Resolution parameter γ.
    pub resolution: f64,
}

impl CPM {
    /// Create a new CPM with the given resolution parameter.
    #[must_use = "constructor returns a new instance"]
    pub fn new(resolution: f64) -> Self {
        Self { resolution }
    }
}
150
151impl QualityFunction for CPM {
152    #[inline]
153    fn delta_move_from_components(&self, c: &MoveComponents) -> f64 {
154        (c.k_v_to_target_out + c.k_v_to_target_in)
155            - (c.k_v_to_current_out + c.k_v_to_current_in)
156            - self.resolution * c.node_weight * (c.n_target - c.n_current + c.node_weight)
157    }
158
159    fn total_quality(&self, data: &GraphData, partition: &crate::partition::Partition) -> f64 {
160        let n = data.node_count();
161        let num_comms = partition.num_communities();
162        let mut e_c: Vec<f64> = vec![0.0; num_comms];
163        let mut n_c: Vec<f64> = vec![0.0; num_comms];
164
165        let directed = data.is_directed();
166        for node in 0..n {
167            let comm = partition.community_of(node);
168            if comm >= num_comms {
169                continue;
170            }
171            n_c[comm] += data.node_weight(node);
172            if directed {
173                for (neighbor, weight) in data.out_neighbors(node) {
174                    if partition.community_of(neighbor) == comm {
175                        e_c[comm] += weight;
176                    }
177                }
178            } else {
179                for (neighbor, weight) in data.neighbors(node) {
180                    if neighbor >= node && partition.community_of(neighbor) == comm {
181                        e_c[comm] += weight;
182                    }
183                }
184            }
185        }
186
187        let mut h = 0.0;
188        for c in 0..num_comms {
189            h += e_c[c] - self.resolution * n_c[c] * (n_c[c] - 1.0) / 2.0;
190        }
191        h
192    }
193}
194
/// RBConfiguration: Reichardt-Bornholdt with configuration model null.
///
/// In unnormalised form the objective reads
///
/// Q = Σ_c [e_c - γ * K_c² / (4m)]
///
/// This type delegates both the move gain and the total quality to the same
/// routines as [`Modularity`]. The values it reports are therefore the
/// *normalised* form `Σ_c [e_c/m - γ (K_c/(2m))²]` (the expression above divided
/// by `m`), identical to `Modularity::with_resolution(γ)`.
/// Provided for API compatibility with the leidenalg Python library.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RBConfiguration {
    /// Resolution parameter γ.
    pub resolution: f64,
}

impl RBConfiguration {
    /// Create a new RBConfiguration with default resolution (1.0).
    #[must_use = "constructor returns a new instance"]
    pub fn new() -> Self {
        Self { resolution: 1.0 }
    }

    /// Create a new RBConfiguration with a custom resolution parameter.
    #[must_use = "constructor returns a new instance"]
    pub fn with_resolution(resolution: f64) -> Self {
        Self { resolution }
    }
}

impl Default for RBConfiguration {
    /// Same as [`RBConfiguration::new`]: resolution 1.0.
    fn default() -> Self {
        Self::new()
    }
}
225
226impl QualityFunction for RBConfiguration {
227    #[inline]
228    fn delta_move_from_components(&self, c: &MoveComponents) -> f64 {
229        modularity_delta(self.resolution, c)
230    }
231
232    fn total_quality(&self, data: &GraphData, partition: &crate::partition::Partition) -> f64 {
233        modularity_total_quality(self.resolution, data, partition)
234    }
235}
236
/// RBER: Reichardt-Bornholdt with Erdős-Rényi null model.
///
/// Q = Σ_c [e_c - γ * p * n_c * (n_c - 1) / 2]
///
/// Where p = 2m / (N*(N-1)) is the graph density, N is the total node weight
/// and `n_c` is the summed node weight of community `c`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RBER {
    /// Resolution parameter γ.
    pub resolution: f64,
}

impl RBER {
    /// Create a new RBER with the given resolution parameter.
    #[must_use = "constructor returns a new instance"]
    pub fn new(resolution: f64) -> Self {
        Self { resolution }
    }
}
254
255impl QualityFunction for RBER {
256    #[inline]
257    fn delta_move_from_components(&self, c: &MoveComponents) -> f64 {
258        let total_n = c.total_node_weight;
259        if total_n <= 1.0 || c.two_m == 0.0 {
260            return 0.0;
261        }
262        let p = c.two_m / (total_n * (total_n - 1.0));
263        (c.k_v_to_target_out + c.k_v_to_target_in)
264            - (c.k_v_to_current_out + c.k_v_to_current_in)
265            - self.resolution * p * c.node_weight * (c.n_target - c.n_current + c.node_weight)
266    }
267
268    fn total_quality(&self, data: &GraphData, partition: &crate::partition::Partition) -> f64 {
269        let n = data.node_count();
270        let m = data.total_weight();
271        if n <= 1 || m == 0.0 {
272            return 0.0;
273        }
274
275        let total_n = data.total_node_weight();
276        if total_n <= 1.0 {
277            return 0.0;
278        }
279        let p = 2.0 * m / (total_n * (total_n - 1.0));
280
281        let num_comms = partition.num_communities();
282        let mut e_c: Vec<f64> = vec![0.0; num_comms];
283        let mut n_c: Vec<f64> = vec![0.0; num_comms];
284
285        let directed = data.is_directed();
286        for node in 0..n {
287            let comm = partition.community_of(node);
288            if comm >= num_comms {
289                continue;
290            }
291            n_c[comm] += data.node_weight(node);
292            if directed {
293                for (neighbor, weight) in data.out_neighbors(node) {
294                    if partition.community_of(neighbor) == comm {
295                        e_c[comm] += weight;
296                    }
297                }
298            } else {
299                for (neighbor, weight) in data.neighbors(node) {
300                    if neighbor >= node && partition.community_of(neighbor) == comm {
301                        e_c[comm] += weight;
302                    }
303                }
304            }
305        }
306
307        let mut q = 0.0;
308        for c in 0..num_comms {
309            q += e_c[c] - self.resolution * p * n_c[c] * (n_c[c] - 1.0) / 2.0;
310        }
311        q
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphDataBuilder;

    /// Baseline undirected move shared by most delta tests: a unit-weight node
    /// of degree 3 with 2 units of weight into the target community and none
    /// into its current one.
    fn undirected_mc() -> MoveComponents {
        MoveComponents {
            two_m: 20.0,
            node_weight: 1.0,
            total_node_weight: 10.0,
            k_v_out: 3.0,
            k_v_to_target_out: 2.0,
            k_v_to_current_out: 0.0,
            sigma_tot_target_out: 10.0,
            sigma_tot_current_out: 3.0,
            k_v_in: 0.0,
            k_v_to_target_in: 0.0,
            k_v_to_current_in: 0.0,
            sigma_tot_target_in: 0.0,
            sigma_tot_current_in: 0.0,
            n_target: 1.0,
            n_current: 1.0,
            directed: false,
        }
    }

    /// Directed variant of the baseline: same out-aggregates plus in-aggregates.
    fn directed_mc() -> MoveComponents {
        MoveComponents {
            k_v_in: 2.0,
            k_v_to_target_in: 1.0,
            sigma_tot_target_in: 8.0,
            sigma_tot_current_in: 2.0,
            directed: true,
            ..undirected_mc()
        }
    }

    /// Undirected move used by the RBER tests: degree 5, 4 units into a target
    /// community of size 5.
    fn rber_mc() -> MoveComponents {
        MoveComponents {
            k_v_out: 5.0,
            k_v_to_target_out: 4.0,
            sigma_tot_current_out: 5.0,
            n_target: 5.0,
            ..undirected_mc()
        }
    }

    #[test]
    fn test_graph_data_extraction() {
        let data = {
            let mut builder = GraphDataBuilder::new(3);
            builder.add_edge(0, 1, 1.0).unwrap();
            builder.add_edge(1, 2, 2.0).unwrap();
            builder.build().unwrap()
        };
        assert_eq!(data.node_count(), 3);
        assert!((data.total_weight() - 3.0).abs() < 1e-10);
        for (node, expected) in [(0, 1.0), (1, 3.0), (2, 2.0)] {
            assert!((data.degree_of(node) - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_modularity_delta_positive() {
        let delta = Modularity::new().delta_move_from_components(&undirected_mc());
        assert!(delta > 0.0);
    }

    #[test]
    fn test_cpm_delta_positive() {
        let c = MoveComponents {
            n_target: 5.0,
            ..undirected_mc()
        };
        let delta = CPM::new(0.1).delta_move_from_components(&c);
        // delta = (2+0) - (0+0) - 0.1 * 1.0 * (5 - 1 + 1) = 2.0 - 0.5 = 1.5
        assert!((delta - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_rbconfiguration_matches_modularity() {
        let c = undirected_mc();
        let rb_delta = RBConfiguration::new().delta_move_from_components(&c);
        let mod_delta = Modularity::new().delta_move_from_components(&c);
        assert!((rb_delta - mod_delta).abs() < 1e-10);
    }

    #[test]
    fn test_rbconfiguration_with_resolution() {
        let c = MoveComponents {
            two_m: 30.0,
            total_node_weight: 20.0,
            k_v_out: 5.0,
            k_v_to_target_out: 3.0,
            k_v_to_current_out: 1.0,
            sigma_tot_target_out: 15.0,
            sigma_tot_current_out: 8.0,
            n_target: 3.0,
            n_current: 2.0,
            ..undirected_mc()
        };
        let rb_delta = RBConfiguration::with_resolution(2.0).delta_move_from_components(&c);
        let mod_delta = Modularity::with_resolution(2.0).delta_move_from_components(&c);
        assert!((rb_delta - mod_delta).abs() < 1e-10);
    }

    #[test]
    fn test_rber_delta_positive() {
        let delta = RBER::new(1.0).delta_move_from_components(&rber_mc());
        assert!(delta > 0.0, "RBER delta should be positive, got {delta}");
    }

    #[test]
    fn test_rber_delta_calculation() {
        // p = 20 / (10 * 9) = 0.2222...
        // delta = (4+0 - 0+0) - 1.0 * 0.2222 * 1.0 * (5 - 1 + 1) = 4 - 1.111 = 2.889
        let delta = RBER::new(1.0).delta_move_from_components(&rber_mc());
        let p = 20.0 / (10.0 * 9.0);
        let expected = 4.0 - 1.0 * p * 1.0 * (5.0 - 1.0 + 1.0);
        assert!(
            (delta - expected).abs() < 1e-10,
            "expected {expected}, got {delta}"
        );
    }

    #[test]
    fn test_rber_zero_two_m() {
        let c = MoveComponents {
            two_m: 0.0,
            k_v_out: 0.0,
            k_v_to_target_out: 0.0,
            sigma_tot_target_out: 0.0,
            sigma_tot_current_out: 0.0,
            ..undirected_mc()
        };
        assert!((RBER::new(1.0).delta_move_from_components(&c)).abs() < 1e-10);
    }

    #[test]
    fn test_modularity_directed_delta() {
        let delta = Modularity::new().delta_move_from_components(&directed_mc());
        // m = 10.0
        // d_internal = (2+1) - (0+0) = 3.0
        // d_expected = 2.0*(10-3) + 3.0*(8-2) + 2*3*2 = 14 + 18 + 12 = 44
        // delta = 3.0/10.0 - 1.0 * 44.0/100.0 = 0.3 - 0.44 = -0.14
        let expected = 3.0 / 10.0 - 44.0 / 100.0;
        assert!(
            (delta - expected).abs() < 1e-10,
            "expected {expected}, got {delta}"
        );
    }

    #[test]
    fn test_cpm_directed_delta() {
        let c = MoveComponents {
            k_v_to_current_out: 1.0,
            n_target: 5.0,
            ..directed_mc()
        };
        let delta = CPM::new(0.1).delta_move_from_components(&c);
        // (2+1) - (1+0) - 0.1*1.0*(5-1+1) = 3 - 1 - 0.5 = 1.5
        assert!((delta - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_rbconfiguration_directed_matches_modularity() {
        let c = directed_mc();
        let rb_delta = RBConfiguration::new().delta_move_from_components(&c);
        let mod_delta = Modularity::new().delta_move_from_components(&c);
        assert!((rb_delta - mod_delta).abs() < 1e-10);
    }

    #[test]
    fn test_rber_directed_delta() {
        let c = MoveComponents {
            k_v_out: 5.0,
            k_v_to_target_out: 4.0,
            k_v_to_current_out: 1.0,
            sigma_tot_current_out: 5.0,
            k_v_in: 3.0,
            k_v_to_target_in: 2.0,
            sigma_tot_current_in: 3.0,
            n_target: 5.0,
            ..directed_mc()
        };
        let delta = RBER::new(1.0).delta_move_from_components(&c);
        let p = 20.0 / (10.0 * 9.0);
        let expected = (4.0 + 2.0) - (1.0 + 0.0) - 1.0 * p * 1.0 * (5.0 - 1.0 + 1.0);
        assert!(
            (delta - expected).abs() < 1e-10,
            "expected {expected}, got {delta}"
        );
    }
}