1use rand::rngs::StdRng;
4use rand::SeedableRng;
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7use rustc_hash::FxHashMap;
8
9use crate::algorithm;
10use crate::leiden::QualityType;
11use crate::partition::Partition;
12use crate::quality::{GraphData, Modularity, QualityFunction};
13
/// Configuration for [`run_multiplex`].
///
/// [`MultiplexConfig::validate`] checks the value ranges below; `run_multiplex`
/// calls it before doing any work.
#[derive(Debug, Clone)]
pub struct MultiplexConfig {
    /// Upper bound on the number of outer iterations (local moving, refinement,
    /// aggregation). Must be greater than zero.
    pub max_iterations: usize,
    /// Resolution parameter handed to the selected quality function.
    /// Must be finite and non-negative.
    pub resolution: f64,
    /// Seed for the random number generator; `None` seeds it from `rand::rng()`.
    pub seed: Option<u64>,
    /// Which quality function (`QualityType`) to optimise.
    pub quality: QualityType,
    /// Improvement threshold forwarded to local moving and refinement, and the
    /// convergence tolerance on the change in quality between the start and the
    /// end of local moving. Must be finite and strictly positive.
    pub epsilon: f64,
    /// Community size cap forwarded to local moving (`MovingConfig::max_comm_size`).
    /// NOTE(review): `0` presumably means "no limit" (refinement always passes 0) —
    /// confirm in `algorithm`.
    pub max_comm_size: usize,
    /// Weight applied to each layer's quality. Empty means every layer gets
    /// weight `1.0`; otherwise the length must equal the number of layers.
    /// Every entry must be finite.
    pub layer_weights: Vec<f64>,
    /// Minimum number of edge slots (sum of `out_offsets[n]` over the layers) at
    /// which edge aggregation uses the parallel path when the `rayon` feature is
    /// enabled. `None` falls back to `crate::parallel::AGG_PARALLEL_THRESHOLD`.
    pub parallel_aggregation_threshold: Option<usize>,
}
35
36impl Default for MultiplexConfig {
37 fn default() -> Self {
38 Self {
39 max_iterations: 100,
40 resolution: 1.0,
41 seed: None,
42 quality: QualityType::default(),
43 epsilon: 1e-10,
44 max_comm_size: 0,
45 layer_weights: Vec::new(),
46 parallel_aggregation_threshold: None,
47 }
48 }
49}
50
51impl MultiplexConfig {
52 pub fn validate(&self) -> crate::error::Result<()> {
57 if self.max_iterations == 0 {
58 return Err(crate::error::LeidenError::InvalidParameter {
59 message: "max_iterations must be > 0".to_string(),
60 });
61 }
62 if !self.resolution.is_finite() || self.resolution < 0.0 {
63 return Err(crate::error::LeidenError::InvalidParameter {
64 message: format!(
65 "resolution must be finite and non-negative, got {}",
66 self.resolution
67 ),
68 });
69 }
70 if !self.epsilon.is_finite() || self.epsilon <= 0.0 {
71 return Err(crate::error::LeidenError::InvalidParameter {
72 message: format!("epsilon must be finite and positive, got {}", self.epsilon),
73 });
74 }
75 if !self.layer_weights.iter().all(|w| w.is_finite()) {
76 return Err(crate::error::LeidenError::InvalidParameter {
77 message: format!(
78 "all layer_weights must be finite, got {:?}",
79 self.layer_weights
80 ),
81 });
82 }
83 Ok(())
84 }
85}
86
/// Result of [`run_multiplex`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct MultiplexOutput {
    /// Final partition of the original nodes, renumbered.
    pub partition: Partition,
    /// Sum over layers of `layer_weight * layer_quality` for `partition`.
    pub quality: f64,
    /// Quality of `partition` evaluated on each input layer, in input order.
    pub layer_qualities: Vec<f64>,
}
98
99pub fn run_multiplex(
111 layers: &[GraphData],
112 config: &MultiplexConfig,
113) -> crate::error::Result<MultiplexOutput> {
114 config.validate()?;
115 if layers.is_empty() {
116 return Err(crate::error::LeidenError::InvalidParameter {
117 message: "at least one layer is required".to_string(),
118 });
119 }
120
121 let n = layers[0].node_count();
122 for (i, layer) in layers.iter().enumerate() {
123 if layer.node_count() != n {
124 return Err(crate::error::LeidenError::InvalidParameter {
125 message: format!(
126 "layer {} has {} nodes, expected {} (all layers must share the same vertex set)",
127 i,
128 layer.node_count(),
129 n
130 ),
131 });
132 }
133 }
134
135 let layer_weights: Vec<f64> = if config.layer_weights.is_empty() {
136 vec![1.0; layers.len()]
137 } else if config.layer_weights.len() != layers.len() {
138 return Err(crate::error::LeidenError::InvalidParameter {
139 message: format!(
140 "layer_weights has {} entries but there are {} layers",
141 config.layer_weights.len(),
142 layers.len()
143 ),
144 });
145 } else {
146 config.layer_weights.clone()
147 };
148
149 if n == 0 {
150 return Ok(MultiplexOutput {
151 partition: Partition::new(0),
152 quality: 0.0,
153 layer_qualities: vec![0.0; layers.len()],
154 });
155 }
156
157 let modularity = Modularity::with_resolution(config.resolution);
158 let cpm = crate::quality::CPM::new(config.resolution);
159 let rbconfig = crate::quality::RBConfiguration::with_resolution(config.resolution);
160 let rber = crate::quality::RBER::new(config.resolution);
161 let quality: &(dyn QualityFunction + Sync) = match config.quality {
162 QualityType::Modularity => &modularity,
163 QualityType::CPM => &cpm,
164 QualityType::RBConfiguration => &rbconfig,
165 QualityType::RBER => &rber,
166 };
167
168 let original_n = n;
169 let mut partition = Partition::new(n);
170 let mut flat_mapping: Vec<usize> = (0..n).collect();
171
172 let mut rng = match config.seed {
173 Some(seed) => StdRng::seed_from_u64(seed),
174 None => StdRng::from_rng(&mut rand::rng()),
175 };
176
177 let mut local_moving_buffers =
178 algorithm::LocalMovingBuffers::new(n, layers.len());
179 let mut refinement_buffers =
180 algorithm::RefinementBuffers::new(n, layers.len());
181
182 let mut agg_layer: Option<GraphData> = None;
183
184 for _iter in 0..config.max_iterations {
185 let (current_layers, current_weights): (&[GraphData], &[f64]) = match &agg_layer {
186 Some(data) => (std::slice::from_ref(data), &[1.0][..]),
187 None => (layers, &layer_weights),
188 };
189
190 let q_before = weighted_quality(current_layers, current_weights, &partition, quality);
191
192 let changed = algorithm::local_moving_generic(
193 current_layers,
194 current_weights,
195 &mut partition,
196 quality,
197 &mut rng,
198 &algorithm::MovingConfig {
199 max_comm_size: config.max_comm_size,
200 epsilon: config.epsilon,
201 },
202 &mut local_moving_buffers,
203 );
204 if !changed {
205 break;
206 }
207 partition.renumber();
208
209 let q_after = weighted_quality(current_layers, current_weights, &partition, quality);
210 if (q_after - q_before).abs() < config.epsilon {
211 break;
212 }
213
214 let refined = multiplex_refinement(
215 current_layers,
216 current_weights,
217 &partition,
218 quality,
219 &mut rng,
220 config.epsilon,
221 &mut refinement_buffers,
222 );
223
224 let (agg_data, orig_to_agg, agg_initial) =
225 multiplex_aggregate(current_layers, &refined, &partition, config.parallel_aggregation_threshold)?;
226
227 for orig_node in 0..original_n {
228 flat_mapping[orig_node] = orig_to_agg[flat_mapping[orig_node]];
229 }
230
231 if agg_data.node_count() <= 1 {
232 break;
233 }
234
235 agg_layer = Some(agg_data);
236 partition = agg_initial;
237 }
238
239 let mut result = Partition::from_membership(vec![0; original_n]);
240 for (orig_node, &agg_node) in flat_mapping.iter().enumerate() {
241 let comm = partition.community_of(agg_node);
242 result.move_node(orig_node, comm);
243 }
244 result.renumber();
245
246 let layer_qualities: Vec<f64> = layers
247 .iter()
248 .map(|layer| quality.total_quality(layer, &result))
249 .collect();
250 let total_quality: f64 = layer_weights
251 .iter()
252 .zip(layer_qualities.iter())
253 .map(|(w, q)| w * q)
254 .sum();
255
256 Ok(MultiplexOutput {
257 partition: result,
258 quality: total_quality,
259 layer_qualities,
260 })
261}
262
263fn weighted_quality(
264 layers: &[GraphData],
265 layer_weights: &[f64],
266 partition: &Partition,
267 quality: &dyn QualityFunction,
268) -> f64 {
269 layer_weights
270 .iter()
271 .zip(layers.iter())
272 .map(|(w, layer)| w * quality.total_quality(layer, partition))
273 .sum()
274}
275
276fn multiplex_refinement(
277 layers: &[GraphData],
278 layer_weights: &[f64],
279 partition: &Partition,
280 quality: &(dyn QualityFunction + Sync),
281 rng: &mut StdRng,
282 epsilon: f64,
283 buffers: &mut algorithm::RefinementBuffers,
284) -> Partition {
285 let m: f64 = layers.iter().map(|l| l.total_weight()).sum();
286 if m <= 0.0 {
287 return Partition::new(layers[0].node_count());
288 }
289 algorithm::refinement_generic(
290 layers[0].node_count(),
291 layers.len(),
292 partition,
293 rng,
294 buffers,
295 |community, nodes, buf| {
296 algorithm::refine_community_generic(
297 layers,
298 layer_weights,
299 partition,
300 quality,
301 &algorithm::CommunitySubset { community, nodes },
302 &algorithm::MovingConfig {
303 max_comm_size: 0,
304 epsilon,
305 },
306 buf,
307 )
308 },
309 )
310}
311
312fn multiplex_aggregate_edges_sequential(
313 layers: &[GraphData],
314 orig_to_agg: &[usize],
315 n: usize,
316) -> FxHashMap<(usize, usize), f64> {
317 let mut agg_edges: FxHashMap<(usize, usize), f64> = FxHashMap::default();
318 for layer in layers {
319 let directed = layer.is_directed();
320 for u in 0..n {
321 algorithm::aggregate_node_edges_into(layer, u, orig_to_agg, directed, &mut agg_edges);
322 }
323 }
324 agg_edges
325}
326
/// Rayon-parallel edge aggregation.
///
/// Source nodes are split across worker threads. Each thread folds its nodes'
/// edges from every layer into a private map, and the private maps are then
/// summed key by key.
#[cfg(feature = "rayon")]
fn multiplex_aggregate_edges_parallel(
    layers: &[GraphData],
    orig_to_agg: &[usize],
    n: usize,
) -> FxHashMap<(usize, usize), f64> {
    type EdgeMap = FxHashMap<(usize, usize), f64>;

    let accumulate = |mut local: EdgeMap, source: usize| {
        for layer in layers {
            algorithm::aggregate_node_edges_into(
                layer,
                source,
                orig_to_agg,
                layer.is_directed(),
                &mut local,
            );
        }
        local
    };

    let merge = |mut acc: EdgeMap, partial: EdgeMap| {
        for (key, weight) in partial {
            *acc.entry(key).or_default() += weight;
        }
        acc
    };

    (0..n)
        .into_par_iter()
        .fold(EdgeMap::default, accumulate)
        .reduce(EdgeMap::default, merge)
}
352
353fn multiplex_aggregate(
354 layers: &[GraphData],
355 refined_partition: &Partition,
356 coarse_partition: &Partition,
357 parallel_aggregation_threshold: Option<usize>,
358) -> crate::error::Result<(GraphData, Vec<usize>, Partition)> {
359 let n = layers[0].node_count();
360 let (orig_to_agg, agg_n) = algorithm::build_orig_to_agg_mapping(n, refined_partition);
361
362 let any_directed = layers.iter().any(|l| l.is_directed());
363 let agg_edges_map: FxHashMap<(usize, usize), f64> = {
364 #[cfg(feature = "rayon")]
365 {
366 let edge_slots: usize = layers.iter().map(|l| l.out_offsets[n]).sum();
367 let threshold = parallel_aggregation_threshold.unwrap_or(crate::parallel::AGG_PARALLEL_THRESHOLD);
368 if edge_slots >= threshold {
369 multiplex_aggregate_edges_parallel(layers, &orig_to_agg, n)
370 } else {
371 multiplex_aggregate_edges_sequential(layers, &orig_to_agg, n)
372 }
373 }
374 #[cfg(not(feature = "rayon"))]
375 {
376 let _ = parallel_aggregation_threshold;
377 multiplex_aggregate_edges_sequential(layers, &orig_to_agg, n)
378 }
379 };
380
381 algorithm::build_aggregated_graph(
382 orig_to_agg,
383 agg_n,
384 any_directed,
385 agg_edges_map,
386 coarse_partition,
387 |orig| layers[0].node_weight(orig),
388 )
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphDataBuilder;

    /// Edges of two disjoint triangles, {0, 1, 2} and {3, 4, 5}.
    const TWO_TRIANGLES: [(usize, usize, f64); 6] = [
        (0, 1, 1.0),
        (1, 2, 1.0),
        (0, 2, 1.0),
        (3, 4, 1.0),
        (4, 5, 1.0),
        (3, 5, 1.0),
    ];

    /// Builds a graph with `n` nodes from `(u, v, weight)` triples, panicking on
    /// any builder error.
    fn build_graph(n: usize, edges: &[(usize, usize, f64)]) -> crate::graph::GraphData {
        let mut builder = GraphDataBuilder::new(n);
        for &(src, dst, weight) in edges.iter() {
            builder.add_edge(src, dst, weight).unwrap();
        }
        builder.build().unwrap()
    }

    /// Six-node graph: the two triangles followed by one extra edge.
    fn triangles_with(extra: (usize, usize, f64)) -> crate::graph::GraphData {
        let edges: Vec<(usize, usize, f64)> = TWO_TRIANGLES
            .iter()
            .copied()
            .chain(std::iter::once(extra))
            .collect();
        build_graph(6, &edges)
    }

    /// Asserts that `config` fails validation with a message mentioning `field`.
    fn assert_rejected(config: MultiplexConfig, field: &str) {
        let err = config.validate().unwrap_err();
        assert!(
            err.to_string().contains(field),
            "expected {} error, got: {}",
            field,
            err
        );
    }

    #[test]
    fn test_multiplex_two_layers() {
        let bridged = triangles_with((2, 3, 1.0));
        let crossed = triangles_with((1, 4, 1.0));
        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![1.0, 1.0],
            ..Default::default()
        };

        let output = run_multiplex(&[bridged, crossed], &config).unwrap();
        assert!(output.partition.num_communities() >= 1);
        assert_eq!(output.layer_qualities.len(), 2);
    }

    #[test]
    fn test_multiplex_single_layer_matches_standard() {
        let only_layer = triangles_with((2, 3, 1.0));
        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![1.0],
            ..Default::default()
        };

        let output = run_multiplex(&[only_layer], &config).unwrap();
        assert!(output.partition.num_communities() >= 1);
    }

    #[test]
    fn test_multiplex_mismatched_nodes() {
        let small = build_graph(2, &[(0, 1, 1.0)]);
        let large = build_graph(3, &[(0, 1, 1.0), (1, 2, 1.0)]);
        let config = MultiplexConfig {
            layer_weights: vec![1.0, 1.0],
            ..Default::default()
        };

        assert!(run_multiplex(&[small, large], &config).is_err());
    }

    #[test]
    fn test_multiplex_empty_layers() {
        let config = MultiplexConfig::default();
        assert!(run_multiplex(&[], &config).is_err());
    }

    #[test]
    fn test_multiplex_per_layer_stats_correctness() {
        let triangle = build_graph(4, &[(0, 1, 1.0), (1, 2, 1.0), (0, 2, 1.0)]);
        let path = build_graph(4, &[(0, 1, 2.0), (1, 2, 2.0), (2, 3, 2.0)]);
        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![1.0, 0.5],
            quality: QualityType::Modularity,
            ..Default::default()
        };

        let output = run_multiplex(&[triangle, path], &config).unwrap();

        assert!(
            output.quality >= 0.0,
            "total quality should be non-negative, got {}",
            output.quality,
        );
        assert_eq!(output.layer_qualities.len(), 2);
        assert!(
            output.layer_qualities.iter().all(|&q| q.is_finite()),
            "per-layer qualities should be finite: {:?}",
            output.layer_qualities,
        );
    }

    #[test]
    fn test_multiplex_weighted_layers() {
        let strong = triangles_with((2, 3, 0.01));
        let weak = build_graph(6, &[(0, 3, 1.0), (1, 4, 1.0), (2, 5, 1.0)]);
        let config = MultiplexConfig {
            seed: Some(42),
            layer_weights: vec![10.0, 0.1],
            ..Default::default()
        };

        let output = run_multiplex(&[strong, weak], &config).unwrap();
        let found = output.partition.num_communities();
        assert!(found >= 2, "expected >= 2 communities, got {}", found);
    }

    #[test]
    fn test_validate_resolution_negative() {
        assert_rejected(
            MultiplexConfig {
                resolution: -0.1,
                ..Default::default()
            },
            "resolution",
        );
    }

    #[test]
    fn test_validate_resolution_nan() {
        assert_rejected(
            MultiplexConfig {
                resolution: f64::NAN,
                ..Default::default()
            },
            "resolution",
        );
    }

    #[test]
    fn test_validate_epsilon_zero() {
        assert_rejected(
            MultiplexConfig {
                epsilon: 0.0,
                ..Default::default()
            },
            "epsilon",
        );
    }

    #[test]
    fn test_validate_epsilon_negative() {
        assert_rejected(
            MultiplexConfig {
                epsilon: -1e-10,
                ..Default::default()
            },
            "epsilon",
        );
    }

    #[test]
    fn test_validate_epsilon_nan() {
        assert_rejected(
            MultiplexConfig {
                epsilon: f64::NAN,
                ..Default::default()
            },
            "epsilon",
        );
    }

    #[test]
    fn test_validate_max_iterations_zero() {
        assert_rejected(
            MultiplexConfig {
                max_iterations: 0,
                ..Default::default()
            },
            "max_iterations",
        );
    }

    #[test]
    fn test_validate_layer_weights_nan() {
        assert_rejected(
            MultiplexConfig {
                layer_weights: vec![1.0, f64::NAN],
                ..Default::default()
            },
            "layer_weights",
        );
    }

    #[test]
    fn test_validate_layer_weights_inf() {
        assert_rejected(
            MultiplexConfig {
                layer_weights: vec![1.0, f64::INFINITY],
                ..Default::default()
            },
            "layer_weights",
        );
    }

    #[test]
    fn test_validate_valid() {
        let config = MultiplexConfig {
            layer_weights: vec![1.0, 2.0],
            ..Default::default()
        };
        config.validate().expect("valid config should pass validation");
    }
}