Skip to main content

pysochrone/
graph.rs

1use petgraph::graph::{DiGraph, NodeIndex};
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::sync::Arc;
5use crate::utils::{calculate_distance, calculate_travel_time};
6use crate::simplify::simplify_graph;
7use rstar::{RTree, RTreeObject, AABB, PointDistance};
8
/// Root of a parsed OSM XML document, as produced by `parse_xml`.
///
/// Holds the `<node>` and `<way>` child elements; each list defaults to empty
/// when the document contains no elements of that kind.
#[derive(Debug, Deserialize)]
pub struct XmlData {
    /// Every `<node>` element in the document.
    #[serde(rename = "node", default)]
    pub nodes: Vec<XmlNode>,
    /// Every `<way>` element in the document.
    #[serde(rename = "way", default)]
    pub ways: Vec<XmlWay>,
}
16
/// An OSM `<node>` element: its id, coordinates, and `<tag>` children.
///
/// Stored as the node weight of the routing graph built by `create_graph`.
#[derive(Debug, Deserialize, Clone)]
pub struct XmlNode {
    /// Value of the `id` attribute.
    #[serde(rename = "@id")]
    pub id: i64,
    /// Value of the `lat` attribute.
    #[serde(rename = "@lat")]
    pub lat: f64,
    /// Value of the `lon` attribute.
    #[serde(rename = "@lon")]
    pub lon: f64,
    /// `<tag>` children; empty when the node has none.
    #[serde(rename = "tag", default)]
    pub tags: Vec<XmlTag>,
    // NOTE(review): no `@` rename, so this is not read from an XML attribute;
    // presumably `None` unless populated elsewhere — confirm.
    pub geohash: Option<String>,
}
29
/// An OSM `<way>` element, also used as the edge weight of the routing graph.
///
/// `id`, `nodes`, and `tags` come from the XML. The numeric fields default to
/// `0.0` on deserialization and are filled in by `create_graph`'s helpers
/// (`add_edge_lengths`, `add_edge_speeds`, `add_edge_travel_times`).
#[derive(Debug, Deserialize, Clone)]
pub struct XmlWay {
    /// Value of the `id` attribute.
    #[serde(rename = "@id")]
    pub id: i64,
    /// Ordered `<nd ref="…">` node references. `create_graph` takes this list
    /// out before storing the way on edges, so it is empty on graph edges.
    #[serde(rename = "nd", default)]
    pub nodes: Vec<XmlNodeRef>,
    /// `<tag>` children (reduced by `filter_useful_tags` during graph building).
    #[serde(rename = "tag", default)]
    pub tags: Vec<XmlTag>,
    /// Edge length as returned by `calculate_distance`.
    #[serde(default)]
    pub length: f64,
    /// Driving speed in km/h (from `maxspeed` or a per-highway default).
    #[serde(default)]
    pub speed_kph: f64,
    /// Travel time at walking speed, as returned by `calculate_travel_time`.
    #[serde(default)]
    pub walk_travel_time: f64,
    /// Travel time at cycling speed, as returned by `calculate_travel_time`.
    #[serde(default)]
    pub bike_travel_time: f64,
    /// Travel time at `speed_kph`, as returned by `calculate_travel_time`.
    #[serde(default)]
    pub drive_travel_time: f64,
}
49
50impl XmlWay {
51    pub fn filter_useful_tags(self) -> Self {
52        const USEFUL_TAGS: &[&str] = &[
53            "bridge", "tunnel", "oneway", "lanes", "ref", "name",
54            "highway", "maxspeed", "service", "access", "area",
55            "landuse", "width", "est_width", "junction",
56        ];
57        // Linear search on 15-element static slice — no HashSet allocation needed.
58        let tags = self.tags.into_iter()
59            .filter(|tag| USEFUL_TAGS.iter().any(|&k| k == tag.key.as_str()))
60            .collect();
61        XmlWay { tags, ..self }
62    }
63}
64
/// An `<nd ref="…">` child of a way: a reference to a node by its OSM id.
#[derive(Debug, Deserialize, Clone)]
pub struct XmlNodeRef {
    /// The referenced node's id (the `ref` attribute).
    #[serde(rename = "@ref")]
    pub node_id: i64,
}
70
/// A `<tag k="…" v="…">` key/value pair attached to a node or way.
#[derive(Debug, Deserialize, Clone)]
pub struct XmlTag {
    /// The `k` attribute.
    #[serde(rename = "@k")]
    pub key: String,
    /// The `v` attribute.
    #[serde(rename = "@v")]
    pub value: String,
}
78
/// How a way's consecutive node pairs become directed edges in `create_graph`.
struct PathDirectionality {
    /// When true, only one edge is added per node pair; otherwise both directions.
    is_one_way: bool,
    /// When true, edges run against the way's node order (`oneway=-1`/`reverse`).
    is_reversed: bool,
}
83
84// Function to parse the XML response
85pub fn parse_xml(xml_data: &str) -> Result<XmlData, quick_xml::DeError> {
86    let root: XmlData = quick_xml::de::from_str(xml_data)?;
87    Ok(root)
88}
89
90fn find_tag<'a>(tags: &'a [XmlTag], key: &str) -> Option<&'a XmlTag> {
91    tags.iter().find(|tag| tag.key == key)
92}
93
94fn assess_path_directionality(path: &XmlWay) -> PathDirectionality {
95    let oneway_tag = find_tag(&path.tags, "oneway");
96    let junction_tag = find_tag(&path.tags, "junction");
97
98    let is_one_way = match oneway_tag {
99        Some(tag) => {
100            // The oneway tag can have several values indicating true: "yes", "true", "1"
101            // or indicating reversed: "-1", "reverse"
102            // Any other value (including absence of the tag) defaults to false
103            matches!(tag.value.as_str(), "yes" | "true" | "1" | "-1" | "reverse")
104        },
105        None => false,
106    };
107
108    let is_reversed = oneway_tag.map_or(false, |tag| {
109        matches!(tag.value.as_str(), "-1" | "reverse")
110    });
111
112    // Roundabouts are considered one-way implicitly
113    let is_roundabout = junction_tag.map_or(false, |tag| tag.value == "roundabout");
114
115    PathDirectionality {
116        is_one_way: is_one_way || is_roundabout,
117        is_reversed,
118    }
119}
120
121// Function to create the network graph
122pub fn create_graph(
123    nodes: Vec<XmlNode>,
124    ways: Vec<XmlWay>,
125    retain_all: bool,
126    _bidirectional: bool,
127) -> DiGraph<XmlNode, XmlWay> {
128    let mut graph = DiGraph::<XmlNode, XmlWay>::new();
129    let mut node_index_map = HashMap::new();
130
131    // Add nodes to the graph and keep track of their indices
132    for node in nodes {
133        let id = node.id;
134        let node_index = graph.add_node(node); // move — no clone needed, nodes is already owned
135        node_index_map.insert(id, node_index);
136    }
137
138    // Add edges to the graph
139    for mut way in ways {
140        // Extract node refs before consuming `way` so that edge weights are stored
141        // without the construction-only node list (saves memory for every edge in the graph).
142        let node_refs = std::mem::take(&mut way.nodes);
143        let filtered_way = way.filter_useful_tags();
144        let path_direction = assess_path_directionality(&filtered_way);
145
146        for window in node_refs.windows(2) {
147            if let [start_ref, end_ref] = window {
148                let (start_index, end_index) = if path_direction.is_reversed {
149                    (
150                        node_index_map[&end_ref.node_id],
151                        node_index_map[&start_ref.node_id],
152                    )
153                } else {
154                    (
155                        node_index_map[&start_ref.node_id],
156                        node_index_map[&end_ref.node_id],
157                    )
158                };
159
160                graph.add_edge(start_index, end_index, filtered_way.clone());
161                if !path_direction.is_one_way {
162                    graph.add_edge(end_index, start_index, filtered_way.clone());
163                }
164            }
165        }
166    }
167
168    // Add distance as edge weights
169    add_edge_lengths(&mut graph);
170
171    // Standard OSM highway type speeds (kph), based on typical urban defaults.
172    // These apply when no maxspeed tag is present.
173    let hwy_speeds = HashMap::from([
174        ("motorway".to_string(),       110.0),
175        ("motorway_link".to_string(),   60.0),
176        ("trunk".to_string(),           90.0),
177        ("trunk_link".to_string(),      45.0),
178        ("primary".to_string(),         65.0),
179        ("primary_link".to_string(),    45.0),
180        ("secondary".to_string(),       55.0),
181        ("secondary_link".to_string(),  40.0),
182        ("tertiary".to_string(),        45.0),
183        ("tertiary_link".to_string(),   35.0),
184        ("unclassified".to_string(),    45.0),
185        ("residential".to_string(),     30.0),
186        ("living_street".to_string(),   10.0),
187        ("service".to_string(),         20.0),
188        ("track".to_string(),           20.0),
189        ("road".to_string(),            50.0),
190    ]);
191    let fallback_speed = 50.0;
192
193    add_edge_speeds(&mut graph, &hwy_speeds, fallback_speed);
194    add_edge_travel_times(&mut graph);
195
196    // Simplify graph topology for faster downstream calculations
197    // Consolidates distance and speed from
198    if !retain_all {
199        graph = simplify_graph(&graph)
200    }
201    // ... other future logic
202
203    graph
204}
205
206fn add_edge_lengths(graph: &mut DiGraph<XmlNode, XmlWay>) {
207    for edge in graph.edge_indices() {
208        let (start_index, end_index) = graph.edge_endpoints(edge).unwrap();
209        let start_node = &graph[start_index];
210        let end_node = &graph[end_index];
211
212        let distance =
213            calculate_distance(start_node.lat, start_node.lon, end_node.lat, end_node.lon);
214
215        let way = graph.edge_weight_mut(edge).unwrap();
216        way.length = distance;
217    }
218}
219
220fn add_edge_speeds(
221    graph: &mut DiGraph<XmlNode, XmlWay>,
222    hwy_speeds: &HashMap<String, f64>,
223    fallback: f64,
224) {
225    for edge in graph.edge_indices() {
226        let way = graph.edge_weight_mut(edge).unwrap();
227        let speed = way
228            .tags
229            .iter()
230            .find(|tag| tag.key == "maxspeed")
231            .map_or_else(
232                || {
233                    way.tags
234                        .iter()
235                        .find(|tag| tag.key == "highway")
236                        .and_then(|tag| hwy_speeds.get(&tag.value).copied())
237                        .unwrap_or(fallback)
238                },
239                |tag| clean_maxspeed(&tag.value),
240            );
241        way.speed_kph = speed;
242    }
243}
244
/// Convert an OSM `maxspeed` tag value to km/h.
///
/// Accepted forms:
/// - A bare number: "50".
/// - A number followed by a unit: "30 mph", "30mph", "50 km/h", "50kmh", "50 kph".
/// - A `;`-separated list such as "50;30", whose parsed components are averaged.
///
/// Values in mph are converted to km/h.
///
/// Returns `0.0` when no component can be parsed, e.g. "none", "signals",
/// "walk", or a zone code like "DE:urban". This keeps the previous
/// "unparseable → 0.0" contract, so callers can detect the value and fall back.
fn clean_maxspeed(maxspeed: &str) -> f64 {
    let speeds: Vec<f64> = maxspeed
        .split(';')
        .filter_map(parse_single_maxspeed)
        .collect();
    if speeds.is_empty() {
        0.0
    } else {
        speeds.iter().sum::<f64>() / speeds.len() as f64
    }
}

