use std::{
hash::Hash,
marker::PhantomData,
ops::{Add, Sub},
sync::Arc,
};
use crate::{State, Task, TransitionSystem};
/// A pair of `start`/`end` bounds over an ordered coordinate type `C`.
///
/// `DC` is the type produced by subtracting two `C` values. It is stored in no
/// field; it is tied to `C` only through the `Sub<C, Output = DC>` bound in the
/// where-clause, which is why that bound lives on the struct itself.
///
/// Both fields are public and the type does not enforce `start <= end`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Interval<C, DC>
where
C: PartialEq + Eq + PartialOrd + Ord + LimitValues + Sub<C, Output = DC> + Copy,
{
/// First bound of the interval.
pub start: C,
/// Second bound of the interval.
pub end: C,
}
/// The default interval runs from `C::min_value()` to `C::max_value()`.
impl<C, DC> Default for Interval<C, DC>
where
    C: PartialEq + Eq + PartialOrd + Ord + LimitValues + Sub<C, Output = DC> + Copy,
{
    /// Builds the interval whose bounds are the two limit values of `C`.
    fn default() -> Self {
        let lowest = C::min_value();
        let highest = C::max_value();
        Self {
            start: lowest,
            end: highest,
        }
    }
}
impl<C, DC> Interval<C, DC>
where
C: PartialEq + Eq + PartialOrd + Ord + LimitValues + Sub<C, Output = DC> + Copy,
{
pub fn new(start: C, end: C) -> Self {
Self { start, end }
}
pub fn overlaps(&self, other: &Self) -> bool {
self.start <= other.end && other.start <= self.end
}
pub fn contains(&self, other: &Self) -> bool {
self.start <= other.start && other.end <= self.end
}
pub fn length(&self) -> DC {
self.end - self.start
}
}
/// One step of a plan: an optional action of type `A` paired with a cost of type `DC`.
#[derive(Debug, Clone, Copy)]
pub struct Action<A, DC> {
/// The action taken; `None` is what `Action::wait` stores.
pub action: Option<A>,
/// Cost attached to this step.
pub cost: DC,
}
/// Constructors for [`Action`].
impl<A, DC> Action<A, DC> {
    /// Wraps a concrete `action` together with its `cost`.
    pub fn new(action: A, cost: DC) -> Self {
        let action = Some(action);
        Self { action, cost }
    }

