1use std::collections::{HashMap, HashSet};
2use super::graph_core::{Graph, GraphKind, NodeId};
3
/// Deterministic hash-based pseudo-random number in `[0, 1]`.
///
/// `seed` and `i` are combined with two wrapping multiplications. The result is then
/// scrambled by an xor-shift / multiply / xor-shift round. The multiply constant is
/// the first multiplier of MurmurHash3's 64-bit finalizer. The same `(seed, i)` pair
/// always produces the same value, so callers get reproducible "random" data.
fn pseudo_random(seed: u64, i: u64) -> f64 {
    const SEED_MUL: u64 = 6_364_136_223_846_793_005;
    const INDEX_MUL: u64 = 1_442_695_040_888_963_407;
    const MIX_MUL: u64 = 0xff51_afd7_ed55_8ccd;

    let combined = seed
        .wrapping_mul(SEED_MUL)
        .wrapping_add(i.wrapping_mul(INDEX_MUL));
    let shifted = combined ^ (combined >> 33);
    let mixed = shifted.wrapping_mul(MIX_MUL);
    let hashed = mixed ^ (mixed >> 33);
    hashed as f64 / u64::MAX as f64
}
11
12pub fn spectral_partition<N, E>(graph: &Graph<N, E>) -> (Vec<NodeId>, Vec<NodeId>) {
15 let node_ids = graph.node_ids();
16 let n = node_ids.len();
17 if n <= 1 {
18 return (node_ids, Vec::new());
19 }
20
21 let idx: HashMap<NodeId, usize> = node_ids.iter().enumerate().map(|(i, &nid)| (nid, i)).collect();
22
23 let mut laplacian = vec![vec![0.0f64; n]; n];
25 for edge in graph.edges() {
26 if let (Some(&i), Some(&j)) = (idx.get(&edge.from), idx.get(&edge.to)) {
27 laplacian[i][j] -= 1.0;
28 laplacian[j][i] -= 1.0;
29 laplacian[i][i] += 1.0;
30 laplacian[j][j] += 1.0;
31 }
32 }
33
34 let fiedler = fiedler_vector(&laplacian, n);
36
37 let mut sorted_vals: Vec<f64> = fiedler.clone();
39 sorted_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
40 let median = sorted_vals[n / 2];
41
42 let mut part_a = Vec::new();
43 let mut part_b = Vec::new();
44 for (i, &nid) in node_ids.iter().enumerate() {
45 if fiedler[i] <= median {
46 part_a.push(nid);
47 } else {
48 part_b.push(nid);
49 }
50 }
51
52 if part_a.is_empty() {
54 part_a.push(part_b.pop().unwrap());
55 } else if part_b.is_empty() {
56 part_b.push(part_a.pop().unwrap());
57 }
58
59 (part_a, part_b)
60}
61
62fn fiedler_vector(laplacian: &[Vec<f64>], n: usize) -> Vec<f64> {
63 let max_iter = 300;
64 let mut v: Vec<f64> = (0..n).map(|i| pseudo_random(42, i as u64) - 0.5).collect();
65
66 let mean: f64 = v.iter().sum::<f64>() / n as f64;
68 for x in v.iter_mut() { *x -= mean; }
69 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
70 if norm > 1e-12 { for x in v.iter_mut() { *x /= norm; } }
71
72 for _ in 0..max_iter {
73 let mut w = vec![0.0f64; n];
74 for i in 0..n {
75 for j in 0..n {
76 w[i] += laplacian[i][j] * v[j];
77 }
78 }
79 let mean: f64 = w.iter().sum::<f64>() / n as f64;
80 for x in w.iter_mut() { *x -= mean; }
81 let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
82 if norm > 1e-12 { for x in w.iter_mut() { *x /= norm; } }
83 v = w;
84 }
85 v
86}
87
88pub fn kernighan_lin<N, E>(graph: &Graph<N, E>, partition: (Vec<NodeId>, Vec<NodeId>)) -> (Vec<NodeId>, Vec<NodeId>) {
91 let (mut part_a, mut part_b) = partition;
92 if part_a.is_empty() || part_b.is_empty() {
93 return (part_a, part_b);
94 }
95
96 let node_ids = graph.node_ids();
97 let node_set: HashSet<NodeId> = node_ids.iter().copied().collect();
98
99 let mut adj: HashMap<(NodeId, NodeId), f32> = HashMap::new();
101 for edge in graph.edges() {
102 let w = edge.weight;
103 *adj.entry((edge.from, edge.to)).or_insert(0.0) += w;
104 if graph.kind == GraphKind::Undirected {
105 *adj.entry((edge.to, edge.from)).or_insert(0.0) += w;
106 }
107 }
108
109 let max_passes = 20;
110 for _ in 0..max_passes {
111 let set_a: HashSet<NodeId> = part_a.iter().copied().collect();
112 let set_b: HashSet<NodeId> = part_b.iter().copied().collect();
113
114 let mut d: HashMap<NodeId, f32> = HashMap::new();
116 for &v in &part_a {
117 let ext: f32 = part_b.iter()
118 .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
119 .sum();
120 let int: f32 = part_a.iter()
121 .filter(|&&u| u != v)
122 .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
123 .sum();
124 d.insert(v, ext - int);
125 }
126 for &v in &part_b {
127 let ext: f32 = part_a.iter()
128 .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
129 .sum();
130 let int: f32 = part_b.iter()
131 .filter(|&&u| u != v)
132 .map(|&u| adj.get(&(v, u)).copied().unwrap_or(0.0))
133 .sum();
134 d.insert(v, ext - int);
135 }
136
137 let mut best_gain = f32::NEG_INFINITY;
139 let mut best_a = part_a[0];
140 let mut best_b = part_b[0];
141
142 for &a in &part_a {
143 for &b in &part_b {
144 let c_ab = adj.get(&(a, b)).copied().unwrap_or(0.0);
145 let gain = d[&a] + d[&b] - 2.0 * c_ab;
146 if gain > best_gain {
147 best_gain = gain;
148 best_a = a;
149 best_b = b;
150 }
151 }
152 }
153
154 if best_gain <= 0.0 {
155 break;
156 }
157
158 if let Some(pos) = part_a.iter().position(|&x| x == best_a) {
160 part_a[pos] = best_b;
161 }
162 if let Some(pos) = part_b.iter().position(|&x| x == best_b) {
163 part_b[pos] = best_a;
164 }
165 }
166
167 (part_a, part_b)
168}
169
170pub fn recursive_bisection<N: Clone, E: Clone>(graph: &Graph<N, E>, depth: usize) -> Vec<Vec<NodeId>> {
172 if depth == 0 || graph.node_count() <= 1 {
173 return vec![graph.node_ids()];
174 }
175
176 let (a, b) = spectral_partition(graph);
177
178 let mut result = Vec::new();
179 if depth > 1 && a.len() > 1 {
180 let sub_a = graph.subgraph(&a);
181 result.extend(recursive_bisection(&sub_a, depth - 1));
182 } else {
183 result.push(a);
184 }
185 if depth > 1 && b.len() > 1 {
186 let sub_b = graph.subgraph(&b);
187 result.extend(recursive_bisection(&sub_b, depth - 1));
188 } else {
189 result.push(b);
190 }
191
192 result
193}
194
195pub fn partition_quality<N, E>(graph: &Graph<N, E>, parts: &[Vec<NodeId>]) -> f32 {
198 let total_edges = graph.edge_count() as f32;
199 if total_edges == 0.0 { return 0.0; }
200
201 let mut node_part: HashMap<NodeId, usize> = HashMap::new();
202 for (pi, part) in parts.iter().enumerate() {
203 for &nid in part {
204 node_part.insert(nid, pi);
205 }
206 }
207
208 let mut cut_edges = 0usize;
209 for edge in graph.edges() {
210 let pa = node_part.get(&edge.from).copied();
211 let pb = node_part.get(&edge.to).copied();
212 if pa != pb {
213 cut_edges += 1;
214 }
215 }
216
217 cut_edges as f32 / total_edges
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::generators;

    #[test]
    fn test_spectral_partition_splits() {
        let g = generators::path_graph(10);
        let (a, b) = spectral_partition(&g);
        assert!(!a.is_empty());
        assert!(!b.is_empty());
        assert_eq!(a.len() + b.len(), 10);
    }

    #[test]
    fn test_spectral_partition_two_components() {
        let mut g = Graph::<(), ()>::new(GraphKind::Undirected);
        let n1: Vec<NodeId> = (0..5).map(|_| g.add_node(())).collect();
        for i in 0..5 {
            for j in (i + 1)..5 {
                g.add_edge(n1[i], n1[j], ());
            }
        }
        let n2: Vec<NodeId> = (0..5).map(|_| g.add_node(())).collect();
        for i in 0..5 {
            for j in (i + 1)..5 {
                g.add_edge(n2[i], n2[j], ());
            }
        }

        let (a, b) = spectral_partition(&g);
        assert!(!a.is_empty());
        assert!(!b.is_empty());
    }

    #[test]
    fn test_kernighan_lin_improves() {
        let g = generators::path_graph(8);
        let ids = g.node_ids();
        // Alternating assignment: every path edge is cut.
        let a: Vec<NodeId> = ids.iter().step_by(2).copied().collect();
        let b: Vec<NodeId> = ids.iter().skip(1).step_by(2).copied().collect();
        let q_before = partition_quality(&g, &[a.clone(), b.clone()]);
        let (ra, rb) = kernighan_lin(&g, (a, b));
        let q_after = partition_quality(&g, &[ra, rb]);
        assert!(q_after <= q_before + 0.01, "KL should not significantly worsen: {} vs {}", q_after, q_before);
    }

    #[test]
    fn test_recursive_bisection() {
        let g = generators::path_graph(16);
        let parts = recursive_bisection(&g, 2);
        assert!(parts.len() >= 2);
        let total: usize = parts.iter().map(|p| p.len()).sum();
        assert_eq!(total, 16);
    }

    #[test]
    fn test_partition_quality_perfect() {
        let mut g = Graph::<(), ()>::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(c, d, ());
        let q = partition_quality(&g, &[vec![a, b], vec![c, d]]);
        assert_eq!(q, 0.0);
    }

    /// Keeping every node in one part cuts nothing.
    #[test]
    fn test_partition_quality_single_part() {
        let g = generators::complete_bipartite(3, 3);
        let ids = g.node_ids();
        let q = partition_quality(&g, &[ids]);
        assert_eq!(q, 0.0);
    }

    /// Splitting K_{3,3} along its bipartition cuts every edge (the worst case).
    #[test]
    fn test_partition_quality_all_edges_cut() {
        let g = generators::complete_bipartite(3, 3);
        let ids = g.node_ids();
        let anchor = ids[0];
        // In K_{3,3} the anchor's neighbours are exactly the opposite side.
        let mut across: HashSet<NodeId> = HashSet::new();
        for e in g.edges() {
            if e.from == anchor {
                across.insert(e.to);
            }
            if e.to == anchor {
                across.insert(e.from);
            }
        }
        let (side_b, side_a): (Vec<NodeId>, Vec<NodeId>) =
            ids.iter().copied().partition(|nid| across.contains(nid));
        assert_eq!(side_a.len(), 3);
        assert_eq!(side_b.len(), 3);
        let q = partition_quality(&g, &[side_a, side_b]);
        assert_eq!(q, 1.0);
    }

    #[test]
    fn test_single_node() {
        let mut g = Graph::<(), ()>::new(GraphKind::Undirected);
        g.add_node(());
        let (a, b) = spectral_partition(&g);
        assert_eq!(a.len() + b.len(), 1);
    }

    #[test]
    fn test_recursive_bisection_depth_0() {
        let g = generators::path_graph(5);
        let parts = recursive_bisection(&g, 0);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].len(), 5);
    }
}
308}