ch_router/
ch.rs

1use super::{Edge as InitEdge, Node, NodeIdx};
2
3use rustc_hash::{FxHashMap, FxHashSet};
4use std::collections::BinaryHeap;
5use unfolding::{AllEdges, Predecessor, Predecessors, ShortcutVia};
6
7mod storage;
8
9#[cfg(not(feature = "path-unfolding"))]
10mod no_unfolding;
11#[cfg(not(feature = "path-unfolding"))]
12use no_unfolding as unfolding;
13
14#[cfg(feature = "path-unfolding")]
15mod unfolding;
16
17type HashMap<K, V> = FxHashMap<K, V>;
18type HashSet<K> = FxHashSet<K>;
19
20#[derive(Copy, Clone)]
21struct Edge {
22    to: NodeIdx,
23    weight: f32,
24    shortcut_via: ShortcutVia,
25}
26
27pub struct ContractionHierarchy {
28    forward_edges: Vec<Vec<Edge>>,
29    backward_edges: Vec<Vec<Edge>>,
30    all_edges: AllEdges,
31}
32
33struct NodeImportance {
34    node_idx: NodeIdx,
35    importance: u32,
36}
37
38impl Eq for NodeImportance {}
39impl PartialEq for NodeImportance {
40    fn eq(&self, other: &Self) -> bool {
41        self.importance == other.importance
42    }
43}
44
45impl PartialOrd for NodeImportance {
46    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
47        Some(self.cmp(&other))
48    }
49}
50
51impl Ord for NodeImportance {
52    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
53        // Inverted for min-heap
54        other.importance.cmp(&self.importance)
55    }
56}
57
58impl ContractionHierarchy {
59    pub fn new(nodes: &[Node], edges: &[InitEdge], max_speed: f32) -> Self {
60        let mut forward_edges = vec![vec![]; nodes.len()];
61        let mut backward_edges = vec![vec![]; nodes.len()];
62
63        for edge in edges {
64            forward_edges[edge.from as usize].push(Edge {
65                to: edge.to,
66                weight: edge.weight,
67                shortcut_via: ShortcutVia::none(),
68            });
69            backward_edges[edge.to as usize].push(Edge {
70                to: edge.from,
71                weight: edge.weight,
72                shortcut_via: ShortcutVia::none(),
73            });
74        }
75
76        let mut this = Self {
77            all_edges: AllEdges::new(&forward_edges, &backward_edges),
78            forward_edges,
79            backward_edges,
80        };
81
82        let mut contracted = vec![false; nodes.len()];
83        let mut node_importance = vec![0; nodes.len()];
84
85        let mut remaining_nodes: BinaryHeap<_> = (0..(contracted.len() as NodeIdx))
86            .map(|node_idx| NodeImportance {
87                importance: this
88                    .get_required_shortcuts(&nodes, node_idx, &contracted, max_speed)
89                    .len() as u32,
90                node_idx,
91            })
92            .collect();
93
94        let mut next_importance = 0;
95
96        while let Some(node_imp) = remaining_nodes.pop() {
97            let node_idx = node_imp.node_idx;
98            let required_shortcuts =
99                this.get_required_shortcuts(&nodes, node_idx, &contracted, max_speed);
100            let importance = required_shortcuts.len() as u32;
101
102            if importance > node_imp.importance {
103                remaining_nodes.push(NodeImportance {
104                    node_idx,
105                    importance,
106                });
107                continue;
108            }
109
110            node_importance[node_idx as usize] = next_importance;
111            next_importance += 1;
112
113            for (from, to, weight) in required_shortcuts {
114                this.add_shortcut(from, to, node_idx, weight);
115            }
116
117            contracted[node_idx as usize] = true;
118        }
119
120        this.prune_edges(&node_importance);
121
122        this
123    }
124
125    pub fn distance(&self, start: NodeIdx, target: NodeIdx) -> Option<f32> {
126        Router::new(self).distance(start, target)
127    }
128
129    #[cfg(feature = "path-unfolding")]
130    pub fn route(&self, start: NodeIdx, target: NodeIdx) -> Option<Route> {
131        Router::new(self).route(start, target)
132    }
133
134    pub fn node_count(&self) -> u32 {
135        self.forward_edges.len() as u32
136    }
137
138    pub fn create_hot_group(&self, nodes: &[NodeIdx]) -> HotGroup {
139        Router::new(self).create_hot_group(nodes)
140    }
141
142    fn prune_edges(&mut self, node_importance: &[u32]) {
143        for (i, edges) in self.forward_edges.iter_mut().enumerate() {
144            edges.retain(|edge| node_importance[i] < node_importance[edge.to as usize]);
145        }
146
147        for (i, edges) in self.backward_edges.iter_mut().enumerate() {
148            edges.retain(|edge| node_importance[edge.to as usize] > node_importance[i]);
149        }
150    }
151
152    fn get_required_shortcuts(
153        &self,
154        nodes: &[Node],
155        node_idx: NodeIdx,
156        contracted: &[bool],
157        max_speed: f32,
158    ) -> Vec<(NodeIdx, NodeIdx, f32)> {
159        let incoming = &self.backward_edges[node_idx as usize];
160        let outgoing = &self.forward_edges[node_idx as usize];
161
162        let mut required_shortcuts = vec![];
163
164        let mut pairs = HashMap::default();
165
166        for in_edge in incoming {
167            for out_edge in outgoing {
168                if contracted[in_edge.to as usize] || contracted[out_edge.to as usize] {
169                    continue;
170                }
171
172                let path_length = in_edge.weight + out_edge.weight;
173                let current = pairs
174                    .entry((in_edge.to, out_edge.to))
175                    .or_insert(path_length);
176                *current = current.min(path_length);
177            }
178        }
179
180        for ((from, to), path_length) in pairs {
181            if !self.witness_path_exists(
182                &nodes,
183                from,
184                to,
185                node_idx,
186                path_length,
187                max_speed,
188                contracted,
189            ) {
190                required_shortcuts.push((from, to, path_length));
191            }
192        }
193
194        required_shortcuts
195    }
196
197    fn witness_path_exists(
198        &self,
199        nodes: &[Node],
200        from: NodeIdx,
201        to: NodeIdx,
202        via: NodeIdx,
203        max_length: f32,
204        max_speed: f32,
205        contracted: &[bool],
206    ) -> bool {
207        struct State {
208            cost: f32,
209            h_cost: f32,
210            idx: NodeIdx,
211        }
212
213        impl Eq for State {}
214        impl PartialEq for State {
215            fn eq(&self, other: &Self) -> bool {
216                self.h_cost == other.h_cost
217            }
218        }
219
220        impl PartialOrd for State {
221            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
222                Some(self.cmp(other))
223            }
224        }
225
226        impl Ord for State {
227            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
228                other.h_cost.partial_cmp(&self.h_cost).unwrap()
229            }
230        }
231
232        let target = nodes[to as usize];
233        let straight_line_distance = |from: NodeIdx| {
234            let start = nodes[from as usize];
235            ((target.x - start.x).powi(2) + (target.y - start.y).powi(2)).sqrt() / max_speed
236        };
237
238        let mut distances = HashMap::default();
239        let mut heap = BinaryHeap::new();
240
241        distances.insert(from, 0.0);
242        heap.push(State {
243            cost: 0.0,
244            h_cost: straight_line_distance(from),
245            idx: from,
246        });
247
248        while let Some(State { cost, idx, h_cost }) = heap.pop() {
249            if cost > distances[&idx] {
250                continue;
251            }
252
253            if h_cost > max_length {
254                return false;
255            }
256
257            if idx == to {
258                return true;
259            }
260
261            for edge in &self.forward_edges[idx as usize] {
262                if edge.to == via || contracted[edge.to as usize] {
263                    continue;
264                }
265
266                let cost = cost + edge.weight;
267                let next = State {
268                    h_cost: cost + straight_line_distance(edge.to),
269                    idx: edge.to,
270                    cost,
271                };
272
273                let dist = distances.entry(next.idx).or_insert(f32::MAX);
274
275                if cost < *dist {
276                    *dist = cost;
277                    heap.push(next);
278                }
279            }
280        }
281
282        false
283    }
284
285    fn add_shortcut(&mut self, from: NodeIdx, to: NodeIdx, via: NodeIdx, weight: f32) {
286        let forward = &mut self.forward_edges[from as usize];
287        let backward = &mut self.backward_edges[to as usize];
288
289        forward.retain(|edge| edge.to != to);
290        backward.retain(|edge| edge.to != from);
291
292        let edge = Edge {
293            to,
294            weight,
295            shortcut_via: ShortcutVia::node(via),
296        };
297        forward.push(edge);
298        self.all_edges.add_forward(from, edge);
299
300        let edge = Edge {
301            to: from,
302            weight,
303            shortcut_via: ShortcutVia::node(via),
304        };
305        backward.push(edge);
306        self.all_edges.add_backward(to, edge);
307    }
308}
309
310struct SearchState {
311    node: NodeIdx,
312    distance: f32,
313    pred: Predecessor,
314}
315
316impl Eq for SearchState {}
317impl PartialEq for SearchState {
318    fn eq(&self, other: &Self) -> bool {
319        self.distance == other.distance
320    }
321}
322
323impl Ord for SearchState {
324    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
325        other.distance.partial_cmp(&self.distance).unwrap()
326    }
327}
328
329impl PartialOrd for SearchState {
330    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
331        Some(self.cmp(other))
332    }
333}
334
/// Result of a route query (only available with the `path-unfolding` feature).
#[cfg(feature = "path-unfolding")]
#[derive(Debug)]
pub struct Route {
    /// Node sequence produced by unfolding the search result between start and target.
    pub path: Vec<NodeIdx>,
    /// Total cost of the route.
    pub distance: f32,
}
341
342pub struct Router<'ch> {
343    ch: &'ch ContractionHierarchy,
344    queue: BinaryHeap<SearchState>,
345    forward_distances: HashMap<NodeIdx, f32>,
346    backward_distances: HashMap<NodeIdx, f32>,
347    forward_settled: HashSet<NodeIdx>,
348    predecessors: Predecessors,
349}
350
351impl<'ch> Router<'ch> {
352    pub fn new(ch: &'ch ContractionHierarchy) -> Self {
353        Self {
354            ch,
355            queue: BinaryHeap::new(),
356            forward_distances: HashMap::default(),
357            backward_distances: HashMap::default(),
358            forward_settled: HashSet::default(),
359            predecessors: Predecessors::new(),
360        }
361    }
362
363    pub fn distance(&mut self, start: NodeIdx, target: NodeIdx) -> Option<f32> {
364        self.bidirectional_dijkstra(start, target)
365            .map(|(_, distance)| distance)
366    }
367
368    #[cfg(feature = "path-unfolding")]
369    pub fn route(&mut self, start: NodeIdx, target: NodeIdx) -> Option<Route> {
370        let (node, distance) = self.bidirectional_dijkstra(start, target)?;
371
372        Some(Route {
373            path: self
374                .predecessors
375                .unfold(&self.ch.all_edges, start, target, node),
376            distance,
377        })
378    }
379
380    pub fn create_hot_group(&mut self, nodes: &[NodeIdx]) -> HotGroup {
381        let forward_meeting_nodes: Vec<_> = nodes
382            .iter()
383            .map(|&node| {
384                self.queue.push(SearchState {
385                    distance: 0.0,
386                    node,
387                    pred: Predecessor::new(node),
388                });
389
390                let mut forward_settled = vec![];
391
392                self.forward_distances.insert(node, 0.0);
393
394                while let Some(SearchState { node, distance, .. }) = self.queue.pop() {
395                    if let Some(&current) = self.forward_distances.get(&node) {
396                        if distance > current {
397                            continue;
398                        }
399                    }
400
401                    forward_settled.push(MeetingNode { node, distance });
402
403                    for edge in &self.ch.forward_edges[node as usize] {
404                        let next = SearchState {
405                            distance: distance + edge.weight,
406                            node: edge.to,
407                            pred: Predecessor::new(node),
408                        };
409
410                        let forward = self.forward_distances.entry(next.node).or_insert(f32::MAX);
411                        if next.distance < *forward {
412                            *forward = next.distance;
413                            self.queue.push(next);
414                        }
415                    }
416                }
417
418                self.forward_distances.clear();
419                forward_settled.sort_by_key(|mn| mn.node);
420                forward_settled
421            })
422            .collect();
423
424        let backward_meeting_nodes: Vec<_> = nodes
425            .iter()
426            .map(|&node| {
427                self.queue.push(SearchState {
428                    distance: 0.0,
429                    node,
430                    pred: Predecessor::new(node),
431                });
432
433                let mut backward_settled = vec![];
434
435                self.backward_distances.insert(node, 0.0);
436
437                while let Some(SearchState { node, distance, .. }) = self.queue.pop() {
438                    if let Some(&current) = self.backward_distances.get(&node) {
439                        if distance > current {
440                            continue;
441                        }
442                    }
443
444                    backward_settled.push(MeetingNode { node, distance });
445
446                    for edge in &self.ch.backward_edges[node as usize] {
447                        let next = SearchState {
448                            distance: distance + edge.weight,
449                            node: edge.to,
450                            pred: Predecessor::new(node),
451                        };
452
453                        let backward = self.backward_distances.entry(next.node).or_insert(f32::MAX);
454
455                        if next.distance < *backward {
456                            *backward = next.distance;
457                            self.queue.push(next);
458                        }
459                    }
460                }
461
462                self.backward_distances.clear();
463                backward_settled.sort_by_key(|mn| mn.node);
464                backward_settled
465            })
466            .collect();
467
468        self.forward_distances.clear();
469        self.backward_distances.clear();
470
471        HotGroup {
472            forward_meeting_nodes,
473            backward_meeting_nodes,
474        }
475    }
476
477    fn bidirectional_dijkstra(
478        &mut self,
479        start: NodeIdx,
480        target: NodeIdx,
481    ) -> Option<(NodeIdx, f32)> {
482        let mut best_distance = f32::MAX;
483        let mut best_meeting_node = None;
484
485        self.queue.push(SearchState {
486            distance: 0.0,
487            node: start,
488            pred: Predecessor::new(start),
489        });
490
491        self.forward_distances.insert(start, 0.0);
492        self.backward_distances.insert(target, 0.0);
493
494        while let Some(SearchState {
495            node,
496            distance,
497            pred,
498        }) = self.queue.pop()
499        {
500            if let Some(&current) = self.forward_distances.get(&node) {
501                if distance > current {
502                    continue;
503                }
504            }
505
506            self.forward_settled.insert(node);
507            self.predecessors.insert_forward(node, pred);
508
509            for edge in &self.ch.forward_edges[node as usize] {
510                let next = SearchState {
511                    distance: distance + edge.weight,
512                    node: edge.to,
513                    pred: Predecessor::new(node),
514                };
515
516                let forward = self.forward_distances.entry(next.node).or_insert(f32::MAX);
517                if next.distance < *forward {
518                    *forward = next.distance;
519                    self.queue.push(next);
520                }
521            }
522        }
523
524        self.queue.push(SearchState {
525            distance: 0.0,
526            node: target,
527            pred: Predecessor::new(target),
528        });
529
530        while let Some(SearchState {
531            node,
532            distance,
533            pred,
534        }) = self.queue.pop()
535        {
536            if let Some(&current) = self.backward_distances.get(&node) {
537                if distance > current {
538                    continue;
539                }
540            }
541
542            self.predecessors.insert_backward(node, pred);
543
544            if self.forward_settled.contains(&node) {
545                let forward_distance = self.forward_distances[&node];
546                let total_distance = distance + forward_distance;
547                if total_distance < best_distance {
548                    best_distance = total_distance;
549                    best_meeting_node = Some(node);
550                }
551            }
552
553            for edge in &self.ch.backward_edges[node as usize] {
554                let next = SearchState {
555                    distance: distance + edge.weight,
556                    node: edge.to,
557                    pred: Predecessor::new(node),
558                };
559
560                let backward = self.backward_distances.entry(next.node).or_insert(f32::MAX);
561
562                if next.distance < *backward {
563                    *backward = next.distance;
564                    self.queue.push(next);
565                }
566            }
567        }
568
569        self.forward_distances.clear();
570        self.backward_distances.clear();
571
572        self.forward_settled.clear();
573
574        best_meeting_node.map(|node| (node, best_distance))
575    }
576}
577
578pub struct HotGroup {
579    forward_meeting_nodes: Vec<Vec<MeetingNode>>,
580    backward_meeting_nodes: Vec<Vec<MeetingNode>>,
581}
582
583struct MeetingNode {
584    node: NodeIdx,
585    distance: f32,
586}
587
588impl HotGroup {
589    pub fn distance(&self, start: NodeIdx, end: NodeIdx) -> Option<f32> {
590        let forward_nodes = &self.forward_meeting_nodes[start as usize];
591        let backward_nodes = &self.backward_meeting_nodes[end as usize];
592
593        let mut i = 0;
594        let mut j = 0;
595
596        let mut min_distance = f32::MAX;
597        let mut found = false;
598
599        while i < forward_nodes.len() && j < backward_nodes.len() {
600            let forward = &forward_nodes[i];
601            let backward = &backward_nodes[j];
602
603            if forward.node == backward.node {
604                i += 1;
605                j += 1;
606
607                min_distance = min_distance.min(forward.distance + backward.distance);
608                found = true;
609            } else if forward.node < backward.node {
610                i += 1;
611            } else {
612                j += 1;
613            }
614        }
615
616        if found {
617            Some(min_distance)
618        } else {
619            None
620        }
621    }
622
623    pub fn node_count(&self) -> u32 {
624        self.forward_meeting_nodes.len() as u32
625    }
626}
627
628#[cfg(test)]
629mod tests {
630    use crate::*;
631
632    fn create_ch() -> ContractionHierarchy {
633        let nodes = [
634            Node { x: 0.0, y: 0.0 },
635            Node { x: 8.0, y: 0.0 },
636            Node { x: 4.0, y: 1.0 },
637            Node { x: 2.0, y: 4.0 },
638            Node { x: 6.0, y: 4.0 },
639            Node { x: 4.0, y: 7.0 },
640            Node { x: 0.0, y: 8.0 },
641            Node { x: 8.0, y: 8.0 },
642        ];
643
644        let edges = [
645            Edge {
646                from: 0,
647                to: 3,
648                weight: 2.0,
649            },
650            Edge {
651                from: 1,
652                to: 2,
653                weight: 1.6,
654            },
655            Edge {
656                from: 1,
657                to: 4,
658                weight: 2.5,
659            },
660            Edge {
661                from: 2,
662                to: 0,
663                weight: 1.3,
664            },
665            Edge {
666                from: 2,
667                to: 3,
668                weight: 1.7,
669            },
670            Edge {
671                from: 2,
672                to: 4,
673                weight: 1.5,
674            },
675            Edge {
676                from: 3,
677                to: 0,
678                weight: 2.0,
679            },
680            Edge {
681                from: 3,
682                to: 5,
683                weight: 1.3,
684            },
685            Edge {
686                from: 3,
687                to: 2,
688                weight: 1.7,
689            },
690            Edge {
691                from: 4,
692                to: 5,
693                weight: 1.2,
694            },
695            Edge {
696                from: 4,
697                to: 1,
698                weight: 2.5,
699            },
700            Edge {
701                from: 5,
702                to: 6,
703                weight: 1.9,
704            },
705            Edge {
706                from: 5,
707                to: 7,
708                weight: 1.2,
709            },
710            Edge {
711                from: 5,
712                to: 3,
713                weight: 1.3,
714            },
715            Edge {
716                from: 5,
717                to: 4,
718                weight: 1.2,
719            },
720            Edge {
721                from: 6,
722                to: 3,
723                weight: 3.2,
724            },
725            Edge {
726                from: 6,
727                to: 5,
728                weight: 1.9,
729            },
730            Edge {
731                from: 7,
732                to: 4,
733                weight: 3.0,
734            },
735        ];
736
737        let max_speed = edges
738            .iter()
739            .map(|e| {
740                let from = nodes[e.from as usize];
741                let to = nodes[e.to as usize];
742                let straight_line_distance =
743                    ((from.x - to.x).powi(2) + (from.y - to.y).powi(2)).sqrt();
744                straight_line_distance / e.weight
745            })
746            .fold(f32::MIN, f32::max);
747
748        ContractionHierarchy::new(&nodes, &edges, max_speed)
749    }
750
751    fn check_distances(ch: &ContractionHierarchy) {
752        let mut router = Router::new(ch);
753
754        assert_eq!(router.distance(0, 1), Some(7.0));
755        assert_eq!(router.distance(1, 6), Some(5.6));
756        assert_eq!(router.distance(7, 2), Some(7.1));
757    }
758
759    #[cfg(feature = "path-unfolding")]
760    fn check_routing(ch: &ContractionHierarchy) {
761        let mut router = Router::new(&ch);
762
763        let route = router.route(0, 1).unwrap();
764        assert_eq!(route.path, vec![0, 3, 5, 4, 1]);
765        assert_eq!(route.distance, 7.0);
766    }
767
768    #[test]
769    fn test_distance() {
770        check_distances(&create_ch());
771    }
772
773    #[cfg(feature = "path-unfolding")]
774    #[test]
775    fn test_route() {
776        check_routing(&create_ch());
777    }
778
779    #[test]
780    fn test_load_distance() {
781        let mut ch_data = vec![];
782        let ch = create_ch();
783        ch.write(&mut ch_data).unwrap();
784        let ch = ContractionHierarchy::read(&mut &*ch_data).unwrap();
785        check_distances(&ch);
786    }
787
788    #[cfg(feature = "path-unfolding")]
789    #[test]
790    fn test_load_route() {
791        let mut ch_data = vec![];
792        let ch = create_ch();
793        ch.write(&mut ch_data).unwrap();
794        let ch = ContractionHierarchy::read(&mut &*ch_data).unwrap();
795        check_routing(&ch);
796    }
797}