1use thiserror::Error;
17
/// A strategy for pricing a single move between two cells of a 2-D grid.
///
/// Coordinates are `(x, y)` pairs. The grid geometry is passed on every call
/// rather than stored in the metric, so one metric value can be shared by
/// grids of different sizes.
pub trait CostMetric: std::fmt::Debug {
    /// Returns the cost of moving from `from` to `to`.
    ///
    /// # Parameters
    ///
    /// * `from`, `to` – origin and destination cells as `(x, y)`.
    /// * `periodic` – whether the grid wraps around at its edges.
    /// * `width`, `height` – grid dimensions, needed to resolve wrapping.
    /// * `diagonal` – whether diagonal steps are allowed. Implementors may
    ///   ignore this flag if their metric has no notion of diagonals.
    fn delta_cost(
        &self,
        from: (usize, usize),
        to: (usize, usize),
        periodic: bool,
        width: usize,
        height: usize,
        diagonal: bool,
    ) -> f64;
}
36
/// Grid distance that prices cardinal and diagonal steps separately.
///
/// With diagonal moves enabled a path takes `min(dx, dy)` diagonal steps and
/// `max(dx, dy) - min(dx, dy)` cardinal steps; otherwise every one of the
/// `dx + dy` steps is cardinal.
#[derive(Debug, Clone)]
pub struct DirectDistance {
    /// Cost of one horizontal or vertical step.
    pub cost_cardinal: f64,
    /// Cost of one diagonal step.
    pub cost_diagonal: f64,
}

impl DirectDistance {
    /// Creates a metric whose step costs are the Euclidean step lengths:
    /// `1` for a cardinal step and `√2` for a diagonal step.
    pub fn new() -> Self {
        Self::with_costs(1.0, std::f64::consts::SQRT_2)
    }

    /// Creates a metric with caller-chosen cardinal and diagonal step costs.
    pub fn with_costs(cost_cardinal: f64, cost_diagonal: f64) -> Self {
        Self {
            cost_cardinal,
            cost_diagonal,
        }
    }
}

