use crate::adjacencies::{Adjacencies, InEdges, Neighbors, OutEdges};
use crate::collections::{BinHeap, ItemMap, ItemPriQueue};
pub use crate::search::astar::AStarHeuristic as Heuristic;
use crate::search::path_from_incomings;
use crate::traits::{Digraph, Graph, GraphType};
use either::Either::{self, Left, Right};
use num_traits::Zero;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::hash::Hash;
use std::ops::{Add, Neg, Sub};
pub use super::astar::default_data;
/// Tags a value with the search direction it belongs to.
///
/// The bidirectional search stores forward and backward labels in the same
/// map and queue; wrapping nodes in `Direction` keeps the two apart. The
/// iterator also uses it to tag each reported edge.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction<E> {
    /// The value belongs to the search started at the source.
    Forward(E),
    /// The value belongs to the search started at the sink.
    Backward(E),
}
/// Queue payload of a node that has been labelled but not yet popped.
///
/// Entries are ordered by `lower + distance`, and equality uses the same key,
/// so `a == b` holds exactly when `a.partial_cmp(&b)` is
/// `Some(Ordering::Equal)`, as the standard library requires of
/// `PartialEq`/`PartialOrd` implementations.
#[derive(Clone, Copy, Debug)]
pub struct BiData<E, D, H> {
    /// Edge over which the node received its current label.
    edge: E,
    /// Tentative distance of the node within its search direction.
    distance: D,
    /// Heuristic term that is added to `distance` to form the queue key.
    lower: H,
}

impl<E, D, H> PartialEq for BiData<E, D, H>
where
    D: PartialEq + Clone,
    H: Add<D, Output = D> + Clone,
{
    /// Two entries are equal when their keys `lower + distance` are equal.
    ///
    /// Comparing only `distance` would disagree with `partial_cmp`, which
    /// orders by the full key.
    fn eq(&self, other: &Self) -> bool {
        (self.lower.clone() + self.distance.clone()) == (other.lower.clone() + other.distance.clone())
    }
}