    /// Creates a step that carries only a `cost` and no action.
    pub fn wait(cost: DC) -> Self {
        Self { cost, action: None }
    }
}
/// Result of a search: a total cost, the visited steps and the actions taken.
///
/// The struct deliberately carries no trait bounds. `C: Default` is only
/// needed to build a default value, so it belongs on the `Default` impl rather
/// than on the data structure; otherwise every user of the type would have to
/// repeat the bound.
#[derive(Debug, Clone)]
pub struct Solution<S, A, C, DC> {
    /// Total cost of the solution.
    pub cost: C,
    /// Sequence of `(state, cost)` pairs.
    pub steps: Vec<(S, C)>,
    /// Sequence of actions, each with its own cost of type `DC`.
    pub actions: Vec<Action<A, DC>>,
}
/// An empty solution: default cost, no steps and no actions.
///
/// Only `C` needs `Default`; `S`, `A` and `DC` are never instantiated here.
impl<S, A, C, DC> Default for Solution<S, A, C, DC>
where
    C: Default,
{
    /// Returns a solution with `C::default()` as cost and two empty vectors.
    fn default() -> Self {
        let cost = C::default();
        Self {
            cost,
            steps: Vec::new(),
            actions: Vec::new(),
        }
    }
}
/// Estimator that maps a state `S` to an optional value of the cost-difference type `DC`.
///
/// The bounds relate the two cost types: `C + DC` yields `C` and `C - C` yields `DC`.
/// `TS` and `A` are not used by the method itself; they only tie an
/// implementation to a particular `TransitionSystem<S, A, C, DC>`.
pub trait Heuristic<TS, S, A, C, DC>
where
TS: TransitionSystem<S, A, C, DC>,
S: Hash + Eq + Clone,
C: Eq
+ PartialOrd
+ Ord
+ Add<DC, Output = C>
+ Sub<C, Output = DC>
+ Copy
+ Default
+ LimitValues,
{
/// Returns the estimate for `state`, or `None`.
///
/// NOTE(review): the meaning of `None` is not defined here; presumably it
/// means "no estimate available" — confirm with the implementors.
fn get_heuristic(&self, state: &S) -> Option<DC>;
}
/// Constructor trait: builds an implementor from a shared transition system and a shared task.
///
/// The trait has no supertrait, so it does not by itself require the built
/// value to implement `Heuristic`.
pub trait HeuristicBuilder<TS, S, A, C, DC>
where
TS: TransitionSystem<S, A, C, DC>,
S: State + Hash + Eq + Clone,
C: Eq
+ PartialOrd
+ Ord
+ Add<DC, Output = C>
+ Sub<C, Output = DC>
+ Copy
+ Default
+ LimitValues,
{
/// Creates `Self` from `transition_system` and `task`, both passed as `Arc` handles.
fn build(transition_system: Arc<TS>, task: Arc<Task<S, C>>) -> Self;
}
/// Heuristic assembled from a list of per-pivot heuristics of type `H`.
///
/// When the task's goal state is equivalent to one of the pivots, the matching
/// pivot heuristic is used directly; otherwise the estimate is derived from the
/// differences between each pivot heuristic's value for the queried state and
/// for the goal state.
pub struct DifferentialHeuristic<TS, S, A, C, DC, H>
where
TS: TransitionSystem<S, A, C, DC>,
S: State + Hash + Eq + Clone,
C: Ord + Add<DC, Output = C> + Sub<C, Output = DC> + Copy + Default + LimitValues,
H: Heuristic<TS, S, A, C, DC>,
{
// Task whose `goal_state` the estimates are computed against.
task: Arc<Task<S, C>>,
// Pivot heuristics; indexed with the position found in the pivot list given to `new`.
heuristic_to_pivots: Arc<Vec<Arc<H>>>,
// Position of the first pivot equivalent to the task's goal state, if any.
task_heuristic: Option<usize>,
// Keeps the type parameters attached to the struct; `TS`, `A` and `DC` appear in no other field.
_phantom: PhantomData<(TS, S, A, DC)>,
}
/// Construction of [`DifferentialHeuristic`].
impl<TS, S, A, C, DC, H> DifferentialHeuristic<TS, S, A, C, DC, H>
where
    TS: TransitionSystem<S, A, C, DC>,
    S: State + Hash + Eq + Clone,
    C: Ord + Add<DC, Output = C> + Sub<C, Output = DC> + Copy + Default + LimitValues,
    DC: Ord + Sub<DC, Output = DC> + Copy,
    H: Heuristic<TS, S, A, C, DC>,
{
    /// Creates the heuristic for `task` from a pivot list and the heuristics to those pivots.
    ///
    /// `pivots` is scanned once, in order, for the first pivot that
    /// `is_equivalent` to the task's goal state; its position is remembered so
    /// that queries can later index `heuristic_to_pivots` with it. The two
    /// lists are therefore expected to be index-aligned; nothing here checks
    /// that they have the same length.
    pub fn new(
        task: Arc<Task<S, C>>,
        pivots: Arc<Vec<S>>,
        heuristic_to_pivots: Arc<Vec<Arc<H>>>,
    ) -> Self {
        // Resolve the goal pivot before `task` is moved into the struct.
        let mut task_heuristic = None;
        for (index, pivot) in pivots.iter().enumerate() {
            if pivot.is_equivalent(&task.goal_state) {
                task_heuristic = Some(index);
                break;
            }
        }

        Self {
            task,
            heuristic_to_pivots,
            task_heuristic,
            _phantom: PhantomData,
        }
    }
}
/// [`Heuristic`] implementation that combines the per-pivot heuristics.
impl<TS, S, A, C, DC, H> Heuristic<TS, S, A, C, DC> for DifferentialHeuristic<TS, S, A, C, DC, H>
where
    TS: TransitionSystem<S, A, C, DC>,
    S: State + Hash + Eq + Clone,
    C: Ord + Add<DC, Output = C> + Sub<C, Output = DC> + Copy + Default + LimitValues,
    DC: Ord + Sub<DC, Output = DC> + Copy,
    H: Heuristic<TS, S, A, C, DC>,
{
    /// Estimates the cost from `state` to the task's goal state.
    ///
    /// * If the goal state was found among the pivots at construction time,
    ///   the matching pivot heuristic is queried directly and its result
    ///   (including `None`) is returned as is.
    /// * Otherwise the result is the largest absolute difference between a
    ///   pivot heuristic's value for `state` and its value for the goal state,
    ///   taken over all pivots for which both values are available. Pivots
    ///   missing either value are skipped. The result is never `None` on this
    ///   path; with no usable pivot it is `C::default() - C::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the stored goal-pivot index is out of range for
    /// `heuristic_to_pivots` (i.e. the two lists given to `new` were not
    /// index-aligned).
    fn get_heuristic(&self, state: &S) -> Option<DC> {
        if let Some(task_heuristic) = self.task_heuristic {
            return self.heuristic_to_pivots[task_heuristic].get_heuristic(state);
        }

        // `DC` has no `Default` bound, so the starting value (presumably zero)
        // is obtained as the difference of two equal `C` values.
        let mut heuristic = C::default() - C::default();
        for heuristic_to_pivot in self.heuristic_to_pivots.iter() {
            if let (Some(h1), Some(h2)) = (
                heuristic_to_pivot.get_heuristic(state),
                heuristic_to_pivot.get_heuristic(&self.task.goal_state),
            ) {
                // Subtract the smaller from the larger so that only the
                // non-negative difference is ever computed. Evaluating both
                // `h1 - h2` and `h2 - h1` would underflow for unsigned `DC`.
                let gap = if h1 > h2 { h1 - h2 } else { h2 - h1 };
                heuristic = heuristic.max(gap);
            }
        }
        Some(heuristic)
    }
}
/// Types that can report a lower and an upper limit value.
///
/// `Interval::default` uses these two values as its `start` and `end`.
pub trait LimitValues {
/// Returns the value used as the lower limit of the type.
fn min_value() -> Self;
/// Returns the value used as the upper limit of the type.
fn max_value() -> Self;
}