Skip to main content

pysochrone/
lib.rs

1#![allow(dead_code)]
2
3// Public modules — available to any Rust crate that depends on this library.
4// None of these import pyo3, so they compile cleanly without the extension-module feature.
5pub mod error;
6pub mod geocoding;
7pub mod graph;
8pub mod isochrone;
9pub mod overpass;
10pub mod pbf;
11pub mod poi;
12pub mod routing;
13pub mod utils;
14
15// Internal implementation details; not part of the public Rust API.
16mod cache;
17mod simplify;
18
19// ---------------------------------------------------------------------------
20// Python extension module
21//
22// Everything below is compiled ONLY when the "extension-module" feature is
23// active (i.e. when maturin is building the .pyd/.so for Python).
24// When another Rust crate depends on this library as an rlib, none of this
25// code is included and there is no pyo3 / Python linkage at all.
26// ---------------------------------------------------------------------------
27
28#[cfg(feature = "extension-module")]
29use pyo3::prelude::*;
30
31#[cfg(feature = "extension-module")]
32use std::sync::Arc;
33
#[cfg(feature = "extension-module")]
lazy_static::lazy_static! {
    /// Process-wide multi-threaded Tokio runtime, created on first use.
    ///
    /// The synchronous Python entry points in this file drive the crate's async
    /// graph-building, geocoding and POI calls to completion with `block_on`
    /// on this single shared runtime.
    static ref TOKIO_RT: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .expect("failed to create tokio runtime");
}
39
40// ---------------------------------------------------------------------------
41// Shared helpers (Python-binding layer only)
42// ---------------------------------------------------------------------------
43
/// Converts the Python-facing network-type name into an [`overpass::NetworkType`].
///
/// Matching is exact and case-sensitive. An unrecognised name yields a
/// `ValueError` whose message lists every accepted spelling.
#[cfg(feature = "extension-module")]
fn parse_network_type(s: &str) -> PyResult<overpass::NetworkType> {
    use crate::overpass::NetworkType;

    let parsed = match s {
        "Drive" => Some(NetworkType::Drive),
        "DriveService" => Some(NetworkType::DriveService),
        "Walk" => Some(NetworkType::Walk),
        "Bike" => Some(NetworkType::Bike),
        "All" => Some(NetworkType::All),
        "AllPrivate" => Some(NetworkType::AllPrivate),
        _ => None,
    };
    parsed.ok_or_else(|| {
        pyo3::exceptions::PyValueError::new_err(format!(
            "Invalid network type '{}'. Expected: Drive, DriveService, Walk, Bike, All, AllPrivate",
            s
        ))
    })
}
58
/// Converts the Python-facing hull-type name into an [`isochrone::HullType`].
///
/// Only the exact, case-sensitive spellings `Convex`, `FastConcave` and
/// `Concave` are accepted; anything else becomes a `ValueError`.
#[cfg(feature = "extension-module")]
fn parse_hull_type(s: &str) -> PyResult<isochrone::HullType> {
    use crate::isochrone::HullType;

    let parsed = match s {
        "Convex" => Some(HullType::Convex),
        "FastConcave" => Some(HullType::FastConcave),
        "Concave" => Some(HullType::Concave),
        _ => None,
    };
    parsed.ok_or_else(|| {
        pyo3::exceptions::PyValueError::new_err(format!(
            "Invalid hull type '{}'. Expected: Convex, FastConcave, Concave",
            s
        ))
    })
}
70
71// ---------------------------------------------------------------------------
72// PyGraph — exposes a loaded SpatialGraph to Python
73// ---------------------------------------------------------------------------
74
/// A road-network graph built from OpenStreetMap data, exposed to Python as `Graph`.
///
/// Create one with `build_graph(...)` and keep it for repeated queries over
/// the same area: isochrones, routing and POI lookups all run against this
/// in-memory graph instead of fetching the network again.
#[cfg(feature = "extension-module")]
#[pyclass(name = "Graph")]
struct PyGraph {
    /// Travel mode the graph was built for; isochrone and routing queries on
    /// this graph are computed for this mode.
    network_type: overpass::NetworkType,
    /// The loaded graph; also answers `nearest_node` lookups.
    sg: graph::SpatialGraph,
}
86
#[cfg(feature = "extension-module")]
#[pymethods]
impl PyGraph {
    /// Number of nodes in the graph.
    fn node_count(&self) -> usize {
        self.sg.graph.node_count()
    }

