// cch_rs/cch/customize.rs
1//! Parallelized metric customization phase.
2//!
3//! Customization is the "C" in CCH. Once the structural topology is built, this module
4//! takes a specific set of edge weights (e.g., travel times for a specific vehicle profile)
5//! and propagates them through the shortcut arcs.
6//!
7//! This phase is heavily optimized and parallelized using `rayon`. It utilizes a lock-free
8//! approach with atomic spin-locks (`AtomicBool`) and thread-local workspaces to evaluate
9//! the elimination tree levels concurrently.
10
11use std::cell::RefCell;
12use std::cmp::Reverse;
13use std::collections::BinaryHeap;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16use rayon::prelude::*;
17
18use crate::cch::contract::Topology;
19use crate::cch::{Cch, CchGraph};
20
// Per-thread scratch buffer used by `relax_node` to snapshot the weights of the
// current node's arcs, keyed by head node. Invariant: all entries are reset to
// the infinite sentinel before the thread moves on to the next node.
thread_local! {
    static WORKSPACE: RefCell<Vec<WeightPair>> = const { RefCell::new(Vec::new()) };
}
24
/// Upward and downward weight of a single CCH arc, stored together
/// (array-of-structs) so both directions share a cache line during relaxation.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct WeightPair {
    up: f32,   // weight in the upward direction of the arc
    down: f32, // weight in the downward direction of the arc
}
/// Customized metric in struct-of-arrays form: one upward and one downward
/// weight per CCH arc (indexed by arc id).
#[derive(Clone)]
pub struct Weights {
    pub up: Vec<f32>,
    pub down: Vec<f32>,
}
/// Middle node that realized the best shortcut weight for an arc, one per
/// direction. `u32::MAX` marks "no shortcut": the weight comes directly from an
/// input edge (see `recompute_arc` / `customize`, which initialize with MAX).
#[repr(C)]
#[derive(Clone, Copy)]
pub struct WitnessPair {
    pub up: u32,
    pub down: u32,
}
/// Per-arc witness middles produced by customization; used to unpack shortcut
/// arcs back into original edges.
#[derive(Clone)]
pub struct Shortcuts {
    pub middles: Vec<WitnessPair>,
}
46
/// Maps each input-graph edge index to the CCH arc that carries its weight,
/// split by direction. `u32::MAX` means the edge has no matching arc in that
/// direction.
#[derive(Debug, Clone)]
pub struct WeightMap {
    pub up_map: Vec<u32>,
    pub down_map: Vec<u32>,
}
52impl WeightMap {
53    pub fn build<G: CchGraph>(graph: &G, ranks: &[u32], topology: &Topology) -> Self {
54        let (mut up_map, mut down_map) = (vec![u32::MAX; graph.num_edges()], vec![u32::MAX; graph.num_edges()]);
55        (0..graph.num_nodes() as u32).for_each(|u| {
56            let u_rank = ranks[u as usize];
57            graph.edge_indices(u as usize).for_each(|edge_idx| {
58                let v_rank = ranks[graph.head()[edge_idx] as usize];
59                let (from_r, to_r) = if u_rank < v_rank { (u_rank, v_rank) } else { (v_rank, u_rank) };
60                if let Some(arc_id) = unsafe { topology.find_edge_id(from_r, to_r) } {
61                    let target = if u_rank < v_rank { &mut up_map } else { &mut down_map };
62                    target[edge_idx] = arc_id;
63                }
64            });
65        });
66        Self { up_map, down_map }
67    }
68}
69impl WeightPair {
70    const INFINITY: Self = Self {
71        up: f32::INFINITY,
72        down: f32::INFINITY,
73    };
74    #[inline(always)]
75    fn is_infinite(&self) -> bool { self.up == f32::INFINITY && self.down == f32::INFINITY }
76    #[inline(always)]
77    fn reset(&mut self) {
78        self.up = f32::INFINITY;
79        self.down = f32::INFINITY;
80    }
81}
82
/// Copyable raw pointer to the base of a slice, shared across rayon workers.
#[derive(Clone, Copy)]
struct RawSlice<T>(*mut T);
// SAFETY: the pointer itself provides no synchronization. Callers must ensure
// that concurrent accesses through copies of this pointer are either to
// disjoint indices or externally serialized — here, `relax_node` guards writes
// with per-node spin locks. The underlying allocation must outlive all copies.
unsafe impl<T> Send for RawSlice<T> {}
unsafe impl<T> Sync for RawSlice<T> {}
impl<T> RawSlice<T> {
    /// Returns a raw pointer to element `i`.
    ///
    /// # Safety
    /// `i` must be in bounds of the allocation the base pointer was created
    /// from, and that allocation must still be live.
    #[inline(always)]
    unsafe fn get(self, i: usize) -> *mut T { unsafe { self.0.add(i) } }
}
91
/// Work queue over arc ids that always pops the smallest pending arc and
/// ignores duplicate pushes, via a per-arc membership bitmap.
struct MinArcQueue {
    heap: BinaryHeap<Reverse<u32>>, // min-heap of pending arc ids
    queued: Vec<bool>,              // queued[a] == true iff arc `a` is pending
}

