use std::collections::HashMap;
use petgraph::algo::dijkstra;
use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
use petgraph::visit::EdgeRef;
use crate::graph::{SpatialGraph, XmlNode, XmlWay};
use crate::overpass::NetworkType;
/// Outcome of a cost-bounded reachability search from a single start node.
///
/// Produced by `compute_reachability_with` and its wrappers in this module.
#[derive(Debug, Clone)]
pub struct ReachabilityResult {
/// Node the search was started from.
pub start: NodeIndex,
/// Budget that was applied: only nodes whose accumulated cost is
/// `<= max_cost` are kept in `distances`.
pub max_cost: f64,
/// Accumulated cost from `start` to every node that fits within the budget.
/// The unit is whatever the edge-cost function used for the search returns.
pub distances: HashMap<NodeIndex, f64>,
}
/// Borrowed view of a single graph edge, handed to custom edge-cost closures.
///
/// The lifetime `'a` is that of the borrow of the edge's `XmlWay` weight.
#[derive(Debug, Clone, Copy)]
pub struct EdgeInfo<'a> {
/// Index of the edge inside the graph.
pub id: EdgeIndex,
/// Node the edge starts at.
pub source: NodeIndex,
/// Node the edge points to.
pub target: NodeIndex,
/// Way data stored on the edge.
pub weight: &'a XmlWay,
}
/// Computes every node reachable from `start` within `max_cost`, using a
/// caller-supplied edge cost.
///
/// # Parameters
/// * `graph` – directed graph to search.
/// * `start` – node the search begins at.
/// * `max_cost` – inclusive upper bound on the accumulated cost of a node.
/// * `cost` – called for each edge the search examines; returns that edge's cost.
///
/// # Returns
/// A [`ReachabilityResult`] echoing `start` and `max_cost`, whose `distances`
/// map holds the accumulated cost of every node with cost `<= max_cost`.
///
/// Note that the shortest-path search itself is not bounded: it is run without
/// a goal node and the budget is applied to its output afterwards.
pub fn compute_reachability_with<F>(
    graph: &DiGraph<XmlNode, XmlWay>,
    start: NodeIndex,
    max_cost: f64,
    mut cost: F,
) -> ReachabilityResult
where
    F: FnMut(EdgeInfo<'_>) -> f64,
{
    // Unbounded shortest-path costs from `start`; each edge is priced by `cost`.
    let all_costs = dijkstra(graph, start, None, |edge| {
        let info = EdgeInfo {
            id: edge.id(),
            source: edge.source(),
            target: edge.target(),
            weight: edge.weight(),
        };
        cost(info)
    });

    // Keep only the nodes that fit within the budget.
    let mut distances: HashMap<NodeIndex, f64> = HashMap::new();
    for (node, total) in all_costs {
        if total <= max_cost {
            distances.insert(node, total);
        }
    }

    ReachabilityResult {
        start,
        max_cost,
        distances,
    }
}
pub fn compute_reachability(
graph: &DiGraph<XmlNode, XmlWay>,
start: NodeIndex,
max_cost: f64,
network_type: NetworkType,
) -> ReachabilityResult {
compute_reachability_with(graph, start, max_cost, |e| {
e.weight.travel_time(network_type)
})
}
impl SpatialGraph {
    /// Same as [`SpatialGraph::reachability`]; this method simply forwards to it.
    pub fn reachable_from(
        &self,
        lat: f64,
        lon: f64,
        max_time: f64,
        network_type: NetworkType,
    ) -> Option<ReachabilityResult> {
        self.reachability(lat, lon, max_time, network_type)
    }

    /// Runs a travel-time reachability search from the graph node nearest to
    /// (`lat`, `lon`).
    ///
    /// Returns `None` when `nearest_node` yields no node for the coordinates;
    /// otherwise the result of [`compute_reachability`] with `max_time` as the
    /// budget and `network_type` selecting the travel time used per edge.
    pub fn reachability(
        &self,
        lat: f64,
        lon: f64,
        max_time: f64,
        network_type: NetworkType,
    ) -> Option<ReachabilityResult> {
        self.nearest_node(lat, lon)
            .map(|origin| compute_reachability(&self.graph, origin, max_time, network_type))
    }

    /// Looks up the points of interest that lie within the area reachable from
    /// (`lat`, `lon`).
    ///
    /// The outer `Option` is `None` when the reachability search could not be
    /// started (see [`SpatialGraph::reachability`]). The inner `Result` is
    /// whatever `crate::poi::fetch_pois_within_reachability` returns for the
    /// computed reachability.
    pub async fn reachable_pois(
        &self,
        lat: f64,
        lon: f64,
        max_time: f64,
        network_type: NetworkType,
    ) -> Option<Result<Vec<crate::poi::ReachablePoi>, crate::error::OsmGraphError>> {
        let reach = self.reachability(lat, lon, max_time, network_type)?;
        let pois = crate::poi::fetch_pois_within_reachability(self, &reach).await;
        Some(pois)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::create_graph;
    use crate::graph::{XmlNode, XmlNodeRef, XmlTag, XmlWay};

    /// Builds an untagged node with the given id and coordinates.
    fn node(id: i64, lat: f64, lon: f64) -> XmlNode {
        let tags = Vec::new();
        XmlNode { id, lat, lon, tags }
    }

    /// Builds way `1` through `node_ids` carrying the given `(key, value)` tags.
    ///
    /// Length, speed and travel times are left at zero here; presumably
    /// `create_graph` fills them in (the drive-time tests below rely on
    /// non-zero travel times) — confirm against `create_graph`.
    fn way(node_ids: Vec<i64>, tags: Vec<(&str, &str)>) -> XmlWay {
        let mut refs = Vec::with_capacity(node_ids.len());
        for node_id in node_ids {
            refs.push(XmlNodeRef { node_id });
        }

        let mut way_tags = Vec::with_capacity(tags.len());
        for (k, v) in tags {
            way_tags.push(XmlTag {
                key: k.into(),
                value: v.into(),
            });
        }

        XmlWay {
            id: 1,
            nodes: refs,
            tags: way_tags,
            length: 0.0,
            speed_kph: 0.0,
            walk_travel_time: 0.0,
            bike_travel_time: 0.0,
            drive_travel_time: 0.0,
        }
    }

    /// Nodes with ids `1, 2, 3, …` at latitude `0.0` and the given longitudes.
    fn equator_nodes(lons: &[f64]) -> Vec<XmlNode> {
        (1_i64..)
            .zip(lons)
            .map(|(id, &lon)| node(id, 0.0, lon))
            .collect()
    }

    /// Graph index of the node whose `id` field equals `id`.
    fn index_of(graph: &DiGraph<XmlNode, XmlWay>, id: i64) -> NodeIndex {
        graph
            .node_indices()
            .find(|&idx| graph[idx].id == id)
            .unwrap()
    }

    #[test]
    fn budget_excludes_distant_nodes() {
        let g = create_graph(
            equator_nodes(&[0.0, 0.001, 0.002]),
            vec![way(vec![1, 2, 3], vec![("highway", "residential")])],
            true,
            false,
        );
        let start = index_of(&g, 1);

        let unbounded = compute_reachability(&g, start, f64::INFINITY, NetworkType::Drive);
        assert_eq!(
            unbounded.distances.len(),
            3,
            "all 3 nodes should reach with infinite budget"
        );

        let budget = 5.0;
        let bounded = compute_reachability(&g, start, budget, NetworkType::Drive);
        assert!(
            bounded.distances.len() < 3,
            "tight budget should exclude at least one node"
        );
        for &t in bounded.distances.values() {
            assert!(t <= 5.0);
        }
    }

    #[test]
    fn custom_cost_closure_controls_distances() {
        let g = create_graph(
            equator_nodes(&[0.0, 0.001, 0.002, 0.003]),
            vec![way(vec![1, 2, 3, 4], vec![("highway", "residential")])],
            true,
            false,
        );
        let start = index_of(&g, 1);

        // Every edge costs a flat 10, so costs grow by 10 per hop.
        let reach = compute_reachability_with(&g, start, 100.0, |_| 10.0);

        let mut times = Vec::with_capacity(reach.distances.len());
        times.extend(reach.distances.values().copied());
        times.sort_by(|a: &f64, b: &f64| a.total_cmp(b));
        assert_eq!(times, vec![0.0, 10.0, 20.0, 30.0]);
    }

    #[test]
    fn closure_can_double_baseline_cost() {
        let g = create_graph(
            equator_nodes(&[0.0, 0.001, 0.002]),
            vec![way(vec![1, 2, 3], vec![("highway", "residential")])],
            true,
            false,
        );
        let start = index_of(&g, 1);

        let baseline = compute_reachability(&g, start, f64::INFINITY, NetworkType::Drive);
        let doubled = compute_reachability_with(&g, start, f64::INFINITY, |edge| {
            let single = edge.weight.travel_time(NetworkType::Drive);
            single * 2.0
        });

        for (node, &base_cost) in baseline.distances.iter() {
            let doubled_cost = doubled.distances[node];
            let error = (doubled_cost - 2.0 * base_cost).abs();
            assert!(error < 1e-9, "node {:?}: expected 2x baseline", node);
        }
    }
}