use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
use petgraph::visit::EdgeRef;
use crate::graph::{SpatialGraph, XmlNode, XmlWay};
use crate::overpass::NetworkType;
/// Outcome of a cost-bounded search from a single start node.
#[derive(Clone, Debug)]
pub struct ReachabilityResult {
    /// Node the search was started from.
    pub start: NodeIndex,
    /// Cost budget the search was run with.
    pub max_cost: f64,
    /// Lowest accumulated cost found for every node reached within the budget.
    ///
    /// Empty when the budget was NaN or negative; otherwise it contains at
    /// least `start` with a cost of `0.0`.
    pub distances: HashMap<NodeIndex, f64>,
}
/// A full graph paired with the result of a reachability search over it.
///
/// The reachable part is described by `result`; `graph` itself is not pruned.
#[derive(Clone)]
pub struct ReachableGraph {
    /// The complete graph the search was run on.
    pub graph: SpatialGraph,
    /// Per-node costs produced by the search.
    pub result: ReachabilityResult,
    /// Network type the search was run with; reused for routing and isochrones.
    pub network_type: NetworkType,
}
/// Borrowed description of one edge, handed to the cost callback of
/// `compute_reachability_with`.
#[derive(Clone, Copy, Debug)]
pub struct EdgeInfo<'a> {
    /// Index of the edge in the graph.
    pub id: EdgeIndex,
    /// Node the edge leaves from.
    pub source: NodeIndex,
    /// Node the edge arrives at.
    pub target: NodeIndex,
    /// The way stored as the edge's weight.
    pub weight: &'a XmlWay,
}
/// Frontier entry of the bounded search: a node together with the accumulated
/// cost of the path that reached it.
///
/// The ordering is reversed on `cost` so that `BinaryHeap` (a max-heap) pops
/// the cheapest entry first. Ties on cost are broken by node index so that
/// `Ord`, `PartialOrd`, `PartialEq` and `Eq` all agree with each other.
#[derive(Clone, Copy, Debug)]
struct SearchState {
    cost: f64,
    node: NodeIndex,
}