impl MinArcQueue {
    /// Creates an empty queue able to track `arc_count` distinct arcs.
    fn new(arc_count: usize) -> Self {
        Self {
            heap: BinaryHeap::new(),
            queued: vec![false; arc_count],
        }
    }

    /// Enqueues `arc` unless it is already pending.
    fn push(&mut self, arc: u32) {
        if !self.queued[arc as usize] {
            self.queued[arc as usize] = true;
            self.heap.push(Reverse(arc));
        }
    }

    /// Removes and returns the smallest pending arc, or `None` when empty.
    fn pop(&mut self) -> Option<u32> {
        let Reverse(arc) = self.heap.pop()?;
        self.queued[arc as usize] = false;
        Some(arc)
    }
}
119
/// Auxiliary indices for incremental customization: reverse adjacency of the
/// CCH arcs plus, per arc, the input edges whose weights feed it.
#[derive(Clone)]
pub struct PartialUpdateContext {
    tail: Vec<u32>,       // tail node of each arc (head lives in `Topology`)
    first_in: Vec<u32>,   // CSR offsets into `in_arc`, one slot per node (+1)
    in_arc: Vec<u32>,     // arc ids grouped by head node
    up_first: Vec<u32>,   // CSR offsets into `up_src`, one slot per arc (+1)
    up_src: Vec<u32>,     // input edge indices mapped to each arc (upward)
    down_first: Vec<u32>, // CSR offsets into `down_src`, one slot per arc (+1)
    down_src: Vec<u32>,   // input edge indices mapped to each arc (downward)
}
impl PartialUpdateContext {
    /// Precomputes the reverse-adjacency and edge-source indices needed by
    /// `Cch::customize_partial` to enumerate triangles and re-seed arc weights.
    pub fn build(topology: &Topology, mapper: &WeightMap) -> Self {
        let arc_count = topology.head.len();
        // Recover the tail of each arc from consecutive `first_out` offsets.
        let mut tail = vec![0; arc_count];
        topology.first_out.windows(2).enumerate().for_each(|(u, w)| tail[w[0] as usize..w[1] as usize].fill(u as u32));

        // Counting sort of arcs by head node: prefix sums give CSR offsets.
        let mut first_in = vec![0; topology.first_out.len()];
        topology.head.iter().for_each(|&v| first_in[v as usize + 1] += 1);
        (1..first_in.len()).for_each(|i| first_in[i] += first_in[i - 1]);

        // Scatter arc ids into their head-node buckets; `next_in` is the
        // per-node write cursor.
        let mut in_arc = vec![0; arc_count];
        let mut next_in = first_in[..first_in.len() - 1].to_vec();
        topology.head.iter().enumerate().for_each(|(arc, &v)| {
            in_arc[next_in[v as usize] as usize] = arc as u32;
            next_in[v as usize] += 1;
        });

        let (up_first, up_src) = Self::build_source_index(&mapper.up_map, arc_count);
        let (down_first, down_src) = Self::build_source_index(&mapper.down_map, arc_count);

        Self {
            tail,
            first_in,
            in_arc,
            up_first,
            up_src,
            down_first,
            down_src,
        }
    }