    /// Number of edges in the graph.
    fn edge_count(&self) -> usize {
        self.sg.graph.edge_count()
    }

    /// Return `(id, lat, lon)` of the graph node nearest to `(lat, lon)`,
    /// or `None` when no nearest node is found.
    fn nearest_node(&self, lat: f64, lon: f64) -> PyResult<Option<(i64, f64, f64)>> {
        Ok(self.sg.nearest_node(lat, lon).map(|idx| {
            let n = &self.sg.graph[idx];
            (n.id, n.lat, n.lon)
        }))
    }

    /// Compute isochrone polygons around the graph node nearest to `(lat, lon)`.
    ///
    /// `time_limits` is forwarded unchanged to the isochrone calculator
    /// (presumably seconds; confirm in the `isochrone` module). `hull_type` is
    /// one of `"Convex"`, `"FastConcave"` or `"Concave"` (the default).
    /// Returns one GeoJSON polygon string per computed isochrone.
    ///
    /// Raises `ValueError` for an unknown hull type, or when no graph node is
    /// found near the given coordinates.
    #[pyo3(signature = (lat, lon, time_limits, hull_type = "Concave"))]
    fn isochrones(
        &self,
        lat: f64,
        lon: f64,
        time_limits: Vec<f64>,
        hull_type: &str,
    ) -> PyResult<Vec<String>> {
        let hull = parse_hull_type(hull_type)?;
        let node = self.sg.nearest_node(lat, lon)
            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err(
                "No node found near the given coordinates"
            ))?;