impl<E, D, H> PartialOrd for BiData<E, D, H>
where
    D: PartialOrd + Clone,
    H: Add<D, Output = D> + Clone,
{
    /// Orders entries by their key `lower + distance`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let lhs = self.lower.clone() + self.distance.clone();
        let rhs = other.lower.clone() + other.distance.clone();
        lhs.partial_cmp(&rhs)
    }
}
/// Best connection between the forward and the backward search found so far.
///
/// A candidate replaces the stored one only if its `total_distance` is
/// strictly smaller.
struct Meet<N, E, D> {
/// Node reported for this connection (returned by `BiAStar::meet`).
node: N,
/// Edge reported together with `node`, tagged `Direction::Forward`, when the
/// search stops.
edge: E,
/// Distance reported together with `node` when the search stops.
fwd_distance: D,
/// Length of the whole connection (returned by `BiAStar::value`).
total_distance: D,
}
/// Iterator state of a bidirectional A* search.
///
/// Each call to `next` pops one entry from the shared priority queue and
/// scans the neighbours of the popped node in the corresponding direction.
pub struct BiAStar<'a, Aout, Ain, D, W, M, P, H>
where
Aout: Adjacencies<'a>,
Ain: Adjacencies<'a, Node = Aout::Node, Edge = Aout::Edge>,
M: ItemMap<Direction<Aout::Node>, Either<P::Item, D>>,
P: ItemPriQueue<Direction<Aout::Node>, BiData<Aout::Edge, D, H::Result>>,
D: Copy,
W: Fn(Aout::Edge) -> D,
H: Heuristic<Aout::Node>,
H::Result: Copy,
{
/// Adjacencies scanned when a forward node is popped.
adjout: Aout,
/// Adjacencies scanned when a backward node is popped.
adjin: Ain,
/// Per direction-tagged node: `Left(item)` while the node is in the queue,
/// `Right(distance)` once it has been popped.
nodes: M,
/// Queue of labelled nodes of both directions, keyed by `lower + distance`.
pqueue: P,
/// Edge weight function.
weights: W,
/// Heuristic; used as is for forward labels and negated for backward labels.
heur: H,
/// Best connection between the two searches found so far, if any.
meet: Option<Meet<Aout::Node, Aout::Edge, D>>,
/// Key (`lower + distance`) of the most recently popped forward entry.
top_fwd: D,
/// Key (`lower + distance`) of the most recently popped backward entry.
top_bwd: D,
}
/// Default node map of the search: a `HashMap` from direction-tagged nodes to
/// either the `BinHeap` item of a queued node (`Left`) or a distance (`Right`).
pub type DefaultMap<'a, A, D, H> = HashMap<
Direction<<A as GraphType<'a>>::Node>,
Either<
<BinHeap<Direction<<A as GraphType<'a>>::Node>, BiData<<A as GraphType<'a>>::Edge, D, H>> as ItemPriQueue<
Direction<<A as GraphType<'a>>::Node>,
BiData<<A as GraphType<'a>>::Edge, D, H>,
>>::Item,
D,
>,
>;
/// Default priority queue of the search: a `BinHeap` from direction-tagged
/// nodes to `BiData` entries. The `ID` parameter defaults to `u32`.
pub type DefaultPriQueue<'a, A, D, H, ID = u32> =
BinHeap<Direction<<A as GraphType<'a>>::Node>, BiData<<A as GraphType<'a>>::Edge, D, H>, ID>;
/// `BiAStar` instantiated with `DefaultMap` and `DefaultPriQueue`, both
/// parameterised by the result type of the heuristic `H`.
pub type BiAStarDefault<'a, Aout, Ain, D, W, H> = BiAStar<
'a,
Aout,
Ain,
D,
W,
DefaultMap<'a, Aout, D, <H as Heuristic<<Aout as GraphType<'a>>::Node>>::Result>,
DefaultPriQueue<'a, Aout, D, <H as Heuristic<<Aout as GraphType<'a>>::Node>>::Result>,
H,
>;
/// Starts a bidirectional A* search from `src` to `snk` using the default
/// map and priority queue.
///
/// * `adjout` - adjacencies followed by the forward search.
/// * `adjin` - adjacencies followed by the backward search.
/// * `weights` - edge weight function.
/// * `heur` - node heuristic.
///
/// The containers come from `default_data()`; everything else is forwarded
/// unchanged to `start_with_data`.
pub fn start<'a, Aout, Ain, D, W, H>(
    adjout: Aout,
    adjin: Ain,
    src: Aout::Node,
    snk: Aout::Node,
    weights: W,
    heur: H,
) -> BiAStarDefault<'a, Aout, Ain, D, W, H>
where
    Aout: Adjacencies<'a>,
    Aout::Node: Hash + Eq,
    Ain: Adjacencies<'a, Node = Aout::Node, Edge = Aout::Edge>,
    D: Copy + Zero + PartialOrd,
    W: Fn(Aout::Edge) -> D,
    H: Heuristic<Aout::Node>,
    H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
    let containers = default_data();
    start_with_data(adjout, adjin, src, snk, weights, heur, containers)
}
pub fn start_with_data<'a, Aout, Ain, D, W, H, M, P>(
adjout: Aout,
adjin: Ain,
src: Aout::Node,
snk: Aout::Node,
weights: W,
heur: H,
data: (M, P),
) -> BiAStar<'a, Aout, Ain, D, W, M, P, H>
where
Aout: Adjacencies<'a>,
Ain: Adjacencies<'a, Node = Aout::Node, Edge = Aout::Edge>,
D: Copy + PartialOrd + Zero,
W: Fn(Aout::Edge) -> D,
M: ItemMap<Direction<Aout::Node>, Either<P::Item, D>>,
P: ItemPriQueue<Direction<Aout::Node>, BiData<Aout::Edge, D, H::Result>>,
H: Heuristic<Aout::Node>,
H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
let (mut nodes, mut pqueue) = data;
pqueue.clear();
nodes.clear();
if src == snk {
return BiAStar {
adjout,
adjin,
nodes,
pqueue,
weights,
heur,
meet: None,
top_fwd: D::zero(),
top_bwd: D::zero(),
};
}
nodes.insert(Direction::Forward(src), Right(D::zero()));
nodes.insert(Direction::Backward(snk), Right(D::zero()));
for (e, v) in adjout.neighs(src) {
let dir_v = Direction::Forward(v);
let d = (weights)(e);
match nodes.get_mut(dir_v) {
Some(Left(item_v)) => {
let (distance, lower) = {
let data = pqueue.value(item_v);
(data.distance, data.lower)
};
if d < distance {
pqueue.decrease_key(
item_v,
BiData {
edge: e,
distance: d,
lower,
},
);
}
}
None => {
let item_v = pqueue.push(
dir_v,
BiData {
edge: e,
distance: d,
lower: heur.call(v),
},
);
nodes.insert(dir_v, Left(item_v));
}
_ => (),
}
}
let mut meet: Option<Meet<_, _, _>> = None;
for (e, v) in adjin.neighs(snk) {
let dir_v = Direction::Backward(v);
let d = (weights)(e);
if v == src {
if meet.as_ref().map(|m| d < m.total_distance).unwrap_or(true) {
meet = Some(Meet {
node: snk,
edge: e,
fwd_distance: d,
total_distance: d,
});
}
}
match nodes.get_mut(dir_v) {
Some(Left(item_v)) => {
let (distance, lower) = {
let data = pqueue.value(item_v);
(data.distance, data.lower)
};
if d < distance {
pqueue.decrease_key(
item_v,
BiData {
edge: e,
distance: d,
lower,
},
);
}
}
None => {
let item_v = pqueue.push(
dir_v,
BiData {
edge: e,
distance: d,
lower: -heur.call(v),
},
);
nodes.insert(dir_v, Left(item_v));
}
_ => (),
}
}
BiAStar {
adjout,
adjin,
nodes,
pqueue,
weights,
heur,
meet,
top_fwd: D::zero(),
top_bwd: D::zero(),
}
}
/// Runs the bidirectional search one popped node at a time.
///
/// Each item is `(node, edge, distance)`: the node just popped, the edge over
/// which it was labelled (tagged with the direction of the search that
/// reached it) and its distance within that direction.
impl<'a, Aout, Ain, D, W, M, P, H> Iterator for BiAStar<'a, Aout, Ain, D, W, M, P, H>
where
    Aout: Adjacencies<'a>,
    Ain: Adjacencies<'a, Node = Aout::Node, Edge = Aout::Edge>,
    D: Copy + PartialOrd + Add<D, Output = D> + Sub<D, Output = D> + Zero,
    W: Fn(Aout::Edge) -> D,
    M: ItemMap<Direction<Aout::Node>, Either<P::Item, D>>,
    P: ItemPriQueue<Direction<Aout::Node>, BiData<Aout::Edge, D, H::Result>>,
    H: Heuristic<Aout::Node>,
    H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
    type Item = (Aout::Node, Direction<Aout::Edge>, D);

    /// Pops the next node and scans its neighbours.
    ///
    /// Returns `None` once the queue is empty. When the best known connection
    /// is no longer than the sum of the last popped forward and backward
    /// keys, the queue is cleared and the meeting node is returned with its
    /// edge tagged `Direction::Forward`; the following call returns `None`.
    ///
    /// # Panics
    ///
    /// Panics while scanning a backward node `u` if, for some scanned edge
    /// `e` to a neighbour `v`, `(-heur(v) + heur(u)) + weight(e)` is
    /// negative. No corresponding check is made in the forward direction.
    fn next(&mut self) -> Option<Self::Item> {
        let (dir_u, popped) = self.pqueue.pop_min()?;
        let (distance, edge) = (popped.distance, popped.edge);

        // The node leaves the queue: replace its queue item by the final distance.
        self.nodes.insert_or_replace(dir_u, Right(distance));

        let key = popped.lower + popped.distance;
        match dir_u {
            Direction::Forward(_) => self.top_fwd = key,
            Direction::Backward(_) => self.top_bwd = key,
        }

        // Stop as soon as the best known connection is within the popped keys.
        if let Some(m) = &self.meet {
            if m.total_distance <= self.top_fwd + self.top_bwd {
                self.pqueue.clear();
                return Some((m.node, Direction::Forward(m.edge), m.fwd_distance));
            }
        }

        match dir_u {
            Direction::Forward(u) => {
                for (e, v) in self.adjout.neighs(u) {
                    let dir_v = Direction::Forward(v);
                    let edge_weight = (self.weights)(e);
                    let d = distance + edge_weight;

                    // `v` was already popped by the backward search: candidate connection.
                    if let Some(Right(bwd_dist)) = self.nodes.get(Direction::Backward(v)) {
                        let through = *bwd_dist + distance + edge_weight;
                        if self.meet.as_ref().map_or(true, |m| through < m.total_distance) {
                            self.meet = Some(Meet {
                                node: v,
                                edge: e,
                                fwd_distance: d,
                                total_distance: through,
                            });
                        }
                    }

                    match self.nodes.get_mut(dir_v) {
                        None => {
                            let item_v = self.pqueue.push(
                                dir_v,
                                BiData {
                                    edge: e,
                                    distance: d,
                                    lower: self.heur.call(v),
                                },
                            );
                            self.nodes.insert(dir_v, Left(item_v));
                        }
                        Some(Left(item_v)) => {
                            let (queued_dist, lower) = {
                                let queued = self.pqueue.value(item_v);
                                (queued.distance, queued.lower)
                            };
                            if d < queued_dist {
                                self.pqueue.decrease_key(
                                    item_v,
                                    BiData {
                                        edge: e,
                                        distance: d,
                                        lower,
                                    },
                                );
                            }
                        }
                        // Already popped in this direction.
                        Some(Right(_)) => (),
                    }
                }
            }
            Direction::Backward(u) => {
                for (e, v) in self.adjin.neighs(u) {
                    assert!((-self.heur.call(v) + self.heur.call(u)) + (self.weights)(e) >= D::zero());
                    let dir_v = Direction::Backward(v);
                    let edge_weight = (self.weights)(e);
                    let d = distance + edge_weight;

                    // `v` was already popped by the forward search: candidate connection
                    // that enters `u` over `e`.
                    if let Some(Right(fwd_dist)) = self.nodes.get(Direction::Forward(v)) {
                        let through = *fwd_dist + distance + edge_weight;
                        if self.meet.as_ref().map_or(true, |m| through < m.total_distance) {
                            self.meet = Some(Meet {
                                node: u,
                                edge: e,
                                fwd_distance: *fwd_dist + edge_weight,
                                total_distance: through,
                            });
                        }
                    }

                    match self.nodes.get_mut(dir_v) {
                        None => {
                            let item_v = self.pqueue.push(
                                dir_v,
                                BiData {
                                    edge: e,
                                    distance: d,
                                    lower: -self.heur.call(v),
                                },
                            );
                            self.nodes.insert(dir_v, Left(item_v));
                        }
                        Some(Left(item_v)) => {
                            let (queued_dist, lower) = {
                                let queued = self.pqueue.value(item_v);
                                (queued.distance, queued.lower)
                            };
                            if d < queued_dist {
                                self.pqueue.decrease_key(
                                    item_v,
                                    BiData {
                                        edge: e,
                                        distance: d,
                                        lower,
                                    },
                                );
                            }
                        }
                        Some(Right(_)) => (),
                    }
                }
            }
        }

        Some(match dir_u {
            Direction::Forward(u) => (u, Direction::Forward(edge), distance),
            Direction::Backward(u) => (u, Direction::Backward(edge), distance),
        })
    }
}
/// Accessors for the best connection found by the search.
impl<'a, Aout, Ain, D, W, M, P, H> BiAStar<'a, Aout, Ain, D, W, M, P, H>
where
    Aout: Adjacencies<'a>,
    Ain: Adjacencies<'a, Node = Aout::Node, Edge = Aout::Edge>,
    D: Copy + PartialOrd + Add<D, Output = D> + Sub<D, Output = D>,
    W: Fn(Aout::Edge) -> D,
    M: ItemMap<Direction<Aout::Node>, Either<P::Item, D>>,
    P: ItemPriQueue<Direction<Aout::Node>, BiData<Aout::Edge, D, H::Result>>,
    H: Heuristic<Aout::Node>,
    H::Result: Add<D, Output = D> + Neg<Output = H::Result>,
{
    /// Node of the best connection found so far, or `None` if there is none.
    fn meet(&self) -> Option<Aout::Node> {
        let best = self.meet.as_ref()?;
        Some(best.node)
    }

    /// Total length of the best connection found so far, or `None` if there
    /// is none.
    fn value(&self) -> Option<D> {
        let best = self.meet.as_ref()?;
        Some(best.total_distance)
    }
}
/// Starts a bidirectional A* search from `src` to `snk` on the undirected
/// graph `g`.
///
/// Both search directions use the `Neighbors` adjacencies of `g`; `weights`
/// and `heur` are passed on to `start` unchanged.
pub fn start_undirected<'a, G, D, W, H>(
    g: &'a G,
    src: G::Node,
    snk: G::Node,
    weights: W,
    heur: H,
) -> BiAStarDefault<'a, Neighbors<'a, G>, Neighbors<'a, G>, D, W, H>
where
    G: Graph<'a>,
    G::Node: Hash,
    D: Copy + PartialOrd + Zero,
    W: Fn(G::Edge) -> D,
    H: Heuristic<G::Node>,
    H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
    let forward = Neighbors(g);
    let backward = Neighbors(g);
    start(forward, backward, src, snk, weights, heur)
}
/// Runs a bidirectional A* search on the undirected graph `g` and returns
/// the path found between `src` and `snk`.
///
/// Returns `Some((edges, length))`, with the edges ordered from `src` to
/// `snk`, or `None` if the search ends without a connection. For
/// `src == snk` the result is an empty path of length zero.
pub fn find_undirected_path<'a, G, D, W, H>(
    g: &'a G,
    src: G::Node,
    snk: G::Node,
    weights: W,
    heur: H,
) -> Option<(Vec<G::Edge>, D)>
where
    G: Graph<'a>,
    G::Node: Hash,
    D: Copy + PartialOrd + Zero + Add<D, Output = D> + Sub<D, Output = D>,
    W: Fn(G::Edge) -> D,
    H: Heuristic<G::Node>,
    H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
    if src == snk {
        return Some((Vec::new(), D::zero()));
    }

    // Run the search to completion, remembering per direction the edge over
    // which every reported node was reached.
    let mut search = start_undirected(g, src, snk, weights, heur);
    let mut incoming_edges = HashMap::new();
    for (node, dir_edge, _) in search.by_ref() {
        let (key, edge) = match dir_edge {
            Direction::Forward(e) => (Direction::Forward(node), e),
            Direction::Backward(e) => (Direction::Backward(node), e),
        };
        incoming_edges.insert(key, edge);
    }

    let meet = search.meet()?;

    // End node of `e` that is not `u` (edges carry no orientation here).
    let other_end = |e: G::Edge, u: G::Node| {
        let (v, w) = g.enodes(e);
        if v == u {
            w
        } else {
            v
        }
    };

    // The forward half is traced from the meeting node back to the source,
    // hence the reversal.
    let mut path: Vec<_> = path_from_incomings(meet, |u| {
        incoming_edges
            .get(&Direction::Forward(u))
            .map(|&e| (e, other_end(e, u)))
    })
    .collect();
    path.reverse();

    // The backward half is already traced in travel direction.
    path.extend(path_from_incomings(meet, |u| {
        incoming_edges
            .get(&Direction::Backward(u))
            .map(|&e| (e, other_end(e, u)))
    }));

    Some((path, search.value()?))
}
/// Starts a bidirectional A* search from `src` to `snk` on the directed
/// graph `g`.
///
/// The forward search uses the `OutEdges` adjacencies of `g`, the backward
/// search its `InEdges`; `weights` and `heur` are passed on to `start`
/// unchanged.
pub fn start_directed<'a, G, D, W, H>(
    g: &'a G,
    src: G::Node,
    snk: G::Node,
    weights: W,
    heur: H,
) -> BiAStarDefault<'a, OutEdges<'a, G>, InEdges<'a, G>, D, W, H>
where
    G: Digraph<'a>,
    G::Node: Hash,
    D: Copy + PartialOrd + Zero,
    W: Fn(G::Edge) -> D,
    H: Heuristic<G::Node>,
    H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
    let forward = OutEdges(g);
    let backward = InEdges(g);
    start(forward, backward, src, snk, weights, heur)
}
/// Runs a bidirectional A* search on the directed graph `g` and returns the
/// path found from `src` to `snk`.
///
/// Returns `Some((edges, length))`, with the edges ordered from `src` to
/// `snk`, or `None` if the search ends without a connection. For
/// `src == snk` the result is an empty path of length zero.
pub fn find_directed_path<'a, G, D, W, H>(
    g: &'a G,
    src: G::Node,
    snk: G::Node,
    weights: W,
    heur: H,
) -> Option<(Vec<G::Edge>, D)>
where
    G: Digraph<'a>,
    G::Node: Hash,
    D: Copy + PartialOrd + Zero + Add<D, Output = D> + Sub<D, Output = D>,
    W: Fn(G::Edge) -> D,
    H: Heuristic<G::Node>,
    H::Result: Add<D, Output = D> + Add<H::Result, Output = H::Result> + Neg<Output = H::Result>,
{
    if src == snk {
        return Some((Vec::new(), D::zero()));
    }

    // Run the search to completion, remembering per direction the edge over
    // which every reported node was reached.
    let mut search = start_directed(g, src, snk, weights, heur);
    let mut incoming_edges = HashMap::new();
    for (node, dir_edge, _) in search.by_ref() {
        let (key, edge) = match dir_edge {
            Direction::Forward(e) => (Direction::Forward(node), e),
            Direction::Backward(e) => (Direction::Backward(node), e),
        };
        incoming_edges.insert(key, edge);
    }

    let meet = search.meet()?;

    // Forward half: follow each stored edge to its source node, back to
    // `src`, then reverse into travel direction.
    let mut path: Vec<_> = path_from_incomings(meet, |u| {
        incoming_edges.get(&Direction::Forward(u)).map(|&e| (e, g.src(e)))
    })
    .collect();
    path.reverse();

    // Backward half: follow each stored edge to its sink node, towards `snk`.
    path.extend(path_from_incomings(meet, |u| {
        incoming_edges.get(&Direction::Backward(u)).map(|&e| (e, g.snk(e)))
    }));

    Some((path, search.value()?))
}
// Searches a weighted grid from `s` to `t` with a Manhattan-based heuristic,
// checks the path found and that the search only reports nodes close to the
// bounding box of `s` and `t`.
#[test]
fn test_biastar() {
    use crate::search::biastar;
    use crate::string::{from_ascii, Data};
    use crate::traits::*;
    use crate::LinkedListGraph;

    let Data {
        graph: g,
        weights,
        nodes,
    } = from_ascii::<LinkedListGraph>(
        r"
*--2--*--2--*--2--*--2--*--2--*--2--*--2--*--2--*--2--*
| | | | | | | | | |
2 2 2 2 2 2 2 2 2 2
| | | | | | | | | |
*--2--*--2--*--2--*--2--*--3--e--2--f--2--t--2--*--2--*
| | | | | | | | | |
2 2 2 2 2 2 3 2 2 2
| | | | | | | | | |
*--2--*--2--*--3--*--3--c--2--d--2--*--3--*--2--*--2--*
| | | | | | | | | |
2 2 2 2 2 3 2 2 2 2
| | | | | | | | | |
*--2--*--2--s--2--a--2--b--2--*--2--*--3--*--2--*--2--*
| | | | | | | | | |
2 2 2 2 2 2 2 2 2 2
| | | | | | | | | |
*--2--*--2--*--2--*--2--*--2--*--2--*--2--*--2--*--2--*
",
    )
    .unwrap();

    let s = nodes[&'s'];
    let t = nodes[&'t'];

    // Grid coordinates derived from the node id (ten nodes per row).
    let coords = |u| ((g.node_id(u) % 10) as isize, (g.node_id(u) / 10) as isize);
    let (xs, ys) = coords(s);
    let (xt, yt) = coords(t);

    // Half the difference between the Manhattan distances to `t` and to `s`.
    let manhattan = |(x0, y0): (isize, isize), (x1, y1): (isize, isize)| ((x0 - x1).abs() + (y0 - y1).abs()) as f64;
    let manh_heur = |u| {
        let pos = coords(u);
        0.5 * (manhattan(pos, (xt, yt)) - manhattan(pos, (xs, ys)))
    };

    let (path, dist) = biastar::find_undirected_path(&g, s, t, |e| weights[e.index()] as f64, manh_heur).unwrap();
    assert!((dist - 14.0).abs() < 1e-6);

    // Turn the edge list into the sequence of visited nodes.
    let mut pathnodes = vec![s];
    for e in path {
        let (a, b) = g.enodes(e);
        let last = *pathnodes.last().unwrap();
        pathnodes.push(if a == last { b } else { a });
    }
    let expected: Vec<_> = "sabcdeft".chars().map(|c| nodes[&c]).collect();
    assert_eq!(pathnodes, expected);

    // Every reported node lies within one step of the s-t bounding box.
    for (v, _, _) in biastar::start_undirected(&g, s, t, |e| weights[e.index()] as f64, manh_heur) {
        let (x, y) = coords(v);
        let within_columns = x + 1 >= xs && x <= xt + 1;
        let within_rows = y + 1 >= yt && y <= ys + 1;
        assert!(within_columns && within_rows);
    }
}
// Checks on a small graph, with a zero heuristic, that the search returns
// the expected path of length 14.
#[test]
fn test_biastar_correct() {
    use crate::search::biastar;
    use crate::string::{from_ascii, Data};
    use crate::traits::*;
    use crate::LinkedListGraph;

    let Data {
        graph: g,
        weights,
        nodes,
    } = from_ascii::<LinkedListGraph>(
        r"
b--11--c---1--t
| |
1 8
| |
s--1--a--10--*
",
    )
    .unwrap();

    let s = nodes[&'s'];
    let t = nodes[&'t'];
    let a = nodes[&'a'];
    let b = nodes[&'b'];
    let c = nodes[&'c'];

    let (path, dist) = biastar::find_undirected_path(&g, s, t, |e| weights[e.index()] as isize, |_| 0).unwrap();

    // The reported distance must match the summed weights of the returned edges.
    let length: usize = path.iter().map(|e| weights[e.index()]).sum();
    assert_eq!(dist, 14);
    assert_eq!(length, 14);

    // Normalise each edge to (smaller id, larger id) before comparing.
    let path: Vec<_> = path
        .into_iter()
        .map(|e| {
            let (u, v) = g.enodes(e);
            if g.node_id(u) < g.node_id(v) {
                (u, v)
            } else {
                (v, u)
            }
        })
        .collect();
    assert_eq!(path, vec![(s, a), (b, a), (b, c), (c, t)]);
}