    /// Inverts an edge-to-arc map into a CSR listing (offsets `f`, values `s`)
    /// of the input edge indices mapped onto each arc. Entries of `u32::MAX`
    /// (edges without a matching arc) are skipped.
    fn build_source_index(map: &[u32], arc_count: usize) -> (Vec<u32>, Vec<u32>) {
        let mut f = vec![0; arc_count + 1];
        map.iter().filter(|&&a| a != u32::MAX).for_each(|&a| f[a as usize + 1] += 1);
        (1..f.len()).for_each(|i| f[i] += f[i - 1]);
        let (mut s, mut n) = (vec![0; f[arc_count] as usize], f.clone());
        map.iter().enumerate().filter(|&(_, &a)| a != u32::MAX).for_each(|(i, &a)| {
            s[n[a as usize] as usize] = i as u32;
            n[a as usize] += 1;
        });
        (f, s)
    }

    /// Linear merge intersection of two sorted key sequences of lengths
    /// `n1`/`n2`; calls `f(i, j)` for every index pair with equal keys.
    /// Both key functions must yield keys in non-decreasing order.
    #[inline(always)]
    fn intersect_by<K: Ord, F: FnMut(usize, usize)>(n1: usize, n2: usize, mut k1: impl FnMut(usize) -> K, mut k2: impl FnMut(usize) -> K, mut f: F) {
        let (mut i, mut j) = (0, 0);
        while i < n1 && j < n2 {
            match k1(i).cmp(&k2(j)) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    f(i, j);
                    i += 1;
                    j += 1;
                },
            }
        }
    }

    /// Enumerates triangles in which `arc` (x -> y) is the LOWEST arc: for each
    /// common neighbor z reachable from both x (only via arcs after `arc` in
    /// x's range) and y, calls `f(arc, x->z, y->z)`.
    /// NOTE(review): assumes out-arc heads are sorted within each node's CSR
    /// range — required by `intersect_by`'s sorted-merge; confirm in `Topology`.
    pub fn for_each_upper<F: FnMut(u32, u32, u32)>(&self, topo: &Topology, arc: u32, mut f: F) {
        let (x, y) = (self.tail[arc as usize] as usize, topo.head[arc as usize] as usize);
        let (xo, yo) = (arc as usize + 1, topo.first_out[y] as usize);
        Self::intersect_by(
            (topo.first_out[x + 1] as usize).saturating_sub(xo),
            (topo.first_out[y + 1] as usize).saturating_sub(yo),
            |i| topo.head[xo + i],
            |j| topo.head[yo + j],
            |i, j| f(arc, (xo + i) as u32, (yo + j) as u32),
        );
    }

    /// Enumerates triangles in which `arc` (x -> y) is the MIDDLE arc: for each
    /// node z with an out-arc x->z that precedes `arc` in x's range and an arc
    /// z->y into y, calls `f(x->z, arc, z->y)`.
    pub fn for_each_intermediate<F: FnMut(u32, u32, u32)>(&self, topo: &Topology, arc: u32, mut f: F) {
        let (x, y) = (self.tail[arc as usize] as usize, topo.head[arc as usize] as usize);
        let (xo, yo) = (topo.first_out[x] as usize, self.first_in[y] as usize);
        Self::intersect_by(
            (arc as usize).saturating_sub(xo),
            (self.first_in[y + 1] as usize).saturating_sub(yo),
            |i| topo.head[xo + i],
            |j| self.tail[self.in_arc[yo + j] as usize],
            |i, j| f((xo + i) as u32, arc, self.in_arc[yo + j]),
        );
    }

    /// Enumerates triangles in which `arc` (x -> y) is the UPPER arc: for each
    /// node z with arcs z->x and z->y, calls `f(z->x, z->y, arc)`.
    /// NOTE(review): assumes incoming-arc tails are sorted inside each
    /// `first_in` bucket, which holds because `build` scatters arcs in
    /// ascending arc-id order and tails are non-decreasing with arc id.
    pub fn for_each_lower<F: FnMut(u32, u32, u32)>(&self, topo: &Topology, arc: u32, mut f: F) {
        let (x, y) = (self.tail[arc as usize] as usize, topo.head[arc as usize] as usize);
        let (xo, yo) = (self.first_in[x] as usize, self.first_in[y] as usize);
        Self::intersect_by(
            (self.first_in[x + 1] as usize).saturating_sub(xo),
            (self.first_in[y + 1] as usize).saturating_sub(yo),
            |i| self.tail[self.in_arc[xo + i] as usize],
            |j| self.tail[self.in_arc[yo + j] as usize],
            |i, j| f(self.in_arc[xo + i], self.in_arc[yo + j], arc),
        );
    }
}
225
impl Cch {
    /// Builds the auxiliary index required by [`Self::customize_partial`].
    pub fn build_partial_update_context(&self) -> PartialUpdateContext { PartialUpdateContext::build(&self.topology, &self.weight_map) }

