1use std::cell::RefCell;
12use std::cmp::Reverse;
13use std::collections::BinaryHeap;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16use rayon::prelude::*;
17
18use crate::cch::contract::Topology;
19use crate::cch::{Cch, CchGraph};
20
// Per-thread scratch buffer indexed by node id, used by `Cch::relax_node` to
// stage the weight pairs of the active node's out-arcs. It is grown lazily via
// `Cch::ensure_workspace_size` and restored to all-INFINITY after each node.
thread_local! {
    static WORKSPACE: RefCell<Vec<WeightPair>> = const { RefCell::new(Vec::new()) };
}
24
/// The two weights stored per CCH arc, one for each traversal direction.
///
/// `repr(C)` pins the field layout; `customize` manipulates these pairs through
/// raw pointers before splitting them into the struct-of-arrays [`Weights`].
#[repr(C)]
#[derive(Clone, Copy)]
pub struct WeightPair {
    up: f32,   // weight when traversing the arc upward (lower rank -> higher rank)
    down: f32, // weight when traversing the arc downward
}
/// Customized per-arc weights in struct-of-arrays form, as produced by
/// [`Cch::customize`]; indexed by CCH arc id.
#[derive(Clone)]
pub struct Weights {
    pub up: Vec<f32>,   // upward weight of each arc
    pub down: Vec<f32>, // downward weight of each arc
}
/// Middle node of the best shortcut triangle for each direction of an arc.
/// `u32::MAX` means no triangle improved that direction (the weight comes
/// straight from an original edge, or the arc is unreachable).
#[repr(C)]
#[derive(Clone, Copy)]
pub struct WitnessPair {
    pub up: u32,
    pub down: u32,
}
/// Per-arc shortcut unpacking information produced alongside [`Weights`];
/// `middles[a]` records the witness middle nodes of arc `a`.
#[derive(Clone)]
pub struct Shortcuts {
    pub middles: Vec<WitnessPair>,
}
46
/// Maps every directed input-graph edge to the CCH arc that carries its weight.
///
/// `up_map[e]` is set when the edge's tail has the lower rank (the edge feeds
/// the arc's upward weight), `down_map[e]` otherwise; entries are `u32::MAX`
/// when no matching CCH arc exists.
#[derive(Debug, Clone)]
pub struct WeightMap {
    pub up_map: Vec<u32>,
    pub down_map: Vec<u32>,
}
52impl WeightMap {
53 pub fn build<G: CchGraph>(graph: &G, ranks: &[u32], topology: &Topology) -> Self {
54 let (mut up_map, mut down_map) = (vec![u32::MAX; graph.num_edges()], vec![u32::MAX; graph.num_edges()]);
55 (0..graph.num_nodes() as u32).for_each(|u| {
56 let u_rank = ranks[u as usize];
57 graph.edge_indices(u as usize).for_each(|edge_idx| {
58 let v_rank = ranks[graph.head()[edge_idx] as usize];
59 let (from_r, to_r) = if u_rank < v_rank { (u_rank, v_rank) } else { (v_rank, u_rank) };
60 if let Some(arc_id) = unsafe { topology.find_edge_id(from_r, to_r) } {
61 let target = if u_rank < v_rank { &mut up_map } else { &mut down_map };
62 target[edge_idx] = arc_id;
63 }
64 });
65 });
66 Self { up_map, down_map }
67 }
68}
69impl WeightPair {
70 const INFINITY: Self = Self {
71 up: f32::INFINITY,
72 down: f32::INFINITY,
73 };
74 #[inline(always)]
75 fn is_infinite(&self) -> bool { self.up == f32::INFINITY && self.down == f32::INFINITY }
76 #[inline(always)]
77 fn reset(&mut self) {
78 self.up = f32::INFINITY;
79 self.down = f32::INFINITY;
80 }
81}
82
/// Copyable wrapper around a raw element pointer so it can be captured by the
/// rayon closures in `customize` (references would not satisfy the borrow rules
/// for concurrent writes).
#[derive(Clone, Copy)]
struct RawSlice<T>(*mut T);
// SAFETY: sharing the raw pointer across threads is only sound because the call
// sites serialize conflicting writes externally (the per-node `AtomicBool` spin
// locks in `relax_node`); the wrapper itself provides no synchronization.
// NOTE(review): every new use site must uphold this discipline — verify before
// reusing this type elsewhere.
unsafe impl<T> Send for RawSlice<T> {}
unsafe impl<T> Sync for RawSlice<T> {}
impl<T> RawSlice<T> {
    /// Returns a pointer to element `i`.
    ///
    /// # Safety
    /// `i` must be within bounds of the allocation the base pointer came from,
    /// and all dereferences must be externally synchronized.
    #[inline(always)]
    unsafe fn get(self, i: usize) -> *mut T { unsafe { self.0.add(i) } }
}
91
/// Min-priority queue over arc ids with O(1) duplicate suppression: an arc that
/// is already pending is not enqueued a second time.
struct MinArcQueue {
    heap: BinaryHeap<Reverse<u32>>,
    queued: Vec<bool>,
}