        // The concurrent calculator takes the graph by `Arc`, so hand it a
        // cheap shared handle rather than a copy of the graph.
        let shared = Arc::clone(&self.sg.graph);
        let isos = isochrone::calculate_isochrones_concurrently(
            shared, node, time_limits, self.network_type, hull,
        );
        Ok(isos.iter().map(|p| utils::polygon_to_geojson_string(p)).collect())
    }

    /// Route between two coordinates on this graph, using its network type.
    ///
    /// Returns a GeoJSON `Feature` string. Its geometry is a `LineString` with
    /// positions in `[lon, lat]` order (the GeoJSON convention). Its
    /// properties are `distance_m`, `duration_s` and `cumulative_times_s`.
    fn route(
        &self,
        origin_lat: f64,
        origin_lon: f64,
        dest_lat: f64,
        dest_lon: f64,
    ) -> PyResult<String> {
        let r = routing::route(&self.sg, origin_lat, origin_lon, dest_lat, dest_lon, self.network_type)?;

        // The router yields (lat, lon) pairs; GeoJSON positions are [lon, lat].
        let coords: Vec<Vec<f64>> = r.coordinates.iter()
            .map(|(lat, lon)| vec![*lon, *lat])
            .collect();
        let geometry = geojson::Geometry::new(geojson::Value::LineString(coords));
        let mut props = geojson::JsonObject::new();
        props.insert("distance_m".into(), r.distance_m.into());
        props.insert("duration_s".into(), r.duration_s.into());
        props.insert(
            "cumulative_times_s".into(),
            geojson::JsonValue::Array(r.cumulative_times_s.iter().map(|&t| t.into()).collect()),
        );
        let feature = geojson::Feature { geometry: Some(geometry), properties: Some(props), ..Default::default() };
        Ok(geojson::GeoJson::Feature(feature).to_string())
    }

    /// Fetch points of interest inside an isochrone polygon.
    ///
    /// `isochrone_geojson` is parsed by `poi::parse_isochrone` (presumably a
    /// polygon string as returned by `isochrones`). The graph itself is not
    /// consulted; this behaves exactly like the module-level `fetch_pois`.
    fn fetch_pois(&self, isochrone_geojson: String) -> PyResult<String> {
        let polygon = poi::parse_isochrone(&isochrone_geojson)?;
        // NOTE(review): the GIL stays held while blocking on the network here;
        // consider `py.allow_threads` if the captured/returned types are `Send`.
        let pois = TOKIO_RT.block_on(poi::fetch_pois_within(&polygon))?;
        Ok(poi::pois_to_geojson(&pois))
    }

    /// Export every node as a GeoJSON `FeatureCollection` string of `Point`
    /// features carrying `id`, `lat` and `lon` properties.
    fn nodes_geojson(&self) -> String {
        let features: Vec<geojson::Feature> = self.sg.graph.node_indices().map(|idx| {
            let n = &self.sg.graph[idx];
            let geom = geojson::Geometry::new(geojson::Value::Point(vec![n.lon, n.lat]));
            let mut props = geojson::JsonObject::new();
            props.insert("id".into(), n.id.into());
            props.insert("lat".into(), n.lat.into());
            props.insert("lon".into(), n.lon.into());
            geojson::Feature { geometry: Some(geom), properties: Some(props), ..Default::default() }
        }).collect();
        feature_collection_string(features)
    }

    /// Export every edge as a GeoJSON `FeatureCollection` string.
    ///
    /// Each edge is drawn as a straight two-point `LineString` between its end
    /// nodes. Its properties are:
    /// - `highway`: the edge's `highway` tag, or `"unknown"` when absent.
    /// - `length_m` and `speed_kph`.
    /// - `drive_time_s`, `walk_time_s` and `bike_time_s`, from the stored
    ///   per-mode travel times.
    fn edges_geojson(&self) -> String {
        let features: Vec<geojson::Feature> = self.sg.graph.edge_indices().map(|eidx| {
            // `eidx` comes from this same graph's `edge_indices()`, so both
            // lookups are guaranteed to succeed.
            let (u, v) = self.sg.graph.edge_endpoints(eidx)
                .expect("edge index yielded by edge_indices() must have endpoints");
            let way = self.sg.graph.edge_weight(eidx)
                .expect("edge index yielded by edge_indices() must have a weight");
            let from = &self.sg.graph[u];
            let to   = &self.sg.graph[v];
            let coords = vec![vec![from.lon, from.lat], vec![to.lon, to.lat]];
            let geom = geojson::Geometry::new(geojson::Value::LineString(coords));
            let highway = way.tags.iter()
                .find(|t| t.key == "highway")
                .map(|t| t.value.as_str())
                .unwrap_or("unknown")
                .to_string();
            let mut props = geojson::JsonObject::new();
            props.insert("highway".into(),      highway.into());
            props.insert("length_m".into(),     way.length.into());
            props.insert("speed_kph".into(),    way.speed_kph.into());
            props.insert("drive_time_s".into(), way.drive_travel_time.into());
            props.insert("walk_time_s".into(),  way.walk_travel_time.into());
            props.insert("bike_time_s".into(),  way.bike_travel_time.into());
            geojson::Feature { geometry: Some(geom), properties: Some(props), ..Default::default() }
        }).collect();
        feature_collection_string(features)
    }

    /// `Graph(nodes=…, edges=…, network_type=…)` summary for the Python REPL.
    fn __repr__(&self) -> String {
        format!(
            "Graph(nodes={}, edges={}, network_type={:?})",
            self.sg.graph.node_count(),
            self.sg.graph.edge_count(),
            self.network_type,
        )
    }
}