impl Default for DirectDistance {
    /// Same as [`DirectDistance::new`].
    fn default() -> Self {
        Self::new()
    }
}
77
78impl CostMetric for DirectDistance {
79 fn delta_cost(
80 &self,
81 from: (usize, usize),
82 to: (usize, usize),
83 periodic: bool,
84 width: usize,
85 height: usize,
86 diagonal: bool,
87 ) -> f64 {
88 let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
89 if diagonal {
90 let min_d = dx.min(dy) as f64;
91 let max_d = dx.max(dy) as f64;
92 self.cost_diagonal * min_d + self.cost_cardinal * (max_d - min_d)
93 } else {
94 self.cost_cardinal * (dx + dy) as f64
95 }
96 }
97}
98
99#[derive(Debug, Clone, Copy)]
105pub struct MaxDistance;
106
107impl CostMetric for MaxDistance {
108 fn delta_cost(
109 &self,
110 from: (usize, usize),
111 to: (usize, usize),
112 periodic: bool,
113 width: usize,
114 height: usize,
115 _diagonal: bool,
116 ) -> f64 {
117 let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
118 dx.max(dy) as f64
119 }
120}
121
122#[derive(Debug, Clone, Copy)]
126pub struct Manhattan;
127
128impl CostMetric for Manhattan {
129 fn delta_cost(
130 &self,
131 from: (usize, usize),
132 to: (usize, usize),
133 periodic: bool,
134 width: usize,
135 height: usize,
136 _diagonal: bool,
137 ) -> f64 {
138 let (dx, dy) = delta_with_periodic(from, to, periodic, width, height);
139 (dx + dy) as f64
140 }
141}
142
143#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
145pub enum PenaltyMapError {
146 #[error("penalty-map dimensions must be positive")]
147 InvalidDimensions,
148 #[error("penalty map size mismatch: expected {expected}, got {actual}")]
149 SizeMismatch { expected: usize, actual: usize },
150}
151
152#[derive(Debug)]
173pub struct PenaltyMap {
174 penalties: Vec<i32>,
176 map_width: usize,
177 #[allow(dead_code)]
178 map_height: usize,
179 base_metric: Box<dyn CostMetric>,
181}
182
183impl PenaltyMap {
187 pub fn new(
189 penalties: Vec<i32>,
190 width: usize,
191 height: usize,
192 base: impl CostMetric + 'static,
193 ) -> Result<Self, PenaltyMapError> {
194 if width == 0 || height == 0 {
195 return Err(PenaltyMapError::InvalidDimensions);
196 }
197 let expected = width * height;
198 if penalties.len() != expected {
199 return Err(PenaltyMapError::SizeMismatch {
200 expected,
201 actual: penalties.len(),
202 });
203 }
204 Ok(Self {
205 penalties,
206 map_width: width,
207 map_height: height,
208 base_metric: Box::new(base),
209 })
210 }
211
212 pub fn penalties(&self) -> &[i32] {
214 &self.penalties
215 }
216
217 pub fn penalties_mut(&mut self) -> &mut [i32] {
219 &mut self.penalties
220 }
221
222 pub fn penalty_at(&self, x: usize, y: usize) -> i32 {
224 self.penalties[y * self.map_width + x]
225 }
226}
227
228impl CostMetric for PenaltyMap {
229 fn delta_cost(
230 &self,
231 from: (usize, usize),
232 to: (usize, usize),
233 periodic: bool,
234 width: usize,
235 height: usize,
236 diagonal: bool,
237 ) -> f64 {
238 let base = self
239 .base_metric
240 .delta_cost(from, to, periodic, width, height, diagonal);
241 let pen_from = self.penalties[from.1 * self.map_width + from.0];
242 let pen_to = self.penalties[to.1 * self.map_width + to.0];
243 base + (pen_to - pen_from).unsigned_abs() as f64
244 }
245}
246
/// Returns the per-axis offsets `(dx, dy)` between two cells.
///
/// When `periodic` is `true` the grid wraps around, so each offset is the
/// shorter of the direct path and the path through the boundary.
/// Coordinates at or beyond an axis's extent are treated modulo that extent,
/// so `x` and `x + width` name the same column. An extent of `0` disables
/// wrapping on that axis instead of underflowing.
fn delta_with_periodic(
    from: (usize, usize),
    to: (usize, usize),
    periodic: bool,
    width: usize,
    height: usize,
) -> (usize, usize) {
    if periodic {
        (
            wrapped_axis_distance(from.0, to.0, width),
            wrapped_axis_distance(from.1, to.1, height),
        )
    } else {
        // `abs_diff` avoids the `as isize` casts, which misbehave above
        // `isize::MAX`.
        (from.0.abs_diff(to.0), from.1.abs_diff(to.1))
    }
}

