use std::{
cmp::Reverse,
collections::{
hash_map::Entry::{Occupied, Vacant},
BinaryHeap,
},
fmt::Debug,
hash::Hash,
iter::Sum,
marker::PhantomData,
ops::Add,
ops::Sub,
sync::Arc,
};
use fxhash::{FxHashMap, FxHashSet};
use parking_lot::Mutex;
use crate::{Heuristic, LimitValues, State, Task, TransitionSystem};
use super::SearchNode;
/// Reverse Resumable A* (RRA*): a backward search rooted at `task.goal_state` whose settled
/// costs are exposed through [`Heuristic`]. The search is resumed lazily, one query at a time.
pub struct ReverseResumableAStar<TS, S, A, C, DC, H>
where
    TS: TransitionSystem<S, A, C, DC>,
    S: Debug + State + Hash + Eq + Clone,
    C: Copy + Default + Eq + PartialOrd + Ord + LimitValues,
    C: Add<DC, Output = C> + Sub<C, Output = DC>,
    DC: Copy,
    H: Heuristic<TS, S, A, C, DC>,
{
    /// Source of the reverse actions, reverse transitions and their costs.
    transition_system: Arc<TS>,
    /// Provides the search root (`goal_state`) and the base cost (`initial_cost`).
    task: Arc<Task<S, C>>,
    /// Orders the backward expansion; states it rates `None` are never enqueued.
    heuristic: H,
    /// Search state kept across queries. It sits behind a lock so `&self` queries can mutate it.
    data: Mutex<RraData<S, C, DC>>,
    /// Ties the otherwise unused action type `A` to the struct.
    _phantom: PhantomData<A>,
}
impl<TS, S, A, C, DC, H> Heuristic<TS, S, A, C, DC> for ReverseResumableAStar<TS, S, A, C, DC, H>
where
    TS: TransitionSystem<S, A, C, DC>,
    S: Debug + State + Hash + Eq + Clone,
    C: Copy + Default + Eq + PartialOrd + Ord + LimitValues,
    C: Add<DC, Output = C> + Sub<C, Output = DC>,
    DC: Copy,
    H: Heuristic<TS, S, A, C, DC>,
{
    /// Cost from `state` to the goal found by the reverse search, measured relative to
    /// `task.initial_cost`.
    ///
    /// Returns `None` when the backward search exhausts its open list without reaching `state`.
    fn get_heuristic(&self, state: &S) -> Option<DC> {
        Self::find_path(self, state)
    }
}
impl<TS, S, A, C, DC, H> ReverseResumableAStar<TS, S, A, C, DC, H>
where
    TS: TransitionSystem<S, A, C, DC>,
    S: Debug + State + Hash + Eq + Clone,
    C: Eq
        + PartialOrd
        + Ord
        + Add<DC, Output = C>
        + Sub<C, Output = DC>
        + Copy
        + Default
        + LimitValues,
    DC: Copy,
    H: Heuristic<TS, S, A, C, DC>,
{
    /// Creates a reverse search rooted at `task.goal_state`.
    ///
    /// `heuristic` orders the backward expansion. The goal state is enqueued immediately,
    /// with cost `task.initial_cost`, but nothing is expanded until the first query.
    pub fn new(transition_system: Arc<TS>, task: Arc<Task<S, C>>, heuristic: H) -> Self {
        // The `Arc`s are owned here, so they are stored directly with no extra refcount bump.
        let mut rra = ReverseResumableAStar {
            transition_system,
            task,
            heuristic,
            data: Mutex::new(RraData::default()),
            _phantom: PhantomData,
        };
        rra.init();
        rra
    }

    /// Seeds the distance map and the open list with the goal state.
    fn init(&mut self) {
        let goal_node = SearchNode {
            state: Arc::new(self.task.goal_state.clone()),
            cost: self.task.initial_cost,
            // `C - C` yields a `DC`. Subtracting two equal costs gives the root's heuristic
            // value (presumably zero; this depends on `C`'s `Sub` impl).
            heuristic: C::default() - C::default(),
        };
        // `&mut self` guarantees exclusive access, so there is no need to lock the mutex.
        let data = self.data.get_mut();
        data.distance
            .insert(Arc::clone(&goal_node.state), goal_node.cost);
        data.queue.push(Reverse(goal_node));
    }

    /// Returns the cost from `state` to the goal, minus `task.initial_cost`.
    ///
    /// If `state` was settled by an earlier query, the cached distance is returned directly.
    /// Otherwise the backward search resumes until `state` is popped from the open list.
    /// Returns `None` if the open list empties first.
    fn find_path(&self, state: &S) -> Option<DC> {
        let mut data = self.data.lock();

        // Already settled by a previous query: answer from the distance map.
        if data.closed.contains(state) {
            data.stats.cached_query += 1;
            return Some(data.distance[state] - self.task.initial_cost);
        }
        data.stats.new_query += 1;

        while let Some(Reverse(current)) = data.queue.pop() {
            // Skip stale queue entries that a cheaper path has since superseded.
            if current.cost > data.distance[&current.state] {
                continue;
            }
            data.closed.insert(Arc::clone(&current.state));

            if *current.state == *state {
                let cost = current.cost - self.task.initial_cost;
                // Re-enqueue the target unexpanded. A later query will pop it again and
                // generate its successors, which lets the search resume.
                data.queue.push(Reverse(current));
                return Some(cost);
            }

            for action in self.transition_system.reverse_actions_from(&current.state) {
                let successor_state = Arc::new(
                    self.transition_system
                        .reverse_transition(&current.state, action),
                );
                let successor_cost = current.cost
                    + self
                        .transition_system
                        .reverse_transition_cost(&current.state, action);

                // Record the cost only when it is new or strictly better than the known one.
                let improved = match data.distance.entry(Arc::clone(&successor_state)) {
                    Occupied(mut entry) if successor_cost < *entry.get() => {
                        entry.insert(successor_cost);
                        true
                    }
                    Occupied(_) => false,
                    Vacant(entry) => {
                        entry.insert(successor_cost);
                        true
                    }
                };

                // States with no heuristic value keep their distance entry but are never enqueued.
                if improved {
                    if let Some(heuristic) = self.heuristic.get_heuristic(&successor_state) {
                        data.queue.push(Reverse(SearchNode {
                            state: successor_state,
                            cost: successor_cost,
                            heuristic,
                        }));
                    }
                }
            }
            data.stats.expanded += 1;
        }
        None
    }

    /// Returns a snapshot (a copy) of the query and expansion counters.
    pub fn get_stats(&self) -> RraStats {
        self.data.lock().stats
    }
}
/// Mutable RRA* search state, kept behind the `Mutex` in `ReverseResumableAStar`.
struct RraData<S, C, DC>
where
    C: Copy + Ord,
    C: Add<DC, Output = C>,
    DC: Copy,
{
    /// Open list. `Reverse` turns the max-heap into a min-heap, so the node that compares
    /// smallest under `SearchNode`'s ordering is popped first.
    queue: BinaryHeap<Reverse<SearchNode<S, C, DC>>>,
    /// Best cost found so far for every generated state.
    distance: FxHashMap<Arc<S>, C>,
    /// States popped from the open list. Their `distance` entry answers repeat queries.
    closed: FxHashSet<Arc<S>>,
    /// Counters returned by `get_stats`.
    stats: RraStats,
}
// Written by hand rather than derived: `#[derive(Default)]` would add `S: Default`,
// `C: Default` and `DC: Default` bounds that an empty search does not need.
impl<S, C, DC> Default for RraData<S, C, DC>
where
    C: Copy + Ord,
    C: Add<DC, Output = C>,
    DC: Copy,
{
    /// An empty search: empty open list, no distances, nothing settled, zeroed counters.
    fn default() -> Self {
        let queue = BinaryHeap::default();
        let distance = FxHashMap::default();
        let closed = FxHashSet::default();
        Self {
            queue,
            distance,
            closed,
            stats: RraStats::default(),
        }
    }
}
/// Work counters of a `ReverseResumableAStar` search.
///
/// Counters from several searches can be added together with `Iterator::sum`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RraStats {
    /// Queries that had to resume the backward search.
    pub new_query: usize,
    /// Queries answered from already-settled states without searching.
    pub cached_query: usize,
    /// States whose reverse successors have been generated.
    pub expanded: usize,
}
impl Sum for RraStats {
    /// Adds up every item's counters field by field. An empty iterator yields all zeros.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = Self::default();
        for stats in iter {
            total.new_query += stats.new_query;
            total.cached_query += stats.cached_query;
            total.expanded += stats.expanded;
        }
        total
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use ordered_float::OrderedFloat;