impl MinArcQueue {
    /// Creates an empty queue able to hold arcs `0..arc_count`.
    fn new(arc_count: usize) -> Self {
        Self { heap: BinaryHeap::new(), queued: vec![false; arc_count] }
    }

    /// Enqueues `arc` unless it is already pending.
    fn push(&mut self, arc: u32) {
        if self.queued[arc as usize] {
            return;
        }
        self.queued[arc as usize] = true;
        self.heap.push(Reverse(arc));
    }

    /// Removes and returns the smallest pending arc, or `None` when empty.
    fn pop(&mut self) -> Option<u32> {
        let Reverse(arc) = self.heap.pop()?;
        self.queued[arc as usize] = false;
        Some(arc)
    }
}
119
/// Precomputed adjacency indices that let `customize_partial` enumerate the
/// triangles an arc participates in without re-deriving them from the topology.
#[derive(Clone)]
pub struct PartialUpdateContext {
    /// Tail node of each CCH arc (inverse of the outgoing `first_out` CSR).
    tail: Vec<u32>,
    /// CSR offsets into `in_arc`: the incoming arcs of node `v` are
    /// `in_arc[first_in[v]..first_in[v + 1]]`, ordered by ascending arc id.
    first_in: Vec<u32>,
    in_arc: Vec<u32>,
    /// CSR from arc id to the input edges feeding its upward weight:
    /// `up_src[up_first[a]..up_first[a + 1]]` (built from `WeightMap::up_map`).
    up_first: Vec<u32>,
    up_src: Vec<u32>,
    /// Same as `up_first`/`up_src`, for the downward direction.
    down_first: Vec<u32>,
    down_src: Vec<u32>,
}
impl PartialUpdateContext {
    /// Builds all triangle-enumeration indices from the CCH topology and the
    /// input-edge-to-arc mapping. Pure precomputation; does not touch weights.
    pub fn build(topology: &Topology, mapper: &WeightMap) -> Self {
        let arc_count = topology.head.len();
        // Invert the outgoing CSR: tail[a] = node whose out-range contains arc a.
        let mut tail = vec![0; arc_count];
        topology.first_out.windows(2).enumerate().for_each(|(u, w)| tail[w[0] as usize..w[1] as usize].fill(u as u32));

        // Counting sort of arcs by head node to obtain the incoming-arc CSR.
        let mut first_in = vec![0; topology.first_out.len()];
        topology.head.iter().for_each(|&v| first_in[v as usize + 1] += 1);
        (1..first_in.len()).for_each(|i| first_in[i] += first_in[i - 1]);

        // Scatter pass: arcs are visited in ascending id order, so each node's
        // incoming-arc list ends up sorted by arc id (and hence by tail).
        let mut in_arc = vec![0; arc_count];
        let mut next_in = first_in[..first_in.len() - 1].to_vec();
        topology.head.iter().enumerate().for_each(|(arc, &v)| {
            in_arc[next_in[v as usize] as usize] = arc as u32;
            next_in[v as usize] += 1;
        });

        // Invert the edge->arc maps so each arc knows which input edges feed it.
        let (up_first, up_src) = Self::build_source_index(&mapper.up_map, arc_count);
        let (down_first, down_src) = Self::build_source_index(&mapper.down_map, arc_count);

        Self {
            tail,
            first_in,
            in_arc,
            up_first,
            up_src,
            down_first,
            down_src,
        }
    }

