//! `pysochrone/routing.rs` — point-to-point routing over a `SpatialGraph`.

use petgraph::algo::astar;
use petgraph::graph::NodeIndex;

use crate::error::OsmGraphError;
use crate::graph::{SpatialGraph, XmlWay};
use crate::overpass::NetworkType;
use crate::utils::calculate_distance;

/// A computed route between two points on a spatial graph.
///
/// `coordinates` and `cumulative_times_s` are parallel vectors of equal length: entry `i` of
/// `cumulative_times_s` is the travel time from the start of the route to `coordinates[i]`.
/// The first entry is therefore always `0.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct Route {
    /// Ordered list of (lat, lon) coordinates along the route
    pub coordinates: Vec<(f64, f64)>,
    /// Cumulative travel time in seconds at each coordinate (parallel to `coordinates`)
    pub cumulative_times_s: Vec<f64>,
    /// Total route distance in meters
    pub distance_m: f64,
    /// Total travel time in seconds for the given network type
    pub duration_s: f64,
}

21pub fn route(
22    sg: &SpatialGraph,
23    origin_lat: f64,
24    origin_lon: f64,
25    dest_lat: f64,
26    dest_lon: f64,
27    network_type: NetworkType,
28) -> Result<Route, OsmGraphError> {
29    let origin = sg.nearest_node(origin_lat, origin_lon).ok_or(OsmGraphError::NodeNotFound)?;
30    let dest = sg.nearest_node(dest_lat, dest_lon).ok_or(OsmGraphError::NodeNotFound)?;
31
32    let edge_cost = |e: petgraph::graph::EdgeReference<XmlWay>| -> f64 {
33        let way = e.weight();
34        match network_type {
35            NetworkType::Walk => way.walk_travel_time,
36            NetworkType::Bike => way.bike_travel_time,
37            _ => way.drive_travel_time,
38        }
39    };
40
41    // Heuristic: straight-line travel time from node to destination
42    let heuristic = |node: NodeIndex| -> f64 {
43        let n = &sg.graph[node];
44        let d = &sg.graph[dest];
45        let dist = calculate_distance(n.lat, n.lon, d.lat, d.lon);
46        // Use a generous speed so the heuristic is admissible (never overestimates)
47        let max_speed_m_per_s = 200.0 / 3.6; // 200 kph
48        dist / max_speed_m_per_s
49    };
50
51    let result = astar(&*sg.graph, origin, |n| n == dest, edge_cost, heuristic)
52        .ok_or(OsmGraphError::NodeNotFound)?; // no path found
53
54    let (_, path) = result;
55
56    let coordinates: Vec<(f64, f64)> = path
57        .iter()
58        .map(|&idx| {
59            let n = &sg.graph[idx];
60            (n.lat, n.lon)
61        })
62        .collect();
63
64    // Aggregate distance, duration, and cumulative times along the path
65    let mut distance_m = 0.0;
66    let mut duration_s = 0.0;
67    let mut cumulative_times_s = vec![0.0_f64]; // origin starts at t=0
68    for window in path.windows(2) {
69        if let [u, v] = window {
70            if let Some(edge) = sg.graph.find_edge(*u, *v) {
71                let way = sg.graph.edge_weight(edge).unwrap();
72                distance_m += way.length;
73                duration_s += match network_type {
74                    NetworkType::Walk => way.walk_travel_time,
75                    NetworkType::Bike => way.bike_travel_time,
76                    _ => way.drive_travel_time,
77                };
78                cumulative_times_s.push(duration_s);
79            }
80        }
81    }
82
83    Ok(Route { coordinates, cumulative_times_s, distance_m, duration_s })
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{SpatialGraph, XmlNode, XmlTag, XmlWay};
    use crate::overpass::NetworkType;
    use petgraph::graph::DiGraph;

    /// Tolerance for comparing accumulated floating-point times and distances.
    const EPS: f64 = 1e-6;

    fn make_node(id: i64, lat: f64, lon: f64) -> XmlNode {
        XmlNode { id, lat, lon, tags: vec![], geohash: None }
    }

    /// Builds a residential way whose walk/bike times derive from `length` at 5 / 15 kph.
    fn make_way(drive_travel_time: f64, length: f64) -> XmlWay {
        XmlWay {
            id: 1,
            nodes: vec![],
            tags: vec![XmlTag { key: "highway".into(), value: "residential".into() }],
            length,
            speed_kph: 50.0,
            walk_travel_time: length / (5.0 / 3.6),
            bike_travel_time: length / (15.0 / 3.6),
            drive_travel_time,
        }
    }

    fn linear_graph() -> SpatialGraph {
        // A → B → C along a straight line (directed: no edges back towards A)
        let mut g = DiGraph::new();
        let a = g.add_node(make_node(1, 0.0, 0.0));
        let b = g.add_node(make_node(2, 0.001, 0.0));
        let c = g.add_node(make_node(3, 0.002, 0.0));
        g.add_edge(a, b, make_way(10.0, 111.0));
        g.add_edge(b, c, make_way(10.0, 111.0));
        SpatialGraph::new(g)
    }

    #[test]
    fn test_cumulative_times_starts_at_zero() {
        let sg = linear_graph();
        let r = route(&sg, 0.0, 0.0, 0.002, 0.0, NetworkType::Drive).unwrap();
        assert_eq!(r.cumulative_times_s[0], 0.0);
    }

    #[test]
    fn test_cumulative_times_parallel_to_coordinates() {
        let sg = linear_graph();
        let r = route(&sg, 0.0, 0.0, 0.002, 0.0, NetworkType::Drive).unwrap();
        assert_eq!(r.cumulative_times_s.len(), r.coordinates.len());
    }

    #[test]
    fn test_cumulative_times_monotonic() {
        let sg = linear_graph();
        let r = route(&sg, 0.0, 0.0, 0.002, 0.0, NetworkType::Drive).unwrap();
        for w in r.cumulative_times_s.windows(2) {
            assert!(w[1] >= w[0], "times decreased: {:?}", r.cumulative_times_s);
        }
    }

    #[test]
    fn test_cumulative_times_last_equals_duration() {
        let sg = linear_graph();
        let r = route(&sg, 0.0, 0.0, 0.002, 0.0, NetworkType::Drive).unwrap();
        let last = *r.cumulative_times_s.last().unwrap();
        assert!(
            (last - r.duration_s).abs() < EPS,
            "last cumulative time {last:.6} != duration {:.6}", r.duration_s
        );
    }

    #[test]
    fn test_drive_totals_sum_edges() {
        let sg = linear_graph();
        let r = route(&sg, 0.0, 0.0, 0.002, 0.0, NetworkType::Drive).unwrap();
        assert_eq!(r.coordinates, vec![(0.0, 0.0), (0.001, 0.0), (0.002, 0.0)]);
        assert!((r.distance_m - 222.0).abs() < EPS, "distance {}", r.distance_m);
        assert!((r.duration_s - 20.0).abs() < EPS, "duration {}", r.duration_s);
    }

    #[test]
    fn test_walk_uses_walk_travel_time() {
        let sg = linear_graph();
        let r = route(&sg, 0.0, 0.0, 0.002, 0.0, NetworkType::Walk).unwrap();
        let expected = 2.0 * (111.0 / (5.0 / 3.6));
        assert!(
            (r.duration_s - expected).abs() < EPS,
            "walk duration {} != {expected}", r.duration_s
        );
    }

    #[test]
    fn test_same_origin_and_destination() {
        let sg = linear_graph();
        let r = route(&sg, 0.001, 0.0, 0.001, 0.0, NetworkType::Drive).unwrap();
        assert_eq!(r.coordinates, vec![(0.001, 0.0)]);
        assert_eq!(r.cumulative_times_s, vec![0.0]);
        assert_eq!(r.distance_m, 0.0);
        assert_eq!(r.duration_s, 0.0);
    }

    #[test]
    fn test_unreachable_destination_is_error() {
        // The graph is directed A → B → C, so C cannot reach A.
        let sg = linear_graph();
        assert!(route(&sg, 0.002, 0.0, 0.0, 0.0, NetworkType::Drive).is_err());
    }
}