    use crate::{
        simple_graph, GraphNodeId, Heuristic, ReverseResumableAStar, RraStats, SimpleHeuristic,
        SimpleState, SimpleWorld, Task,
    };

    /// Every cell `(x, y)` (node id `x + y * size`) must report the Manhattan distance to the
    /// goal cell `(size - 1, size - 1)`.
    #[test]
    fn test_simple() {
        let size = 10;
        let world = Arc::new(SimpleWorld::new(simple_graph(size), 0.4));
        let task = Arc::new(Task::new(
            SimpleState(GraphNodeId(0)),
            SimpleState(GraphNodeId(size * size - 1)),
            OrderedFloat(0.0),
        ));
        let reverse_task = Arc::new(task.reverse());
        let rra = ReverseResumableAStar::new(
            world.clone(),
            task.clone(),
            SimpleHeuristic::new(world, reverse_task),
        );

        // Same visiting order as a nested `x`-outer / `y`-inner loop.
        let cells = (0..size).flat_map(|x| (0..size).map(move |y| (x, y)));
        for (x, y) in cells {
            let actual = rra.get_heuristic(&SimpleState(GraphNodeId(x + y * size)));
            let expected = OrderedFloat(((size - x - 1) + (size - y - 1)) as f64);
            assert_eq!(actual.unwrap(), expected);
        }
    }

    /// Repeating a query must be served from the cache, with no additional expansions.
    #[test]
    fn test_caching() {
        let size = 10;
        let world = Arc::new(SimpleWorld::new(simple_graph(size), 0.4));
        let task = Arc::new(Task::new(
            SimpleState(GraphNodeId(0)),
            SimpleState(GraphNodeId(size * size - 1)),
            OrderedFloat(0.0),
        ));
        let reverse_task = Arc::new(task.reverse());
        let rra = ReverseResumableAStar::new(
            world.clone(),
            task.clone(),
            SimpleHeuristic::new(world, reverse_task),
        );
        let query_origin = || {
            rra.get_heuristic(&SimpleState(GraphNodeId(0)));
        };

        let initial = rra.get_stats();
        query_origin();
        let after_one_query = rra.get_stats();
        query_origin();
        let after_same_query = rra.get_stats();

        assert_eq!(
            initial,
            RraStats {
                new_query: 0,
                cached_query: 0,
                expanded: 0
            }
        );
        assert_eq!(
            (after_one_query.new_query, after_one_query.cached_query),
            (1, 0)
        );
        assert_eq!(
            (after_same_query.new_query, after_same_query.cached_query),
            (1, 1)
        );
        assert_eq!(after_same_query.expanded, after_one_query.expanded);
    }
}