use crate::simplify::simplify_graph;
use crate::utils::{calculate_distance, calculate_travel_time};
use petgraph::graph::{DiGraph, NodeIndex};
use rstar::{PointDistance, RTree, RTreeObject, AABB};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Deserialize)]
pub struct XmlData {
#[serde(rename = "node", default)]
pub nodes: Vec<XmlNode>,
#[serde(rename = "way", default)]
pub ways: Vec<XmlWay>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct XmlNode {
#[serde(rename = "@id")]
pub id: i64,
#[serde(rename = "@lat")]
pub lat: f64,
#[serde(rename = "@lon")]
pub lon: f64,
#[serde(rename = "tag", default)]
pub tags: Vec<XmlTag>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct XmlWay {
#[serde(rename = "@id")]
pub id: i64,
#[serde(rename = "nd", default)]
pub nodes: Vec<XmlNodeRef>,
#[serde(rename = "tag", default)]
pub tags: Vec<XmlTag>,
#[serde(default)]
pub length: f64,
#[serde(default)]
pub speed_kph: f64,
#[serde(default)]
pub walk_travel_time: f64,
#[serde(default)]
pub bike_travel_time: f64,
#[serde(default)]
pub drive_travel_time: f64,
#[serde(default)]
pub geometry: Vec<(f64, f64)>,
}
impl XmlWay {
pub fn filter_useful_tags(mut self) -> Self {
const USEFUL_TAGS: &[&str] = &["highway", "name", "ref", "bridge", "tunnel", "service"];
self.tags
.retain(|tag| USEFUL_TAGS.contains(&tag.key.as_str()));
self
}
#[inline]
pub fn travel_time(&self, network_type: crate::overpass::NetworkType) -> f64 {
match network_type {
crate::overpass::NetworkType::Walk => self.walk_travel_time,
crate::overpass::NetworkType::Bike => self.bike_travel_time,
_ => self.drive_travel_time,
}
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct XmlNodeRef {
#[serde(rename = "@ref")]
pub node_id: i64,
}
#[derive(Debug, Deserialize, Clone)]
pub struct XmlTag {
#[serde(rename = "@k")]
pub key: String,
#[serde(rename = "@v")]
pub value: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Direction {
Bidirectional,
OneWayForward,
OneWayReverse,
}
pub fn parse_xml(xml_data: &str) -> Result<XmlData, quick_xml::DeError> {
let root: XmlData = quick_xml::de::from_str(xml_data)?;
Ok(root)
}
fn find_tag<'a>(tags: &'a [XmlTag], key: &str) -> Option<&'a XmlTag> {
tags.iter().find(|tag| tag.key == key)
}
fn assess_path_directionality(path: &XmlWay) -> Direction {
let oneway_tag = find_tag(&path.tags, "oneway");
let junction_tag = find_tag(&path.tags, "junction");
if oneway_tag.is_some_and(|tag| matches!(tag.value.as_str(), "-1" | "reverse")) {
return Direction::OneWayReverse;
}
if oneway_tag.is_some_and(|tag| matches!(tag.value.as_str(), "yes" | "true" | "1")) {
return Direction::OneWayForward;
}
let is_roundabout = junction_tag.is_some_and(|tag| tag.value == "roundabout");
if is_roundabout {
Direction::OneWayForward
} else {
Direction::Bidirectional
}
}
fn highway_speed_kph(highway: &str) -> Option<f64> {
match highway {
"motorway" => Some(110.0),
"motorway_link" => Some(60.0),
"trunk" => Some(90.0),
"trunk_link" => Some(45.0),
"primary" => Some(65.0),
"primary_link" => Some(45.0),
"secondary" => Some(55.0),
"secondary_link" => Some(40.0),
"tertiary" => Some(45.0),
"tertiary_link" => Some(35.0),
"unclassified" => Some(45.0),
"residential" => Some(30.0),
"living_street" => Some(10.0),
"service" => Some(20.0),
"track" => Some(20.0),
"road" => Some(50.0),
_ => None,
}
}
fn way_speed_kph(way: &XmlWay) -> f64 {
const FALLBACK_SPEED_KPH: f64 = 50.0;
if let Some(maxspeed) =
find_tag(&way.tags, "maxspeed").and_then(|tag| clean_maxspeed(&tag.value))
{
return maxspeed;
}
find_tag(&way.tags, "highway")
.and_then(|tag| highway_speed_kph(&tag.value))
.unwrap_or(FALLBACK_SPEED_KPH)
}
fn edge_way_from_template(
template: &XmlWay,
length: f64,
speed_kph: f64,
geometry: Vec<(f64, f64)>,
) -> XmlWay {
XmlWay {
id: template.id,
nodes: Vec::new(),
tags: template.tags.clone(),
length,
speed_kph,
walk_travel_time: calculate_travel_time(length, 5.0),
bike_travel_time: calculate_travel_time(length, 15.0),
drive_travel_time: calculate_travel_time(length, speed_kph),
geometry,
}
}
pub fn create_graph(
nodes: Vec<XmlNode>,
ways: Vec<XmlWay>,
retain_all: bool,
bidirectional: bool,
) -> DiGraph<XmlNode, XmlWay> {
let mut graph = DiGraph::<XmlNode, XmlWay>::new();
let mut node_index_map = HashMap::new();
for node in nodes {
let id = node.id;
let node_index = graph.add_node(node); node_index_map.insert(id, node_index);
}
for mut way in ways {
let node_refs = std::mem::take(&mut way.nodes);
let path_direction = assess_path_directionality(&way);
let speed_kph = way_speed_kph(&way);
let filtered_way = way.filter_useful_tags();
for window in node_refs.windows(2) {
if let [start_ref, end_ref] = window {
let start_index = node_index_map[&start_ref.node_id];
let end_index = node_index_map[&end_ref.node_id];
let (length, forward_geometry, reverse_geometry) = {
let start_node = &graph[start_index];
let end_node = &graph[end_index];
let length = calculate_distance(
start_node.lat,
start_node.lon,
end_node.lat,
end_node.lon,
);
(
length,
vec![
(start_node.lat, start_node.lon),
(end_node.lat, end_node.lon),
],
vec![
(end_node.lat, end_node.lon),
(start_node.lat, start_node.lon),
],
)
};
match path_direction {
Direction::OneWayForward => {
let edge_way = edge_way_from_template(
&filtered_way,
length,
speed_kph,
forward_geometry.clone(),
);
graph.add_edge(start_index, end_index, edge_way);
}
Direction::OneWayReverse => {
let reverse_way = edge_way_from_template(
&filtered_way,
length,
speed_kph,
reverse_geometry.clone(),
);
graph.add_edge(end_index, start_index, reverse_way);
}
Direction::Bidirectional => {
let edge_way = edge_way_from_template(
&filtered_way,
length,
speed_kph,
forward_geometry.clone(),
);
let reverse_way = edge_way_from_template(
&filtered_way,
length,
speed_kph,
reverse_geometry.clone(),
);
graph.add_edge(start_index, end_index, edge_way);
graph.add_edge(end_index, start_index, reverse_way);
}
}
if bidirectional && path_direction != Direction::Bidirectional {
match path_direction {
Direction::OneWayForward => {
let reverse_way = edge_way_from_template(
&filtered_way,
length,
speed_kph,
reverse_geometry.clone(),
);
graph.add_edge(end_index, start_index, reverse_way);
}
Direction::OneWayReverse => {
let reverse_way = edge_way_from_template(
&filtered_way,
length,
speed_kph,
forward_geometry.clone(),
);
graph.add_edge(start_index, end_index, reverse_way);
}
Direction::Bidirectional => {}
}
}
}
}
}
if !retain_all {
graph = simplify_graph(&graph)
}
graph
}
fn clean_maxspeed(maxspeed: &str) -> Option<f64> {
let mph_to_kph = 1.60934;
let trimmed = maxspeed.trim();
let numeric_prefix: String = trimmed
.chars()
.take_while(|c| c.is_ascii_digit() || *c == '.')
.collect();
let speed = numeric_prefix.parse::<f64>().ok()?;
if speed <= 0.0 {
return None;
}
if trimmed.to_ascii_lowercase().contains("mph") {
Some(speed * mph_to_kph)
} else {
Some(speed)
}
}
pub fn node_to_latlon(graph: &DiGraph<XmlNode, XmlWay>, node_index: NodeIndex) -> (f64, f64) {
let node = &graph[node_index];
(node.lat, node.lon)
}
#[derive(Clone)]
struct NodeEntry {
point: [f64; 2],
index: NodeIndex,
}
impl RTreeObject for NodeEntry {
type Envelope = AABB<[f64; 2]>;
fn envelope(&self) -> Self::Envelope {
AABB::from_point(self.point)
}
}
fn spatial_index_point(lat: f64, lon: f64) -> [f64; 2] {
let meters_per_degree = 111_320.0;
[
lat * meters_per_degree,
lon * meters_per_degree * lat.to_radians().cos(),
]
}
impl PointDistance for NodeEntry {
fn distance_2(&self, point: &[f64; 2]) -> f64 {
let dlat = self.point[0] - point[0];
let dlon = self.point[1] - point[1];
dlat * dlat + dlon * dlon
}
}
#[derive(Debug, Clone, Copy)]
pub struct SnapResult {
pub input_lat: f64,
pub input_lon: f64,
pub node_index: NodeIndex,
pub node_id: i64,
pub node_lat: f64,
pub node_lon: f64,
pub distance_m: f64,
}
#[derive(Debug, Clone, Copy)]
pub struct SnappedPoi {
pub poi_id: i64,
pub snap: SnapResult,
}
#[derive(Clone)]
pub struct SpatialGraph {
pub graph: Arc<DiGraph<XmlNode, XmlWay>>,
tree: Arc<RTree<NodeEntry>>,
pub poi_snaps: Option<Arc<HashMap<i64, SnappedPoi>>>,
}
impl SpatialGraph {
pub fn new(graph: DiGraph<XmlNode, XmlWay>) -> Self {
let entries: Vec<NodeEntry> = graph
.node_indices()
.map(|i| NodeEntry {
point: spatial_index_point(graph[i].lat, graph[i].lon),
index: i,
})
.collect();
let tree = Arc::new(RTree::bulk_load(entries));
let graph = Arc::new(graph);
Self {
graph,
tree,
poi_snaps: None,
}
}
pub(crate) fn from_parsed_osm(
data: XmlData,
network_type: crate::overpass::NetworkType,
retain_all: bool,
) -> Self {
let bidirectional = matches!(network_type, crate::overpass::NetworkType::Walk);
let graph = create_graph(data.nodes, data.ways, retain_all, bidirectional);
Self::new(graph)
}
pub fn from_osm(
xml: &str,
network_type: crate::overpass::NetworkType,
retain_all: Option<bool>,
) -> Result<Self, quick_xml::DeError> {
let data = parse_xml(xml)?;
Ok(Self::from_parsed_osm(
data,
network_type,
retain_all.unwrap_or(false),
))
}
pub fn snap_pois(&mut self, pois: &[crate::poi::Poi]) {
let snaps: HashMap<i64, SnappedPoi> = pois
.iter()
.filter_map(|poi| {
self.snap_point(poi.lat, poi.lon).map(|snap| {
(
poi.id,
SnappedPoi {
poi_id: poi.id,
snap,
},
)
})
})
.collect();
self.poi_snaps = Some(Arc::new(snaps));
}
pub fn nearest_node(&self, lat: f64, lon: f64) -> Option<NodeIndex> {
self.tree
.nearest_neighbor(&spatial_index_point(lat, lon))
.map(|e| e.index)
}
pub fn snap_point(&self, lat: f64, lon: f64) -> Option<SnapResult> {
self.nearest_node(lat, lon).map(|node_index| {
let node = &self.graph[node_index];
SnapResult {
input_lat: lat,
input_lon: lon,
node_index,
node_id: node.id,
node_lat: node.lat,
node_lon: node.lon,
distance_m: calculate_distance(lat, lon, node.lat, node.lon),
}
})
}
pub fn snap_point_within(
&self,
lat: f64,
lon: f64,
max_distance_m: Option<f64>,
) -> Option<SnapResult> {
let snap = self.snap_point(lat, lon)?;
match max_distance_m {
Some(max_distance_m) if snap.distance_m > max_distance_m => None,
_ => Some(snap),
}
}
pub fn nearest_node_within(
&self,
lat: f64,
lon: f64,
max_distance_m: Option<f64>,
) -> Option<NodeIndex> {
self.snap_point_within(lat, lon, max_distance_m)
.map(|snap| snap.node_index)
}
}
#[cfg(test)]
mod tests {
use super::*;
use petgraph::visit::EdgeRef;
fn make_node(id: i64, lat: f64, lon: f64) -> XmlNode {
XmlNode {
id,
lat,
lon,
tags: vec![],
}
}
fn make_way_raw(node_ids: Vec<i64>, tags: Vec<(&str, &str)>) -> XmlWay {
XmlWay {
id: 1,
nodes: node_ids
.into_iter()
.map(|id| XmlNodeRef { node_id: id })
.collect(),
tags: tags
.into_iter()
.map(|(k, v)| XmlTag {
key: k.into(),
value: v.into(),
})
.collect(),
length: 0.0,
speed_kph: 0.0,
walk_travel_time: 0.0,
bike_travel_time: 0.0,
drive_travel_time: 0.0,
geometry: Vec::new(),
}
}
fn edge_id_pairs(graph: &DiGraph<XmlNode, XmlWay>) -> Vec<(i64, i64)> {
let mut pairs: Vec<(i64, i64)> = graph
.edge_references()
.map(|edge| (graph[edge.source()].id, graph[edge.target()].id))
.collect();
pairs.sort_unstable();
pairs
}
#[test]
fn test_graph_respects_maxspeed_tag() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("maxspeed", "30")],
);
let graph = create_graph(
vec![nodes[0].clone(), nodes[1].clone()],
vec![way],
true,
false,
);
assert_eq!(graph.edge_weights().next().unwrap().speed_kph, 30.0);
}
#[test]
fn test_graph_parses_mph_maxspeed_tag() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("maxspeed", "30 mph")],
);
let graph = create_graph(
vec![nodes[0].clone(), nodes[1].clone()],
vec![way],
true,
false,
);
let speed = graph.edge_weights().next().unwrap().speed_kph;
assert!((speed - 48.2802).abs() < 1e-4);
}
#[test]
fn test_graph_falls_back_when_maxspeed_is_non_numeric() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("maxspeed", "signals")],
);
let graph = create_graph(
vec![nodes[0].clone(), nodes[1].clone()],
vec![way],
true,
false,
);
assert_eq!(graph.edge_weights().next().unwrap().speed_kph, 30.0);
}
#[test]
fn test_oneway_produces_single_edge() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("oneway", "yes")],
);
let graph = create_graph(
vec![nodes[0].clone(), nodes[1].clone()],
vec![way],
true,
false,
);
assert_eq!(graph.edge_count(), 1);
}
#[test]
fn test_oneway_yes_points_in_way_order() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("oneway", "yes")],
);
let graph = create_graph(nodes, vec![way], true, false);
assert_eq!(edge_id_pairs(&graph), vec![(1, 2)]);
}
#[test]
fn test_oneway_reverse_points_against_way_order() {
for value in ["-1", "reverse"] {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("oneway", value)],
);
let graph = create_graph(nodes, vec![way], true, false);
assert_eq!(edge_id_pairs(&graph), vec![(2, 1)], "oneway={value}");
}
}
#[test]
fn test_oneway_truthy_values_point_in_way_order() {
for value in ["true", "1"] {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("oneway", value)],
);
let graph = create_graph(nodes, vec![way], true, false);
assert_eq!(edge_id_pairs(&graph), vec![(1, 2)], "oneway={value}");
}
}
#[test]
fn test_roundabout_is_oneway_forward() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("junction", "roundabout")],
);
let graph = create_graph(nodes, vec![way], true, false);
assert_eq!(edge_id_pairs(&graph), vec![(1, 2)]);
}
#[test]
fn test_bidirectional_profile_overrides_oneway() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(
vec![1, 2],
vec![("highway", "residential"), ("oneway", "yes")],
);
let graph = create_graph(nodes, vec![way], true, true);
assert_eq!(edge_id_pairs(&graph), vec![(1, 2), (2, 1)]);
}
#[test]
fn test_bidirectional_produces_two_edges() {
let nodes = vec![make_node(1, 0.0, 0.0), make_node(2, 0.001, 0.0)];
let way = make_way_raw(vec![1, 2], vec![("highway", "residential")]);
let graph = create_graph(
vec![nodes[0].clone(), nodes[1].clone()],
vec![way],
true,
false,
);
assert_eq!(graph.edge_count(), 2);
}
#[test]
fn test_nearest_node_finds_closest() {
let mut graph = DiGraph::new();
graph.add_node(make_node(1, 48.0, 11.0));
graph.add_node(make_node(2, 52.0, 13.0));
let sg = SpatialGraph::new(graph);
let idx = sg.nearest_node(48.001, 11.001).unwrap();
assert_eq!(sg.graph[idx].id, 1);
}
#[test]
fn test_snap_point_returns_diagnostics() {
let mut graph = DiGraph::new();
graph.add_node(make_node(1, 48.0, 11.0));
let sg = SpatialGraph::new(graph);
let snap = sg.snap_point(48.001, 11.001).unwrap();
assert_eq!(snap.node_id, 1);
assert_eq!(snap.node_lat, 48.0);
assert_eq!(snap.node_lon, 11.0);
assert!(snap.distance_m > 0.0);
}
#[test]
fn test_snap_point_within_rejects_far_snap() {
let mut graph = DiGraph::new();
graph.add_node(make_node(1, 48.0, 11.0));
let sg = SpatialGraph::new(graph);
assert!(sg.snap_point_within(48.001, 11.001, Some(500.0)).is_some());
assert!(sg.snap_point_within(48.001, 11.001, Some(1.0)).is_none());
}
#[test]
fn test_parse_xml_minimal_osm_fixture() {
let xml = r#"
<osm version="0.6">
<node id="1" lat="48.0" lon="11.0" />
<node id="2" lat="48.001" lon="11.0">
<tag k="amenity" v="cafe" />
</node>
<way id="10">
<nd ref="1" />
<nd ref="2" />
<tag k="highway" v="residential" />
<tag k="name" v="Fixture Street" />
</way>
</osm>
"#;
let parsed = parse_xml(xml).unwrap();
assert_eq!(parsed.nodes.len(), 2);
assert_eq!(parsed.ways.len(), 1);
assert_eq!(parsed.ways[0].nodes.len(), 2);
assert!(parsed.ways[0]
.tags
.iter()
.any(|tag| tag.key == "highway" && tag.value == "residential"));
assert_eq!(parsed.nodes[1].tags[0].key, "amenity");
}
#[test]
fn test_parse_xml_malformed_input_errors() {
let err = parse_xml("<osm><node id=\"1\"></osm>").unwrap_err();
assert!(!err.to_string().is_empty());
}
}