1use petgraph::graph::{DiGraph, NodeIndex};
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::sync::Arc;
5use crate::utils::{calculate_distance, calculate_travel_time};
6use crate::simplify::simplify_graph;
7use rstar::{RTree, RTreeObject, AABB, PointDistance};
8
/// Root of an OSM XML document as produced by [`parse_xml`].
///
/// Only `<node>` and `<way>` children are mapped; other child elements (e.g.
/// `<relation>`, `<bounds>`) have no field here and are skipped by serde's
/// default unknown-field handling. Both lists are empty when the document
/// contains no such elements.
///
/// NOTE(review): quick-xml collects repeated elements into a `Vec` only when
/// they are adjacent, unless its `overlapped-lists` feature is enabled. OSM
/// files list all nodes before all ways, so this presumably holds — confirm for
/// other input sources.
#[derive(Debug, Deserialize)]
pub struct XmlData {
    /// Every `<node>` element of the document.
    #[serde(rename = "node", default)]
    pub nodes: Vec<XmlNode>,
    /// Every `<way>` element of the document.
    #[serde(rename = "way", default)]
    pub ways: Vec<XmlWay>,
}
16
/// An OSM `<node>` element: an id, a coordinate and optional tags.
///
/// Used as the vertex weight of the graph built by [`create_graph`].
#[derive(Debug, Deserialize, Clone)]
pub struct XmlNode {
    /// OSM node id (`id` attribute); `<nd ref>` entries of ways point at it.
    #[serde(rename = "@id")]
    pub id: i64,
    /// Latitude (`lat` attribute) — presumably decimal degrees as in OSM XML.
    #[serde(rename = "@lat")]
    pub lat: f64,
    /// Longitude (`lon` attribute) — presumably decimal degrees as in OSM XML.
    #[serde(rename = "@lon")]
    pub lon: f64,
    /// `<tag k v>` children; empty when the node has none.
    #[serde(rename = "tag", default)]
    pub tags: Vec<XmlTag>,
    /// Read from a `<geohash>` child element if present, otherwise `None`.
    /// Not set anywhere in this module — presumably populated elsewhere; confirm.
    pub geohash: Option<String>,
}
29
/// An OSM `<way>` element, plus per-edge metrics filled in during graph construction.
///
/// In the graph built by [`create_graph`], every edge carries its own copy of
/// the way. That copy has its tags reduced by [`XmlWay::filter_useful_tags`]
/// and its `nodes` list emptied. The metric fields below then describe that
/// single segment, not the whole way.
#[derive(Debug, Deserialize, Clone)]
pub struct XmlWay {
    /// OSM way id (`id` attribute).
    #[serde(rename = "@id")]
    pub id: i64,
    /// Ordered `<nd ref>` children; consecutive pairs form the way's segments.
    #[serde(rename = "nd", default)]
    pub nodes: Vec<XmlNodeRef>,
    /// `<tag k v>` children.
    #[serde(rename = "tag", default)]
    pub tags: Vec<XmlTag>,
    /// Segment length from `calculate_distance`; 0.0 until computed.
    /// Unit is whatever `calculate_distance` returns (presumably metres — confirm).
    #[serde(default)]
    pub length: f64,
    /// Assumed driving speed in km/h, derived from `maxspeed` or the `highway` class.
    #[serde(default)]
    pub speed_kph: f64,
    /// Travel time at 5 km/h, in the unit returned by `calculate_travel_time`.
    #[serde(default)]
    pub walk_travel_time: f64,
    /// Travel time at 15 km/h, in the unit returned by `calculate_travel_time`.
    #[serde(default)]
    pub bike_travel_time: f64,
    /// Travel time at `speed_kph`, in the unit returned by `calculate_travel_time`.
    #[serde(default)]
    pub drive_travel_time: f64,
}
49
50impl XmlWay {
51 pub fn filter_useful_tags(self) -> Self {
52 const USEFUL_TAGS: &[&str] = &[
53 "bridge", "tunnel", "oneway", "lanes", "ref", "name",
54 "highway", "maxspeed", "service", "access", "area",
55 "landuse", "width", "est_width", "junction",
56 ];
57 let tags = self.tags.into_iter()
59 .filter(|tag| USEFUL_TAGS.iter().any(|&k| k == tag.key.as_str()))
60 .collect();
61 XmlWay { tags, ..self }
62 }
63}
64
/// A way's `<nd ref="…"/>` child, referring to an [`XmlNode`] by its id.
#[derive(Debug, Deserialize, Clone)]
pub struct XmlNodeRef {
    /// Id of the referenced node (`ref` attribute).
    #[serde(rename = "@ref")]
    pub node_id: i64,
}
70
/// A `<tag k="…" v="…"/>` key/value pair attached to a node or way.
#[derive(Debug, Deserialize, Clone)]
pub struct XmlTag {
    /// Tag key (`k` attribute), e.g. `highway`.
    #[serde(rename = "@k")]
    pub key: String,
    /// Tag value (`v` attribute), e.g. `residential`.
    #[serde(rename = "@v")]
    pub value: String,
}
78
/// How a way's segments are turned into directed edges by `create_graph`.
struct PathDirectionality {
    /// Add a single edge per segment instead of one edge in each direction.
    is_one_way: bool,
    /// Point that single edge against the way's node order (`oneway=-1`/`reverse`).
    is_reversed: bool,
}
83
84pub fn parse_xml(xml_data: &str) -> Result<XmlData, quick_xml::DeError> {
86 let root: XmlData = quick_xml::de::from_str(xml_data)?;
87 Ok(root)
88}
89
90fn find_tag<'a>(tags: &'a [XmlTag], key: &str) -> Option<&'a XmlTag> {
91 tags.iter().find(|tag| tag.key == key)
92}
93
94fn assess_path_directionality(path: &XmlWay) -> PathDirectionality {
95 let oneway_tag = find_tag(&path.tags, "oneway");
96 let junction_tag = find_tag(&path.tags, "junction");
97
98 let is_one_way = match oneway_tag {
99 Some(tag) => {
100 matches!(tag.value.as_str(), "yes" | "true" | "1" | "-1" | "reverse")
104 },
105 None => false,
106 };
107
108 let is_reversed = oneway_tag.map_or(false, |tag| {
109 matches!(tag.value.as_str(), "-1" | "reverse")
110 });
111
112 let is_roundabout = junction_tag.map_or(false, |tag| tag.value == "roundabout");
114
115 PathDirectionality {
116 is_one_way: is_one_way || is_roundabout,
117 is_reversed,
118 }
119}
120
121pub fn create_graph(
123 nodes: Vec<XmlNode>,
124 ways: Vec<XmlWay>,
125 retain_all: bool,
126 _bidirectional: bool,
127) -> DiGraph<XmlNode, XmlWay> {
128 let mut graph = DiGraph::<XmlNode, XmlWay>::new();
129 let mut node_index_map = HashMap::new();
130
131 for node in nodes {
133 let id = node.id;
134 let node_index = graph.add_node(node); node_index_map.insert(id, node_index);
136 }
137
138 for mut way in ways {
140 let node_refs = std::mem::take(&mut way.nodes);
143 let filtered_way = way.filter_useful_tags();
144 let path_direction = assess_path_directionality(&filtered_way);
145
146 for window in node_refs.windows(2) {
147 if let [start_ref, end_ref] = window {
148 let (start_index, end_index) = if path_direction.is_reversed {
149 (
150 node_index_map[&end_ref.node_id],
151 node_index_map[&start_ref.node_id],
152 )
153 } else {
154 (
155 node_index_map[&start_ref.node_id],
156 node_index_map[&end_ref.node_id],
157 )
158 };
159
160 graph.add_edge(start_index, end_index, filtered_way.clone());
161 if !path_direction.is_one_way {
162 graph.add_edge(end_index, start_index, filtered_way.clone());
163 }
164 }
165 }
166 }
167
168 add_edge_lengths(&mut graph);
170
171 let hwy_speeds = HashMap::from([
174 ("motorway".to_string(), 110.0),
175 ("motorway_link".to_string(), 60.0),
176 ("trunk".to_string(), 90.0),
177 ("trunk_link".to_string(), 45.0),
178 ("primary".to_string(), 65.0),
179 ("primary_link".to_string(), 45.0),
180 ("secondary".to_string(), 55.0),
181 ("secondary_link".to_string(), 40.0),
182 ("tertiary".to_string(), 45.0),
183 ("tertiary_link".to_string(), 35.0),
184 ("unclassified".to_string(), 45.0),
185 ("residential".to_string(), 30.0),
186 ("living_street".to_string(), 10.0),
187 ("service".to_string(), 20.0),
188 ("track".to_string(), 20.0),
189 ("road".to_string(), 50.0),
190 ]);
191 let fallback_speed = 50.0;
192
193 add_edge_speeds(&mut graph, &hwy_speeds, fallback_speed);
194 add_edge_travel_times(&mut graph);
195
196 if !retain_all {
199 graph = simplify_graph(&graph)
200 }
201 graph
204}
205
206fn add_edge_lengths(graph: &mut DiGraph<XmlNode, XmlWay>) {
207 for edge in graph.edge_indices() {
208 let (start_index, end_index) = graph.edge_endpoints(edge).unwrap();
209 let start_node = &graph[start_index];
210 let end_node = &graph[end_index];
211
212 let distance =
213 calculate_distance(start_node.lat, start_node.lon, end_node.lat, end_node.lon);
214
215 let way = graph.edge_weight_mut(edge).unwrap();
216 way.length = distance;
217 }
218}
219
220fn add_edge_speeds(
221 graph: &mut DiGraph<XmlNode, XmlWay>,
222 hwy_speeds: &HashMap<String, f64>,
223 fallback: f64,
224) {
225 for edge in graph.edge_indices() {
226 let way = graph.edge_weight_mut(edge).unwrap();
227 let speed = way
228 .tags
229 .iter()
230 .find(|tag| tag.key == "maxspeed")
231 .map_or_else(
232 || {
233 way.tags
234 .iter()
235 .find(|tag| tag.key == "highway")
236 .and_then(|tag| hwy_speeds.get(&tag.value).copied())
237 .unwrap_or(fallback)
238 },
239 |tag| clean_maxspeed(&tag.value),
240 );
241 way.speed_kph = speed;
242 }
243}
244
/// Converts an OSM `maxspeed` tag value to km/h.
///
/// Accepted input:
/// - plain numbers (`"50"`);
/// - an optional unit suffix, with or without a space: `km/h`, `kmh`, `kph`,
///   or `mph` (the last converted at 1.60934 km/h per mph);
/// - a decimal comma (`"7,5"`);
/// - several `;`-separated values (`"50;70"`), which are averaged.
///
/// Matching is case-insensitive and ignores surrounding whitespace.
///
/// Returns 0.0 when no component can be interpreted as a positive speed
/// (e.g. `"none"`, `"signals"`, `""`). Callers should treat a non-positive
/// result as "unknown".
fn clean_maxspeed(maxspeed: &str) -> f64 {
    /// Parses one value such as `"50"`, `"50 km/h"`, `"30mph"` or `"7,5"`.
    fn parse_single(raw: &str) -> Option<f64> {
        const MPH_TO_KPH: f64 = 1.60934;

        let raw = raw.trim().to_lowercase();
        let (number, factor) = if let Some(rest) = raw.strip_suffix("mph") {
            (rest, MPH_TO_KPH)
        } else if let Some(rest) = raw
            .strip_suffix("km/h")
            .or_else(|| raw.strip_suffix("kmh"))
            .or_else(|| raw.strip_suffix("kph"))
        {
            (rest, 1.0)
        } else {
            (raw.as_str(), 1.0)
        };

        let value: f64 = number.trim().replace(',', ".").parse().ok()?;
        // `parse` accepts "nan"/"inf"; those are not speeds.
        if value.is_finite() && value > 0.0 {
            Some(value * factor)
        } else {
            None
        }
    }

    let speeds: Vec<f64> = maxspeed.split(';').filter_map(parse_single).collect();
    if speeds.is_empty() {
        0.0
    } else {
        speeds.iter().sum::<f64>() / speeds.len() as f64
    }
}
254
255fn add_edge_travel_times(graph: &mut DiGraph<XmlNode, XmlWay>) {
257 for edge in graph.edge_indices() {
258 let way = graph.edge_weight_mut(edge).unwrap();
259 let walk_travel_time = calculate_travel_time(way.length, 5.0);
260 let bike_travel_time = calculate_travel_time(way.length, 15.0);
261 let drive_travel_time = calculate_travel_time(way.length, way.speed_kph);
262
263 way.walk_travel_time = walk_travel_time;
264 way.bike_travel_time = bike_travel_time;
265 way.drive_travel_time = drive_travel_time;
266 }
267}
268
269pub fn node_to_latlon(graph: &DiGraph<XmlNode, XmlWay>, node_index: NodeIndex) -> (f64, f64) {
270 let node = &graph[node_index];
271 (node.lat, node.lon)
272}
273
274#[derive(Clone)]
276struct NodeEntry {
277 point: [f64; 2],
278 index: NodeIndex,
279}
280
281impl RTreeObject for NodeEntry {
282 type Envelope = AABB<[f64; 2]>;
283 fn envelope(&self) -> Self::Envelope {
284 AABB::from_point(self.point)
285 }
286}
287
288impl PointDistance for NodeEntry {
289 fn distance_2(&self, point: &[f64; 2]) -> f64 {
290 let dlat = self.point[0] - point[0];
291 let dlon = self.point[1] - point[1];
292 dlat * dlat + dlon * dlon
293 }
294}
295
296#[derive(Clone)]
300pub struct SpatialGraph {
301 pub graph: Arc<DiGraph<XmlNode, XmlWay>>,
302 tree: Arc<RTree<NodeEntry>>,
303}
304
305impl SpatialGraph {
306 pub fn new(graph: DiGraph<XmlNode, XmlWay>) -> Self {
307 let entries = graph.node_indices()
308 .map(|i| NodeEntry { point: [graph[i].lat, graph[i].lon], index: i })
309 .collect();
310 let tree = Arc::new(RTree::bulk_load(entries));
311 let graph = Arc::new(graph);
312 Self { graph, tree }
313 }
314
315 pub fn nearest_node(&self, lat: f64, lon: f64) -> Option<NodeIndex> {
316 self.tree.nearest_neighbor(&[lat, lon]).map(|e| e.index)
317 }
318}
319
320pub fn latlon_to_node(graph: &DiGraph<XmlNode, XmlWay>, lat: f64, lon: f64) -> Option<NodeIndex> {
322 SpatialGraph::new(graph.clone()).nearest_node(lat, lon)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// An untagged node at the given coordinates.
    fn sample_node(id: i64, lat: f64, lon: f64) -> XmlNode {
        XmlNode {
            id,
            lat,
            lon,
            tags: Vec::new(),
            geohash: None,
        }
    }

    /// A way with id 1 through `node_ids`, carrying `tags`, with all metrics zeroed.
    fn sample_way(node_ids: &[i64], tags: &[(&str, &str)]) -> XmlWay {
        XmlWay {
            id: 1,
            nodes: node_ids.iter().map(|&node_id| XmlNodeRef { node_id }).collect(),
            tags: tags
                .iter()
                .map(|&(key, value)| XmlTag {
                    key: key.to_string(),
                    value: value.to_string(),
                })
                .collect(),
            length: 0.0,
            speed_kph: 0.0,
            walk_travel_time: 0.0,
            bike_travel_time: 0.0,
            drive_travel_time: 0.0,
        }
    }

    /// Runs `create_graph` (retain_all = true, bidirectional = false) on two
    /// nodes joined by one way tagged with `tags`.
    fn two_node_graph(tags: &[(&str, &str)]) -> DiGraph<XmlNode, XmlWay> {
        let nodes = vec![sample_node(1, 0.0, 0.0), sample_node(2, 0.001, 0.0)];
        create_graph(nodes, vec![sample_way(&[1, 2], tags)], true, false)
    }

    #[test]
    fn test_graph_respects_maxspeed_tag() {
        let graph = two_node_graph(&[("highway", "residential"), ("maxspeed", "30")]);
        let first_edge = graph.edge_weights().next().expect("graph should have an edge");
        assert_eq!(first_edge.speed_kph, 30.0);
    }

    #[test]
    fn test_oneway_produces_single_edge() {
        let graph = two_node_graph(&[("highway", "residential"), ("oneway", "yes")]);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_bidirectional_produces_two_edges() {
        let graph = two_node_graph(&[("highway", "residential")]);
        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_nearest_node_finds_closest() {
        let mut graph = DiGraph::new();
        graph.add_node(sample_node(1, 48.0, 11.0));
        graph.add_node(sample_node(2, 52.0, 13.0));
        let spatial = SpatialGraph::new(graph);
        let nearest = spatial
            .nearest_node(48.001, 11.001)
            .expect("graph is non-empty");
        assert_eq!(spatial.graph[nearest].id, 1);
    }
}