1#![allow(dead_code)]
2use std::collections::{HashMap, HashSet};
9
/// Identifier of one partition produced by [`GraphPartitioner::partition`].
///
/// The wrapped number is the partition's index: partition `i` of a result
/// carries `PartitionId(i)`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PartitionId(pub u32);
13
14impl std::fmt::Display for PartitionId {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(f, "partition_{}", self.0)
17 }
18}
19
/// A vertex of the graph handed to [`GraphPartitioner`].
#[derive(Debug, Clone)]
pub struct PartNode {
    /// Caller-chosen identifier; edges refer to nodes by this value.
    pub id: u64,
    /// Relative load of the node; summed into `Partition::total_weight` and
    /// evened out by `PartitionStrategy::GreedyBalance`.
    pub weight: f64,
    /// Memory attributed to the node; summed into `Partition::total_memory`.
    pub memory_bytes: u64,
}
30
31impl PartNode {
32 pub fn new(id: u64, weight: f64, memory_bytes: u64) -> Self {
34 Self {
35 id,
36 weight,
37 memory_bytes,
38 }
39 }
40}
41
/// A connection between two nodes, identified by their `PartNode::id`s.
#[derive(Debug, Clone)]
pub struct PartEdge {
    /// Id of one endpoint.
    pub from: u64,
    /// Id of the other endpoint.
    pub to: u64,
    /// Cost charged to `PartitionResult::edge_cut_cost` when the two
    /// endpoints end up in different partitions.
    pub comm_cost: f64,
}
52
53impl PartEdge {
54 pub fn new(from: u64, to: u64, comm_cost: f64) -> Self {
56 Self {
57 from,
58 to,
59 comm_cost,
60 }
61 }
62}
63
64#[derive(Debug, Clone)]
66pub struct Partition {
67 pub id: PartitionId,
69 pub nodes: Vec<u64>,
71 pub total_weight: f64,
73 pub total_memory: u64,
75}
76
77impl Partition {
78 pub fn new(id: PartitionId) -> Self {
80 Self {
81 id,
82 nodes: Vec::new(),
83 total_weight: 0.0,
84 total_memory: 0,
85 }
86 }
87
88 pub fn add_node(&mut self, node: &PartNode) {
90 self.nodes.push(node.id);
91 self.total_weight += node.weight;
92 self.total_memory += node.memory_bytes;
93 }
94
95 pub fn node_count(&self) -> usize {
97 self.nodes.len()
98 }
99
100 pub fn contains(&self, node_id: u64) -> bool {
102 self.nodes.contains(&node_id)
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct PartitionResult {
109 pub partitions: Vec<Partition>,
111 pub assignment: HashMap<u64, PartitionId>,
113 pub edge_cut_cost: f64,
115 pub imbalance: f64,
117}
118
119impl PartitionResult {
120 pub fn partition_count(&self) -> usize {
122 self.partitions.len()
123 }
124
125 pub fn partition_of(&self, node_id: u64) -> Option<PartitionId> {
127 self.assignment.get(&node_id).copied()
128 }
129
130 pub fn cut_edges<'a>(&'a self, edges: &'a [PartEdge]) -> Vec<&'a PartEdge> {
132 edges
133 .iter()
134 .filter(|e| {
135 let p_from = self.assignment.get(&e.from);
136 let p_to = self.assignment.get(&e.to);
137 match (p_from, p_to) {
138 (Some(a), Some(b)) => a != b,
139 _ => false,
140 }
141 })
142 .collect()
143 }
144}
145
/// Heuristic used by [`GraphPartitioner::partition`] to place nodes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PartitionStrategy {
    /// The node at position `i` (input order) goes to partition `i % k`.
    /// Weights and edges are ignored.
    RoundRobin,
    /// Heaviest node first, each into the partition with the smallest
    /// accumulated weight so far.
    GreedyBalance,
    /// In input order, each node prefers the partition whose already-placed
    /// neighbours share the most edge cost with it, keeping heavy edges
    /// internal.
    GreedyMinCut,
}
156
157pub struct GraphPartitioner {
159 nodes: Vec<PartNode>,
161 edges: Vec<PartEdge>,
163}
164
165impl GraphPartitioner {
166 pub fn new(nodes: Vec<PartNode>, edges: Vec<PartEdge>) -> Self {
168 Self { nodes, edges }
169 }
170
171 #[allow(clippy::cast_precision_loss)]
173 pub fn partition(&self, k: u32, strategy: PartitionStrategy) -> PartitionResult {
174 if k == 0 {
175 return PartitionResult {
176 partitions: Vec::new(),
177 assignment: HashMap::new(),
178 edge_cut_cost: 0.0,
179 imbalance: 0.0,
180 };
181 }
182 if self.nodes.is_empty() {
183 let partitions = (0..k).map(|i| Partition::new(PartitionId(i))).collect();
184 return PartitionResult {
185 partitions,
186 assignment: HashMap::new(),
187 edge_cut_cost: 0.0,
188 imbalance: 0.0,
189 };
190 }
191
192 let assignment = match strategy {
193 PartitionStrategy::RoundRobin => self.round_robin(k),
194 PartitionStrategy::GreedyBalance => self.greedy_balance(k),
195 PartitionStrategy::GreedyMinCut => self.greedy_min_cut(k),
196 };
197
198 self.build_result(k, &assignment)
199 }
200
201 fn round_robin(&self, k: u32) -> HashMap<u64, PartitionId> {
203 let mut assignment = HashMap::new();
204 for (i, node) in self.nodes.iter().enumerate() {
205 let part = PartitionId((i as u32) % k);
206 assignment.insert(node.id, part);
207 }
208 assignment
209 }
210
211 #[allow(clippy::cast_precision_loss)]
213 fn greedy_balance(&self, k: u32) -> HashMap<u64, PartitionId> {
214 let mut assignment = HashMap::new();
215 let mut weights = vec![0.0_f64; k as usize];
216
217 let mut sorted_nodes: Vec<_> = self.nodes.iter().collect();
219 sorted_nodes.sort_by(|a, b| {
220 b.weight
221 .partial_cmp(&a.weight)
222 .unwrap_or(std::cmp::Ordering::Equal)
223 });
224
225 for node in sorted_nodes {
226 let min_idx = weights
228 .iter()
229 .enumerate()
230 .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
231 .map(|(i, _)| i)
232 .unwrap_or(0);
233 assignment.insert(node.id, PartitionId(min_idx as u32));
234 weights[min_idx] += node.weight;
235 }
236
237 assignment
238 }
239
240 fn greedy_min_cut(&self, k: u32) -> HashMap<u64, PartitionId> {
242 let mut assignment = HashMap::new();
243 let mut partition_nodes: Vec<HashSet<u64>> = vec![HashSet::new(); k as usize];
244
245 let mut adjacency: HashMap<u64, Vec<(u64, f64)>> = HashMap::new();
247 for edge in &self.edges {
248 adjacency
249 .entry(edge.from)
250 .or_default()
251 .push((edge.to, edge.comm_cost));
252 adjacency
253 .entry(edge.to)
254 .or_default()
255 .push((edge.from, edge.comm_cost));
256 }
257
258 for node in &self.nodes {
259 let mut best_part = 0_usize;
262 let mut best_saved = f64::NEG_INFINITY;
263
264 for p in 0..k as usize {
265 let saved: f64 = adjacency
266 .get(&node.id)
267 .map(|neighbors| {
268 neighbors
269 .iter()
270 .filter(|(nid, _)| partition_nodes[p].contains(nid))
271 .map(|(_, cost)| *cost)
272 .sum()
273 })
274 .unwrap_or(0.0);
275
276 if saved > best_saved
277 || (saved == best_saved
278 && partition_nodes[p].len() < partition_nodes[best_part].len())
279 {
280 best_saved = saved;
281 best_part = p;
282 }
283 }
284
285 assignment.insert(node.id, PartitionId(best_part as u32));
286 partition_nodes[best_part].insert(node.id);
287 }
288
289 assignment
290 }
291
292 #[allow(clippy::cast_precision_loss)]
294 fn build_result(&self, k: u32, assignment: &HashMap<u64, PartitionId>) -> PartitionResult {
295 let node_map: HashMap<u64, &PartNode> = self.nodes.iter().map(|n| (n.id, n)).collect();
296
297 let mut partitions: Vec<Partition> =
298 (0..k).map(|i| Partition::new(PartitionId(i))).collect();
299
300 for (node_id, part_id) in assignment {
301 if let Some(node) = node_map.get(node_id) {
302 if (part_id.0 as usize) < partitions.len() {
303 partitions[part_id.0 as usize].add_node(node);
304 }
305 }
306 }
307
308 let edge_cut_cost: f64 = self
309 .edges
310 .iter()
311 .filter(|e| {
312 let p_from = assignment.get(&e.from);
313 let p_to = assignment.get(&e.to);
314 match (p_from, p_to) {
315 (Some(a), Some(b)) => a != b,
316 _ => false,
317 }
318 })
319 .map(|e| e.comm_cost)
320 .sum();
321
322 let weights: Vec<f64> = partitions.iter().map(|p| p.total_weight).collect();
323 let avg = if weights.is_empty() {
324 1.0
325 } else {
326 let sum: f64 = weights.iter().sum();
327 sum / weights.len() as f64
328 };
329 let max_w = weights.iter().cloned().fold(0.0_f64, f64::max);
330 let imbalance = if avg > 0.0 { max_w / avg } else { 0.0 };
331
332 PartitionResult {
333 partitions,
334 assignment: assignment.clone(),
335 edge_cut_cost,
336 imbalance,
337 }
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` unit-weight, 1 KiB nodes with ids `0..n`.
    fn make_nodes(n: u64) -> Vec<PartNode> {
        (0..n).map(|i| PartNode::new(i, 1.0, 1024)).collect()
    }

    /// Unit-cost path `0 - 1 - ... - (n - 1)`.
    fn make_chain_edges(n: u64) -> Vec<PartEdge> {
        (0..n.saturating_sub(1))
            .map(|i| PartEdge::new(i, i + 1, 1.0))
            .collect()
    }

    /// Float comparison with a tolerance far below any value used here.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn test_partition_id_display() {
        assert_eq!(format!("{}", PartitionId(3)), "partition_3");
    }

    #[test]
    fn test_part_node() {
        let n = PartNode::new(1, 5.0, 2048);
        assert_eq!(n.id, 1);
        assert!((n.weight - 5.0).abs() < f64::EPSILON);
        assert_eq!(n.memory_bytes, 2048);
    }

    #[test]
    fn test_partition_add_node() {
        let mut p = Partition::new(PartitionId(0));
        p.add_node(&PartNode::new(1, 3.0, 100));
        p.add_node(&PartNode::new(2, 2.0, 200));
        assert_eq!(p.node_count(), 2);
        assert!((p.total_weight - 5.0).abs() < f64::EPSILON);
        assert_eq!(p.total_memory, 300);
    }

    #[test]
    fn test_partition_contains() {
        let mut p = Partition::new(PartitionId(0));
        p.add_node(&PartNode::new(42, 1.0, 10));
        assert!(p.contains(42));
        assert!(!p.contains(99));
    }

    #[test]
    fn test_round_robin_partition() {
        let partitioner = GraphPartitioner::new(make_nodes(6), make_chain_edges(6));
        let result = partitioner.partition(3, PartitionStrategy::RoundRobin);
        assert_eq!(result.partition_count(), 3);
        for p in &result.partitions {
            assert_eq!(p.node_count(), 2);
        }
        for i in 0..6_u64 {
            assert_eq!(result.partition_of(i), Some(PartitionId((i % 3) as u32)));
        }
        // Consecutive ids never share a partition, so all 5 chain edges are cut.
        assert!(approx_eq(result.edge_cut_cost, 5.0));
    }

    #[test]
    fn test_greedy_balance_partition() {
        let nodes = vec![
            PartNode::new(0, 10.0, 100),
            PartNode::new(1, 5.0, 100),
            PartNode::new(2, 3.0, 100),
            PartNode::new(3, 2.0, 100),
        ];
        let partitioner = GraphPartitioner::new(nodes, Vec::new());
        let result = partitioner.partition(2, PartitionStrategy::GreedyBalance);
        assert_eq!(result.partition_count(), 2);
        // Heaviest-first: 10 sits alone, 5 + 3 + 2 = 10 share the other side.
        let heavy = result.partition_of(0).expect("node 0 assigned");
        let light = result.partition_of(1).expect("node 1 assigned");
        assert_ne!(heavy, light);
        assert_eq!(result.partition_of(2), Some(light));
        assert_eq!(result.partition_of(3), Some(light));
        assert!(approx_eq(result.imbalance, 1.0));
    }

    #[test]
    fn test_greedy_min_cut() {
        let nodes = make_nodes(4);
        let edges = vec![
            PartEdge::new(0, 1, 10.0),
            PartEdge::new(2, 3, 10.0),
            PartEdge::new(1, 2, 1.0),
        ];
        let partitioner = GraphPartitioner::new(nodes, edges);
        let result = partitioner.partition(2, PartitionStrategy::GreedyMinCut);
        assert_eq!(result.partition_count(), 2);
        for id in 0..4 {
            assert!(result.partition_of(id).is_some());
        }
        let placed: usize = result.partitions.iter().map(Partition::node_count).sum();
        assert_eq!(placed, 4);
        // The cut can never exceed the total edge cost (10 + 10 + 1).
        assert!(result.edge_cut_cost >= 0.0 && result.edge_cut_cost <= 21.0);
    }

    #[test]
    fn test_partition_result_partition_of() {
        let partitioner = GraphPartitioner::new(make_nodes(4), Vec::new());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        for i in 0..4 {
            assert!(result.partition_of(i).is_some());
        }
        assert!(result.partition_of(999).is_none());
    }

    #[test]
    fn test_cut_edges() {
        let edges = vec![
            PartEdge::new(0, 1, 5.0),
            PartEdge::new(2, 3, 5.0),
            PartEdge::new(1, 2, 3.0),
        ];
        let partitioner = GraphPartitioner::new(make_nodes(4), edges.clone());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        // Round robin splits even and odd ids, so every edge crosses.
        let cuts = result.cut_edges(&edges);
        assert_eq!(cuts.len(), 3);
        let cut_sum: f64 = cuts.iter().map(|e| e.comm_cost).sum();
        assert!(approx_eq(cut_sum, result.edge_cut_cost));
        assert!(approx_eq(result.edge_cut_cost, 13.0));
    }

    #[test]
    fn test_cut_edges_ignores_unassigned_endpoints() {
        let edges = vec![PartEdge::new(0, 99, 7.0)];
        let partitioner = GraphPartitioner::new(make_nodes(2), edges.clone());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        assert!(result.cut_edges(&edges).is_empty());
        assert!(approx_eq(result.edge_cut_cost, 0.0));
    }

    #[test]
    fn test_totals_are_conserved() {
        for strategy in [
            PartitionStrategy::RoundRobin,
            PartitionStrategy::GreedyBalance,
            PartitionStrategy::GreedyMinCut,
        ] {
            let partitioner = GraphPartitioner::new(make_nodes(5), make_chain_edges(5));
            let result = partitioner.partition(2, strategy);
            let memory: u64 = result.partitions.iter().map(|p| p.total_memory).sum();
            let weight: f64 = result.partitions.iter().map(|p| p.total_weight).sum();
            assert_eq!(memory, 5 * 1024);
            assert!(approx_eq(weight, 5.0));
        }
    }

    #[test]
    fn test_empty_graph_partition() {
        let partitioner = GraphPartitioner::new(Vec::new(), Vec::new());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        assert_eq!(result.partition_count(), 2);
        assert!(result.assignment.is_empty());
        assert!(approx_eq(result.edge_cut_cost, 0.0));
        assert!(approx_eq(result.imbalance, 0.0));
    }

    #[test]
    fn test_zero_partitions() {
        let partitioner = GraphPartitioner::new(make_nodes(4), Vec::new());
        let result = partitioner.partition(0, PartitionStrategy::RoundRobin);
        assert!(result.partitions.is_empty());
        assert!(result.assignment.is_empty());
    }

    #[test]
    fn test_single_partition() {
        let partitioner = GraphPartitioner::new(make_nodes(4), make_chain_edges(4));
        let result = partitioner.partition(1, PartitionStrategy::GreedyBalance);
        assert_eq!(result.partition_count(), 1);
        assert_eq!(result.partitions[0].node_count(), 4);
        assert!(approx_eq(result.edge_cut_cost, 0.0));
        assert!(approx_eq(result.imbalance, 1.0));
    }

    #[test]
    fn test_imbalance_ratio() {
        let nodes = vec![PartNode::new(0, 10.0, 100), PartNode::new(1, 1.0, 100)];
        let partitioner = GraphPartitioner::new(nodes, Vec::new());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        // max / mean = 10 / 5.5
        assert!(approx_eq(result.imbalance, 10.0 / 5.5));
    }
}