// leiden_rs/multiplex.rs

//! Multiplex/multilayer network community detection.
3use rand::rngs::StdRng;
4use rand::SeedableRng;
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7use rustc_hash::FxHashMap;
8
9use crate::algorithm;
10use crate::leiden::QualityType;
11use crate::partition::Partition;
12use crate::quality::{GraphData, Modularity, QualityFunction};
13
14/// Configuration for multiplex Leiden optimization.
15#[derive(Debug, Clone)]
16pub struct MultiplexConfig {
17    /// Maximum number of Leiden iterations per level.
18    pub max_iterations: usize,
19    /// Resolution parameter γ for all layers.
20    pub resolution: f64,
21    /// Optional RNG seed for reproducible results.
22    pub seed: Option<u64>,
23    /// Quality function to optimize (applied to all layers).
24    pub quality: QualityType,
25    /// Convergence threshold.
26    pub epsilon: f64,
27    /// Maximum community size (0 = unlimited).
28    pub max_comm_size: usize,
29    /// Weight for each layer. Length must match the number of layers.
30    /// Negative weights invert the quality (push nodes apart).
31    pub layer_weights: Vec<f64>,
32    /// Minimum edge slots (CSR entries) for parallel aggregation (default: 10000).
33    pub parallel_aggregation_threshold: Option<usize>,
34}
35
36impl Default for MultiplexConfig {
37    fn default() -> Self {
38        Self {
39            max_iterations: 100,
40            resolution: 1.0,
41            seed: None,
42            quality: QualityType::default(),
43            epsilon: 1e-10,
44            max_comm_size: 0,
45            layer_weights: Vec::new(),
46            parallel_aggregation_threshold: None,
47        }
48    }
49}
50
51impl MultiplexConfig {
52    /// Validate configuration parameters.
53    ///
54    /// Returns `Ok(())` if all parameters are valid, or an `InvalidParameter`
55    /// error describing the first invalid parameter found.
56    pub fn validate(&self) -> crate::error::Result<()> {
57        if self.max_iterations == 0 {
58            return Err(crate::error::LeidenError::InvalidParameter {
59                message: "max_iterations must be > 0".to_string(),
60            });
61        }
62        if !self.resolution.is_finite() || self.resolution < 0.0 {
63            return Err(crate::error::LeidenError::InvalidParameter {
64                message: format!(
65                    "resolution must be finite and non-negative, got {}",
66                    self.resolution
67                ),
68            });
69        }
70        if !self.epsilon.is_finite() || self.epsilon <= 0.0 {
71            return Err(crate::error::LeidenError::InvalidParameter {
72                message: format!("epsilon must be finite and positive, got {}", self.epsilon),
73            });
74        }
75        if !self.layer_weights.iter().all(|w| w.is_finite()) {
76            return Err(crate::error::LeidenError::InvalidParameter {
77                message: format!(
78                    "all layer_weights must be finite, got {:?}",
79                    self.layer_weights
80                ),
81            });
82        }
83        Ok(())
84    }
85}
86
87/// Result of multiplex Leiden optimization.
88#[derive(Debug, Clone)]
89#[non_exhaustive]
90pub struct MultiplexOutput {
91    /// The community partition (same for all layers).
92    pub partition: Partition,
93    /// Total weighted quality score: Σ w_l * Q_l.
94    pub quality: f64,
95    /// Per-layer quality scores.
96    pub layer_qualities: Vec<f64>,
97}
98
99/// Run Leiden community detection on multiple graph layers simultaneously.
100///
101/// All layers must have the same number of nodes. The algorithm optimizes a
102/// weighted sum of quality functions: Q = Σ_l w_l * Q_l.
103///
104/// # Arguments
105/// * `layers` - Graph data for each layer (one GraphData per layer)
106/// * `config` - Multiplex configuration including layer weights
107///
108/// # Returns
109/// A `MultiplexOutput` with the shared partition and quality scores.
110pub fn run_multiplex(
111    layers: &[GraphData],
112    config: &MultiplexConfig,
113) -> crate::error::Result<MultiplexOutput> {
114    config.validate()?;
115    if layers.is_empty() {
116        return Err(crate::error::LeidenError::InvalidParameter {
117            message: "at least one layer is required".to_string(),
118        });
119    }
120
121    let n = layers[0].node_count();
122    for (i, layer) in layers.iter().enumerate() {
123        if layer.node_count() != n {
124            return Err(crate::error::LeidenError::InvalidParameter {
125                message: format!(
126                    "layer {} has {} nodes, expected {} (all layers must share the same vertex set)",
127                    i,
128                    layer.node_count(),
129                    n
130                ),
131            });
132        }
133    }
134
135    let layer_weights: Vec<f64> = if config.layer_weights.is_empty() {
136        vec![1.0; layers.len()]
137    } else if config.layer_weights.len() != layers.len() {
138        return Err(crate::error::LeidenError::InvalidParameter {
139            message: format!(
140                "layer_weights has {} entries but there are {} layers",
141                config.layer_weights.len(),
142                layers.len()
143            ),
144        });
145    } else {
146        config.layer_weights.clone()
147    };
148
149    if n == 0 {
150        return Ok(MultiplexOutput {
151            partition: Partition::new(0),
152            quality: 0.0,
153            layer_qualities: vec![0.0; layers.len()],
154        });
155    }
156
157    let modularity = Modularity::with_resolution(config.resolution);
158    let cpm = crate::quality::CPM::new(config.resolution);
159    let rbconfig = crate::quality::RBConfiguration::with_resolution(config.resolution);
160    let rber = crate::quality::RBER::new(config.resolution);
161    let quality: &(dyn QualityFunction + Sync) = match config.quality {
162        QualityType::Modularity => &modularity,
163        QualityType::CPM => &cpm,
164        QualityType::RBConfiguration => &rbconfig,
165        QualityType::RBER => &rber,
166    };
167
168    let original_n = n;
169    let mut partition = Partition::new(n);
170    let mut flat_mapping: Vec<usize> = (0..n).collect();
171
172    let mut rng = match config.seed {
173        Some(seed) => StdRng::seed_from_u64(seed),
174        None => StdRng::from_rng(&mut rand::rng()),
175    };
176
177    let mut local_moving_buffers =
178        algorithm::LocalMovingBuffers::new(n, layers.len());
179    let mut refinement_buffers =
180        algorithm::RefinementBuffers::new(n, layers.len());
181
182    let mut agg_layer: Option<GraphData> = None;
183
184    for _iter in 0..config.max_iterations {
185        let (current_layers, current_weights): (&[GraphData], &[f64]) = match &agg_layer {
186            Some(data) => (std::slice::from_ref(data), &[1.0][..]),
187            None => (layers, &layer_weights),
188        };
189
190        let q_before = weighted_quality(current_layers, current_weights, &partition, quality);
191
192        let changed = algorithm::local_moving_generic(
193            current_layers,
194            current_weights,
195            &mut partition,
196            quality,
197            &mut rng,
198            &algorithm::MovingConfig {
199                max_comm_size: config.max_comm_size,
200                epsilon: config.epsilon,
201            },
202            &mut local_moving_buffers,
203        );
204        if !changed {
205            break;
206        }
207        partition.renumber();
208
209        let q_after = weighted_quality(current_layers, current_weights, &partition, quality);
210        if (q_after - q_before).abs() < config.epsilon {
211            break;
212        }
213
214        let refined = multiplex_refinement(
215            current_layers,
216            current_weights,
217            &partition,
218            quality,
219            &mut rng,
220            config.epsilon,
221            &mut refinement_buffers,
222        );
223
224        let (agg_data, orig_to_agg, agg_initial) =
225            multiplex_aggregate(current_layers, &refined, &partition, config.parallel_aggregation_threshold)?;
226
227        for orig_node in 0..original_n {
228            flat_mapping[orig_node] = orig_to_agg[flat_mapping[orig_node]];
229        }
230
231        if agg_data.node_count() <= 1 {
232            break;
233        }
234
235        agg_layer = Some(agg_data);
236        partition = agg_initial;
237    }
238
239    let mut result = Partition::from_membership(vec![0; original_n]);
240    for (orig_node, &agg_node) in flat_mapping.iter().enumerate() {
241        let comm = partition.community_of(agg_node);
242        result.move_node(orig_node, comm);
243    }
244    result.renumber();
245
246    let layer_qualities: Vec<f64> = layers
247        .iter()
248        .map(|layer| quality.total_quality(layer, &result))
249        .collect();
250    let total_quality: f64 = layer_weights
251        .iter()
252        .zip(layer_qualities.iter())
253        .map(|(w, q)| w * q)
254        .sum();
255
256    Ok(MultiplexOutput {
257        partition: result,
258        quality: total_quality,
259        layer_qualities,
260    })
261}
262
263fn weighted_quality(
264    layers: &[GraphData],
265    layer_weights: &[f64],
266    partition: &Partition,
267    quality: &dyn QualityFunction,
268) -> f64 {
269    layer_weights
270        .iter()
271        .zip(layers.iter())
272        .map(|(w, layer)| w * quality.total_quality(layer, partition))
273        .sum()
274}
275
276fn multiplex_refinement(
277    layers: &[GraphData],
278    layer_weights: &[f64],
279    partition: &Partition,
280    quality: &(dyn QualityFunction + Sync),
281    rng: &mut StdRng,
282    epsilon: f64,
283    buffers: &mut algorithm::RefinementBuffers,
284) -> Partition {
285    let m: f64 = layers.iter().map(|l| l.total_weight()).sum();
286    if m <= 0.0 {
287        return Partition::new(layers[0].node_count());
288    }
289    algorithm::refinement_generic(
290        layers[0].node_count(),
291        layers.len(),
292        partition,
293        rng,
294        buffers,
295        |community, nodes, buf| {
296            algorithm::refine_community_generic(
297                layers,
298                layer_weights,
299                partition,
300                quality,
301                &algorithm::CommunitySubset { community, nodes },
302                &algorithm::MovingConfig {
303                    max_comm_size: 0,
304                    epsilon,
305                },
306                buf,
307            )
308        },
309    )
310}
311
312fn multiplex_aggregate_edges_sequential(
313    layers: &[GraphData],
314    orig_to_agg: &[usize],
315    n: usize,
316) -> FxHashMap<(usize, usize), f64> {
317    let mut agg_edges: FxHashMap<(usize, usize), f64> = FxHashMap::default();
318    for layer in layers {
319        let directed = layer.is_directed();
320        for u in 0..n {
321            algorithm::aggregate_node_edges_into(layer, u, orig_to_agg, directed, &mut agg_edges);
322        }
323    }
324    agg_edges
325}
326
/// Parallel counterpart of `multiplex_aggregate_edges_sequential`.
///
/// It accumulates the edges of every layer, remapped through `orig_to_agg`,
/// into one map keyed by aggregated node pairs.
///
/// Nodes are split into fixed-size chunks whose boundaries depend only on `n`,
/// not on the thread count or on work stealing. Each chunk is folded into its
/// own map in parallel. The per-chunk maps are then merged sequentially in
/// chunk order. As a result, the floating-point addition order of every
/// aggregated edge weight is the same on every run, which keeps seeded runs
/// reproducible. (A `fold`/`reduce` tree would split at scheduling-dependent
/// points and could differ in the last bits for non-integer weights.)
#[cfg(feature = "rayon")]
fn multiplex_aggregate_edges_parallel(
    layers: &[GraphData],
    orig_to_agg: &[usize],
    n: usize,
) -> FxHashMap<(usize, usize), f64> {
    /// Nodes per parallel task. Fixed so that chunk boundaries are deterministic.
    const CHUNK_NODES: usize = 2048;

    // Indexed parallel `collect` preserves chunk order.
    let chunk_maps: Vec<FxHashMap<(usize, usize), f64>> = (0..n.div_ceil(CHUNK_NODES))
        .into_par_iter()
        .map(|chunk| {
            let start = chunk * CHUNK_NODES;
            let end = (start + CHUNK_NODES).min(n);
            let mut local = FxHashMap::default();
            for u in start..end {
                for layer in layers {
                    let directed = layer.is_directed();
                    algorithm::aggregate_node_edges_into(
                        layer,
                        u,
                        orig_to_agg,
                        directed,
                        &mut local,
                    );
                }
            }
            local
        })
        .collect();

    let mut chunks = chunk_maps.into_iter();
    let mut merged = chunks.next().unwrap_or_default();
    for local in chunks {
        for (key, weight) in local {
            *merged.entry(key).or_default() += weight;
        }
    }
    merged
}