    /// Inverts an edge->arc map (`u32::MAX` = unmapped) into a CSR of
    /// (offsets `f`, edge ids `s`): the input edges feeding arc `a` are
    /// `s[f[a]..f[a + 1]]`.
    fn build_source_index(map: &[u32], arc_count: usize) -> (Vec<u32>, Vec<u32>) {
        let mut f = vec![0; arc_count + 1];
        map.iter().filter(|&&a| a != u32::MAX).for_each(|&a| f[a as usize + 1] += 1);
        (1..f.len()).for_each(|i| f[i] += f[i - 1]);
        let (mut s, mut n) = (vec![0; f[arc_count] as usize], f.clone());
        map.iter().enumerate().filter(|&(_, &a)| a != u32::MAX).for_each(|(i, &a)| {
            s[n[a as usize] as usize] = i as u32;
            n[a as usize] += 1;
        });
        (f, s)
    }

    /// Linear merge intersection of two sorted key sequences of lengths
    /// `n1`/`n2`: calls `f(i, j)` for every position pair with equal keys,
    /// advancing both cursors on a match. Both sequences must be sorted
    /// ascending for the enumeration to be complete.
    #[inline(always)]
    fn intersect_by<K: Ord, F: FnMut(usize, usize)>(n1: usize, n2: usize, mut k1: impl FnMut(usize) -> K, mut k2: impl FnMut(usize) -> K, mut f: F) {
        let (mut i, mut j) = (0, 0);
        while i < n1 && j < n2 {
            match k1(i).cmp(&k2(j)) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    f(i, j);
                    i += 1;
                    j += 1;
                },
            }
        }
    }

    /// Enumerates the triangles in which `arc` is the *lowest* arc. For
    /// `arc = (x, y)`, finds each node `z` reachable by an out-arc from both
    /// `x` (with arc id above `arc`) and `y`, and calls `f(arc, x->z, y->z)`.
    /// Assumes each node's out-arcs are sorted by head (required by the merge).
    pub fn for_each_upper<F: FnMut(u32, u32, u32)>(&self, topo: &Topology, arc: u32, mut f: F) {
        let (x, y) = (self.tail[arc as usize] as usize, topo.head[arc as usize] as usize);
        // xo starts just past `arc`, skipping x's out-arcs at or below it.
        let (xo, yo) = (arc as usize + 1, topo.first_out[y] as usize);
        Self::intersect_by(
            (topo.first_out[x + 1] as usize).saturating_sub(xo),
            (topo.first_out[y + 1] as usize).saturating_sub(yo),
            |i| topo.head[xo + i],
            |j| topo.head[yo + j],
            |i, j| f(arc, (xo + i) as u32, (yo + j) as u32),
        );
    }

    /// Enumerates the triangles in which `arc` is the *middle* arc. For
    /// `arc = (x, y)`, finds each node `w` with arcs `x->w` (arc id below
    /// `arc`) and `w->y`, and calls `f(x->w, arc, w->y)`.
    pub fn for_each_intermediate<F: FnMut(u32, u32, u32)>(&self, topo: &Topology, arc: u32, mut f: F) {
        let (x, y) = (self.tail[arc as usize] as usize, topo.head[arc as usize] as usize);
        let (xo, yo) = (topo.first_out[x] as usize, self.first_in[y] as usize);
        Self::intersect_by(
            // Only x's out-arcs strictly below `arc` qualify as the bottom arc.
            (arc as usize).saturating_sub(xo),
            (self.first_in[y + 1] as usize).saturating_sub(yo),
            |i| topo.head[xo + i],
            |j| self.tail[self.in_arc[yo + j] as usize],
            |i, j| f((xo + i) as u32, arc, self.in_arc[yo + j]),
        );
    }

    /// Enumerates the triangles in which `arc` is the *top* arc. For
    /// `arc = (x, y)`, finds each common in-neighbor `w` of `x` and `y` and
    /// calls `f(w->x, w->y, arc)`. Relies on each node's incoming arcs being
    /// sorted by tail (guaranteed by the construction in `build`).
    pub fn for_each_lower<F: FnMut(u32, u32, u32)>(&self, topo: &Topology, arc: u32, mut f: F) {
        let (x, y) = (self.tail[arc as usize] as usize, topo.head[arc as usize] as usize);
        let (xo, yo) = (self.first_in[x] as usize, self.first_in[y] as usize);
        Self::intersect_by(
            (self.first_in[x + 1] as usize).saturating_sub(xo),
            (self.first_in[y + 1] as usize).saturating_sub(yo),
            |i| self.tail[self.in_arc[xo + i] as usize],
            |j| self.tail[self.in_arc[yo + j] as usize],
            |i, j| f(self.in_arc[xo + i], self.in_arc[yo + j], arc),
        );
    }
}
225
impl Cch {
    /// Convenience constructor for the partial-update indices of this CCH.
    pub fn build_partial_update_context(&self) -> PartialUpdateContext { PartialUpdateContext::build(&self.topology, &self.weight_map) }

