use crate::{
CollapseError, CollapsedPath, Costing, MatchError, PredicateCache, Reachable, Solver,
TransitionContext, Trip,
candidate::{CandidateEdge, CandidateId},
costing::{EmissionStrategy, TransitionStrategy},
entity::Transition,
primitives::RoutingContext,
};
use routers_network::{Entry, Metadata, Network};
use log::{debug, info};
use core::cell::RefCell;
use rustc_hash::FxHashMap;
use std::{marker::PhantomData, sync::Arc};
use geo::{Distance, Haversine};
use itertools::Itertools;
use measure_time::debug_time;
use pathfinding::{num_traits::Zero, prelude::*};
/// A [`Solver`] that explores the candidate graph forward from the start
/// candidate, computing transition costs only for the candidates it visits.
pub struct SelectiveForwardSolver<E, M, N>
where
E: Entry,
M: Metadata,
N: Network<E, M>,
{
/// Shared predicate cache, queried in `reachable` with the target node of
/// the source candidate's edge.
predicate: Arc<PredicateCache<E, M, N>>,
/// Every `Reachable` produced by `reach`, stored under the key returned by
/// `Reachable::hash()`. `solve` reads it back using the
/// `(source.index(), target.index())` pair of consecutive path nodes.
reachable_hash: RefCell<FxHashMap<(usize, usize), Reachable<E>>>,
/// Zero-sized marker for the network type `N`; carries no data.
_phantom: PhantomData<N>,
}
impl<E, M, N> Default for SelectiveForwardSolver<E, M, N>
where
    E: Entry,
    M: Metadata,
    N: Network<E, M>,
{
    /// Builds a solver with a fresh default predicate cache and an empty
    /// reachable-hop map.
    fn default() -> Self {
        let predicate = Arc::new(PredicateCache::default());
        let reachable_hash = RefCell::new(FxHashMap::default());

        Self {
            predicate,
            reachable_hash,
            _phantom: PhantomData,
        }
    }
}
impl<E, M, N> SelectiveForwardSolver<E, M, N>
where
    E: Entry,
    M: Metadata,
    N: Network<E, M>,
{
    /// Swaps the solver's predicate cache for `cache` and hands the solver
    /// back; all other fields are carried over untouched.
    pub fn use_cache(self, cache: Arc<PredicateCache<E, M, N>>) -> Self {
        let mut solver = self;
        solver.predicate = cache;
        solver
    }

    /// Successor function for the graph search: lists the candidates that can
    /// follow `source`, each paired with the cost of stepping onto it.
    ///
    /// * When `source` is the `start` candidate, every candidate in the next
    ///   layer is returned with a zero-cost edge.
    /// * When the next layer contains the `end` candidate, only `end` is
    ///   returned, again with a zero-cost edge.
    /// * Otherwise each hop found by [`Self::reachable`] is costed as
    ///   60% transition cost plus 40% emission cost of the target (each term
    ///   truncated to `u32`, summed with saturation). The hop is recorded in
    ///   `reachable_hash` under `Reachable::hash()`.
    ///
    /// Hops whose candidates or layers cannot be looked up are dropped.
    fn reach<'a, 'b, Emmis, Trans>(
        &'b self,
        transition: &'b Transition<'b, Emmis, Trans, E, M, N>,
        context: &'b RoutingContext<'b, E, M, N>,
        (start, end): (CandidateId, CandidateId),
        source: &CandidateId,
    ) -> Vec<(CandidateId, CandidateEdge)>
    where
        Emmis: EmissionStrategy + Send + Sync,
        Trans: TransitionStrategy<E, M, N> + Send + Sync,
        'b: 'a,
    {
        // Relative contribution of each cost component to the edge weight.
        const TRANSITION_WEIGHT: f64 = 0.6;
        const EMISSION_WEIGHT: f64 = 0.4;

        let successors = transition.candidates.next_layer(source);

        // Leaving the start node is free, whichever candidate is chosen.
        if *source == start {
            return successors
                .into_iter()
                .map(|id| (id, CandidateEdge::zero()))
                .collect();
        }

        // Once the end node is in reach, it is the only successor offered.
        if successors.contains(&end) {
            debug!("End-Successors: {successors:?}");
            return vec![(end, CandidateEdge::zero())];
        }

        let hops = self
            .reachable(context, source, successors.as_slice())
            .unwrap_or_default();

        {
            debug_time!("Format Reachable Elements");

            let mut hash = self.reachable_hash.borrow_mut();

            hops.into_iter()
                .filter_map(move |hop| {
                    let path_vec = hop.path_nodes().collect_vec();
                    let optimal_path = Trip::new_with_map(context.map, &path_vec);

                    let from = context.candidate(&hop.source)?;
                    let to = context.candidate(&hop.target)?;

                    let from_layer = transition.layers.layers.get(from.location.layer_id)?;
                    let to_layer = transition.layers.layers.get(to.location.layer_id)?;
                    let layer_width = Haversine.distance(from_layer.origin, to_layer.origin);

                    let transition_cost = transition.heuristics.transition(TransitionContext {
                        map_path: &path_vec,
                        requested_resolution_method: hop.resolution_method,
                        source_candidate: &hop.source,
                        target_candidate: &hop.target,
                        routing_context: context,
                        layer_width,
                        optimal_path,
                    });

                    let weighted_transition = (transition_cost as f64 * TRANSITION_WEIGHT) as u32;
                    let weighted_emission = (to.emission as f64 * EMISSION_WEIGHT) as u32;
                    let cost = weighted_emission.saturating_add(weighted_transition);

                    // Remember the hop so the final path can be rebuilt later.
                    hash.insert(hop.hash(), hop.clone());
                    Some((hop.target, CandidateEdge::new(cost)))
                })
                .collect::<Vec<_>>()
        }
    }

    /// Works out, for each candidate in `targets`, how it can be reached from
    /// `source`.
    ///
    /// Returns `None` only when `source` itself cannot be resolved through
    /// `ctx`. Individual targets are skipped when their candidate cannot be
    /// resolved, when a required edge percentage is unavailable, or when
    /// `path_builder` yields no path.
    ///
    /// A target on the same edge as the source (same id, same endpoints) whose
    /// percentage along the edge is not behind the source's becomes a
    /// distance-only hop with an empty path. Every other target gets the path
    /// built from the predicate map, converted pairwise into edges via
    /// `ctx.edge`; node pairs without an edge are left out.
    fn reachable<'a>(
        &self,
        ctx: &'a RoutingContext<'a, E, M, N>,
        source: &CandidateId,
        targets: &'a [CandidateId],
    ) -> Option<Vec<Reachable<E>>> {
        let source_candidate = ctx.candidate(source)?;

        let predicate_map = {
            debug_time!("query predicate for {source:?}");
            self.predicate.query(ctx, source_candidate.edge.target)
        };

        let hops = {
            debug_time!("predicates {source:?} -> reachable");

            targets
                .iter()
                .filter_map(|target| {
                    let candidate = ctx.candidate(target)?;

                    if candidate.edge.id.index() == source_candidate.edge.id.index() {
                        let same_source = candidate.edge.source == source_candidate.edge.source;
                        let same_target = candidate.edge.target == source_candidate.edge.target;

                        // Both percentages are resolved before deciding, so a
                        // missing value drops this target altogether.
                        let source_percentage = source_candidate.percentage(ctx.map)?;
                        let target_percentage = candidate.percentage(ctx.map)?;

                        if same_source && same_target && source_percentage <= target_percentage {
                            return Some(Reachable::new(*source, *target, vec![]).distance_only());
                        }
                        // Otherwise fall through and route like any other pair.
                    }

                    let path_to_target = Self::path_builder(
                        &candidate.edge.source,
                        &source_candidate.edge.target,
                        &predicate_map,
                    )?;

                    let path = path_to_target
                        .windows(2)
                        .filter_map(|pair| match pair {
                            [a, b] => ctx.edge(a, b),
                            _ => None,
                        })
                        .collect::<Vec<_>>();

                    Some(Reachable::new(*source, *target, path))
                })
                .collect::<Vec<_>>()
        };

        Some(hops)
    }
}
impl<N, E, M> Solver<E, M, N> for SelectiveForwardSolver<E, M, N>
where
    E: Entry,
    M: Metadata,
    N: Network<E, M>,
{
    /// Finds the cheapest route through the candidate graph of `transition`.
    ///
    /// Steps:
    /// 1. Attach the start and end candidates and weave the candidate layers.
    /// 2. Run `astar` from start to end with [`Self::reach`] as the successor
    ///    function and a constant-zero heuristic.
    /// 3. Look up the recorded `Reachable` for each consecutive pair of path
    ///    nodes; pairs with no record (e.g. the zero-cost hops leaving the
    ///    start or entering the end) are skipped.
    ///
    /// # Errors
    ///
    /// * [`MatchError::EndAttachFailure`] when the ends cannot be attached.
    /// * [`MatchError::CollapseFailure`] wrapping
    ///   [`CollapseError::NoPathFound`] when the search finds no route.
    fn solve<Emmis, Trans>(
        &self,
        mut transition: Transition<Emmis, Trans, E, M, N>,
        runtime: &M::Runtime,
    ) -> Result<CollapsedPath<E>, MatchError>
    where
        Emmis: EmissionStrategy + Send + Sync,
        Trans: TransitionStrategy<E, M, N> + Send + Sync,
    {
        // The hop map is keyed by candidate indices, which belong to the
        // `transition` handed to this call. Entries left over from an earlier
        // solve on the same solver could otherwise be picked up for unrelated
        // pairs below, so start from an empty map.
        self.reachable_hash.borrow_mut().clear();

        let (start, end) = {
            transition
                .candidates
                .attach_ends(&transition.layers)
                .map_err(MatchError::EndAttachFailure)?
        };
        debug!("Attached Ends");

        transition.candidates.weave(&transition.layers);
        debug!("Weaved all candidate layers.");

        info!("Solving: Start={start:?}. End={end:?}. ");

        let context = transition.context(runtime);

        let Some((path, cost)) = ({
            debug_time!("Solved transition graph");

            astar(
                &start,
                |source| self.reach(&transition, &context, (start, end), source),
                |_| CandidateEdge::zero(),
                |node| *node == end,
            )
        }) else {
            return Err(MatchError::CollapseFailure(CollapseError::NoPathFound));
        };

        info!("Total cost of solve: {}", cost.weight);

        let reached = {
            // Borrow the map once for the whole pass rather than per pair.
            let recorded = self.reachable_hash.borrow();

            path.windows(2)
                .filter_map(|nodes| match nodes {
                    [a, b] => recorded.get(&(a.index(), b.index())).cloned(),
                    _ => None,
                })
                .collect::<Vec<_>>()
        };

        Ok(CollapsedPath::new(
            cost.weight,
            reached,
            path,
            transition.candidates,
        ))
    }
}