/// Serialise `features` as a GeoJSON `FeatureCollection` string with no
/// `bbox` and no foreign members.
#[cfg(feature = "extension-module")]
fn feature_collection_string(features: Vec<geojson::Feature>) -> String {
    geojson::GeoJson::FeatureCollection(geojson::FeatureCollection {
        features,
        bbox: None,
        foreign_members: None,
    })
    .to_string()
}
205
206// ---------------------------------------------------------------------------
207// Module-level Python functions
208// ---------------------------------------------------------------------------
209
/// Build the road-network graph around `(lat, lon)` and return it as a `Graph`.
///
/// Parameters:
/// - `network_type`: one of `Drive`, `DriveService`, `Walk`, `Bike`, `All`
///   or `AllPrivate`. Any other value raises `ValueError`.
/// - `max_dist`: defaults to `5000` when omitted. It uses the units expected
///   by `isochrone::calculate_isochrones_from_point`, presumably metres.
/// - `retain_all`: forwarded unchanged.
#[cfg(feature = "extension-module")]
#[pyfunction]
#[pyo3(signature = (lat, lon, network_type, max_dist = None, retain_all = false))]
fn build_graph(
    lat: f64,
    lon: f64,
    network_type: String,
    max_dist: Option<f64>,
    retain_all: bool,
) -> PyResult<PyGraph> {
    let network = parse_network_type(&network_type)?;
    let radius = max_dist.or(Some(5_000.0));
    // Only the graph is wanted: no time limits are requested, so presumably
    // no polygons are built and the (Convex) hull choice should not matter.
    let fetch = isochrone::calculate_isochrones_from_point(
        lat,
        lon,
        radius,
        Vec::new(),
        network,
        isochrone::HullType::Convex,
        retain_all,
    );
    let (_, spatial_graph) = TOKIO_RT.block_on(fetch)?;
    Ok(PyGraph {
        network_type: network,
        sg: spatial_graph,
    })
}
224
/// One-shot isochrones: build the graph around `(lat, lon)` and return one
/// GeoJSON polygon string per computed isochrone.
///
/// `network_type` and `hull_type` accept the same names as `build_graph` and
/// `Graph.isochrones`; an unknown name raises `ValueError`. The network type is
/// validated first. `max_dist` (including `None`) and `retain_all` are passed
/// through unchanged.
#[cfg(feature = "extension-module")]
#[pyfunction]
#[pyo3(signature = (lat, lon, time_limits, network_type, hull_type, max_dist=None, retain_all=false))]
fn calc_isochrones(
    lat: f64,
    lon: f64,
    time_limits: Vec<f64>,
    network_type: String,
    hull_type: String,
    max_dist: Option<f64>,
    retain_all: bool,
) -> PyResult<Vec<String>> {
    let network = parse_network_type(&network_type)?;
    let hull = parse_hull_type(&hull_type)?;
    let job = isochrone::calculate_isochrones_from_point(
        lat,
        lon,
        max_dist,
        time_limits,
        network,
        hull,
        retain_all,
    );
    let (polygons, _) = TOKIO_RT.block_on(job)?;
    let geojson_strings: Vec<String> = polygons
        .iter()
        .map(|polygon| utils::polygon_to_geojson_string(polygon))
        .collect();
    Ok(geojson_strings)
}
240
/// Midpoint of two coordinates, used to place the graph fetched by `calc_route`.
///
/// Latitudes are averaged. Longitudes are averaged too, unless the two points
/// are more than 180° apart in longitude. In that case the short way between
/// them crosses the antimeridian, and a plain average would land on the
/// opposite side of the globe (179° and -179° would give 0°). The average is
/// therefore shifted by 180° and wrapped back into (-180, 180].
/// Assumes both longitudes are already in [-180, 180].
#[cfg(any(test, feature = "extension-module"))]
fn route_midpoint(origin_lat: f64, origin_lon: f64, dest_lat: f64, dest_lon: f64) -> (f64, f64) {
    let mid_lat = (origin_lat + dest_lat) / 2.0;
    let mut mid_lon = (origin_lon + dest_lon) / 2.0;
    if (dest_lon - origin_lon).abs() > 180.0 {
        mid_lon += 180.0;
        if mid_lon > 180.0 {
            mid_lon -= 360.0;
        }
    }
    (mid_lat, mid_lon)
}