    /// Grows the thread-local workspace to at least `required_size` entries,
    /// filling any new slots with the infinite sentinel. Never shrinks.
    #[inline]
    fn ensure_workspace_size(&self, ws: &mut Vec<WeightPair>, required_size: usize) {
        if ws.len() < required_size {
            ws.resize(required_size, WeightPair::INFINITY);
        }
    }

    /// Re-derives an arc's base weight directly from the input edges mapped
    /// onto it: the per-direction minimum over all mapped edge weights,
    /// ignoring any shortcut contribution.
    #[inline(always)]
    fn initial_arc_weight(&self, ctx: &PartialUpdateContext, input_weights: &[f32], arc: u32) -> WeightPair {
        let mut pair = WeightPair::INFINITY;
        ctx.up_src[ctx.up_first[arc as usize] as usize..ctx.up_first[arc as usize + 1] as usize]
            .iter()
            .for_each(|&e| pair.up = pair.up.min(input_weights[e as usize]));
        ctx.down_src[ctx.down_first[arc as usize] as usize..ctx.down_first[arc as usize + 1] as usize]
            .iter()
            .for_each(|&e| pair.down = pair.down.min(input_weights[e as usize]));
        pair
    }

    /// Recomputes one arc from scratch: starts from its raw input weight, then
    /// relaxes every lower triangle, recording the middle node that realized
    /// the minimum in each direction. Writes the result into `weights` /
    /// `shortcuts` and returns the new pair.
    fn recompute_arc(&self, ctx: &PartialUpdateContext, input_weights: &[f32], weights: &mut Weights, shortcuts: &mut Shortcuts, arc: u32) -> WeightPair {
        let mut pair = self.initial_arc_weight(ctx, input_weights, arc);
        let mut middle = WitnessPair { up: u32::MAX, down: u32::MAX };

        ctx.for_each_lower(&self.topology, arc, |bottom_arc, mid_arc, _| {
            let bottom = bottom_arc as usize;
            let mid = mid_arc as usize;
            // Tail of the bottom arc is the triangle's third node z.
            let witness = ctx.tail[bottom];

            // Upward relaxation via z: down(z->x) + up(z->y).
            let next_up = weights.down[bottom] + weights.up[mid];
            if next_up < pair.up {
                pair.up = next_up;
                middle.up = witness;
            }

            // Symmetric relaxation for the downward direction.
            let next_down = weights.up[bottom] + weights.down[mid];
            if next_down < pair.down {
                pair.down = next_down;
                middle.down = witness;
            }
        });

        weights.up[arc as usize] = pair.up;
        weights.down[arc as usize] = pair.down;
        shortcuts.middles[arc as usize] = middle;
        pair
    }

