1use std::collections::HashMap;
8
/// The partition assigned to a single node of the input graph.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GraphPartition {
    /// Identifier of the node, copied from the input node list.
    pub node_id: String,
    /// Zero-based index of the partition the node was placed in.
    pub partition: usize,
}
21
22#[derive(Debug, Clone)]
24pub struct PartitionResult {
25 pub assignments: Vec<GraphPartition>,
27 pub num_partitions: usize,
29 pub cut_edges: usize,
31 pub balance_score: f64,
33}
34
/// Algorithm used by [`GraphPartitioner`] to assign nodes to partitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PartitionMethod {
    /// Round-robin seeding followed by a capacity-bounded sweep that moves each node
    /// towards the partition holding most of its neighbours.
    Greedy,
    /// Iterative label propagation: nodes adopt the label most common among their
    /// neighbours until nothing changes or the iteration cap is reached.
    LabelPropagation,
    /// Repeatedly splits the largest partition in half along a breadth-first order.
    Bisection,
}

impl Default for PartitionMethod {
    /// Returns [`PartitionMethod::Greedy`], the same method `GraphPartitioner::new` selects.
    fn default() -> Self {
        PartitionMethod::Greedy
    }
}
45
46#[derive(Debug, Clone)]
52pub struct GraphPartitioner {
53 pub num_partitions: usize,
55 pub method: PartitionMethod,
57 pub max_iterations: usize,
59}
60
61impl GraphPartitioner {
62 pub fn new(num_partitions: usize) -> Self {
64 Self {
65 num_partitions: num_partitions.max(1),
66 method: PartitionMethod::Greedy,
67 max_iterations: 20,
68 }
69 }
70
71 pub fn with_method(mut self, method: PartitionMethod) -> Self {
73 self.method = method;
74 self
75 }
76
77 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
79 self.max_iterations = max_iter;
80 self
81 }
82
83 pub fn partition(&self, nodes: &[String], edges: &[(String, String)]) -> PartitionResult {
89 if nodes.is_empty() {
90 return PartitionResult {
91 assignments: vec![],
92 num_partitions: self.num_partitions,
93 cut_edges: 0,
94 balance_score: 1.0,
95 };
96 }
97
98 let k = self.num_partitions;
99
100 let labels = match &self.method {
101 PartitionMethod::Greedy => Self::greedy_partition(nodes, edges, k),
102 PartitionMethod::LabelPropagation => {
103 Self::label_propagation(nodes, edges, k, self.max_iterations)
104 }
105 PartitionMethod::Bisection => Self::bisection_partition(nodes, edges, k),
106 };
107
108 let node_idx: HashMap<&str, usize> = nodes
110 .iter()
111 .enumerate()
112 .map(|(i, n)| (n.as_str(), i))
113 .collect();
114
115 let int_edges: Vec<(usize, usize)> = edges
116 .iter()
117 .filter_map(|(a, b)| {
118 let ai = node_idx.get(a.as_str())?;
119 let bi = node_idx.get(b.as_str())?;
120 Some((*ai, *bi))
121 })
122 .collect();
123
124 let cut_edges = Self::count_cut_edges(&labels, &int_edges);
125 let balance = Self::balance_score(&labels, k);
126
127 let assignments = nodes
128 .iter()
129 .enumerate()
130 .map(|(i, n)| GraphPartition {
131 node_id: n.clone(),
132 partition: labels[i],
133 })
134 .collect();
135
136 PartitionResult {
137 assignments,
138 num_partitions: k,
139 cut_edges,
140 balance_score: balance,
141 }
142 }
143
144 pub fn greedy_partition(nodes: &[String], edges: &[(String, String)], k: usize) -> Vec<usize> {
150 let k = k.max(1);
151 let n = nodes.len();
152 if n == 0 {
153 return vec![];
154 }
155
156 let mut labels: Vec<usize> = (0..n).map(|i| i % k).collect();
158
159 let adj = Self::build_adjacency(nodes, edges);
161
162 let target_per_partition = (n + k - 1) / k; for i in 0..n {
167 if adj[i].is_empty() {
168 continue;
169 }
170 let mut counts = vec![0usize; k];
171 for &nb in &adj[i] {
172 counts[labels[nb]] += 1;
173 }
174 let current_part = labels[i];
176 let mut best_part = current_part;
177 let mut best_count = counts[current_part];
178
179 let mut part_sizes = vec![0usize; k];
181 for &l in labels.iter() {
182 part_sizes[l.min(k - 1)] += 1;
183 }
184
185 for (p, &c) in counts.iter().enumerate() {
186 if c > best_count && part_sizes[p] < target_per_partition {
187 best_part = p;
188 best_count = c;
189 }
190 }
191 labels[i] = best_part;
192 }
193 labels
194 }
195
196 pub fn label_propagation(
202 nodes: &[String],
203 edges: &[(String, String)],
204 k: usize,
205 max_iter: usize,
206 ) -> Vec<usize> {
207 let k = k.max(1);
208 let n = nodes.len();
209 if n == 0 {
210 return vec![];
211 }
212
213 let adj = Self::build_adjacency(nodes, edges);
214 let mut labels: Vec<usize> = (0..n).map(|i| i % k).collect();
216
217 for _ in 0..max_iter {
218 let mut changed = false;
219 let prev = labels.clone();
220
221 for i in 0..n {
222 if adj[i].is_empty() {
223 continue;
224 }
225 let mut counts = vec![0usize; k];
226 for &nb in &adj[i] {
227 counts[prev[nb]] += 1;
228 }
229 let best = counts
231 .iter()
232 .enumerate()
233 .max_by(|(la, &ca), (lb, &cb)| ca.cmp(&cb).then(lb.cmp(la)))
234 .map(|(p, _)| p)
235 .unwrap_or(labels[i]);
236
237 if best != labels[i] {
238 labels[i] = best;
239 changed = true;
240 }
241 }
242
243 if !changed {
244 break;
245 }
246 }
247
248 for l in &mut labels {
250 *l = (*l).min(k - 1);
251 }
252 labels
253 }
254
255 pub fn bisection_partition(
257 nodes: &[String],
258 edges: &[(String, String)],
259 k: usize,
260 ) -> Vec<usize> {
261 let k = k.max(1);
262 let n = nodes.len();
263 if n == 0 {
264 return vec![];
265 }
266
267 let adj = Self::build_adjacency(nodes, edges);
268 let mut labels = vec![0usize; n];
269
270 let target_splits = k.saturating_sub(1);
272
273 for (current_k, _) in (1usize..).zip(0..target_splits) {
274 if current_k >= k {
275 break;
276 }
277 let mut part_sizes = vec![0usize; current_k];
279 for &l in &labels {
280 part_sizes[l] += 1;
281 }
282 let largest_part = part_sizes
283 .iter()
284 .enumerate()
285 .max_by_key(|(_, &s)| s)
286 .map(|(p, _)| p)
287 .unwrap_or(0);
288
289 let part_nodes: Vec<usize> = (0..n).filter(|&i| labels[i] == largest_part).collect();
291 if part_nodes.len() < 2 {
292 break;
293 }
294
295 let half = part_nodes.len() / 2;
297 let new_part = current_k;
298
299 let mut bfs_order: Vec<usize> = Vec::with_capacity(part_nodes.len());
300 let mut visited = vec![false; n];
301 let mut queue = std::collections::VecDeque::new();
302 let start = part_nodes[0];
303 queue.push_back(start);
304 visited[start] = true;
305
306 while let Some(node) = queue.pop_front() {
307 if labels[node] == largest_part {
308 bfs_order.push(node);
309 }
310 for &nb in &adj[node] {
311 if !visited[nb] && labels[nb] == largest_part {
312 visited[nb] = true;
313 queue.push_back(nb);
314 }
315 }
316 }
317
318 for &pn in &part_nodes {
320 if !bfs_order.contains(&pn) {
321 bfs_order.push(pn);
322 }
323 }
324
325 for &node in bfs_order.iter().skip(half) {
327 labels[node] = new_part;
328 }
329 }
330
331 labels
332 }
333
334 pub fn count_cut_edges(assignments: &[usize], edges: &[(usize, usize)]) -> usize {
338 edges
339 .iter()
340 .filter(|&&(a, b)| {
341 a < assignments.len() && b < assignments.len() && assignments[a] != assignments[b]
342 })
343 .count()
344 }
345
346 pub fn balance_score(assignments: &[usize], k: usize) -> f64 {
351 if assignments.is_empty() || k <= 1 {
352 return 1.0;
353 }
354 let mut counts = vec![0usize; k];
355 for &l in assignments {
356 let idx = l.min(k - 1);
357 counts[idx] += 1;
358 }
359 let max = *counts.iter().max().unwrap_or(&0);
360 let min = *counts.iter().min().unwrap_or(&0);
361 if max == 0 {
362 return 1.0;
363 }
364 min as f64 / max as f64
365 }
366
367 pub fn build_adjacency(nodes: &[String], edges: &[(String, String)]) -> Vec<Vec<usize>> {
369 let n = nodes.len();
370 let node_idx: HashMap<&str, usize> = nodes
371 .iter()
372 .enumerate()
373 .map(|(i, s)| (s.as_str(), i))
374 .collect();
375
376 let mut adj = vec![vec![]; n];
377 for (a, b) in edges {
378 if let (Some(&ai), Some(&bi)) = (node_idx.get(a.as_str()), node_idx.get(b.as_str())) {
379 if ai != bi {
380 adj[ai].push(bi);
381 adj[bi].push(ai);
382 }
383 }
384 }
385 adj
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Node ids `node_0 .. node_{n-1}`.
    fn make_nodes(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("node_{i}")).collect()
    }

    /// Path graph `node_0 - node_1 - ... - node_{n-1}`.
    fn chain_edges(n: usize) -> Vec<(String, String)> {
        (0..n.saturating_sub(1))
            .map(|i| (format!("node_{i}"), format!("node_{}", i + 1)))
            .collect()
    }

    // ----- construction and builder methods -----

    #[test]
    fn test_new_default_method() {
        let gp = GraphPartitioner::new(4);
        assert_eq!(gp.num_partitions, 4);
        assert_eq!(gp.method, PartitionMethod::Greedy);
        assert_eq!(gp.max_iterations, 20);
    }

    #[test]
    fn test_new_zero_becomes_one() {
        let gp = GraphPartitioner::new(0);
        assert_eq!(gp.num_partitions, 1);
    }

    #[test]
    fn test_with_method_label_propagation() {
        let gp = GraphPartitioner::new(3).with_method(PartitionMethod::LabelPropagation);
        assert_eq!(gp.method, PartitionMethod::LabelPropagation);
    }

    #[test]
    fn test_with_max_iterations() {
        let gp = GraphPartitioner::new(3).with_max_iterations(50);
        assert_eq!(gp.max_iterations, 50);
    }

    // ----- partition() entry point -----

    // An empty graph yields no assignments and a perfect balance score.
    #[test]
    fn test_partition_empty_nodes() {
        let gp = GraphPartitioner::new(3);
        let result = gp.partition(&[], &[]);
        assert!(result.assignments.is_empty());
        assert_eq!(result.cut_edges, 0);
        assert_eq!(result.balance_score, 1.0);
    }

    #[test]
    fn test_partition_single_node() {
        let gp = GraphPartitioner::new(3);
        let nodes = vec!["A".to_string()];
        let result = gp.partition(&nodes, &[]);
        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.assignments[0].node_id, "A");
        assert_eq!(result.cut_edges, 0);
    }

    #[test]
    fn test_partition_returns_all_nodes() {
        let nodes = make_nodes(10);
        let edges = chain_edges(10);
        let gp = GraphPartitioner::new(3);
        let result = gp.partition(&nodes, &edges);
        assert_eq!(result.assignments.len(), 10);
    }

    #[test]
    fn test_partition_labels_in_range() {
        let nodes = make_nodes(12);
        let edges = chain_edges(12);
        let k = 4;
        let gp = GraphPartitioner::new(k);
        let result = gp.partition(&nodes, &edges);
        for a in &result.assignments {
            assert!(a.partition < k, "label {} out of range", a.partition);
        }
    }

    #[test]
    fn test_partition_num_partitions_field() {
        let nodes = make_nodes(6);
        let gp = GraphPartitioner::new(3);
        let result = gp.partition(&nodes, &[]);
        assert_eq!(result.num_partitions, 3);
    }

    // ----- greedy_partition -----

    #[test]
    fn test_greedy_partition_count() {
        let nodes = make_nodes(9);
        let edges = chain_edges(9);
        let labels = GraphPartitioner::greedy_partition(&nodes, &edges, 3);
        assert_eq!(labels.len(), 9);
    }

    #[test]
    fn test_greedy_partition_labels_valid() {
        let nodes = make_nodes(9);
        let edges = chain_edges(9);
        let labels = GraphPartitioner::greedy_partition(&nodes, &edges, 3);
        for &l in &labels {
            assert!(l < 3);
        }
    }

    #[test]
    fn test_greedy_partition_empty() {
        let labels = GraphPartitioner::greedy_partition(&[], &[], 3);
        assert!(labels.is_empty());
    }

    #[test]
    fn test_greedy_partition_k1() {
        let nodes = make_nodes(5);
        let labels = GraphPartitioner::greedy_partition(&nodes, &[], 1);
        assert!(labels.iter().all(|&l| l == 0));
    }

    // ----- label_propagation -----

    #[test]
    fn test_label_propagation_count() {
        let nodes = make_nodes(8);
        let edges = chain_edges(8);
        let labels = GraphPartitioner::label_propagation(&nodes, &edges, 2, 10);
        assert_eq!(labels.len(), 8);
    }

    #[test]
    fn test_label_propagation_labels_valid() {
        let nodes = make_nodes(8);
        let edges = chain_edges(8);
        let labels = GraphPartitioner::label_propagation(&nodes, &edges, 4, 10);
        for &l in &labels {
            assert!(l < 4);
        }
    }

    #[test]
    fn test_label_propagation_empty() {
        let labels = GraphPartitioner::label_propagation(&[], &[], 3, 10);
        assert!(labels.is_empty());
    }

    // A generous iteration budget still yields in-range labels for every node.
    #[test]
    fn test_label_propagation_converges() {
        let nodes = make_nodes(10);
        let edges = chain_edges(10);
        let labels = GraphPartitioner::label_propagation(&nodes, &edges, 3, 100);
        assert_eq!(labels.len(), 10);
        for &l in &labels {
            assert!(l < 3);
        }
    }

    // ----- count_cut_edges -----

    #[test]
    fn test_count_cut_edges_none() {
        let assignments = vec![0, 0, 0];
        let edges = vec![(0, 1), (1, 2)];
        assert_eq!(GraphPartitioner::count_cut_edges(&assignments, &edges), 0);
    }

    #[test]
    fn test_count_cut_edges_all() {
        let assignments = vec![0, 1, 0, 1];
        let edges = vec![(0, 1), (1, 2), (2, 3)];
        assert_eq!(GraphPartitioner::count_cut_edges(&assignments, &edges), 3);
    }

    #[test]
    fn test_count_cut_edges_empty() {
        assert_eq!(GraphPartitioner::count_cut_edges(&[], &[]), 0);
    }

    #[test]
    fn test_count_cut_edges_partial() {
        let assignments = vec![0, 0, 1, 1];
        let edges = vec![(0, 1), (1, 2), (2, 3)];
        assert_eq!(GraphPartitioner::count_cut_edges(&assignments, &edges), 1);
    }

    // ----- balance_score -----

    #[test]
    fn test_balance_score_perfect() {
        let assignments = vec![0, 1, 0, 1];
        let score = GraphPartitioner::balance_score(&assignments, 2);
        assert!((score - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_balance_score_empty() {
        assert!((GraphPartitioner::balance_score(&[], 3) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_balance_score_k1() {
        let assignments = vec![0, 0, 0];
        assert!((GraphPartitioner::balance_score(&assignments, 1) - 1.0).abs() < 1e-10);
    }

    // One node in partition 0 versus three in partition 1 gives 1/3.
    #[test]
    fn test_balance_score_imbalanced() {
        let assignments = vec![0, 1, 1, 1];
        let score = GraphPartitioner::balance_score(&assignments, 2);
        assert!((score - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_balance_score_in_range() {
        let assignments = vec![0, 1, 2, 0, 1, 2, 0];
        let score = GraphPartitioner::balance_score(&assignments, 3);
        assert!((0.0..=1.0).contains(&score));
    }

    // ----- build_adjacency -----

    #[test]
    fn test_build_adjacency_empty() {
        let adj = GraphPartitioner::build_adjacency(&[], &[]);
        assert!(adj.is_empty());
    }

    #[test]
    fn test_build_adjacency_chain() {
        let nodes = make_nodes(3);
        let edges = chain_edges(3);
        let adj = GraphPartitioner::build_adjacency(&nodes, &edges);
        assert_eq!(adj.len(), 3);
        assert!(adj[0].contains(&1));
        assert!(adj[1].contains(&0));
        assert!(adj[1].contains(&2));
        assert!(adj[2].contains(&1));
    }

    #[test]
    fn test_build_adjacency_no_self_loops() {
        let nodes = vec!["A".to_string()];
        let edges = vec![("A".to_string(), "A".to_string())];
        let adj = GraphPartitioner::build_adjacency(&nodes, &edges);
        assert!(adj[0].is_empty());
    }

    // An edge to an id that is not in the node list is dropped entirely.
    #[test]
    fn test_build_adjacency_unknown_node_ignored() {
        let nodes = make_nodes(2);
        let edges = vec![("node_0".to_string(), "node_99".to_string())];
        let adj = GraphPartitioner::build_adjacency(&nodes, &edges);
        assert!(adj[0].is_empty());
    }

    // ----- bisection_partition -----

    #[test]
    fn test_bisection_labels_valid() {
        let nodes = make_nodes(8);
        let edges = chain_edges(8);
        let labels = GraphPartitioner::bisection_partition(&nodes, &edges, 4);
        for &l in &labels {
            assert!(l < 4);
        }
    }

    #[test]
    fn test_bisection_count() {
        let nodes = make_nodes(6);
        let labels = GraphPartitioner::bisection_partition(&nodes, &[], 3);
        assert_eq!(labels.len(), 6);
    }

    // ----- partition() with non-default methods -----

    #[test]
    fn test_partition_label_propagation_method() {
        let nodes = make_nodes(10);
        let edges = chain_edges(10);
        let gp = GraphPartitioner::new(2).with_method(PartitionMethod::LabelPropagation);
        let result = gp.partition(&nodes, &edges);
        assert_eq!(result.assignments.len(), 10);
        assert!(result.balance_score >= 0.0 && result.balance_score <= 1.0);
    }

    #[test]
    fn test_partition_bisection_method() {
        let nodes = make_nodes(8);
        let edges = chain_edges(8);
        let gp = GraphPartitioner::new(4).with_method(PartitionMethod::Bisection);
        let result = gp.partition(&nodes, &edges);
        assert_eq!(result.assignments.len(), 8);
        for a in &result.assignments {
            assert!(a.partition < 4);
        }
    }

    // ----- data types -----

    #[test]
    fn test_graph_partition_fields() {
        let gp = GraphPartition {
            node_id: "A".to_string(),
            partition: 2,
        };
        assert_eq!(gp.node_id, "A");
        assert_eq!(gp.partition, 2);
    }

    #[test]
    fn test_graph_partition_clone() {
        let gp = GraphPartition {
            node_id: "X".to_string(),
            partition: 1,
        };
        let gp2 = gp.clone();
        assert_eq!(gp, gp2);
    }

    // On K4 any split into two non-empty parts must cut at least one edge.
    #[test]
    fn test_fully_connected_partition() {
        let nodes = make_nodes(4);
        let edges: Vec<(String, String)> = vec![
            ("node_0".to_string(), "node_1".to_string()),
            ("node_0".to_string(), "node_2".to_string()),
            ("node_0".to_string(), "node_3".to_string()),
            ("node_1".to_string(), "node_2".to_string()),
            ("node_1".to_string(), "node_3".to_string()),
            ("node_2".to_string(), "node_3".to_string()),
        ];
        let gp = GraphPartitioner::new(2);
        let result = gp.partition(&nodes, &edges);
        assert_eq!(result.assignments.len(), 4);
        assert!(result.cut_edges > 0);
    }

    #[test]
    fn test_partition_result_fields() {
        let nodes = make_nodes(6);
        let edges = chain_edges(6);
        let gp = GraphPartitioner::new(2);
        let result = gp.partition(&nodes, &edges);
        assert_eq!(result.num_partitions, 2);
        assert!(result.balance_score >= 0.0 && result.balance_score <= 1.0);
    }

    #[test]
    fn test_partition_result_assignment_node_ids() {
        let nodes = make_nodes(4);
        let gp = GraphPartitioner::new(2);
        let result = gp.partition(&nodes, &[]);
        let ids: Vec<&str> = result
            .assignments
            .iter()
            .map(|a| a.node_id.as_str())
            .collect();
        assert!(ids.contains(&"node_0"));
        assert!(ids.contains(&"node_3"));
    }

    // ----- edge cases -----

    #[test]
    fn test_greedy_no_edges() {
        let nodes = make_nodes(6);
        let labels = GraphPartitioner::greedy_partition(&nodes, &[], 3);
        assert_eq!(labels.len(), 6);
        for &l in &labels {
            assert!(l < 3);
        }
    }

    #[test]
    fn test_label_propagation_k_larger_than_n() {
        let nodes = make_nodes(3);
        let edges = chain_edges(3);
        let labels = GraphPartitioner::label_propagation(&nodes, &edges, 10, 5);
        assert_eq!(labels.len(), 3);
        for &l in &labels {
            assert!(l < 10);
        }
    }

    // Every node in one partition leaves the other empty, so the score is 0.
    #[test]
    fn test_balance_score_all_same() {
        let assignments = vec![0, 0, 0, 0];
        let score = GraphPartitioner::balance_score(&assignments, 2);
        assert_eq!(score, 0.0);
    }

    // Edge endpoints outside the assignment slice are ignored rather than panicking.
    #[test]
    fn test_count_cut_edges_out_of_range() {
        let assignments = vec![0, 1];
        let edges = vec![(0usize, 5usize)];
        assert_eq!(GraphPartitioner::count_cut_edges(&assignments, &edges), 0);
    }

    // A single node cannot be split, so bisection leaves it in partition 0.
    #[test]
    fn test_bisection_single_node() {
        let nodes = vec!["A".to_string()];
        let labels = GraphPartitioner::bisection_partition(&nodes, &[], 2);
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0], 0);
    }

    #[test]
    fn test_partition_method_debug() {
        let m = PartitionMethod::LabelPropagation;
        let s = format!("{m:?}");
        assert!(s.contains("LabelPropagation"));
    }

    // ----- exact-output regression tests -----

    // On the path 0-1-2-3 with k = 3 (cap 2 per partition), node 0 joins its neighbour's
    // partition and node 2 joins node 3's, giving [1, 1, 0, 0].
    #[test]
    fn test_greedy_partition_exact_chain() {
        let labels = GraphPartitioner::greedy_partition(&make_nodes(4), &chain_edges(4), 3);
        assert_eq!(labels, vec![1, 1, 0, 0]);
    }

    // With k > n the cap is one node per partition, so the round-robin seed is kept.
    #[test]
    fn test_greedy_partition_k_larger_than_n() {
        let labels = GraphPartitioner::greedy_partition(&make_nodes(3), &chain_edges(3), 10);
        assert_eq!(labels, vec![0, 1, 2]);
    }

    // The path of 8 is halved, then the tied largest part (highest index) is halved,
    // then the remaining 4-node part is halved.
    #[test]
    fn test_bisection_chain_exact() {
        let labels = GraphPartitioner::bisection_partition(&make_nodes(8), &chain_edges(8), 4);
        assert_eq!(labels, vec![0, 0, 3, 3, 1, 1, 2, 2]);
    }

    // Without edges the breadth-first order degenerates to index order.
    #[test]
    fn test_bisection_disconnected_nodes() {
        let labels = GraphPartitioner::bisection_partition(&make_nodes(6), &[], 3);
        assert_eq!(labels, vec![0, 0, 0, 1, 2, 2]);
    }

    // Nodes without neighbours never change their round-robin label.
    #[test]
    fn test_label_propagation_isolated_nodes_keep_initial_labels() {
        let labels = GraphPartitioner::label_propagation(&make_nodes(4), &[], 2, 10);
        assert_eq!(labels, vec![0, 1, 0, 1]);
    }

    // The reported cut count must agree with the returned assignments for every method.
    // On a path, consecutive assignments are exactly the edge endpoints.
    #[test]
    fn test_partition_cut_edges_consistent() {
        let nodes = make_nodes(10);
        let edges = chain_edges(10);
        for method in [
            PartitionMethod::Greedy,
            PartitionMethod::LabelPropagation,
            PartitionMethod::Bisection,
        ] {
            let name = format!("{method:?}");
            let result = GraphPartitioner::new(3).with_method(method).partition(&nodes, &edges);
            let expected = result
                .assignments
                .windows(2)
                .filter(|w| w[0].partition != w[1].partition)
                .count();
            assert_eq!(result.cut_edges, expected, "method {name}");
        }
    }
}
839}