impl PartialEq for SearchState {
    /// Two states are equal exactly when `cmp` reports `Ordering::Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchState {}

impl PartialOrd for SearchState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchState {
    /// Lower cost compares as greater; on equal cost the lower node index
    /// compares as greater.
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order on f64 (NaN included), which is
        // what `Ord` requires; `partial_cmp` with an `Equal` fallback is not.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node.cmp(&self.node))
    }
}
/// Finds every node reachable from `start` with an accumulated cost of at most
/// `max_cost`, using `cost` to price each edge (Dijkstra with a cost cut-off).
///
/// * `graph` - directed graph to search.
/// * `start` - node the search begins at; recorded with cost `0.0`.
/// * `max_cost` - inclusive budget. A NaN or negative budget yields an empty
///   `distances` map (not even `start` is recorded).
/// * `cost` - called once for every outgoing edge of each expanded node. Edges
///   for which it returns a non-finite or negative value are ignored.
///
/// Returns the start node, the budget, and the lowest cost found per reached
/// node.
pub fn compute_reachability_with<F>(
    graph: &DiGraph<XmlNode, XmlWay>,
    start: NodeIndex,
    max_cost: f64,
    mut cost: F,
) -> ReachabilityResult
where
    F: FnMut(EdgeInfo<'_>) -> f64,
{
    let mut best: HashMap<NodeIndex, f64> = HashMap::new();

    let budget_is_usable = !max_cost.is_nan() && max_cost >= 0.0;
    if budget_is_usable {
        let mut frontier = BinaryHeap::new();
        best.insert(start, 0.0);
        frontier.push(SearchState {
            cost: 0.0,
            node: start,
        });

        while let Some(current) = frontier.pop() {
            // The heap yields the cheapest entry first, so once one entry is
            // over budget every remaining entry is as well.
            if current.cost > max_cost {
                break;
            }

            // Skip stale heap entries that were superseded by a cheaper path.
            let settled = best
                .get(&current.node)
                .copied()
                .unwrap_or(f64::INFINITY);
            if current.cost > settled {
                continue;
            }

            for edge in graph.edges(current.node) {
                let step = cost(EdgeInfo {
                    id: edge.id(),
                    source: edge.source(),
                    target: edge.target(),
                    weight: edge.weight(),
                });
                let step_is_usable = step.is_finite() && step >= 0.0;
                if !step_is_usable {
                    continue;
                }

                let reached = edge.target();
                let candidate = current.cost + step;
                if candidate > max_cost {
                    continue;
                }

                let previous = best.get(&reached).copied().unwrap_or(f64::INFINITY);
                if candidate < previous {
                    best.insert(reached, candidate);
                    frontier.push(SearchState {
                        cost: candidate,
                        node: reached,
                    });
                }
            }
        }
    }

    ReachabilityResult {
        start,
        max_cost,
        distances: best,
    }
}
/// Bounded reachability search that prices each edge with the way's
/// `travel_time` for `network_type`.
///
/// Thin wrapper around `compute_reachability_with`; `max_cost` is the budget
/// in whatever unit `travel_time` returns.
pub fn compute_reachability(
    graph: &DiGraph<XmlNode, XmlWay>,
    start: NodeIndex,
    max_cost: f64,
    network_type: NetworkType,
) -> ReachabilityResult {
    compute_reachability_with(
        graph,
        start,
        max_cost,
        |edge| edge.weight.travel_time(network_type),
    )
}
fn reachable_subgraph(sg: &SpatialGraph, result: &ReachabilityResult) -> SpatialGraph {
let mut subgraph = DiGraph::new();
let mut old_to_new = HashMap::new();
for &old_idx in result.distances.keys() {
let new_idx = subgraph.add_node(sg.graph[old_idx].clone());
old_to_new.insert(old_idx, new_idx);
}
for edge in sg.graph.edge_references() {
let (Some(&source), Some(&target)) = (
old_to_new.get(&edge.source()),
old_to_new.get(&edge.target()),
) else {
continue;
};
subgraph.add_edge(source, target, edge.weight().clone());
}
SpatialGraph::new(subgraph)
}
impl ReachableGraph {
    /// Number of nodes reached by the search (entries in `result.distances`).
    pub fn node_count(&self) -> usize {
        self.result.distances.len()
    }

    /// Number of edges of the full graph whose source and target were both
    /// reached.
    pub fn edge_count(&self) -> usize {
        let reached = &self.result.distances;
        let mut total = 0;
        for edge in self.graph.graph.edge_references() {
            if reached.contains_key(&edge.source()) && reached.contains_key(&edge.target()) {
                total += 1;
            }
        }
        total
    }

    /// Whether a reached node carries the given `id`.
    ///
    /// Linear scan over the reached nodes.
    pub fn contains_node_id(&self, node_id: i64) -> bool {
        self.travel_time_to_node_id(node_id).is_some()
    }

    /// Accumulated cost recorded for the reached node with the given `id`, or
    /// `None` when no reached node has that id.
    ///
    /// Linear scan over the reached nodes.
    pub fn travel_time_to_node_id(&self, node_id: i64) -> Option<f64> {
        for (idx, time) in &self.result.distances {
            if self.graph.graph[*idx].id == node_id {
                return Some(*time);
            }
        }
        None
    }

    /// Builds a standalone `SpatialGraph` holding only the reached nodes and
    /// the edges between them.
    pub fn materialize(&self) -> SpatialGraph {
        reachable_subgraph(&self.graph, &self.result)
    }

    /// Routes between two coordinates on the materialized reachable subgraph,
    /// using the network type stored in this view.
    ///
    /// The subgraph is rebuilt on every call; all arguments are forwarded to
    /// `SpatialGraph::route`.
    pub fn route(
        &self,
        origin_lat: f64,
        origin_lon: f64,
        dest_lat: f64,
        dest_lon: f64,
        max_snap_m: Option<f64>,
    ) -> Result<crate::routing::Route, crate::error::OsmGraphError> {
        let reachable_only = self.materialize();
        reachable_only.route(
            origin_lat,
            origin_lon,
            dest_lat,
            dest_lon,
            self.network_type,
            max_snap_m,
        )
    }

    /// Computes isochrones on the materialized reachable subgraph, using the
    /// network type stored in this view.
    ///
    /// The subgraph is rebuilt on every call; all arguments are forwarded to
    /// `SpatialGraph::isochrones`.
    pub fn isochrones(
        &self,
        lat: f64,
        lon: f64,
        time_limits: Vec<f64>,
        max_snap_m: Option<f64>,
    ) -> Option<Vec<geo::Polygon>> {
        let reachable_only = self.materialize();
        reachable_only.isochrones(lat, lon, time_limits, self.network_type, max_snap_m)
    }
}
impl SpatialGraph {
    /// Runs `reachability` and bundles the result with a clone of this graph.
    ///
    /// Returns `None` when `reachability` does; the graph is only cloned on
    /// success.
    pub fn reachable_graph(
        &self,
        lat: f64,
        lon: f64,
        max_time: f64,
        network_type: NetworkType,
        max_snap_m: Option<f64>,
    ) -> Option<ReachableGraph> {
        self.reachability(lat, lon, max_time, network_type, max_snap_m)
            .map(|result| ReachableGraph {
                graph: self.clone(),
                result,
                network_type,
            })
    }

    /// Resolves the coordinate to a start node via `nearest_node_within` and
    /// runs `compute_reachability` from it with `max_time` as the budget.
    ///
    /// Returns `None` when `nearest_node_within` yields no node.
    pub fn reachability(
        &self,
        lat: f64,
        lon: f64,
        max_time: f64,
        network_type: NetworkType,
        max_snap_m: Option<f64>,
    ) -> Option<ReachabilityResult> {
        self.nearest_node_within(lat, lon, max_snap_m)
            .map(|start| compute_reachability(&self.graph, start, max_time, network_type))
    }

    /// Runs `reachability` (with no snap-distance limit) and passes the result
    /// to `crate::poi::fetch_pois_within_reachability`.
    ///
    /// The outer `Option` is `None` when no start node could be resolved; the
    /// inner `Result` is whatever the POI fetch returned.
    pub async fn reachable_pois(
        &self,
        lat: f64,
        lon: f64,
        max_time: f64,
        network_type: NetworkType,
    ) -> Option<Result<Vec<crate::poi::ReachablePoi>, crate::error::OsmGraphError>> {
        match self.reachability(lat, lon, max_time, network_type, None) {
            Some(result) => Some(crate::poi::fetch_pois_within_reachability(self, &result).await),
            None => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::create_graph;
    use crate::graph::{XmlNode, XmlNodeRef, XmlTag, XmlWay};

    /// Longitudes of the fixture nodes: node `k` (1-based) sits at latitude
    /// `0.0` and longitude `LONGITUDES[k - 1]`.
    const LONGITUDES: [f64; 4] = [0.0, 0.001, 0.002, 0.003];

    /// Untagged node at the given position.
    fn node(id: i64, lat: f64, lon: f64) -> XmlNode {
        XmlNode {
            id,
            lat,
            lon,
            tags: vec![],
        }
    }

    /// Way with id `1` through `node_ids`, carrying `tags`; all derived
    /// numeric fields start at zero.
    fn way(node_ids: Vec<i64>, tags: Vec<(&str, &str)>) -> XmlWay {
        let nodes = node_ids
            .into_iter()
            .map(|node_id| XmlNodeRef { node_id })
            .collect();
        let tags = tags
            .into_iter()
            .map(|(key, value)| XmlTag {
                key: key.into(),
                value: value.into(),
            })
            .collect();
        XmlWay {
            id: 1,
            nodes,
            tags,
            length: 0.0,
            speed_kph: 0.0,
            walk_travel_time: 0.0,
            bike_travel_time: 0.0,
            drive_travel_time: 0.0,
            geometry: Vec::new(),
        }
    }

    /// Nodes `1..=count` in a row, joined in order by one residential way.
    fn chain(count: usize) -> (Vec<XmlNode>, Vec<XmlWay>) {
        let mut nodes = Vec::with_capacity(count);
        let mut ids = Vec::with_capacity(count);
        for (id, &lon) in (1_i64..).zip(&LONGITUDES[..count]) {
            nodes.push(node(id, 0.0, lon));
            ids.push(id);
        }
        let ways = vec![way(ids, vec![("highway", "residential")])];
        (nodes, ways)
    }

    #[test]
    fn budget_excludes_distant_nodes() {
        let (nodes, ways) = chain(3);
        let g = create_graph(nodes, ways, true, false);
        let start = g.node_indices().find(|idx| g[*idx].id == 1).unwrap();

        let unbounded = compute_reachability(&g, start, f64::INFINITY, NetworkType::Drive);
        assert_eq!(
            unbounded.distances.len(),
            3,
            "all 3 nodes should reach with infinite budget"
        );

        let budget = 5.0;
        let tight = compute_reachability(&g, start, budget, NetworkType::Drive);
        assert!(
            tight.distances.len() < 3,
            "tight budget should exclude at least one node"
        );
        assert!(tight.distances.values().all(|&t| t <= budget));
    }

    #[test]
    fn custom_cost_closure_controls_distances() {
        let (nodes, ways) = chain(4);
        let g = create_graph(nodes, ways, true, false);
        let start = g.node_indices().find(|idx| g[*idx].id == 1).unwrap();

        let result = compute_reachability_with(&g, start, 100.0, |_| 10.0);

        let mut times: Vec<f64> = result.distances.into_values().collect();
        times.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(times, vec![0.0, 10.0, 20.0, 30.0]);
    }

    #[test]
    fn bounded_search_does_not_insert_nodes_beyond_budget() {
        let (nodes, ways) = chain(4);
        let g = create_graph(nodes, ways, true, false);
        let start = g.node_indices().find(|idx| g[*idx].id == 1).unwrap();

        let budget = 15.0;
        let result = compute_reachability_with(&g, start, budget, |_| 10.0);

        let mut reached_ids: Vec<i64> = result.distances.keys().map(|idx| g[*idx].id).collect();
        reached_ids.sort_unstable();
        assert_eq!(reached_ids, vec![1, 2]);
        assert!(result.distances.values().all(|&t| t <= budget));
    }

    #[test]
    fn invalid_budget_returns_empty_reachability() {
        let (nodes, ways) = chain(2);
        let g = create_graph(nodes, ways, true, false);
        let start = g.node_indices().find(|idx| g[*idx].id == 1).unwrap();

        let result = compute_reachability(&g, start, f64::NAN, NetworkType::Drive);

        assert!(result.distances.is_empty());
    }

    #[test]
    fn closure_can_double_baseline_cost() {
        let (nodes, ways) = chain(3);
        let g = create_graph(nodes, ways, true, false);
        let start = g.node_indices().find(|idx| g[*idx].id == 1).unwrap();

        let baseline = compute_reachability(&g, start, f64::INFINITY, NetworkType::Drive);
        let doubled = compute_reachability_with(&g, start, f64::INFINITY, |edge| {
            edge.weight.travel_time(NetworkType::Drive) * 2.0
        });

        for (node, base) in baseline.distances.iter() {
            let twice = doubled.distances[node];
            assert!(
                (twice - 2.0 * *base).abs() < 1e-9,
                "node {:?}: expected 2x baseline",
                node
            );
        }
    }

    #[test]
    fn reachable_graph_view_exposes_induced_counts_and_travel_times() {
        let (nodes, ways) = chain(3);
        let graph = SpatialGraph::new(create_graph(nodes, ways, true, false));

        let reachable = graph
            .reachable_graph(0.0, 0.0, 20.0, NetworkType::Drive, None)
            .unwrap();

        assert_eq!(reachable.node_count(), 2);
        assert_eq!(reachable.edge_count(), 2);
        for reached_id in [1, 2] {
            assert!(reachable.contains_node_id(reached_id));
        }
        assert!(!reachable.contains_node_id(3));
        assert_eq!(reachable.travel_time_to_node_id(1), Some(0.0));
        assert!(reachable.travel_time_to_node_id(2).unwrap() > 0.0);
        // The view keeps the full graph; only `materialize` prunes it.
        assert_eq!(reachable.graph.graph.node_count(), 3);
        assert_eq!(reachable.materialize().graph.node_count(), 2);
    }
}