/// Parse one `maxspeed` component (no `;`) into km/h.
///
/// Returns `None` when there is no leading number or the unit is unrecognised.
fn parse_single_maxspeed(raw: &str) -> Option<f64> {
    const MPH_TO_KPH: f64 = 1.60934;
    let value = raw.trim().to_ascii_lowercase();
    // Split "30 mph" into the numeric prefix "30" and the unit suffix " mph".
    let number_end = value
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(value.len());
    let number: f64 = value[..number_end].parse().ok()?;
    match value[number_end..].trim() {
        "" | "km/h" | "kmh" | "kph" => Some(number),
        "mph" => Some(number * MPH_TO_KPH),
        _ => None,
    }
}
254
255// Function to add travel times as an edge weight
256fn add_edge_travel_times(graph: &mut DiGraph<XmlNode, XmlWay>) {
257    for edge in graph.edge_indices() {
258        let way = graph.edge_weight_mut(edge).unwrap();
259        let walk_travel_time = calculate_travel_time(way.length, 5.0);
260        let bike_travel_time = calculate_travel_time(way.length, 15.0);
261        let drive_travel_time = calculate_travel_time(way.length, way.speed_kph);
262
263        way.walk_travel_time = walk_travel_time;
264        way.bike_travel_time = bike_travel_time;
265        way.drive_travel_time = drive_travel_time;
266    }
267}
268
269pub fn node_to_latlon(graph: &DiGraph<XmlNode, XmlWay>, node_index: NodeIndex) -> (f64, f64) {
270    let node = &graph[node_index];
271    (node.lat, node.lon)
272}
273
/// R-tree entry pairing a node's coordinates with its NodeIndex.
///
/// `point` holds `[lat, lon]` (see `SpatialGraph::new`), so the tree works in
/// raw degree space, not projected coordinates.
#[derive(Clone)]
struct NodeEntry {
    /// `[lat, lon]` copied from the node's `XmlNode`.
    point: [f64; 2],
    /// The node's index in the graph the tree was built from.
    index: NodeIndex,
}
280
/// Lets `NodeEntry` be stored in an `rstar::RTree`.
impl RTreeObject for NodeEntry {
    type Envelope = AABB<[f64; 2]>;
    /// The envelope is a zero-size box located at the entry's point.
    fn envelope(&self) -> Self::Envelope {
        AABB::from_point(self.point)
    }
}
287
288impl PointDistance for NodeEntry {
289    fn distance_2(&self, point: &[f64; 2]) -> f64 {
290        let dlat = self.point[0] - point[0];
291        let dlon = self.point[1] - point[1];
292        dlat * dlat + dlon * dlon
293    }
294}
295
/// A graph bundled with an R-tree spatial index for nearest-node queries.
/// Build once via `SpatialGraph::new`, then reuse it for all lookups and routing.
/// Both fields are `Arc`s, so cloning a `SpatialGraph` only bumps reference
/// counts and does not copy the graph or the tree.
#[derive(Clone)]
pub struct SpatialGraph {
    /// The routing graph; node indices returned by `nearest_node` refer to it.
    pub graph: Arc<DiGraph<XmlNode, XmlWay>>,
    /// R-tree over every node's `[lat, lon]`, built from `graph` in `new`.
    tree: Arc<RTree<NodeEntry>>,
}
304
305impl SpatialGraph {
306    pub fn new(graph: DiGraph<XmlNode, XmlWay>) -> Self {
307        let entries = graph.node_indices()
308            .map(|i| NodeEntry { point: [graph[i].lat, graph[i].lon], index: i })
309            .collect();
310        let tree = Arc::new(RTree::bulk_load(entries));
311        let graph = Arc::new(graph);
312        Self { graph, tree }
313    }
314
315    pub fn nearest_node(&self, lat: f64, lon: f64) -> Option<NodeIndex> {
316        self.tree.nearest_neighbor(&[lat, lon]).map(|e| e.index)
317    }
318}
319
320// Keep the free function for backwards compatibility but delegate to SpatialGraph
321pub fn latlon_to_node(graph: &DiGraph<XmlNode, XmlWay>, lat: f64, lon: f64) -> Option<NodeIndex> {
322    SpatialGraph::new(graph.clone()).nearest_node(lat, lon)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    fn make_node(id: i64, lat: f64, lon: f64) -> XmlNode {
        XmlNode { id, lat, lon, tags: vec![], geohash: None }
    }

    fn make_way_raw(node_ids: Vec<i64>, tags: Vec<(&str, &str)>) -> XmlWay {
        XmlWay {
            id: 1,
            nodes: node_ids.into_iter().map(|id| XmlNodeRef { node_id: id }).collect(),
            tags: tags.into_iter().map(|(k, v)| XmlTag { key: k.into(), value: v.into() }).collect(),
            length: 0.0, speed_kph: 0.0,
            walk_travel_time: 0.0, bike_travel_time: 0.0, drive_travel_time: 0.0,
        }
    }

    /// Nodes 1 and 2, 0.001° apart in latitude, for single-segment tests.
    fn two_nodes() -> Vec<XmlNode> {
        vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)]
    }

    #[test]
    fn test_graph_respects_maxspeed_tag() {
        let way = make_way_raw(vec![1, 2], vec![("highway", "residential"), ("maxspeed", "30")]);
        let graph = create_graph(two_nodes(), vec![way], true, false);
        assert_eq!(graph.edge_weights().next().unwrap().speed_kph, 30.0);
    }

    #[test]
    fn test_oneway_produces_single_edge() {
        let way = make_way_raw(vec![1, 2], vec![("highway", "residential"), ("oneway", "yes")]);
        let graph = create_graph(two_nodes(), vec![way], true, false);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_reversed_oneway_points_against_node_order() {
        let way = make_way_raw(vec![1, 2], vec![("highway", "residential"), ("oneway", "-1")]);
        let graph = create_graph(two_nodes(), vec![way], true, false);
        assert_eq!(graph.edge_count(), 1);
        let edge = graph.edge_indices().next().unwrap();
        let (from, to) = graph.edge_endpoints(edge).unwrap();
        assert_eq!((graph[from].id, graph[to].id), (2, 1));
    }

    #[test]
    fn test_roundabout_is_implicitly_oneway() {
        let way = make_way_raw(vec![1, 2], vec![("highway", "primary"), ("junction", "roundabout")]);
        let graph = create_graph(two_nodes(), vec![way], true, false);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_bidirectional_produces_two_edges() {
        let way = make_way_raw(vec![1, 2], vec![("highway", "residential")]);
        let graph = create_graph(two_nodes(), vec![way], true, false);
        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_nearest_node_finds_closest() {
        let mut graph = DiGraph::new();
        graph.add_node(make_node(1, 48.0, 11.0));
        graph.add_node(make_node(2, 52.0, 13.0));
        let sg = SpatialGraph::new(graph);
        let idx = sg.nearest_node(48.001, 11.001).unwrap();
        assert_eq!(sg.graph[idx].id, 1);
    }
}