Skip to main content

pysochrone/
pbf.rs

1//! Read OpenStreetMap PBF files into the same intermediate shape produced by
2//! the Overpass XML parser. This lets the rest of the pipeline (graph building,
3//! POI extraction) work unchanged whether the data came from live Overpass or
4//! a local PBF file.
5
6use std::collections::{HashMap, HashSet};
7use std::path::Path;
8
9use osmpbf::{Element, ElementReader};
10
11use crate::error::OsmGraphError;
12use crate::graph::{XmlData, XmlNode, XmlNodeRef, XmlTag, XmlWay};
13use crate::overpass::NetworkType;
14
15/// Read a PBF file once and produce one `XmlData` per requested network type,
16/// plus the set of node IDs that are POIs (POIs are network-type-independent).
17///
18/// This avoids re-reading the PBF for each network type — useful at server
19/// startup when you want walk/bike/drive graphs for the same region.
20pub fn read_pbf_multi(
21    path: impl AsRef<Path>,
22    network_types: &[NetworkType],
23) -> Result<(HashMap<NetworkType, XmlData>, HashSet<i64>), OsmGraphError> {
24    let mut all_nodes: HashMap<i64, RawNode> = HashMap::new();
25    let mut roads_by_type: HashMap<NetworkType, Vec<RawWay>> =
26        network_types.iter().map(|nt| (*nt, Vec::new())).collect();
27    let mut poi_ids: HashSet<i64> = HashSet::new();
28
29    let reader = ElementReader::from_path(path.as_ref())
30        .map_err(|e| OsmGraphError::PbfError(e.to_string()))?;
31
32    reader
33        .for_each(|element| match element {
34            Element::Node(node) => {
35                let tags: Vec<(String, String)> = node
36                    .tags()
37                    .map(|(k, v)| (k.to_string(), v.to_string()))
38                    .collect();
39                let id = node.id();
40                if is_poi_node(&tags) {
41                    poi_ids.insert(id);
42                }
43                all_nodes.insert(id, RawNode { lat: node.lat(), lon: node.lon(), tags });
44            }
45            Element::DenseNode(node) => {
46                let tags: Vec<(String, String)> = node
47                    .tags()
48                    .map(|(k, v)| (k.to_string(), v.to_string()))
49                    .collect();
50                let id = node.id();
51                if is_poi_node(&tags) {
52                    poi_ids.insert(id);
53                }
54                all_nodes.insert(id, RawNode { lat: node.lat(), lon: node.lon(), tags });
55            }
56            Element::Way(way) => {
57                let tags: Vec<(String, String)> = way
58                    .tags()
59                    .map(|(k, v)| (k.to_string(), v.to_string()))
60                    .collect();
61                // Quick reject: ways without a highway tag aren't roads for any mode.
62                if !tags.iter().any(|(k, _)| k == "highway") { return; }
63                let refs: Vec<i64> = way.refs().collect();
64                for &nt in network_types {
65                    if way_passes_road_filter(&tags, nt) {
66                        roads_by_type.get_mut(&nt).unwrap().push(RawWay {
67                            id: way.id(),
68                            refs: refs.clone(),
69                            tags: tags.clone(),
70                        });
71                    }
72                }
73            }
74            Element::Relation(_) => {}
75        })
76        .map_err(|e| OsmGraphError::PbfError(e.to_string()))?;
77
78    // Per-network-type, emit only the nodes referenced by that type's ways
79    // (plus all POI nodes — they're shared across all network types).
80    let mut out: HashMap<NetworkType, XmlData> = HashMap::new();
81    for (nt, roads) in roads_by_type {
82        let mut needed: HashSet<i64> = poi_ids.clone();
83        for w in &roads {
84            for r in &w.refs {
85                needed.insert(*r);
86            }
87        }
88        let nodes: Vec<XmlNode> = all_nodes
89            .iter()
90            .filter(|(id, _)| needed.contains(id))
91            .map(|(id, n)| XmlNode {
92                id: *id,
93                lat: n.lat,
94                lon: n.lon,
95                tags: n.tags.iter().cloned()
96                    .map(|(k, v)| XmlTag { key: k, value: v })
97                    .collect(),
98                geohash: None,
99            })
100            .collect();
101        let ways: Vec<XmlWay> = roads
102            .into_iter()
103            .map(|w| XmlWay {
104                id: w.id,
105                nodes: w.refs.into_iter().map(|node_id| XmlNodeRef { node_id }).collect(),
106                tags: w.tags.into_iter().map(|(k, v)| XmlTag { key: k, value: v }).collect(),
107                length: 0.0, speed_kph: 0.0,
108                walk_travel_time: 0.0, bike_travel_time: 0.0, drive_travel_time: 0.0,
109            })
110            .collect();
111        out.insert(nt, XmlData { nodes, ways});
112    }
113
114    Ok((out, poi_ids))
115}
116
117/// Read a PBF file and produce an `XmlData` (the canonical intermediate shape
118/// our graph builder consumes) plus the set of node IDs that are POIs.
119///
120/// Two-pass logic implemented in a single PBF iteration:
121///   1. Collect every node into a temporary map (id → lat/lon/tags).
122///   2. Collect every way that passes the road-network filter for `network_type`.
123///   3. Mark POI nodes (any node with our standard amenity/tourism/etc. tags).
124///
125/// After iteration, emit only the nodes we actually need: those referenced by
126/// a kept way, or flagged as a POI. Everything else is discarded — for DC this
127/// drops the ~4 million tagless nodes.
128pub fn read_pbf(
129    path: impl AsRef<Path>,
130    network_type: NetworkType,
131) -> Result<(XmlData, HashSet<i64>), OsmGraphError> {
132    let mut all_nodes: HashMap<i64, RawNode> = HashMap::new();
133    let mut roads: Vec<RawWay> = Vec::new();
134    let mut poi_ids: HashSet<i64> = HashSet::new();
135
136    let reader = ElementReader::from_path(path.as_ref())
137        .map_err(|e| OsmGraphError::PbfError(e.to_string()))?;
138
139    reader
140        .for_each(|element| match element {
141            Element::Node(node) => {
142                let tags: Vec<(String, String)> = node
143                    .tags()
144                    .map(|(k, v)| (k.to_string(), v.to_string()))
145                    .collect();
146                let id = node.id();
147                if is_poi_node(&tags) {
148                    poi_ids.insert(id);
149                }
150                all_nodes.insert(
151                    id,
152                    RawNode { lat: node.lat(), lon: node.lon(), tags },
153                );
154            }
155            Element::DenseNode(node) => {
156                let tags: Vec<(String, String)> = node
157                    .tags()
158                    .map(|(k, v)| (k.to_string(), v.to_string()))
159                    .collect();
160                let id = node.id();
161                if is_poi_node(&tags) {
162                    poi_ids.insert(id);
163                }
164                all_nodes.insert(
165                    id,
166                    RawNode { lat: node.lat(), lon: node.lon(), tags },
167                );
168            }
169            Element::Way(way) => {
170                let tags: Vec<(String, String)> = way
171                    .tags()
172                    .map(|(k, v)| (k.to_string(), v.to_string()))
173                    .collect();
174                if !way_passes_road_filter(&tags, network_type) {
175                    return;
176                }
177                let refs: Vec<i64> = way.refs().collect();
178                roads.push(RawWay { id: way.id(), refs, tags });
179            }
180            Element::Relation(_) => {}
181        })
182        .map_err(|e| OsmGraphError::PbfError(e.to_string()))?;
183
184    // Build the set of nodes we actually need to keep.
185    let mut needed: HashSet<i64> = poi_ids.clone();
186    for w in &roads {
187        for r in &w.refs {
188            needed.insert(*r);
189        }
190    }
191
192    let nodes: Vec<XmlNode> = all_nodes
193        .into_iter()
194        .filter(|(id, _)| needed.contains(id))
195        .map(|(id, n)| XmlNode {
196            id,
197            lat: n.lat,
198            lon: n.lon,
199            tags: n
200                .tags
201                .into_iter()
202                .map(|(k, v)| XmlTag { key: k, value: v })
203                .collect(),
204            geohash: None,
205        })
206        .collect();
207
208    let ways: Vec<XmlWay> = roads
209        .into_iter()
210        .map(|w| XmlWay {
211            id: w.id,
212            nodes: w
213                .refs
214                .into_iter()
215                .map(|node_id| XmlNodeRef { node_id })
216                .collect(),
217            tags: w
218                .tags
219                .into_iter()
220                .map(|(k, v)| XmlTag { key: k, value: v })
221                .collect(),
222            length: 0.0,
223            speed_kph: 0.0,
224            walk_travel_time: 0.0,
225            bike_travel_time: 0.0,
226            drive_travel_time: 0.0,
227        })
228        .collect();
229
230    Ok((XmlData { nodes, ways}, poi_ids))
231}
232
/// A node buffered during the PBF scan. Every node in the file is stored like
/// this until the kept ways are known and unneeded nodes can be dropped.
struct RawNode {
    /// Owned `(key, value)` tag pairs copied from the PBF element.
    tags: Vec<(String, String)>,
    /// Latitude as returned by osmpbf's `lat()`.
    lat: f64,
    /// Longitude as returned by osmpbf's `lon()`.
    lon: f64,
}
238
/// A way that passed the road filter, buffered until the output is assembled.
struct RawWay {
    /// Owned `(key, value)` tag pairs copied from the PBF element.
    tags: Vec<(String, String)>,
    /// OSM way ID.
    id: i64,
    /// Node IDs referenced by the way, in the order osmpbf's `refs()` yields them.
    refs: Vec<i64>,
}
244
245/// Mirror of `overpass::get_osm_filter`. If Overpass filter rules ever change,
246/// these need to change in lockstep.
247fn way_passes_road_filter(tags: &[(String, String)], network_type: NetworkType) -> bool {
248    let get = |k: &str| {
249        tags.iter()
250            .find(|(tk, _)| tk == k)
251            .map(|(_, v)| v.as_str())
252    };
253
254    let highway = match get("highway") {
255        Some(v) => v,
256        None => return false,
257    };
258    if get("area") == Some("yes") {
259        return false;
260    }
261
262    match network_type {
263        NetworkType::Drive => {
264            const EXCLUDE_HIGHWAY: &[&str] = &[
265                "abandoned", "bridleway", "bus_guideway", "construction", "corridor",
266                "cycleway", "elevator", "escalator", "footway", "no", "path", "pedestrian",
267                "planned", "platform", "proposed", "raceway", "razed", "service", "steps", "track",
268            ];
269            if EXCLUDE_HIGHWAY.contains(&highway) { return false; }
270            if get("motor_vehicle") == Some("no") { return false; }
271            if get("motorcar") == Some("no") { return false; }
272            const EXCLUDE_SERVICE: &[&str] = &[
273                "alley", "driveway", "emergency_access", "parking", "parking_aisle", "private",
274            ];
275            if let Some(s) = get("service") {
276                if EXCLUDE_SERVICE.contains(&s) { return false; }
277            }
278        }
279        NetworkType::DriveService => {
280            const EXCLUDE_HIGHWAY: &[&str] = &[
281                "abandoned", "bridleway", "bus_guideway", "construction", "corridor",
282                "cycleway", "elevator", "escalator", "footway", "no", "path", "pedestrian",
283                "planned", "platform", "proposed", "raceway", "razed", "steps", "track",
284            ];
285            if EXCLUDE_HIGHWAY.contains(&highway) { return false; }
286            if get("motor_vehicle") == Some("no") { return false; }
287            if get("motorcar") == Some("no") { return false; }
288            const EXCLUDE_SERVICE: &[&str] = &[
289                "emergency_access", "parking", "parking_aisle", "private",
290            ];
291            if let Some(s) = get("service") {
292                if EXCLUDE_SERVICE.contains(&s) { return false; }
293            }
294        }
295        NetworkType::Walk => {
296            // "motor" is a substring pattern in Overpass — matches motor, motorway, motorroad.
297            const EXCLUDE_HIGHWAY: &[&str] = &[
298                "abandoned", "bus_guideway", "construction", "corridor", "elevator", "escalator",
299                "no", "planned", "platform", "proposed", "raceway", "razed",
300            ];
301            if EXCLUDE_HIGHWAY.contains(&highway) || highway.starts_with("motor") { return false; }
302            if get("foot") == Some("no") { return false; }
303            if get("service") == Some("private") { return false; }
304        }
305        NetworkType::Bike => {
306            const EXCLUDE_HIGHWAY: &[&str] = &[
307                "abandoned", "bus_guideway", "construction", "corridor", "elevator", "escalator",
308                "footway", "no", "planned", "platform", "proposed", "raceway", "razed", "steps",
309            ];
310            if EXCLUDE_HIGHWAY.contains(&highway) || highway.starts_with("motor") { return false; }
311            if get("bicycle") == Some("no") { return false; }
312            if get("service") == Some("private") { return false; }
313        }
314        NetworkType::All => {
315            const EXCLUDE_HIGHWAY: &[&str] = &[
316                "abandoned", "construction", "no", "planned", "platform", "proposed", "raceway", "razed",
317            ];
318            if EXCLUDE_HIGHWAY.contains(&highway) { return false; }
319            if get("service") == Some("private") { return false; }
320        }
321        NetworkType::AllPrivate => {
322            const EXCLUDE_HIGHWAY: &[&str] = &[
323                "abandoned", "construction", "no", "planned", "platform", "proposed", "raceway", "razed",
324            ];
325            if EXCLUDE_HIGHWAY.contains(&highway) { return false; }
326        }
327    }
328    true
329}
330
/// Mirror of the node selectors in `poi::create_poi_query`.
///
/// A node counts as a POI when it has any `tourism` or `historic` tag, or when
/// its `natural`, `amenity`, `leisure`, or `shop` value is one of the curated
/// values listed below. For a repeated key, only its first occurrence is used.
fn is_poi_node(tags: &[(String, String)]) -> bool {
    let value_of = |key: &str| {
        tags.iter()
            .find_map(|(k, v)| if k == key { Some(v.as_str()) } else { None })
    };

    value_of("tourism").is_some()
        || value_of("historic").is_some()
        || value_of("natural").map_or(false, |v| {
            matches!(v, "peak" | "waterfall" | "cave_entrance" | "beach" | "hot_spring")
        })
        || value_of("amenity").map_or(false, |v| {
            matches!(
                v,
                "restaurant" | "fast_food" | "cafe" | "bar" | "pub" | "biergarten" | "ice_cream"
                    | "food_court" | "museum" | "theatre" | "cinema" | "arts_centre" | "library"
                    | "place_of_worship" | "spa" | "swimming_pool"
            )
        })
        || value_of("leisure").map_or(false, |v| {
            matches!(v, "park" | "nature_reserve" | "garden" | "sports_centre" | "fitness_centre")
        })
        || value_of("shop").map_or(false, |v| {
            matches!(
                v,
                "bakery" | "deli" | "chocolate" | "wine" | "cheese" | "mall" | "department_store"
            )
        })
}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn borrowed `(key, value)` pairs into the owned tag list the filters take.
    fn tags(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn poi_detection_amenity() {
        let restaurant = tags(&[("amenity", "restaurant"), ("name", "Joe's")]);
        assert!(is_poi_node(&restaurant));
    }

    #[test]
    fn poi_detection_rejects_unrelated_amenity() {
        assert!(!is_poi_node(&tags(&[("amenity", "atm")])));
    }

    #[test]
    fn poi_detection_tourism() {
        // Any tourism value counts, not just a curated list.
        assert!(is_poi_node(&tags(&[("tourism", "hotel")])));
    }

    #[test]
    fn road_filter_walk_keeps_residential() {
        let residential = tags(&[("highway", "residential")]);
        assert!(way_passes_road_filter(&residential, NetworkType::Walk));
    }

    #[test]
    fn road_filter_walk_rejects_motor() {
        let motorway = tags(&[("highway", "motorway")]);
        assert!(!way_passes_road_filter(&motorway, NetworkType::Walk));
    }

    #[test]
    fn road_filter_drive_rejects_footway() {
        let footway = tags(&[("highway", "footway")]);
        assert!(!way_passes_road_filter(&footway, NetworkType::Drive));
    }

    #[test]
    fn road_filter_rejects_area_yes() {
        let plaza = tags(&[("highway", "residential"), ("area", "yes")]);
        assert!(!way_passes_road_filter(&plaza, NetworkType::Walk));
    }

    #[test]
    fn road_filter_walk_rejects_foot_no() {
        let no_pedestrians = tags(&[("highway", "residential"), ("foot", "no")]);
        assert!(!way_passes_road_filter(&no_pedestrians, NetworkType::Walk));
    }
}