    /// Grows (never shrinks) `ws` so it can be indexed by any node id below
    /// `required_size`; new slots start at `INFINITY`.
    #[inline]
    fn ensure_workspace_size(&self, ws: &mut Vec<WeightPair>, required_size: usize) {
        if ws.len() < required_size {
            ws.resize(required_size, WeightPair::INFINITY);
        }
    }

    /// Weight pair of `arc` derived from the original input edges alone: the
    /// per-direction minimum over all edges mapped onto it, `INFINITY` if none.
    #[inline(always)]
    fn initial_arc_weight(&self, ctx: &PartialUpdateContext, input_weights: &[f32], arc: u32) -> WeightPair {
        let mut pair = WeightPair::INFINITY;
        ctx.up_src[ctx.up_first[arc as usize] as usize..ctx.up_first[arc as usize + 1] as usize]
            .iter()
            .for_each(|&e| pair.up = pair.up.min(input_weights[e as usize]));
        ctx.down_src[ctx.down_first[arc as usize] as usize..ctx.down_first[arc as usize + 1] as usize]
            .iter()
            .for_each(|&e| pair.down = pair.down.min(input_weights[e as usize]));
        pair
    }

    /// Recomputes `arc`'s weight pair from scratch: starts from the mapped
    /// input-edge weights, then relaxes every lower triangle (a detour through
    /// a common lower neighbor `w`), recording the best middle node as witness.
    /// Writes the result into `weights`/`shortcuts` and returns it.
    fn recompute_arc(&self, ctx: &PartialUpdateContext, input_weights: &[f32], weights: &mut Weights, shortcuts: &mut Shortcuts, arc: u32) -> WeightPair {
        let mut pair = self.initial_arc_weight(ctx, input_weights, arc);
        let mut middle = WitnessPair { up: u32::MAX, down: u32::MAX };

        // For arc (x, y): bottom_arc = w->x, mid_arc = w->y for each common
        // lower in-neighbor w (see `for_each_lower`).
        ctx.for_each_lower(&self.topology, arc, |bottom_arc, mid_arc, _| {
            let bottom = bottom_arc as usize;
            let mid = mid_arc as usize;
            let witness = ctx.tail[bottom];

            // Upward x -> y via w: descend x -> w (down weight of w->x),
            // then ascend w -> y (up weight of w->y).
            let next_up = weights.down[bottom] + weights.up[mid];
            if next_up < pair.up {
                pair.up = next_up;
                middle.up = witness;
            }

            // Downward y -> x via w: the symmetric combination.
            let next_down = weights.up[bottom] + weights.down[mid];
            if next_down < pair.down {
                pair.down = next_down;
                middle.down = witness;
            }
        });

        weights.up[arc as usize] = pair.up;
        weights.down[arc as usize] = pair.down;
        shortcuts.middles[arc as usize] = middle;
        pair
    }

