Skip to main content

oximedia_graph/
graph_partition.rs

1#![allow(dead_code)]
//! Graph partitioning for parallel execution.
//!
//! This module provides algorithms for splitting a processing graph into
//! partitions that can be executed on different threads or machines,
//! minimizing inter-partition communication.
7
8use std::collections::{HashMap, HashSet};
9
/// Identifier for a partition.
///
/// A thin newtype around the partition's numeric index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PartitionId(pub u32);

impl std::fmt::Display for PartitionId {
    /// Renders the identifier as `partition_<index>`, e.g. `partition_3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(index) = self;
        f.write_fmt(format_args!("partition_{}", index))
    }
}
19
/// A node in the partitionable graph.
#[derive(Debug, Clone)]
pub struct PartNode {
    /// Identifier of the node; edges and assignments refer to nodes by this value.
    pub id: u64,
    /// Computational weight (cost) of the node.
    pub weight: f64,
    /// Memory requirement of the node, in bytes.
    pub memory_bytes: u64,
}

impl PartNode {
    /// Build a node from its identifier, computational weight and memory
    /// requirement (in bytes).
    pub fn new(id: u64, weight: f64, memory_bytes: u64) -> Self {
        PartNode {
            id,
            weight,
            memory_bytes,
        }
    }
}
41
/// An edge between two nodes, carrying a communication cost.
#[derive(Debug, Clone)]
pub struct PartEdge {
    /// ID of the node the edge starts at.
    pub from: u64,
    /// ID of the node the edge ends at.
    pub to: u64,
    /// Communication cost incurred when the two endpoints are placed in
    /// different partitions.
    pub comm_cost: f64,
}

