use thiserror::Error;
/// A cost function for moving between two cells of a 2-D grid.
///
/// Coordinates are `(x, y)` pairs, where `x` runs along `width` and `y` runs
/// along `height`. `Debug` is required so that boxed metrics
/// (`Box<dyn CostMetric>`, as stored by `PenaltyMap`) can be part of types
/// that derive `Debug`.
pub trait CostMetric: std::fmt::Debug {
    /// Returns the cost of moving from cell `from` to cell `to`.
    ///
    /// * `periodic` - whether the grid wraps around at its edges; the
    ///   implementations in this module then measure each axis along the
    ///   shorter way around.
    /// * `width`, `height` - grid extents, used for the wrap-around.
    /// * `diagonal` - whether diagonal moves are allowed; implementations are
    ///   free to ignore it.
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        diagonal: bool,
    ) -> f64;
}
/// Grid step metric with separate costs for cardinal and diagonal moves.
///
/// The struct is plain data (two `f64`s), so it is `Copy` like the other
/// metrics in this module, and `PartialEq` so configurations can be compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DirectDistance {
    /// Cost charged per horizontal or vertical step.
    pub cost_cardinal: f64,
    /// Cost charged per diagonal step (only relevant when diagonal movement
    /// is enabled).
    pub cost_diagonal: f64,
}
impl DirectDistance {
pub fn new() -> Self {
Self {
cost_cardinal: 1.0,
cost_diagonal: std::f64::consts::SQRT_2,
}
}
pub fn with_costs(cost_cardinal: f64, cost_diagonal: f64) -> Self {
Self {
cost_cardinal,
cost_diagonal,
}
}
}
impl Default for DirectDistance {
fn default() -> Self {
Self::new()
}
}
impl CostMetric for DirectDistance {
    /// Cost of travelling from `from` to `to`.
    ///
    /// Without diagonal moves this is the Manhattan step count scaled by
    /// `cost_cardinal`. With diagonal moves it is the octile distance: the
    /// shorter axis is covered diagonally and the remaining stretch of the
    /// longer axis is covered cardinally.
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        diagonal: bool,
    ) -> f64 {
        let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
        if !diagonal {
            return self.cost_cardinal * (dx + dy) as f64;
        }
        let (short, long) = if dx <= dy { (dx, dy) } else { (dy, dx) };
        let (short, long) = (short as f64, long as f64);
        self.cost_diagonal * short + self.cost_cardinal * (long - short)
    }
}
/// Chebyshev metric: the cost is the larger of the two per-axis offsets.
///
/// The `diagonal` flag has no effect on this metric.
#[derive(Debug, Clone, Copy)]
pub struct MaxDistance;

impl CostMetric for MaxDistance {
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        _diagonal: bool,
    ) -> f64 {
        let offsets = delta_with_periodic(from, to, periodic, width, height);
        std::cmp::max(offsets.0, offsets.1) as f64
    }
}
/// Taxicab metric: the cost is the sum of the two per-axis offsets.
///
/// The `diagonal` flag has no effect on this metric.
#[derive(Debug, Clone, Copy)]
pub struct Manhattan;

impl CostMetric for Manhattan {
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        _diagonal: bool,
    ) -> f64 {
        let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
        let steps = dx + dy;
        steps as f64
    }
}
/// Errors returned when constructing a `PenaltyMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum PenaltyMapError {
    /// The requested grid dimensions are unusable (for example, `width` or
    /// `height` is zero).
    #[error("penalty-map dimensions must be positive")]
    InvalidDimensions,
    /// The penalty vector does not have one entry per grid cell.
    #[error("penalty map size mismatch: expected {expected}, got {actual}")]
    SizeMismatch {
        /// Number of cells implied by the dimensions (`width * height`).
        expected: usize,
        /// Length of the penalty vector that was supplied.
        actual: usize,
    },
}
/// A [`CostMetric`] that adds the absolute change in per-cell penalty to the
/// cost reported by a base metric.
///
/// Penalties are stored row-major: the value for cell `(x, y)` lives at index
/// `y * map_width + x`.
#[derive(Debug)]
pub struct PenaltyMap {
    /// Row-major penalty grid with `map_width * map_height` entries.
    penalties: Vec<i32>,
    /// Number of columns in the penalty grid.
    map_width: usize,
    /// Number of rows in the penalty grid.
    // The height is recorded alongside the width; `allow(dead_code)` silences
    // the unused-field lint in builds where nothing reads it.
    #[allow(dead_code)]
    map_height: usize,
    /// Metric whose cost the penalty difference is added to.
    base_metric: Box<dyn CostMetric>,
}
impl PenaltyMap {
    /// Builds a penalty map over a `width` x `height` grid on top of `base`.
    ///
    /// `penalties` is laid out row-major: the value for cell `(x, y)` is at
    /// index `y * width + x`.
    ///
    /// # Errors
    ///
    /// - [`PenaltyMapError::InvalidDimensions`] if `width` or `height` is zero,
    ///   or if `width * height` does not fit in a `usize`.
    /// - [`PenaltyMapError::SizeMismatch`] if `penalties.len()` is not
    ///   `width * height`.
    pub fn new(
        penalties: Vec<i32>,
        width: usize,
        height: usize,
        base: impl CostMetric + 'static,
    ) -> Result<Self, PenaltyMapError> {
        if width == 0 || height == 0 {
            return Err(PenaltyMapError::InvalidDimensions);
        }
        // An unchecked product panics in debug builds and wraps in release,
        // where a wrapped value could spuriously match `penalties.len()`.
        let expected = width
            .checked_mul(height)
            .ok_or(PenaltyMapError::InvalidDimensions)?;
        if penalties.len() != expected {
            return Err(PenaltyMapError::SizeMismatch {
                expected,
                actual: penalties.len(),
            });
        }
        Ok(Self {
            penalties,
            map_width: width,
            map_height: height,
            base_metric: Box::new(base),
        })
    }