    /// Full customization: computes up/down weights and shortcut witnesses for
    /// every CCH arc from `original_weights` (indexed by input edge).
    ///
    /// `scheduler` partitions the nodes into levels; levels are processed
    /// sequentially while the nodes inside one level are relaxed in parallel
    /// (rayon), synchronizing writes per head node with spin locks.
    pub fn customize(&self, mapper: &WeightMap, scheduler: &[Vec<u32>], original_weights: &[f32]) -> (Weights, Shortcuts) {
        let num_arcs = self.topology.head.len();
        let mut weights = vec![WeightPair::INFINITY; num_arcs];
        let mut middles = vec![WitnessPair { up: u32::MAX, down: u32::MAX }; num_arcs];

        let weights_ptr = RawSlice(weights.as_mut_ptr());
        let witness_ptr = RawSlice(middles.as_mut_ptr());

        // Phase 1 (single-threaded): scatter each original edge weight onto the
        // arc it maps to, keeping the per-direction minimum for parallel edges.
        // SAFETY: no concurrency here, so the raw-pointer writes cannot race.
        original_weights.iter().enumerate().for_each(|(i, &weight)| unsafe {
            let set_weight = |idx: u32, is_up: bool| {
                if idx != u32::MAX {
                    let p = &mut *weights_ptr.get(idx as usize);
                    if is_up {
                        p.up = p.up.min(weight);
                    } else {
                        p.down = p.down.min(weight);
                    }
                }
            };
            set_weight(mapper.up_map[i], true);
            set_weight(mapper.down_map[i], false);
        });

        let num_nodes = self.ranks.len();
        // One spin lock per node, guarding the weights/witnesses of that node's
        // out-arcs during the parallel phase.
        let locks: Vec<AtomicBool> = (0..num_nodes).map(|_| AtomicBool::new(false)).collect();

        // Phase 2: relax nodes level by level, nodes within a level in parallel.
        // NOTE(review): data-race freedom relies on the scheduler grouping nodes
        // so that concurrent relaxations only conflict on the locked head nodes
        // — confirm against how `scheduler` is constructed.
        scheduler.iter().for_each(|l| l.par_iter().for_each(|&u| self.relax_node(u, weights_ptr, witness_ptr, &locks)));

        // Phase 3: split the array-of-pairs into the struct-of-arrays output.
        (
            Weights {
                up: weights.iter().map(|w| w.up).collect(),
                down: weights.iter().map(|w| w.down).collect(),
            },
            Shortcuts { middles },
        )
    }

    /// Incrementally re-customizes after a batch of input-edge weight changes.
    ///
    /// Applies `updates` (pairs of input edge index and new weight) to
    /// `input_weights`, then reprocesses affected arcs in ascending id order
    /// via a deduplicating min-queue. A changed arc re-queues every triangle
    /// arc above it whose value may have depended on the old weight or can be
    /// improved by the new one. `weights`/`shortcuts` are updated in place.
    ///
    /// # Panics
    /// If an edge index in `updates` is out of bounds of `input_weights`.
    pub fn customize_partial(&self, mapper: &WeightMap, ctx: &PartialUpdateContext, input_weights: &mut [f32], weights: &mut Weights, shortcuts: &mut Shortcuts, updates: &[(usize, f32)]) {
        if updates.is_empty() {
            return;
        }

        // Seed the queue with every arc directly fed by a changed input edge.
        let mut queue = MinArcQueue::new(self.topology.head.len());
        updates.iter().for_each(|&(edge_idx, new_weight)| {
            assert!(edge_idx < input_weights.len(), "Input edge index {edge_idx} out of bounds");
            input_weights[edge_idx] = new_weight;

            let up_arc = mapper.up_map[edge_idx];
            if up_arc != u32::MAX {
                queue.push(up_arc);
            }

            let down_arc = mapper.down_map[edge_idx];
            if down_arc != u32::MAX {
                queue.push(down_arc);
            }
        });

        // Ascending-id processing lets lower arcs settle before the higher arcs
        // that depend on them.
        while let Some(arc) = queue.pop() {
            let old_up = weights.up[arc as usize];
            let old_down = weights.down[arc as usize];
            let new_pair = self.recompute_arc(ctx, input_weights, weights, shortcuts, arc);

            // Unchanged value: nothing above can be affected.
            if old_up == new_pair.up && old_down == new_pair.down {
                continue;
            }

            // Triangles where `arc` is the middle arc: re-queue `top` if the
            // old value could have defined it, or the new value improves it.
            ctx.for_each_intermediate(&self.topology, arc, |bottom_arc, _, top_arc| {
                let bottom = bottom_arc as usize;
                let top = top_arc as usize;
                if weights.down[bottom] + old_up == weights.up[top]
                    || weights.up[bottom] + old_down == weights.down[top]
                    || weights.down[bottom] + new_pair.up < weights.up[top]
                    || weights.up[bottom] + new_pair.down < weights.down[top]
                {
                    queue.push(top_arc);
                }
            });

            // Triangles where `arc` is the lowest arc: same re-queue logic with
            // the roles swapped.
            ctx.for_each_upper(&self.topology, arc, |_, mid_arc, top_arc| {
                let mid = mid_arc as usize;
                let top = top_arc as usize;
                if weights.up[mid] + old_down == weights.up[top]
                    || weights.down[mid] + old_up == weights.down[top]
                    || weights.up[mid] + new_pair.down < weights.up[top]
                    || weights.down[mid] + new_pair.up < weights.down[top]
                {
                    queue.push(top_arc);
                }
            });
        }

        // Debug builds only: verify the lower-triangle inequalities hold for
        // every arc after the update has converged.
        #[cfg(debug_assertions)]
        {
            (0..self.topology.head.len() as u32).for_each(|arc| {
                ctx.for_each_lower(&self.topology, arc, |bottom_arc, mid_arc, top_arc| {
                    let bottom = bottom_arc as usize;
                    let mid = mid_arc as usize;
                    let top = top_arc as usize;
                    assert!(weights.up[top] <= weights.down[bottom] + weights.up[mid]);
                    assert!(weights.down[top] <= weights.up[bottom] + weights.down[mid]);
                });
            });
        }
    }