impl PartEdge {
    /// Build an edge from `from` to `to` with the given communication cost.
    pub fn new(from: u64, to: u64, comm_cost: f64) -> Self {
        PartEdge {
            from,
            to,
            comm_cost,
        }
    }
}
63
64/// A partition of graph nodes.
65#[derive(Debug, Clone)]
66pub struct Partition {
67    /// Partition identifier.
68    pub id: PartitionId,
69    /// Node IDs assigned to this partition.
70    pub nodes: Vec<u64>,
71    /// Total computational weight.
72    pub total_weight: f64,
73    /// Total memory requirement.
74    pub total_memory: u64,
75}
76
77impl Partition {
78    /// Create a new empty partition.
79    pub fn new(id: PartitionId) -> Self {
80        Self {
81            id,
82            nodes: Vec::new(),
83            total_weight: 0.0,
84            total_memory: 0,
85        }
86    }
87
88    /// Add a node to the partition.
89    pub fn add_node(&mut self, node: &PartNode) {
90        self.nodes.push(node.id);
91        self.total_weight += node.weight;
92        self.total_memory += node.memory_bytes;
93    }
94
95    /// Number of nodes in this partition.
96    pub fn node_count(&self) -> usize {
97        self.nodes.len()
98    }
99
100    /// Check if a node is in this partition.
101    pub fn contains(&self, node_id: u64) -> bool {
102        self.nodes.contains(&node_id)
103    }
104}
105
106/// Result of a graph partitioning operation.
107#[derive(Debug, Clone)]
108pub struct PartitionResult {
109    /// The computed partitions.
110    pub partitions: Vec<Partition>,
111    /// Node-to-partition assignment.
112    pub assignment: HashMap<u64, PartitionId>,
113    /// Total inter-partition communication cost (edge cut).
114    pub edge_cut_cost: f64,
115    /// Load imbalance ratio (max_weight / avg_weight).
116    pub imbalance: f64,
117}
118
119impl PartitionResult {
120    /// Number of partitions.
121    pub fn partition_count(&self) -> usize {
122        self.partitions.len()
123    }
124
125    /// Get the partition a node belongs to.
126    pub fn partition_of(&self, node_id: u64) -> Option<PartitionId> {
127        self.assignment.get(&node_id).copied()
128    }
129
130    /// Get edges that cross partition boundaries.
131    pub fn cut_edges<'a>(&'a self, edges: &'a [PartEdge]) -> Vec<&'a PartEdge> {
132        edges
133            .iter()
134            .filter(|e| {
135                let p_from = self.assignment.get(&e.from);
136                let p_to = self.assignment.get(&e.to);
137                match (p_from, p_to) {
138                    (Some(a), Some(b)) => a != b,
139                    _ => false,
140                }
141            })
142            .collect()
143    }
144}
145
/// Strategy used to assign nodes to partitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionStrategy {
    /// Simple round-robin assignment following the order of the node list.
    RoundRobin,
    /// Greedy assignment that keeps the maximum partition weight low.
    GreedyBalance,
    /// Greedy assignment that keeps the edge cut cost low.
    GreedyMinCut,
}
156
157/// Graph partitioner.
158pub struct GraphPartitioner {
159    /// Nodes in the graph.
160    nodes: Vec<PartNode>,
161    /// Edges in the graph.
162    edges: Vec<PartEdge>,
163}
164
165impl GraphPartitioner {
166    /// Create a new partitioner with the given nodes and edges.
167    pub fn new(nodes: Vec<PartNode>, edges: Vec<PartEdge>) -> Self {
168        Self { nodes, edges }
169    }
170
171    /// Partition the graph into `k` partitions using the given strategy.
172    #[allow(clippy::cast_precision_loss)]
173    pub fn partition(&self, k: u32, strategy: PartitionStrategy) -> PartitionResult {
174        if k == 0 {
175            return PartitionResult {
176                partitions: Vec::new(),
177                assignment: HashMap::new(),
178                edge_cut_cost: 0.0,
179                imbalance: 0.0,
180            };
181        }
182        if self.nodes.is_empty() {
183            let partitions = (0..k).map(|i| Partition::new(PartitionId(i))).collect();
184            return PartitionResult {
185                partitions,
186                assignment: HashMap::new(),
187                edge_cut_cost: 0.0,
188                imbalance: 0.0,
189            };
190        }
191
192        let assignment = match strategy {
193            PartitionStrategy::RoundRobin => self.round_robin(k),
194            PartitionStrategy::GreedyBalance => self.greedy_balance(k),
195            PartitionStrategy::GreedyMinCut => self.greedy_min_cut(k),
196        };
197
198        self.build_result(k, &assignment)
199    }
200
201    /// Simple round-robin assignment.
202    fn round_robin(&self, k: u32) -> HashMap<u64, PartitionId> {
203        let mut assignment = HashMap::new();
204        for (i, node) in self.nodes.iter().enumerate() {
205            let part = PartitionId((i as u32) % k);
206            assignment.insert(node.id, part);
207        }
208        assignment
209    }
210
211    /// Greedy load-balanced assignment.
212    #[allow(clippy::cast_precision_loss)]
213    fn greedy_balance(&self, k: u32) -> HashMap<u64, PartitionId> {
214        let mut assignment = HashMap::new();
215        let mut weights = vec![0.0_f64; k as usize];
216
217        // Sort nodes by decreasing weight for LPT (Longest Processing Time) heuristic
218        let mut sorted_nodes: Vec<_> = self.nodes.iter().collect();
219        sorted_nodes.sort_by(|a, b| {
220            b.weight
221                .partial_cmp(&a.weight)
222                .unwrap_or(std::cmp::Ordering::Equal)
223        });
224
225        for node in sorted_nodes {
226            // Assign to the partition with the smallest total weight
227            let min_idx = weights
228                .iter()
229                .enumerate()
230                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
231                .map(|(i, _)| i)
232                .unwrap_or(0);
233            assignment.insert(node.id, PartitionId(min_idx as u32));
234            weights[min_idx] += node.weight;
235        }
236
237        assignment
238    }
239
240    /// Greedy assignment minimizing edge cut.
241    fn greedy_min_cut(&self, k: u32) -> HashMap<u64, PartitionId> {
242        let mut assignment = HashMap::new();
243        let mut partition_nodes: Vec<HashSet<u64>> = vec![HashSet::new(); k as usize];
244
245        // Build adjacency for quick neighbor lookup
246        let mut adjacency: HashMap<u64, Vec<(u64, f64)>> = HashMap::new();
247        for edge in &self.edges {
248            adjacency
249                .entry(edge.from)
250                .or_default()
251                .push((edge.to, edge.comm_cost));
252            adjacency
253                .entry(edge.to)
254                .or_default()
255                .push((edge.from, edge.comm_cost));
256        }
257
258        for node in &self.nodes {
259            // For each partition, compute how much communication cost would
260            // be saved by placing this node there (neighbors already in partition).
261            let mut best_part = 0_usize;
262            let mut best_saved = f64::NEG_INFINITY;
263
264            for p in 0..k as usize {
265                let saved: f64 = adjacency
266                    .get(&node.id)
267                    .map(|neighbors| {
268                        neighbors
269                            .iter()
270                            .filter(|(nid, _)| partition_nodes[p].contains(nid))
271                            .map(|(_, cost)| *cost)
272                            .sum()
273                    })
274                    .unwrap_or(0.0);
275
276                if saved > best_saved
277                    || (saved == best_saved
278                        && partition_nodes[p].len() < partition_nodes[best_part].len())
279                {
280                    best_saved = saved;
281                    best_part = p;
282                }
283            }
284
285            assignment.insert(node.id, PartitionId(best_part as u32));
286            partition_nodes[best_part].insert(node.id);
287        }
288
289        assignment
290    }
291
292    /// Build the result from an assignment.
293    #[allow(clippy::cast_precision_loss)]
294    fn build_result(&self, k: u32, assignment: &HashMap<u64, PartitionId>) -> PartitionResult {
295        let node_map: HashMap<u64, &PartNode> = self.nodes.iter().map(|n| (n.id, n)).collect();
296
297        let mut partitions: Vec<Partition> =
298            (0..k).map(|i| Partition::new(PartitionId(i))).collect();
299
300        for (node_id, part_id) in assignment {
301            if let Some(node) = node_map.get(node_id) {
302                if (part_id.0 as usize) < partitions.len() {
303                    partitions[part_id.0 as usize].add_node(node);
304                }
305            }
306        }
307
308        let edge_cut_cost: f64 = self
309            .edges
310            .iter()
311            .filter(|e| {
312                let p_from = assignment.get(&e.from);
313                let p_to = assignment.get(&e.to);
314                match (p_from, p_to) {
315                    (Some(a), Some(b)) => a != b,
316                    _ => false,
317                }
318            })
319            .map(|e| e.comm_cost)
320            .sum();
321
322        let weights: Vec<f64> = partitions.iter().map(|p| p.total_weight).collect();
323        let avg = if weights.is_empty() {
324            1.0
325        } else {
326            let sum: f64 = weights.iter().sum();
327            sum / weights.len() as f64
328        };
329        let max_w = weights.iter().cloned().fold(0.0_f64, f64::max);
330        let imbalance = if avg > 0.0 { max_w / avg } else { 0.0 };
331
332        PartitionResult {
333            partitions,
334            assignment: assignment.clone(),
335            edge_cut_cost,
336            imbalance,
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` nodes with IDs `0..n`, unit weight and 1 KiB of memory each.
    fn make_nodes(n: u64) -> Vec<PartNode> {
        (0..n).map(|i| PartNode::new(i, 1.0, 1024)).collect()
    }

    /// Unit-cost edges `0-1`, `1-2`, ... linking `n` nodes into a chain.
    fn make_chain_edges(n: u64) -> Vec<PartEdge> {
        (0..n.saturating_sub(1))
            .map(|i| PartEdge::new(i, i + 1, 1.0))
            .collect()
    }

    #[test]
    fn test_partition_id_display() {
        assert_eq!(format!("{}", PartitionId(3)), "partition_3");
    }

    #[test]
    fn test_part_node() {
        let n = PartNode::new(1, 5.0, 2048);
        assert_eq!(n.id, 1);
        assert!((n.weight - 5.0).abs() < f64::EPSILON);
        assert_eq!(n.memory_bytes, 2048);
    }

    #[test]
    fn test_partition_add_node() {
        let mut p = Partition::new(PartitionId(0));
        p.add_node(&PartNode::new(1, 3.0, 100));
        p.add_node(&PartNode::new(2, 2.0, 200));
        assert_eq!(p.node_count(), 2);
        assert!((p.total_weight - 5.0).abs() < f64::EPSILON);
        assert_eq!(p.total_memory, 300);
    }

    #[test]
    fn test_partition_contains() {
        let mut p = Partition::new(PartitionId(0));
        p.add_node(&PartNode::new(42, 1.0, 10));
        assert!(p.contains(42));
        assert!(!p.contains(99));
    }

    #[test]
    fn test_round_robin_partition() {
        let nodes = make_nodes(6);
        let edges = make_chain_edges(6);
        let partitioner = GraphPartitioner::new(nodes, edges);
        let result = partitioner.partition(3, PartitionStrategy::RoundRobin);
        assert_eq!(result.partition_count(), 3);
        for p in &result.partitions {
            assert_eq!(p.node_count(), 2);
        }
        // The i-th node lands in partition `i mod 3`.
        for i in 0..6_u32 {
            assert_eq!(
                result.partition_of(u64::from(i)),
                Some(PartitionId(i % 3))
            );
        }
    }

    #[test]
    fn test_greedy_balance_partition() {
        let nodes = vec![
            PartNode::new(0, 10.0, 100),
            PartNode::new(1, 5.0, 100),
            PartNode::new(2, 3.0, 100),
            PartNode::new(3, 2.0, 100),
        ];
        let edges = Vec::new();
        let partitioner = GraphPartitioner::new(nodes, edges);
        let result = partitioner.partition(2, PartitionStrategy::GreedyBalance);
        assert_eq!(result.partition_count(), 2);
        // LPT places 10 alone and 5+3+2 together: both partitions weigh 10,
        // so the imbalance is 10 / 10 = 1.0.
        assert_eq!(result.partition_of(0), Some(PartitionId(0)));
        for id in 1..4 {
            assert_eq!(result.partition_of(id), Some(PartitionId(1)));
        }
        for p in &result.partitions {
            assert!((p.total_weight - 10.0).abs() < 1e-12);
        }
        assert!((result.imbalance - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_greedy_min_cut() {
        let nodes = make_nodes(4);
        let edges = vec![
            PartEdge::new(0, 1, 10.0),
            PartEdge::new(2, 3, 10.0),
            PartEdge::new(1, 2, 1.0),
        ];
        let partitioner = GraphPartitioner::new(nodes, edges.clone());
        let result = partitioner.partition(2, PartitionStrategy::GreedyMinCut);
        assert_eq!(result.partition_count(), 2);

        // Every node is assigned exactly once.
        let placed: usize = result.partitions.iter().map(Partition::node_count).sum();
        assert_eq!(placed, 4);
        for id in 0..4 {
            assert!(result.partition_of(id).is_some());
        }

        // Neither heavy edge may be cut: at most the light 1-2 edge crosses.
        assert_eq!(result.partition_of(0), result.partition_of(1));
        assert_eq!(result.partition_of(2), result.partition_of(3));
        assert!(result.edge_cut_cost <= 1.0 + 1e-12);

        // The reported cost matches the edges reported as cut.
        let cut_total: f64 = result.cut_edges(&edges).iter().map(|e| e.comm_cost).sum();
        assert!((result.edge_cut_cost - cut_total).abs() < 1e-12);
    }

    #[test]
    fn test_partition_result_partition_of() {
        let nodes = make_nodes(4);
        let edges = Vec::new();
        let partitioner = GraphPartitioner::new(nodes, edges);
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        for i in 0..4 {
            assert!(result.partition_of(i).is_some());
        }
        assert!(result.partition_of(999).is_none());
    }

    #[test]
    fn test_cut_edges() {
        let nodes = make_nodes(4);
        let edges = vec![
            PartEdge::new(0, 1, 5.0),
            PartEdge::new(2, 3, 5.0),
            PartEdge::new(1, 2, 3.0),
        ];
        let partitioner = GraphPartitioner::new(nodes, edges.clone());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        let cuts = result.cut_edges(&edges);
        // Round-robin alternates partitions, so every edge between
        // consecutive nodes crosses a boundary.
        assert_eq!(cuts.len(), 3);
        assert!((result.edge_cut_cost - 13.0).abs() < 1e-12);
    }

    #[test]
    fn test_empty_graph_partition() {
        let partitioner = GraphPartitioner::new(Vec::new(), Vec::new());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        assert_eq!(result.partition_count(), 2);
        assert!(result.assignment.is_empty());
        for p in &result.partitions {
            assert_eq!(p.node_count(), 0);
        }
        assert!((result.edge_cut_cost - 0.0).abs() < f64::EPSILON);
        assert!((result.imbalance - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_zero_partitions() {
        let nodes = make_nodes(4);
        let partitioner = GraphPartitioner::new(nodes, Vec::new());
        let result = partitioner.partition(0, PartitionStrategy::RoundRobin);
        assert!(result.partitions.is_empty());
        assert!(result.assignment.is_empty());
    }

    #[test]
    fn test_single_partition() {
        let nodes = make_nodes(4);
        let edges = make_chain_edges(4);
        let partitioner = GraphPartitioner::new(nodes, edges);
        let result = partitioner.partition(1, PartitionStrategy::GreedyBalance);
        assert_eq!(result.partition_count(), 1);
        assert_eq!(result.partitions[0].node_count(), 4);
        assert!((result.edge_cut_cost - 0.0).abs() < f64::EPSILON);
        assert!((result.imbalance - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_imbalance_ratio() {
        let nodes = vec![PartNode::new(0, 10.0, 100), PartNode::new(1, 1.0, 100)];
        let partitioner = GraphPartitioner::new(nodes, Vec::new());
        let result = partitioner.partition(2, PartitionStrategy::RoundRobin);
        // Partition 0 has weight 10, partition 1 has weight 1, avg = 5.5.
        assert!((result.imbalance - 10.0 / 5.5).abs() < 1e-12);
    }
}