    /// Read-only view of the row-major penalty grid.
    pub fn penalties(&self) -> &[i32] {
        &self.penalties
    }

    /// Mutable view of the row-major penalty grid.
    ///
    /// The slice length is fixed, so the grid dimensions cannot change
    /// through it.
    pub fn penalties_mut(&mut self) -> &mut [i32] {
        &mut self.penalties
    }

    /// Penalty stored for cell `(x, y)`.
    ///
    /// # Panics
    ///
    /// Panics if `x >= width` or `y >= height`. Without the explicit check, an
    /// out-of-range `x` would silently read a cell from a later row.
    pub fn penalty_at(&self, x: usize, y: usize) -> i32 {
        assert!(
            x < self.map_width && y < self.map_height,
            "cell ({}, {}) is outside the {}x{} penalty map",
            x,
            y,
            self.map_width,
            self.map_height
        );
        self.penalties[y * self.map_width + x]
    }
}
impl CostMetric for PenaltyMap {
    /// Base-metric cost plus the absolute difference between the penalties of
    /// the two cells.
    ///
    /// The base metric receives the caller's `width`/`height`. Cell penalties
    /// are read with [`PenaltyMap::penalty_at`], which uses this map's own
    /// dimensions.
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        diagonal: bool,
    ) -> f64 {
        let base = self
            .base_metric
            .delta_cost(from, to, periodic, width, height, diagonal);
        let pen_from = self.penalty_at(from.0, from.1);
        let pen_to = self.penalty_at(to.0, to.1);
        // `abs_diff` is exact for every pair of i32 values, whereas
        // `(pen_to - pen_from).unsigned_abs()` overflows for large penalties
        // of opposite sign.
        base + f64::from(pen_to.abs_diff(pen_from))
    }
}
/// Per-axis absolute offsets `(dx, dy)` between two cells.
///
/// When `periodic` is true each axis wraps around at `width` / `height`, and
/// the shorter of the direct and wrapped-around distance is reported.
fn delta_with_periodic(
    from: (usize, usize),
    to: (usize, usize),
    periodic: bool,
    width: usize,
    height: usize,
) -> (usize, usize) {
    // `abs_diff` is exact over the whole `usize` range; casting to `isize`
    // first overflows for coordinates above `isize::MAX`.
    let dx = from.0.abs_diff(to.0);
    let dy = from.1.abs_diff(to.1);
    if periodic {
        (wrapped_axis_distance(dx, width), wrapped_axis_distance(dy, height))
    } else {
        (dx, dy)
    }
}

/// Shortest distance along one periodic axis of length `extent`.
///
/// Reducing modulo `extent` first keeps `extent - d` from underflowing when a
/// coordinate lies outside `0..extent`. A zero-length axis cannot wrap, so the
/// raw distance is returned unchanged.
fn wrapped_axis_distance(d: usize, extent: usize) -> usize {
    if extent == 0 {
        return d;
    }
    let d = d % extent;
    d.min(extent - d)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::SQRT_2;

    /// Absolute tolerance shared by every floating-point check in this module.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "cost {} differs from expected {}",
            actual,
            expected
        );
    }

    #[test]
    fn direct_distance_cardinal() {
        let metric = DirectDistance::new();
        assert_close(metric.delta_cost((0, 0), (3, 0), false, 10, 10, false), 3.0);
    }

    #[test]
    fn direct_distance_diagonal() {
        let metric = DirectDistance::new();
        assert_close(metric.delta_cost((0, 0), (2, 2), false, 10, 10, true), 2.0 * SQRT_2);
    }

    #[test]
    fn direct_distance_mixed() {
        // One diagonal step followed by two cardinal steps.
        let metric = DirectDistance::new();
        assert_close(metric.delta_cost((0, 0), (3, 1), false, 10, 10, true), SQRT_2 + 2.0);
    }

    #[test]
    fn max_distance_basic() {
        assert_close(MaxDistance.delta_cost((0, 0), (3, 5), false, 10, 10, true), 5.0);
    }

    #[test]
    fn manhattan_basic() {
        assert_close(Manhattan.delta_cost((0, 0), (3, 5), false, 10, 10, false), 8.0);
    }

    #[test]
    fn penalty_map_basic() {
        let (width, height) = (5, 5);
        let mut pens = vec![0i32; width * height];
        pens[2 * width + 2] = 100;
        let map = PenaltyMap::new(pens, width, height, DirectDistance::new())
            .expect("25 penalties fit a 5x5 grid");
        let base = DirectDistance::new().delta_cost((0, 0), (2, 2), false, 5, 5, true);
        let cost = map.delta_cost((0, 0), (2, 2), false, 5, 5, true);
        assert_close(cost, base + 100.0);
    }

    #[test]
    fn periodic_distance() {
        // Column 9 is one step from column 0 when the 10-wide grid wraps.
        let metric = DirectDistance::new();
        assert_close(metric.delta_cost((0, 0), (9, 0), true, 10, 10, false), 1.0);
    }

    #[test]
    fn penalty_map_invalid_dimensions() {
        let result = PenaltyMap::new(vec![0i32; 25], 0, 5, DirectDistance::new());
        assert_eq!(result.err(), Some(PenaltyMapError::InvalidDimensions));
    }

    #[test]
    fn penalty_map_size_mismatch() {
        let result = PenaltyMap::new(vec![0i32; 24], 5, 5, DirectDistance::new());
        assert_eq!(
            result.err(),
            Some(PenaltyMapError::SizeMismatch {
                expected: 25,
                actual: 24,
            })
        );
    }
}