Skip to main content

proof_engine/graph/
partition.rs

1use std::collections::{HashMap, HashSet};
2use super::graph_core::{Graph, GraphKind, NodeId};
3
/// Deterministic hash-style generator: maps the pair `(seed, i)` to a value
/// in `[0, 1]`.
///
/// The two inputs are combined with wrapping multiply-adds, scrambled with two
/// xor-shift steps around a wrapping multiply, and the resulting 64-bit word is
/// scaled by `u64::MAX`. The same inputs always produce the same output.
fn pseudo_random(seed: u64, i: u64) -> f64 {
    const SEED_MULTIPLIER: u64 = 6364136223846793005;
    const INDEX_MULTIPLIER: u64 = 1442695040888963407;
    const MIX_MULTIPLIER: u64 = 0xff51afd7ed558ccd;

    let combined = seed
        .wrapping_mul(SEED_MULTIPLIER)
        .wrapping_add(i.wrapping_mul(INDEX_MULTIPLIER));
    let folded = combined ^ (combined >> 33);
    let scrambled = folded.wrapping_mul(MIX_MULTIPLIER);
    let hashed = scrambled ^ (scrambled >> 33);

    // Scale the full 64-bit range down to the unit interval.
    hashed as f64 / u64::MAX as f64
}
11
12/// Spectral bisection using the Fiedler vector (2nd smallest eigenvector of Laplacian).
13/// Nodes with Fiedler value < median go to partition A, rest to partition B.
14pub fn spectral_partition<N, E>(graph: &Graph<N, E>) -> (Vec<NodeId>, Vec<NodeId>) {
15    let node_ids = graph.node_ids();
16    let n = node_ids.len();
17    if n <= 1 {
18        return (node_ids, Vec::new());
19    }
20
21    let idx: HashMap<NodeId, usize> = node_ids.iter().enumerate().map(|(i, &nid)| (nid, i)).collect();
22
23    // Build Laplacian
24    let mut laplacian = vec![vec![0.0f64; n]; n];
25    for edge in graph.edges() {
26        if let (Some(&i), Some(&j)) = (idx.get(&edge.from), idx.get(&edge.to)) {
27            laplacian[i][j] -= 1.0;
28            laplacian[j][i] -= 1.0;
29            laplacian[i][i] += 1.0;
30            laplacian[j][j] += 1.0;
31        }
32    }
33
34    // Compute Fiedler vector via power iteration on L
35    let fiedler = fiedler_vector(&laplacian, n);
36
37    // Split by median
38    let mut sorted_vals: Vec<f64> = fiedler.clone();
39    sorted_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
40    let median = sorted_vals[n / 2];
41
42    let mut part_a = Vec::new();
43    let mut part_b = Vec::new();
44    for (i, &nid) in node_ids.iter().enumerate() {
45        if fiedler[i] <= median {
46            part_a.push(nid);
47        } else {
48            part_b.push(nid);
49        }
50    }
51
52    // Ensure both partitions are non-empty
53    if part_a.is_empty() {
54        part_a.push(part_b.pop().unwrap());
55    } else if part_b.is_empty() {
56        part_b.push(part_a.pop().unwrap());
57    }
58
59    (part_a, part_b)
60}
61
62fn fiedler_vector(laplacian: &[Vec<f64>], n: usize) -> Vec<f64> {
63    let max_iter = 300;
64    let mut v: Vec<f64> = (0..n).map(|i| pseudo_random(42, i as u64) - 0.5).collect();
65
66    // Orthogonalize against constant vector
67    let mean: f64 = v.iter().sum::<f64>() / n as f64;
68    for x in v.iter_mut() { *x -= mean; }
69    let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
70    if norm > 1e-12 { for x in v.iter_mut() { *x /= norm; } }
71
72    for _ in 0..max_iter {
73        let mut w = vec![0.0f64; n];
74        for i in 0..n {
75            for j in 0..n {
76                w[i] += laplacian[i][j] * v[j];
77            }
78        }
79        let mean: f64 = w.iter().sum::<f64>() / n as f64;
80        for x in w.iter_mut() { *x -= mean; }
81        let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
82        if norm > 1e-12 { for x in w.iter_mut() { *x /= norm; } }
83        v = w;
84    }
85    v
86}
87
88/// Kernighan-Lin refinement: iteratively swap pairs of nodes between partitions
89/// to minimize edge cut.
90pub fn kernighan_lin<N, E>(graph: &Graph<N, E>, partition: (Vec<NodeId>, Vec<NodeId>)) -> (Vec<NodeId>, Vec<NodeId>) {
91    let (mut part_a, mut part_b) = partition;
92    if part_a.is_empty() || part_b.is_empty() {
93        return (part_a, part_b);
94    }
95
96    let node_ids = graph.node_ids();
97    let node_set: HashSet<NodeId> = node_ids.iter().copied().collect();
98
99    // Build adjacency weight map
100    let mut adj: HashMap<(NodeId, NodeId), f32> = HashMap::new();
101    for edge in graph.edges() {
102        let w = edge.weight;
103        *adj.entry((edge.from, edge.to)).or_insert(0.0) += w;
104        if graph.kind == GraphKind::Undirected {
105            *adj.entry((edge.to, edge.from)).or_insert(0.0) += w;
106        }
107    }
108
109    let max_passes = 20;
110    for _ in 0..max_passes {
111        let set_a: HashSet<NodeId> = part_a.iter().copied().collect();
112        let set_b: HashSet<NodeId> = part_b.iter().copied().collect();
113
114        // Compute D values: D[v] = external_cost - internal_cost
115        let mut d: HashMap<NodeId, f32> = HashMap::new();
116        for &v in &part_a {
117            let ext: f32 = part_b.iter()
118                .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
119                .sum();
120            let int: f32 = part_a.iter()
121                .filter(|&&u| u != v)
122                .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
123                .sum();
124            d.insert(v, ext - int);
125        }
126        for &v in &part_b {
127            let ext: f32 = part_a.iter()
128                .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
129                .sum();
130            let int: f32 = part_b.iter()
131                .filter(|&&u| u != v)
132                .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
133                .sum();
134            d.insert(v, ext - int);
135        }
136
137        // Find best swap
138        let mut best_gain = f32::NEG_INFINITY;
139        let mut best_a = part_a[0];
140        let mut best_b = part_b[0];
141
142        for &a in &part_a {
143            for &b in &part_b {
144                let c_ab = adj.get(&(a, b)).copied().unwrap_or(0.0);
145                let gain = d[&a] + d[&b] - 2.0 * c_ab;
146                if gain > best_gain {
147                    best_gain = gain;
148                    best_a = a;
149                    best_b = b;
150                }
151            }
152        }
153
154        if best_gain <= 0.0 {
155            break;
156        }
157
158        // Perform swap
159        if let Some(pos) = part_a.iter().position(|&x| x == best_a) {
160            part_a[pos] = best_b;
161        }
162        if let Some(pos) = part_b.iter().position(|&x| x == best_b) {
163            part_b[pos] = best_a;
164        }
165    }
166
167    (part_a, part_b)
168}
169
170/// Recursive bisection: repeatedly partition each part.
171pub fn recursive_bisection<N: Clone, E: Clone>(graph: &Graph<N, E>, depth: usize) -> Vec<Vec<NodeId>> {
172    if depth == 0 || graph.node_count() <= 1 {
173        return vec![graph.node_ids()];
174    }
175
176    let (a, b) = spectral_partition(graph);
177
178    let mut result = Vec::new();
179    if depth > 1 && a.len() > 1 {
180        let sub_a = graph.subgraph(&a);
181        result.extend(recursive_bisection(&sub_a, depth - 1));
182    } else {
183        result.push(a);
184    }
185    if depth > 1 && b.len() > 1 {
186        let sub_b = graph.subgraph(&b);
187        result.extend(recursive_bisection(&sub_b, depth - 1));
188    } else {
189        result.push(b);
190    }
191
192    result
193}
194
195/// Partition quality: ratio of edges cut to total edges.
196/// Lower is better (fewer inter-partition edges).
197pub fn partition_quality<N, E>(graph: &Graph<N, E>, parts: &[Vec<NodeId>]) -> f32 {
198    let total_edges = graph.edge_count() as f32;
199    if total_edges == 0.0 { return 0.0; }
200
201    let mut node_part: HashMap<NodeId, usize> = HashMap::new();
202    for (pi, part) in parts.iter().enumerate() {
203        for &nid in part {
204            node_part.insert(nid, pi);
205        }
206    }
207
208    let mut cut_edges = 0usize;
209    for edge in graph.edges() {
210        let pa = node_part.get(&edge.from).copied();
211        let pb = node_part.get(&edge.to).copied();
212        if pa != pb {
213            cut_edges += 1;
214        }
215    }
216
217    cut_edges as f32 / total_edges
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::generators;

    /// Adds `size` fresh nodes to `g` and connects every pair of them.
    fn add_clique(g: &mut Graph<(), ()>, size: usize) {
        let members: Vec<NodeId> = (0..size).map(|_| g.add_node(())).collect();
        for i in 0..size {
            for j in (i + 1)..size {
                g.add_edge(members[i], members[j], ());
            }
        }
    }

    /// A path graph must be split into two non-empty parts covering all nodes.
    #[test]
    fn test_spectral_partition_splits() {
        let path = generators::path_graph(10);
        let (left, right) = spectral_partition(&path);
        assert!(!left.is_empty());
        assert!(!right.is_empty());
        assert_eq!(left.len() + right.len(), 10);
    }

    /// Two disconnected cliques should be easy to partition.
    #[test]
    fn test_spectral_partition_two_components() {
        let mut g = Graph::<(), ()>::new(GraphKind::Undirected);
        add_clique(&mut g, 5);
        add_clique(&mut g, 5);

        let (left, right) = spectral_partition(&g);
        assert!(!left.is_empty());
        assert!(!right.is_empty());
    }

    /// Starting from a bad (alternating) partition of a path, refinement must
    /// not make the cut ratio noticeably worse.
    #[test]
    fn test_kernighan_lin_improves() {
        let g = generators::path_graph(8);

        let mut evens: Vec<NodeId> = Vec::new();
        let mut odds: Vec<NodeId> = Vec::new();
        for (position, nid) in g.node_ids().into_iter().enumerate() {
            if position % 2 == 0 {
                evens.push(nid);
            } else {
                odds.push(nid);
            }
        }

        let q_before = partition_quality(&g, &[evens.clone(), odds.clone()]);
        let (refined_a, refined_b) = kernighan_lin(&g, (evens, odds));
        let q_after = partition_quality(&g, &[refined_a, refined_b]);
        assert!(q_after <= q_before + 0.01, "KL should not significantly worsen: {} vs {}", q_after, q_before);
    }

    /// Two levels of bisection yield at least two parts that together cover
    /// every node.
    #[test]
    fn test_recursive_bisection() {
        let g = generators::path_graph(16);
        let parts = recursive_bisection(&g, 2);
        assert!(parts.len() >= 2);
        let covered: usize = parts.iter().map(Vec::len).sum();
        assert_eq!(covered, 16);
    }

    /// Two disconnected components, partitioned along the component boundary,
    /// cut no edge.
    #[test]
    fn test_partition_quality_perfect() {
        let mut g = Graph::<(), ()>::new(GraphKind::Undirected);
        let nodes: Vec<NodeId> = (0..4).map(|_| g.add_node(())).collect();
        g.add_edge(nodes[0], nodes[1], ());
        g.add_edge(nodes[2], nodes[3], ());

        let q = partition_quality(&g, &[vec![nodes[0], nodes[1]], vec![nodes[2], nodes[3]]]);
        assert_eq!(q, 0.0);
    }

    /// Putting every node into a single part cuts no edge, even on a complete
    /// bipartite graph.
    #[test]
    fn test_partition_quality_worst() {
        let g = generators::complete_bipartite(3, 3);
        let everything = g.node_ids();
        let q = partition_quality(&g, &[everything]);
        assert_eq!(q, 0.0);
    }

    /// A single-node graph keeps its one node across the two returned parts.
    #[test]
    fn test_single_node() {
        let mut g = Graph::<(), ()>::new(GraphKind::Undirected);
        g.add_node(());
        let (left, right) = spectral_partition(&g);
        assert_eq!(left.len() + right.len(), 1);
    }

    /// Depth 0 performs no split: one part holding every node.
    #[test]
    fn test_recursive_bisection_depth_0() {
        let g = generators::path_graph(5);
        let parts = recursive_bisection(&g, 0);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].len(), 5);
    }
}