    /// Full customization: seeds every arc from `original_weights` via
    /// `mapper`, then processes the levels in `scheduler` in order, relaxing
    /// the nodes of each level in parallel with rayon.
    ///
    /// Returns the customized per-arc weights and the witness middles needed
    /// to unpack shortcuts.
    pub fn customize(&self, mapper: &WeightMap, scheduler: &[Vec<u32>], original_weights: &[f32]) -> (Weights, Shortcuts) {
        let num_arcs = self.topology.head.len();
        let mut weights = vec![WeightPair::INFINITY; num_arcs];
        let mut middles = vec![WitnessPair { up: u32::MAX, down: u32::MAX }; num_arcs];

        // Raw pointers let the parallel phase write without &mut aliasing
        // restrictions; disjointness is enforced by the per-node spin locks in
        // `relax_node`.
        let weights_ptr = RawSlice(weights.as_mut_ptr());
        let witness_ptr = RawSlice(middles.as_mut_ptr());

        // Sequential seeding: copy each input edge weight onto its arc, taking
        // the minimum when several edges map to the same arc.
        // SAFETY: this loop is single-threaded and every index written comes
        // from `mapper`, which was built against this topology's arc ids.
        original_weights.iter().enumerate().for_each(|(i, &weight)| unsafe {
            let set_weight = |idx: u32, is_up: bool| {
                if idx != u32::MAX {
                    let p = &mut *weights_ptr.get(idx as usize);
                    if is_up {
                        p.up = p.up.min(weight);
                    } else {
                        p.down = p.down.min(weight);
                    }
                }
            };
            set_weight(mapper.up_map[i], true);
            set_weight(mapper.down_map[i], false);
        });

        let num_nodes = self.ranks.len();
        // One spin lock per node; `relax_node` holds a node's lock while
        // updating the weights of that node's outgoing arcs.
        let locks: Vec<AtomicBool> = (0..num_nodes).map(|_| AtomicBool::new(false)).collect();

        // Levels run sequentially; nodes within one level run in parallel.
        scheduler.iter().for_each(|l| l.par_iter().for_each(|&u| self.relax_node(u, weights_ptr, witness_ptr, &locks)));

        // Convert the array-of-structs working form back to struct-of-arrays.
        (
            Weights {
                up: weights.iter().map(|w| w.up).collect(),
                down: weights.iter().map(|w| w.down).collect(),
            },
            Shortcuts { middles },
        )
    }

    /// Incremental customization after a batch of input weight changes.
    ///
    /// Seeds a min-ordered queue with the arcs directly affected by `updates`,
    /// then propagates: each popped arc is recomputed from scratch, and if its
    /// weight pair changed, every triangle in which it participates as middle
    /// or lowest arc has its top arc re-enqueued. Popping in ascending arc-id
    /// order settles lower arcs before the arcs built on top of them
    /// (presumably arc ids increase with rank — confirm in `Topology`).
    ///
    /// # Panics
    /// Panics if any updated edge index is out of bounds for `input_weights`.
    pub fn customize_partial(&self, mapper: &WeightMap, ctx: &PartialUpdateContext, input_weights: &mut [f32], weights: &mut Weights, shortcuts: &mut Shortcuts, updates: &[(usize, f32)]) {
        if updates.is_empty() {
            return;
        }

        let mut queue = MinArcQueue::new(self.topology.head.len());
        updates.iter().for_each(|&(edge_idx, new_weight)| {
            assert!(edge_idx < input_weights.len(), "Input edge index {edge_idx} out of bounds");
            input_weights[edge_idx] = new_weight;

            let up_arc = mapper.up_map[edge_idx];
            if up_arc != u32::MAX {
                queue.push(up_arc);
            }

            let down_arc = mapper.down_map[edge_idx];
            if down_arc != u32::MAX {
                queue.push(down_arc);
            }
        });

        while let Some(arc) = queue.pop() {
            let old_up = weights.up[arc as usize];
            let old_down = weights.down[arc as usize];
            let new_pair = self.recompute_arc(ctx, input_weights, weights, shortcuts, arc);

            // Unchanged pair: nothing downstream can be affected.
            if old_up == new_pair.up && old_down == new_pair.down {
                continue;
            }

            // `arc` as the MIDDLE arc of a triangle: re-enqueue the top arc if
            // it may currently depend on the old value (equality test) or could
            // be improved by the new one (strict inequality).
            ctx.for_each_intermediate(&self.topology, arc, |bottom_arc, _, top_arc| {
                let bottom = bottom_arc as usize;
                let top = top_arc as usize;
                if weights.down[bottom] + old_up == weights.up[top]
                    || weights.up[bottom] + old_down == weights.down[top]
                    || weights.down[bottom] + new_pair.up < weights.up[top]
                    || weights.up[bottom] + new_pair.down < weights.down[top]
                {
                    queue.push(top_arc);
                }
            });

            // Same test with `arc` in the LOWEST position of the triangle.
            ctx.for_each_upper(&self.topology, arc, |_, mid_arc, top_arc| {
                let mid = mid_arc as usize;
                let top = top_arc as usize;
                if weights.up[mid] + old_down == weights.up[top]
                    || weights.down[mid] + old_up == weights.down[top]
                    || weights.up[mid] + new_pair.down < weights.up[top]
                    || weights.down[mid] + new_pair.up < weights.down[top]
                {
                    queue.push(top_arc);
                }
            });
        }

        // Debug-only invariant check: after propagation, no arc may still be
        // improvable through any lower triangle.
        #[cfg(debug_assertions)]
        {
            (0..self.topology.head.len() as u32).for_each(|arc| {
                ctx.for_each_lower(&self.topology, arc, |bottom_arc, mid_arc, top_arc| {
                    let bottom = bottom_arc as usize;
                    let mid = mid_arc as usize;
                    let top = top_arc as usize;
                    assert!(weights.up[top] <= weights.down[bottom] + weights.up[mid]);
                    assert!(weights.down[top] <= weights.up[bottom] + weights.down[mid]);
                });
            });
        }
    }