    /// Relaxes all upper triangles based at node `u`: for each pair of
    /// out-arcs u->v, u->w it tries to improve the weights of arc v->w with
    /// the detour through `u`, recording `u` as the witness on success.
    ///
    /// `u`'s own out-weights are staged in the thread-local workspace (indexed
    /// by head node); writes to the shared weight/witness arrays are guarded by
    /// `locks[v]`, a spin lock with capped exponential backoff.
    #[inline(always)]
    fn relax_node(&self, u: u32, weights_ptr: RawSlice<WeightPair>, witnesses_ptr: RawSlice<WitnessPair>, locks: &[AtomicBool]) {
        WORKSPACE.with(|cell| {
            let mut ws = cell.borrow_mut();
            // Size the workspace to the node count so any head id is in bounds.
            self.ensure_workspace_size(&mut ws, self.ranks.len());

            unsafe {
                let ws_ptr = ws.as_mut_ptr();
                let u_rng = self.topology.get_range(u as usize);
                // Stage u's current out-weights into the workspace, keyed by head.
                // NOTE(review): these reads are unlocked — sound only if no other
                // node in the same scheduler level writes u's out-arcs; verify.
                u_rng.clone().for_each(|arc| *ws_ptr.add(self.topology.get_head(arc)) = *weights_ptr.get(arc));

                for uv_arc in u_rng.clone() {
                    let v = self.topology.get_head(uv_arc);
                    let uv_pair = *ws_ptr.add(v);
                    // An unreachable u<->v arc cannot improve anything via u.
                    if uv_pair.is_infinite() {
                        continue;
                    }

                    // Acquire v's spin lock with capped exponential backoff.
                    let mut backoff = 1;
                    while locks[v].compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() {
                        for _ in 0..backoff {
                            std::hint::spin_loop();
                        }
                        if backoff < 64 {
                            backoff <<= 1;
                        }
                    }

                    for vw_arc in self.topology.get_range(v) {
                        let w = self.topology.get_head(vw_arc);
                        // INFINITY when w is not an out-neighbor of u, which
                        // makes both candidates below non-improving.
                        let uw_pair = *ws_ptr.add(w);

                        let (target_weight, target_witness) = (weights_ptr.get(vw_arc), witnesses_ptr.get(vw_arc));
                        // Candidate weights for v->w via u, per direction:
                        // up(v->w) via v->u->w, down(v->w) via w->u->v.
                        let (n_up, n_down) = (uv_pair.down + uw_pair.up, uv_pair.up + uw_pair.down);

                        if n_up < (*target_weight).up {
                            (*target_weight).up = n_up;
                            (*target_witness).up = u;
                        }
                        if n_down < (*target_weight).down {
                            (*target_weight).down = n_down;
                            (*target_witness).down = u;
                        }
                    }
                    locks[v].store(false, Ordering::Release);
                }
                // Restore the workspace to all-INFINITY for the next caller.
                u_rng.into_iter().for_each(|arc| (*ws_ptr.add(self.topology.get_head(arc))).reset());
            }
        });
    }
}