353fn multiplex_aggregate(
354    layers: &[GraphData],
355    refined_partition: &Partition,
356    coarse_partition: &Partition,
357    parallel_aggregation_threshold: Option<usize>,
358) -> crate::error::Result<(GraphData, Vec<usize>, Partition)> {
359    let n = layers[0].node_count();
360    let (orig_to_agg, agg_n) = algorithm::build_orig_to_agg_mapping(n, refined_partition);
361
362    let any_directed = layers.iter().any(|l| l.is_directed());
363    let agg_edges_map: FxHashMap<(usize, usize), f64> = {
364        #[cfg(feature = "rayon")]
365        {
366            let edge_slots: usize = layers.iter().map(|l| l.out_offsets[n]).sum();
367            let threshold = parallel_aggregation_threshold.unwrap_or(crate::parallel::AGG_PARALLEL_THRESHOLD);
368            if edge_slots >= threshold {
369                multiplex_aggregate_edges_parallel(layers, &orig_to_agg, n)
370            } else {
371                multiplex_aggregate_edges_sequential(layers, &orig_to_agg, n)
372            }
373        }
374        #[cfg(not(feature = "rayon"))]
375        {
376            let _ = parallel_aggregation_threshold;
377            multiplex_aggregate_edges_sequential(layers, &orig_to_agg, n)
378        }
379    };
380
381    algorithm::build_aggregated_graph(
382        orig_to_agg,
383        agg_n,
384        any_directed,
385        agg_edges_map,
386        coarse_partition,
387        |orig| layers[0].node_weight(orig),
388    )
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphDataBuilder;

    /// Builds a graph with `n` nodes from `(u, v, weight)` triples, panicking
    /// on any builder error.
    fn build_graph(n: usize, edges: &[(usize, usize, f64)]) -> crate::graph::GraphData {
        let mut builder = GraphDataBuilder::new(n);
        for &(from, to, weight) in edges {
            builder.add_edge(from, to, weight).unwrap();
        }
        builder.build().unwrap()
    }

    /// Two unit-weight triangles {0,1,2} and {3,4,5}, followed by `bridge` as
    /// the last edge.
    fn two_triangles_with(bridge: (usize, usize, f64)) -> [(usize, usize, f64); 7] {
        [
            (0, 1, 1.0),
            (1, 2, 1.0),
            (0, 2, 1.0),
            (3, 4, 1.0),
            (4, 5, 1.0),
            (3, 5, 1.0),
            bridge,
        ]
    }

    /// Asserts that `config` fails validation with a message mentioning `field`.
    fn assert_rejected(config: MultiplexConfig, field: &str) {
        let err = config.validate().unwrap_err();
        assert!(
            err.to_string().contains(field),
            "expected {} error, got: {}",
            field,
            err
        );
    }

    #[test]
    fn test_multiplex_two_layers() {
        let layer1 = build_graph(6, &two_triangles_with((2, 3, 1.0)));
        let layer2 = build_graph(6, &two_triangles_with((1, 4, 1.0)));

        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![1.0, 1.0],
            ..Default::default()
        };

        let result = run_multiplex(&[layer1, layer2], &config).unwrap();
        assert!(result.partition.num_communities() >= 1);
        assert_eq!(result.layer_qualities.len(), 2);
    }

    #[test]
    fn test_multiplex_single_layer_matches_standard() {
        // Sanity check only: a single layer must run and yield a partition.
        let layer = build_graph(6, &two_triangles_with((2, 3, 1.0)));

        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![1.0],
            ..Default::default()
        };

        let result = run_multiplex(&[layer], &config).unwrap();
        assert!(result.partition.num_communities() >= 1);
    }

    #[test]
    fn test_multiplex_mismatched_nodes() {
        let layer1 = build_graph(2, &[(0, 1, 1.0)]);
        let layer2 = build_graph(3, &[(0, 1, 1.0), (1, 2, 1.0)]);

        let config = MultiplexConfig {
            layer_weights: vec![1.0, 1.0],
            ..Default::default()
        };

        assert!(run_multiplex(&[layer1, layer2], &config).is_err());
    }

    #[test]
    fn test_multiplex_empty_layers() {
        let config = MultiplexConfig::default();
        assert!(run_multiplex(&[], &config).is_err());
    }

    #[test]
    fn test_multiplex_per_layer_stats_correctness() {
        // Two layers with different degree distributions to expose the
        // accumulated-sigma bug in modularity delta computation.
        let layer1 = build_graph(4, &[(0, 1, 1.0), (1, 2, 1.0), (0, 2, 1.0)]);
        let layer2 = build_graph(4, &[(0, 1, 2.0), (1, 2, 2.0), (2, 3, 2.0)]);

        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![1.0, 0.5],
            quality: QualityType::Modularity,
            ..Default::default()
        };

        let result = run_multiplex(&[layer1, layer2], &config).unwrap();

        assert!(
            result.quality >= 0.0,
            "total quality should be non-negative, got {}",
            result.quality,
        );
        assert_eq!(result.layer_qualities.len(), 2);
        assert!(
            result.layer_qualities.iter().all(|&q| q.is_finite()),
            "per-layer qualities should be finite: {:?}",
            result.layer_qualities,
        );
    }

    #[test]
    fn test_multiplex_weighted_layers() {
        let layer1 = build_graph(6, &two_triangles_with((2, 3, 0.01)));
        let layer2 = build_graph(6, &[(0, 3, 1.0), (1, 4, 1.0), (2, 5, 1.0)]);

        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![10.0, 0.1],
            ..Default::default()
        };

        let result = run_multiplex(&[layer1, layer2], &config).unwrap();
        let communities = result.partition.num_communities();
        assert!(
            communities >= 2,
            "expected >= 2 communities, got {}",
            communities
        );
    }

    #[test]
    fn test_validate_resolution_negative() {
        assert_rejected(
            MultiplexConfig {
                resolution: -0.1,
                ..Default::default()
            },
            "resolution",
        );
    }

    #[test]
    fn test_validate_resolution_nan() {
        assert_rejected(
            MultiplexConfig {
                resolution: f64::NAN,
                ..Default::default()
            },
            "resolution",
        );
    }

    #[test]
    fn test_validate_epsilon_zero() {
        assert_rejected(
            MultiplexConfig {
                epsilon: 0.0,
                ..Default::default()
            },
            "epsilon",
        );
    }

    #[test]
    fn test_validate_epsilon_negative() {
        assert_rejected(
            MultiplexConfig {
                epsilon: -1e-10,
                ..Default::default()
            },
            "epsilon",
        );
    }

    #[test]
    fn test_validate_epsilon_nan() {
        assert_rejected(
            MultiplexConfig {
                epsilon: f64::NAN,
                ..Default::default()
            },
            "epsilon",
        );
    }

    #[test]
    fn test_validate_max_iterations_zero() {
        assert_rejected(
            MultiplexConfig {
                max_iterations: 0,
                ..Default::default()
            },
            "max_iterations",
        );
    }

    #[test]
    fn test_validate_layer_weights_nan() {
        assert_rejected(
            MultiplexConfig {
                layer_weights: vec![1.0, f64::NAN],
                ..Default::default()
            },
            "layer_weights",
        );
    }

    #[test]
    fn test_validate_layer_weights_inf() {
        assert_rejected(
            MultiplexConfig {
                layer_weights: vec![1.0, f64::INFINITY],
                ..Default::default()
            },
            "layer_weights",
        );
    }

    #[test]
    fn test_validate_valid() {
        let config = MultiplexConfig {
            layer_weights: vec![1.0, 2.0],
            ..Default::default()
        };
        config.validate().expect("valid config should pass validation");
    }
}