/// One-shot routing: fetch a graph covering both endpoints and return the
/// route between them as a GeoJSON `Feature` string.
///
/// The graph is requested around the endpoints' midpoint, with correct
/// handling for routes that cross the antimeridian. When `max_dist` is
/// omitted it defaults to 1.5 × the straight-line distance between the
/// endpoints, but never less than 5000. The units are those of
/// `utils::calculate_distance`, presumably metres.
///
/// The output format is exactly that of `Graph.route`. An unknown
/// `network_type` raises `ValueError`.
#[cfg(feature = "extension-module")]
#[pyfunction]
#[pyo3(signature = (origin_lat, origin_lon, dest_lat, dest_lon, network_type, max_dist=None, retain_all=false))]
fn calc_route(
    origin_lat: f64, origin_lon: f64, dest_lat: f64, dest_lon: f64,
    network_type: String, max_dist: Option<f64>, retain_all: bool,
) -> PyResult<String> {
    let nt = parse_network_type(&network_type)?;
    let (mid_lat, mid_lon) = route_midpoint(origin_lat, origin_lon, dest_lat, dest_lon);
    let computed_dist = max_dist.unwrap_or_else(|| {
        let straight_line = utils::calculate_distance(origin_lat, origin_lon, dest_lat, dest_lon);
        (straight_line * 1.5).max(5_000.0)
    });
    let (_, sg) = TOKIO_RT.block_on(isochrone::calculate_isochrones_from_point(
        mid_lat, mid_lon, Some(computed_dist), vec![], nt, isochrone::HullType::Convex, retain_all,
    ))?;
    // Delegate to `Graph.route` so both entry points serialise routes identically.
    let graph = PyGraph { sg, network_type: nt };
    graph.route(origin_lat, origin_lon, dest_lat, dest_lon)
}
269
/// Resolve a free-text place name to a coordinate pair via `geocoding::geocode`.
///
/// The pair is returned in the order `geocoding::geocode` yields it. The call
/// blocks on the shared Tokio runtime until the lookup finishes. Lookup errors
/// are converted into Python exceptions.
#[cfg(feature = "extension-module")]
#[pyfunction]
fn geocode(place: String) -> PyResult<(f64, f64)> {
    let lookup = geocoding::geocode(&place);
    TOKIO_RT.block_on(lookup).map_err(PyErr::from)
}
275
/// Fetch points of interest that fall inside an isochrone polygon.
///
/// `isochrone_geojson` is parsed by `poi::parse_isochrone`, presumably the
/// polygon GeoJSON produced by `calc_isochrones`. The POIs found are returned
/// as the string produced by `poi::pois_to_geojson`.
#[cfg(feature = "extension-module")]
#[pyfunction]
fn fetch_pois(isochrone_geojson: String) -> PyResult<String> {
    let area = poi::parse_isochrone(&isochrone_geojson)?;
    TOKIO_RT
        .block_on(poi::fetch_pois_within(&area))
        .map(|found| poi::pois_to_geojson(&found))
        .map_err(PyErr::from)
}
283
/// Clear the caches: first `cache::clear_cache`, then `cache::clear_disk_cache`.
///
/// If the first call fails, its error is raised and the disk cache is not
/// touched.
#[cfg(feature = "extension-module")]
#[pyfunction]
fn clear_cache() -> PyResult<()> {
    cache::clear_cache()?;
    cache::clear_disk_cache()
        .map(|_| ())
        .map_err(PyErr::from)
}
291
/// Return the on-disk cache directory as a string.
///
/// Path components that are not valid UTF-8 are replaced with U+FFFD instead
/// of raising.
#[cfg(feature = "extension-module")]
#[pyfunction]
fn cache_dir() -> PyResult<String> {
    let dir = cache::disk_cache_dir();
    let as_text = String::from(dir.to_string_lossy());
    Ok(as_text)
}
297
/// Python extension module `pysochrone`.
///
/// Registers the `Graph` class, then the module-level functions in this
/// order: `build_graph`, `calc_isochrones`, `calc_route`, `geocode`,
/// `fetch_pois`, `clear_cache`, `cache_dir`.
#[cfg(feature = "extension-module")]
#[pymodule]
fn pysochrone(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
    module.add_class::<PyGraph>()?;

    module.add_function(wrap_pyfunction!(build_graph, module)?)?;
    module.add_function(wrap_pyfunction!(calc_isochrones, module)?)?;
    module.add_function(wrap_pyfunction!(calc_route, module)?)?;
    module.add_function(wrap_pyfunction!(geocode, module)?)?;
    module.add_function(wrap_pyfunction!(fetch_pois, module)?)?;
    module.add_function(wrap_pyfunction!(clear_cache, module)?)?;
    module.add_function(wrap_pyfunction!(cache_dir, module)?)?;

    Ok(())
}