1use petgraph::algo::astar;
2use petgraph::graph::NodeIndex;
3
4use crate::error::OsmGraphError;
5use crate::graph::{SpatialGraph, XmlWay};
6use crate::overpass::NetworkType;
7use crate::utils::calculate_distance;
8
/// Result of [`route`]: the node-by-node path plus its length and travel time.
#[derive(Debug, Clone, PartialEq)]
pub struct Route {
    /// `(lat, lon)` of every graph node on the path, from origin to destination.
    pub coordinates: Vec<(f64, f64)>,
    /// Elapsed travel time in seconds at each entry of `coordinates`.
    ///
    /// Starts at `0.0` at the origin and ends at `duration_s`; it has one entry per
    /// coordinate.
    pub cumulative_times_s: Vec<f64>,
    /// Sum of the `length` of every traversed way (metres, per the `_m` naming).
    pub distance_m: f64,
    /// Total travel time in seconds for the network type the route was computed for.
    pub duration_s: f64,
}
20
21pub fn route(
22 sg: &SpatialGraph,
23 origin_lat: f64,
24 origin_lon: f64,
25 dest_lat: f64,
26 dest_lon: f64,
27 network_type: NetworkType,
28) -> Result<Route, OsmGraphError> {
29 let origin = sg.nearest_node(origin_lat, origin_lon).ok_or(OsmGraphError::NodeNotFound)?;
30 let dest = sg.nearest_node(dest_lat, dest_lon).ok_or(OsmGraphError::NodeNotFound)?;
31
32 let edge_cost = |e: petgraph::graph::EdgeReference<XmlWay>| -> f64 {
33 let way = e.weight();
34 match network_type {
35 NetworkType::Walk => way.walk_travel_time,
36 NetworkType::Bike => way.bike_travel_time,
37 _ => way.drive_travel_time,
38 }
39 };
40
41 let heuristic = |node: NodeIndex| -> f64 {
43 let n = &sg.graph[node];
44 let d = &sg.graph[dest];
45 let dist = calculate_distance(n.lat, n.lon, d.lat, d.lon);
46 let max_speed_m_per_s = 200.0 / 3.6; dist / max_speed_m_per_s
49 };
50
51 let result = astar(&*sg.graph, origin, |n| n == dest, edge_cost, heuristic)
52 .ok_or(OsmGraphError::NodeNotFound)?; let (_, path) = result;
55
56 let coordinates: Vec<(f64, f64)> = path
57 .iter()
58 .map(|&idx| {
59 let n = &sg.graph[idx];
60 (n.lat, n.lon)
61 })
62 .collect();
63
64 let mut distance_m = 0.0;
66 let mut duration_s = 0.0;
67 let mut cumulative_times_s = vec![0.0_f64]; for window in path.windows(2) {
69 if let [u, v] = window {
70 if let Some(edge) = sg.graph.find_edge(*u, *v) {
71 let way = sg.graph.edge_weight(edge).unwrap();
72 distance_m += way.length;
73 duration_s += match network_type {
74 NetworkType::Walk => way.walk_travel_time,
75 NetworkType::Bike => way.bike_travel_time,
76 _ => way.drive_travel_time,
77 };
78 cumulative_times_s.push(duration_s);
79 }
80 }
81 }
82
83 Ok(Route { coordinates, cumulative_times_s, distance_m, duration_s })
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{SpatialGraph, XmlNode, XmlTag, XmlWay};
    use crate::overpass::NetworkType;
    use petgraph::graph::DiGraph;

    /// Walking speed (5 km/h) expressed in m/s.
    const WALK_M_PER_S: f64 = 5.0 / 3.6;
    /// Cycling speed (15 km/h) expressed in m/s.
    const BIKE_M_PER_S: f64 = 15.0 / 3.6;

    /// Untagged node at the given coordinates.
    fn make_node(id: i64, lat: f64, lon: f64) -> XmlNode {
        XmlNode {
            id,
            lat,
            lon,
            tags: Vec::new(),
            geohash: None,
        }
    }

    /// Residential way of `length`, with fixed walk/bike speeds and the given drive time.
    fn make_way(drive_travel_time: f64, length: f64) -> XmlWay {
        let highway = XmlTag {
            key: "highway".into(),
            value: "residential".into(),
        };
        XmlWay {
            id: 1,
            nodes: Vec::new(),
            tags: vec![highway],
            length,
            speed_kph: 50.0,
            walk_travel_time: length / WALK_M_PER_S,
            bike_travel_time: length / BIKE_M_PER_S,
            drive_travel_time,
        }
    }

    /// Three nodes along the meridian, chained a -> b -> c with identical ways.
    fn linear_graph() -> SpatialGraph {
        let mut g = DiGraph::new();
        let nodes: Vec<_> = [(1, 0.0), (2, 0.001), (3, 0.002)]
            .iter()
            .map(|&(id, lat)| g.add_node(make_node(id, lat, 0.0)))
            .collect();
        for hop in nodes.windows(2) {
            g.add_edge(hop[0], hop[1], make_way(10.0, 111.0));
        }
        SpatialGraph::new(g)
    }

    /// Drive route from the first to the last node of `linear_graph`.
    fn drive_route_end_to_end() -> Route {
        let graph = linear_graph();
        route(&graph, 0.0, 0.0, 0.002, 0.0, NetworkType::Drive).unwrap()
    }

    #[test]
    fn test_cumulative_times_starts_at_zero() {
        let r = drive_route_end_to_end();
        assert_eq!(r.cumulative_times_s[0], 0.0);
    }

    #[test]
    fn test_cumulative_times_parallel_to_coordinates() {
        let r = drive_route_end_to_end();
        assert_eq!(r.cumulative_times_s.len(), r.coordinates.len());
    }

    #[test]
    fn test_cumulative_times_monotonic() {
        let r = drive_route_end_to_end();
        assert!(
            r.cumulative_times_s.windows(2).all(|w| w[1] >= w[0]),
            "times decreased: {:?}",
            r.cumulative_times_s
        );
    }

    #[test]
    fn test_cumulative_times_last_equals_duration() {
        let r = drive_route_end_to_end();
        let last = *r.cumulative_times_s.last().unwrap();
        let gap = (last - r.duration_s).abs();
        assert!(
            gap < 1e-6,
            "last cumulative time {last:.6} != duration {:.6}", r.duration_s
        );
    }
}