    /// Relaxes all triangles whose lowest node is `u`: for each neighbor `v`
    /// of `u` and each arc (v, w) leaving `v`, tries to improve (v, w) via the
    /// path through `u`, recording `u` as the witness when it does.
    ///
    /// Runs concurrently on rayon workers. Reads of `u`'s own arc weights go
    /// through a thread-local snapshot taken up front; writes to shared arc
    /// state are serialized by `v`'s spin lock.
    #[inline(always)]
    fn relax_node(&self, u: u32, weights_ptr: RawSlice<WeightPair>, witnesses_ptr: RawSlice<WitnessPair>, locks: &[AtomicBool]) {
        WORKSPACE.with(|cell| {
            let mut ws = cell.borrow_mut();
            self.ensure_workspace_size(&mut ws, self.ranks.len());

            // SAFETY: every index passed to `ws_ptr`/`weights_ptr`/
            // `witnesses_ptr` comes from this topology (node ids < num_nodes,
            // arc ids < num_arcs), and writes through `weights_ptr` /
            // `witnesses_ptr` happen only while holding `locks[v]`.
            unsafe {
                let ws_ptr = ws.as_mut_ptr();
                let u_rng = self.topology.get_range(u as usize);
                // Snapshot the current weights of u's arcs, keyed by head node.
                u_rng.clone().for_each(|arc| *ws_ptr.add(self.topology.get_head(arc)) = *weights_ptr.get(arc));

                for uv_arc in u_rng.clone() {
                    let v = self.topology.get_head(uv_arc);
                    let uv_pair = *ws_ptr.add(v);
                    if uv_pair.is_infinite() {
                        continue;
                    }

                    // Acquire v's spin lock with exponential backoff, capped at 64 spins.
                    let mut backoff = 1;
                    while locks[v].compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() {
                        for _ in 0..backoff {
                            std::hint::spin_loop();
                        }
                        if backoff < 64 {
                            backoff <<= 1;
                        }
                    }

                    for vw_arc in self.topology.get_range(v) {
                        let w = self.topology.get_head(vw_arc);
                        // Infinite when w is not also a neighbor of u, making
                        // the relaxation below a no-op (inf never compares less).
                        let uw_pair = *ws_ptr.add(w);

                        let (target_weight, target_witness) = (weights_ptr.get(vw_arc), witnesses_ptr.get(vw_arc));
                        let (n_up, n_down) = (uv_pair.down + uw_pair.up, uv_pair.up + uw_pair.down);

                        if n_up < (*target_weight).up {
                            (*target_weight).up = n_up;
                            (*target_witness).up = u;
                        }
                        if n_down < (*target_weight).down {
                            (*target_weight).down = n_down;
                            (*target_witness).down = u;
                        }
                    }
                    locks[v].store(false, Ordering::Release);
                }
                // Restore the workspace invariant: all touched entries infinite.
                u_rng.into_iter().for_each(|arc| (*ws_ptr.add(self.topology.get_head(arc))).reset());
            }
        });
    }
}