/// Shortest distance between positions `a` and `b` on a ring of `extent`
/// positions (plain `|a - b|` when `extent` is zero).
fn wrapped_axis_distance(a: usize, b: usize, extent: usize) -> usize {
    let direct = a.abs_diff(b);
    if extent == 0 {
        return direct;
    }
    // Reducing first keeps `extent - offset` from underflowing when a
    // coordinate lies outside `0..extent`.
    let offset = direct % extent;
    offset.min(extent - offset)
}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two costs agree to within floating-point noise.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// A 5x5 map that is zero everywhere except a penalty of 100 at (2, 2).
    fn map_with_peak() -> PenaltyMap {
        let mut pens = vec![0i32; 25];
        pens[2 * 5 + 2] = 100;
        PenaltyMap::new(pens, 5, 5, DirectDistance::new()).unwrap()
    }

    #[test]
    fn direct_distance_cardinal() {
        let m = DirectDistance::new();
        assert_close(m.delta_cost((0, 0), (3, 0), false, 10, 10, false), 3.0);
    }

    #[test]
    fn direct_distance_diagonal() {
        let m = DirectDistance::new();
        // Two pure diagonal steps.
        let cost = m.delta_cost((0, 0), (2, 2), false, 10, 10, true);
        assert_close(cost, 2.0 * std::f64::consts::SQRT_2);
    }

    #[test]
    fn direct_distance_mixed() {
        let m = DirectDistance::new();
        // One diagonal step followed by two cardinal steps.
        let cost = m.delta_cost((0, 0), (3, 1), false, 10, 10, true);
        assert_close(cost, std::f64::consts::SQRT_2 + 2.0);
    }

    #[test]
    fn direct_distance_custom_costs() {
        let m = DirectDistance::with_costs(2.0, 3.0);
        // One diagonal (3.0) plus two cardinal (2 * 2.0) steps.
        assert_close(m.delta_cost((0, 0), (3, 1), false, 10, 10, true), 7.0);
        // Four cardinal steps.
        assert_close(m.delta_cost((0, 0), (3, 1), false, 10, 10, false), 8.0);
    }

    #[test]
    fn direct_distance_default_matches_new() {
        let d = DirectDistance::default();
        let n = DirectDistance::new();
        assert_close(d.cost_cardinal, n.cost_cardinal);
        assert_close(d.cost_diagonal, n.cost_diagonal);
    }

    #[test]
    fn max_distance_basic() {
        let cost = MaxDistance.delta_cost((0, 0), (3, 5), false, 10, 10, true);
        assert_close(cost, 5.0);
    }

    #[test]
    fn max_distance_ignores_diagonal_flag() {
        let with_diag = MaxDistance.delta_cost((1, 1), (6, 3), false, 10, 10, true);
        let without_diag = MaxDistance.delta_cost((1, 1), (6, 3), false, 10, 10, false);
        assert_close(with_diag, 5.0);
        assert_close(without_diag, 5.0);
    }

    #[test]
    fn manhattan_basic() {
        let cost = Manhattan.delta_cost((0, 0), (3, 5), false, 10, 10, false);
        assert_close(cost, 8.0);
    }

    #[test]
    fn manhattan_periodic() {
        // Both axes wrap, so each offset of 9 becomes 1.
        let cost = Manhattan.delta_cost((0, 0), (9, 9), true, 10, 10, false);
        assert_close(cost, 2.0);
    }

    #[test]
    fn penalty_map_basic() {
        let m = map_with_peak();
        let cost = m.delta_cost((0, 0), (2, 2), false, 5, 5, true);
        let base = DirectDistance::new().delta_cost((0, 0), (2, 2), false, 5, 5, true);
        assert_close(cost, base + 100.0);
    }

    #[test]
    fn penalty_map_is_symmetric() {
        let m = map_with_peak();
        let there = m.delta_cost((0, 0), (2, 2), false, 5, 5, true);
        let back = m.delta_cost((2, 2), (0, 0), false, 5, 5, true);
        assert_close(there, back);
    }

    #[test]
    fn penalty_map_accessors() {
        let mut m = map_with_peak();
        assert_eq!(m.penalty_at(2, 2), 100);
        m.penalties_mut()[0] = -7;
        assert_eq!(m.penalty_at(0, 0), -7);
        assert_eq!(m.penalties().len(), 25);
    }

    #[test]
    fn periodic_distance() {
        let m = DirectDistance::new();
        // Wrapping makes columns 0 and 9 adjacent on a 10-wide grid.
        let cost = m.delta_cost((0, 0), (9, 0), true, 10, 10, false);
        assert_close(cost, 1.0);
    }

    #[test]
    fn penalty_map_invalid_dimensions() {
        let zero_width = PenaltyMap::new(vec![0i32; 25], 0, 5, DirectDistance::new());
        assert!(matches!(zero_width, Err(PenaltyMapError::InvalidDimensions)));

        let zero_height = PenaltyMap::new(vec![0i32; 25], 5, 0, DirectDistance::new());
        assert!(matches!(zero_height, Err(PenaltyMapError::InvalidDimensions)));
    }

    #[test]
    fn penalty_map_size_mismatch() {
        let pens = vec![0i32; 24];
        let result = PenaltyMap::new(pens, 5, 5, DirectDistance::new());
        assert!(matches!(
            result,
            Err(PenaltyMapError::SizeMismatch {
                expected: 25,
                actual: 24
            })
        ));
    }

    #[test]
    fn penalty_map_error_messages() {
        assert_eq!(
            PenaltyMapError::InvalidDimensions.to_string(),
            "penalty-map dimensions must be positive"
        );
        assert_eq!(
            PenaltyMapError::SizeMismatch {
                expected: 25,
                actual: 24
            }
            .to_string(),
            "penalty map size mismatch: expected 25, got 24"
        );
    }
}