Skip to main content

rustsim_pathfinding/
metrics.rs

1//! Cost metrics for A* pathfinding.
2//!
3//! Mirrors Julia Agents.jl `Pathfinding.CostMetric` system. Cost metrics
4//! estimate the travel cost between two grid cells for use as the A* heuristic
5//! and/or edge cost.
6//!
7//! # Built-in metrics
8//!
9//! | Metric | Description |
10//! |--------|-------------|
11//! | [`DirectDistance`] | Shortest diagonal/orthogonal path cost (default) |
12//! | [`MaxDistance`] | Maximum absolute coordinate difference (Chebyshev) |
13//! | [`Manhattan`] | Sum of absolute coordinate differences (L1) |
14//! | [`PenaltyMap`] | Base metric + per-cell penalty difference (height maps) |
15
16use thiserror::Error;
17
/// A pluggable cost function for A* search on 2D grids.
///
/// Implementors estimate how expensive it is to move between two grid cells;
/// the grid A* algorithm can use that estimate as its heuristic and/or edge
/// cost. `Debug` is a supertrait so that boxed metrics (`Box<dyn CostMetric>`)
/// can still be printed.
pub trait CostMetric: std::fmt::Debug {
    /// Return the estimated cost of moving from cell `from` to cell `to`.
    ///
    /// * `from`, `to` — `(x, y)` grid coordinates.
    /// * `periodic` — whether the grid wraps around at its edges.
    /// * `width`, `height` — grid extents (needed to measure wrap-around offsets).
    /// * `diagonal` — whether diagonal steps are permitted.
    ///
    /// When the result is used as an A* heuristic it should be admissible,
    /// i.e. it must never exceed the true travel cost.
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        diagonal: bool,
    ) -> f64;
}
36
/// Direct-distance metric -- the default.
///
/// Estimates cost as the shortest diagonal+orthogonal path between two cells.
/// With diagonal movement enabled the cost is
/// `cost_diagonal * min(dx, dy) + cost_cardinal * (max(dx, dy) - min(dx, dy))`;
/// otherwise it is `cost_cardinal * (dx + dy)` (Manhattan distance scaled by
/// the cardinal step cost).
///
/// `cost_cardinal` defaults to `1.0` and `cost_diagonal` to `√2`
/// ([`std::f64::consts::SQRT_2`]).
///
/// Note: if `cost_diagonal > 2 * cost_cardinal`, the diagonal formula exceeds
/// the cost of an all-orthogonal route. In that case it is no longer a lower
/// bound on the orthogonal path cost.
///
/// Mirrors Agents.jl `DirectDistance`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DirectDistance {
    /// Cost of an orthogonal step (horizontal/vertical).
    pub cost_cardinal: f64,
    /// Cost of a diagonal step.
    pub cost_diagonal: f64,
}
53
54impl DirectDistance {
55    /// Create with default costs (cardinal = 1.0, diagonal = ?2).
56    pub fn new() -> Self {
57        Self {
58            cost_cardinal: 1.0,
59            cost_diagonal: std::f64::consts::SQRT_2,
60        }
61    }
62
63    /// Create with custom step costs.
64    pub fn with_costs(cost_cardinal: f64, cost_diagonal: f64) -> Self {
65        Self {
66            cost_cardinal,
67            cost_diagonal,
68        }
69    }
70}
71
72impl Default for DirectDistance {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl CostMetric for DirectDistance {
79    fn delta_cost(
80        &self,
81        from: (usize, usize),
82        to: (usize, usize),
83        periodic: bool,
84        width: usize,
85        height: usize,
86        diagonal: bool,
87    ) -> f64 {
88        let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
89        if diagonal {
90            let min_d = dx.min(dy) as f64;
91            let max_d = dx.max(dy) as f64;
92            self.cost_diagonal * min_d + self.cost_cardinal * (max_d - min_d)
93        } else {
94            self.cost_cardinal * (dx + dy) as f64
95        }
96    }
97}
98
99/// Maximum-coordinate-difference metric (Chebyshev distance).
100///
101/// Cost = `max(|dx|, |dy|)`.
102///
103/// Mirrors Agents.jl `MaxDistance`.
104#[derive(Debug, Clone, Copy)]
105pub struct MaxDistance;
106
107impl CostMetric for MaxDistance {
108    fn delta_cost(
109        &self,
110        from: (usize, usize),
111        to: (usize, usize),
112        periodic: bool,
113        width: usize,
114        height: usize,
115        _diagonal: bool,
116    ) -> f64 {
117        let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
118        dx.max(dy) as f64
119    }
120}
121
122/// Manhattan distance metric (L1 / taxicab distance).
123///
124/// Cost = `|dx| + |dy|`.
125#[derive(Debug, Clone, Copy)]
126pub struct Manhattan;
127
128impl CostMetric for Manhattan {
129    fn delta_cost(
130        &self,
131        from: (usize, usize),
132        to: (usize, usize),
133        periodic: bool,
134        width: usize,
135        height: usize,
136        _diagonal: bool,
137    ) -> f64 {
138        let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
139        (dx + dy) as f64
140    }
141}
142
143/// Errors returned by penalty-map validation.
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
145pub enum PenaltyMapError {
146    #[error("penalty-map dimensions must be positive")]
147    InvalidDimensions,
148    #[error("penalty map size mismatch: expected {expected}, got {actual}")]
149    SizeMismatch { expected: usize, actual: usize },
150}
151
152/// Penalty-map metric -- base distance plus per-cell penalty difference.
153///
154/// Cost = `base_metric.delta_cost(from, to) + |penalty[from] - penalty[to]|`.
155///
156/// The penalty map is a 2D array (row-major, `height x width`) of integer
157/// penalties. This is useful for modeling terrain height, road quality, etc.
158///
159/// Mirrors Agents.jl `PenaltyMap`.
160///
161/// # Example
162///
163/// ```
164/// use rustsim_pathfinding::metrics::{PenaltyMap, DirectDistance};
165///
166/// // 5x5 grid with a hill in the center
167/// let mut penalties = vec![0i32; 25];
168/// penalties[2 * 5 + 2] = 100; // cell (2,2) is expensive
169///
170/// let metric = PenaltyMap::new(penalties, 5, 5, DirectDistance::new()).unwrap();
171/// ```
172#[derive(Debug)]
173pub struct PenaltyMap {
174    /// Flat penalty array in row-major order (`penalties[y * width + x]`).
175    penalties: Vec<i32>,
176    map_width: usize,
177    #[allow(dead_code)]
178    map_height: usize,
179    /// Underlying distance metric.
180    base_metric: Box<dyn CostMetric>,
181}
182
// No manual `Debug` impl is needed for `PenaltyMap`: `#[derive(Debug)]` can format the
// `Box<dyn CostMetric>` field because `CostMetric` has `Debug` as a supertrait.
185
186impl PenaltyMap {
187    /// Create a penalty map metric.
188    pub fn new(
189        penalties: Vec<i32>,
190        width: usize,
191        height: usize,
192        base: impl CostMetric + 'static,
193    ) -> Result<Self, PenaltyMapError> {
194        if width == 0 || height == 0 {
195            return Err(PenaltyMapError::InvalidDimensions);
196        }
197        let expected = width * height;
198        if penalties.len() != expected {
199            return Err(PenaltyMapError::SizeMismatch {
200                expected,
201                actual: penalties.len(),
202            });
203        }
204        Ok(Self {
205            penalties,
206            map_width: width,
207            map_height: height,
208            base_metric: Box::new(base),
209        })
210    }
211
212    /// Read-only access to the penalty values.
213    pub fn penalties(&self) -> &[i32] {
214        &self.penalties
215    }
216
217    /// Mutable access to the penalty values (e.g. for dynamic terrain).
218    pub fn penalties_mut(&mut self) -> &mut [i32] {
219        &mut self.penalties
220    }
221
222    /// Get the penalty at a grid cell.
223    pub fn penalty_at(&self, x: usize, y: usize) -> i32 {
224        self.penalties[y * self.map_width + x]
225    }
226}
227
228impl CostMetric for PenaltyMap {
229    fn delta_cost(
230        &self,
231        from: (usize, usize),
232        to: (usize, usize),
233        periodic: bool,
234        width: usize,
235        height: usize,
236        diagonal: bool,
237    ) -> f64 {
238        let base = self
239            .base_metric
240            .delta_cost(from, to, periodic, width, height, diagonal);
241        let pen_from = self.penalties[from.1 * self.map_width + from.0];
242        let pen_to = self.penalties[to.1 * self.map_width + to.0];
243        base + (pen_to - pen_from).unsigned_abs() as f64
244    }
245}
246
/// Compute the per-axis distance `(dx, dy)` between two grid positions.
///
/// When `periodic` is `true`, each axis is measured the short way round a
/// ring of `width` (x axis) or `height` (y axis) cells: `min(d, extent - d)`.
/// An offset larger than its extent (only possible for out-of-range
/// coordinates) is returned unwrapped instead of underflowing `extent - d`.
fn delta_with_periodic(
    from: (usize, usize),
    to: (usize, usize),
    periodic: bool,
    width: usize,
    height: usize,
) -> (usize, usize) {
    let axis = |a: usize, b: usize, extent: usize| -> usize {
        // Exact for every `usize`, unlike a round-trip through `isize`.
        let d = a.max(b) - a.min(b);
        if periodic && d <= extent {
            d.min(extent - d)
        } else {
            d
        }
    };
    (axis(from.0, to.0, width), axis(from.1, to.1, height))
}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that two floats agree to within `1e-10`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn direct_distance_cardinal() {
        let m = DirectDistance::new();
        let cost = m.delta_cost((0, 0), (3, 0), false, 10, 10, false);
        assert_close(cost, 3.0);
    }

    #[test]
    fn direct_distance_diagonal() {
        let m = DirectDistance::new();
        // (0,0) -> (2,2): 2 diagonal steps = 2 * √2
        let cost = m.delta_cost((0, 0), (2, 2), false, 10, 10, true);
        assert_close(cost, 2.0 * std::f64::consts::SQRT_2);
    }

    #[test]
    fn direct_distance_mixed() {
        let m = DirectDistance::new();
        // (0,0) -> (3,1): 1 diagonal + 2 cardinal = √2 + 2
        let cost = m.delta_cost((0, 0), (3, 1), false, 10, 10, true);
        assert_close(cost, std::f64::consts::SQRT_2 + 2.0);
    }

    #[test]
    fn direct_distance_custom_costs() {
        let m = DirectDistance::with_costs(2.0, 3.0);
        // (0,0) -> (1,4): 1 diagonal (3.0) + 3 cardinal (3 * 2.0) = 9.0
        assert_close(m.delta_cost((0, 0), (1, 4), false, 10, 10, true), 9.0);
        // Without diagonals: 5 cardinal steps * 2.0
        assert_close(m.delta_cost((0, 0), (1, 4), false, 10, 10, false), 10.0);
    }

    #[test]
    fn direct_distance_default_matches_new() {
        let d = DirectDistance::default();
        let n = DirectDistance::new();
        assert_close(d.cost_cardinal, n.cost_cardinal);
        assert_close(d.cost_diagonal, n.cost_diagonal);
    }

    #[test]
    fn max_distance_basic() {
        let cost = MaxDistance.delta_cost((0, 0), (3, 5), false, 10, 10, true);
        assert_close(cost, 5.0);
    }

    #[test]
    fn max_distance_periodic() {
        // x offset 7 wraps to 3 on a width-10 ring; y offset is 2.
        let cost = MaxDistance.delta_cost((1, 1), (8, 3), true, 10, 10, false);
        assert_close(cost, 3.0);
    }

    #[test]
    fn manhattan_basic() {
        let cost = Manhattan.delta_cost((0, 0), (3, 5), false, 10, 10, false);
        assert_close(cost, 8.0);
    }

    #[test]
    fn manhattan_periodic_both_axes() {
        // Both axes wrap: 9 -> 1 on a 10x10 ring.
        let cost = Manhattan.delta_cost((0, 0), (9, 9), true, 10, 10, false);
        assert_close(cost, 2.0);
    }

    #[test]
    fn penalty_map_basic() {
        let mut pens = vec![0i32; 25];
        pens[2 * 5 + 2] = 100; // (2,2) has penalty 100
        let m = PenaltyMap::new(pens, 5, 5, DirectDistance::new()).unwrap();

        // From (0,0) penalty=0 to (2,2) penalty=100: base + 100
        let cost = m.delta_cost((0, 0), (2, 2), false, 5, 5, true);
        let base = DirectDistance::new().delta_cost((0, 0), (2, 2), false, 5, 5, true);
        assert_close(cost, base + 100.0);
    }

    #[test]
    fn penalty_map_is_symmetric() {
        let mut pens = vec![0i32; 25];
        pens[2 * 5 + 2] = 100;
        let m = PenaltyMap::new(pens, 5, 5, DirectDistance::new()).unwrap();
        let there = m.delta_cost((0, 0), (2, 2), false, 5, 5, true);
        let back = m.delta_cost((2, 2), (0, 0), false, 5, 5, true);
        assert_close(there, back);
    }

    #[test]
    fn penalty_map_accessors() {
        let mut pens = vec![0i32; 25];
        pens[2 * 5 + 2] = 100;
        let mut m = PenaltyMap::new(pens, 5, 5, DirectDistance::new()).unwrap();
        assert_eq!(m.penalty_at(2, 2), 100);
        assert_eq!(m.penalties().len(), 25);
        m.penalties_mut()[0] = -7;
        assert_eq!(m.penalty_at(0, 0), -7);
    }

    #[test]
    fn periodic_distance() {
        let m = DirectDistance::new();
        // On a 10x10 periodic grid, (0,0) to (9,0) should be 1 step, not 9
        let cost = m.delta_cost((0, 0), (9, 0), true, 10, 10, false);
        assert_close(cost, 1.0);
    }

    #[test]
    fn penalty_map_invalid_dimensions() {
        let pens = vec![0i32; 25];
        let result = PenaltyMap::new(pens, 0, 5, DirectDistance::new());
        assert!(matches!(result, Err(PenaltyMapError::InvalidDimensions)));
    }

    #[test]
    fn penalty_map_size_mismatch() {
        let pens = vec![0i32; 24]; // Wrong size
        let result = PenaltyMap::new(pens, 5, 5, DirectDistance::new());
        assert!(matches!(
            result,
            Err(PenaltyMapError::SizeMismatch {
                expected: 25,
                actual: 24
            })
        ));
    }
}