shortestpath/
gradient.rs

1// Copyright (C) 2024 Christian Mauduit <ufoot@ufoot.org>
2
3use crate::distance::*;
4use crate::gate::*;
5use crate::mesh::*;
6use std::fmt;
7use std::iter::Iterator;
8
9#[derive(Debug, Clone)]
10pub struct Gradient {
11    slots: Vec<Gate>,
12}
13
14impl Gradient {
15    pub fn new_with_len(len: usize) -> Self {
16        let mut slots = Vec::with_capacity(len);
17        for i in 0..len {
18            slots.push(Gate::new(i, DISTANCE_MAX));
19        }
20        Gradient { slots }
21    }
22
23    pub fn new_with_mesh(mesh: &impl Mesh) -> Self {
24        Self::new_with_len(mesh.len())
25    }
26
27    pub fn spread_forward(&mut self, mesh: &impl Mesh, incr: f64) -> usize {
28        if self.slots.len() != mesh.len() {
29            panic!(
30                "slots len {} does not match mesh len {}",
31                self.slots.len(),
32                mesh.len()
33            )
34        }
35        (0..(mesh.len()))
36            .into_iter()
37            .map(|i| self.spread_slot(i, mesh, incr))
38            .filter(|ok| *ok)
39            .count()
40    }
41
42    pub fn spread_backward(&mut self, mesh: &impl Mesh, incr: f64) -> usize {
43        if self.slots.len() != mesh.len() {
44            panic!(
45                "slots len {} does not match mesh len {}",
46                self.slots.len(),
47                mesh.len()
48            )
49        }
50        (0..(mesh.len()))
51            .rev()
52            .into_iter()
53            .map(|i| self.spread_slot(i, mesh, incr))
54            .filter(|ok| *ok)
55            .count()
56    }
57
58    pub fn spread_both(&mut self, mesh: &impl Mesh, incr: f64) -> usize {
59        self.spread_forward(mesh, incr) + self.spread_backward(mesh, incr)
60    }
61
62    pub fn spread_slot(&mut self, here_index: usize, mesh: &impl Mesh, incr: f64) -> bool {
63        let here = Gate::new(here_index, 0.0);
64        let best = mesh.successors(here_index).into_iter().fold(here, |a, b| {
65            if self.slots[a.target].distance + a.distance
66                <= self.slots[b.target].distance + b.distance
67            {
68                a
69            } else {
70                b
71            }
72        });
73        if best.target != here_index {
74            self.slots[here_index] = Gate::new(
75                best.target,
76                self.slots[best.target].distance + best.distance,
77            );
78            true
79        } else {
80            // If nothing has changed, consider we're farther by "incr" which
81            // avoids old paths to remain valid when they should not be.
82            if incr != 0.0 {
83                self.slots[here_index].distance += incr;
84            }
85            false
86        }
87    }
88
89    pub fn incr(&mut self, incr: f64) {
90        self.slots.iter_mut().for_each(|i| i.distance += incr);
91    }
92
93    pub fn spread(&mut self, mesh: &impl Mesh) {
94        while self.spread_both(mesh, 0.0) > 0 {}
95    }
96
97    pub fn set_distance(&mut self, slot_index: usize, value: f64) {
98        self.slots[slot_index].distance = value
99    }
100
101    pub fn get_distance(&self, slot_index: usize) -> f64 {
102        self.slots[slot_index].distance
103    }
104
105    pub fn get_target(&self, slot_index: usize) -> usize {
106        self.slots[slot_index].target
107    }
108
109    pub fn get(&self, slot_index: usize) -> &Gate {
110        &self.slots[slot_index]
111    }
112}
113
114impl std::fmt::Display for Gradient {
115    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116        write!(f, "TODO")
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh_2d::*;

    // All tests use a 5x5 grid addressed row-major (row r holds slots
    // 5*r .. 5*r+4), seeded at the centre slot 12.

    #[test]
    fn test_spread_forward() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_mesh(&rect);
        grad.set_distance(12, 0.0);
        grad.spread_forward(&rect, 0.0);
        // row 0
        assert_eq!(DISTANCE_MAX, grad.get_distance(0));
        assert_eq!(0, grad.get_target(0));
        assert_eq!(DISTANCE_MAX, grad.get_distance(1));
        assert_eq!(1, grad.get_target(1));
        assert_eq!(DISTANCE_MAX, grad.get_distance(2));
        assert_eq!(2, grad.get_target(2));
        assert_eq!(DISTANCE_MAX, grad.get_distance(3));
        assert_eq!(3, grad.get_target(3));
        assert_eq!(DISTANCE_MAX, grad.get_distance(4));
        assert_eq!(4, grad.get_target(4));
        // row 1
        assert_eq!(DISTANCE_MAX, grad.get_distance(5));
        assert_eq!(5, grad.get_target(5));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(6));
        assert_eq!(12, grad.get_target(6));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(7));
        assert_eq!(12, grad.get_target(7));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(8));
        assert_eq!(12, grad.get_target(8));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(9));
        assert_eq!(8, grad.get_target(9));
        // row 2
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(10));
        assert_eq!(6, grad.get_target(10));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(11));
        assert_eq!(12, grad.get_target(11));
        assert_eq!(DISTANCE_MIN, grad.get_distance(12));
        assert_eq!(12, grad.get_target(12));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(13));
        assert_eq!(12, grad.get_target(13));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(14));
        assert_eq!(13, grad.get_target(14));
        // row 3
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(15));
        assert_eq!(11, grad.get_target(15));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(16));
        assert_eq!(12, grad.get_target(16));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(17));
        assert_eq!(12, grad.get_target(17));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(18));
        assert_eq!(12, grad.get_target(18));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(19));
        assert_eq!(18, grad.get_target(19));
        // row 4
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(20));
        assert_eq!(16, grad.get_target(20));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(21));
        assert_eq!(16, grad.get_target(21));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(22));
        assert_eq!(17, grad.get_target(22));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(23));
        assert_eq!(17, grad.get_target(23));
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(24));
        assert_eq!(18, grad.get_target(24));
    }

    #[test]
    fn test_spread_backward() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_mesh(&rect);
        grad.set_distance(12, 0.0);
        grad.spread_backward(&rect, 0.0);
        // row 4
        assert_eq!(DISTANCE_MAX, grad.get_distance(24));
        assert_eq!(24, grad.get_target(24));
        assert_eq!(DISTANCE_MAX, grad.get_distance(23));
        assert_eq!(23, grad.get_target(23));
        assert_eq!(DISTANCE_MAX, grad.get_distance(22));
        assert_eq!(22, grad.get_target(22));
        assert_eq!(DISTANCE_MAX, grad.get_distance(21));
        assert_eq!(21, grad.get_target(21));
        assert_eq!(DISTANCE_MAX, grad.get_distance(20));
        assert_eq!(20, grad.get_target(20));
        // row 3
        assert_eq!(DISTANCE_MAX, grad.get_distance(19));
        assert_eq!(19, grad.get_target(19));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(18));
        assert_eq!(12, grad.get_target(18));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(17));
        assert_eq!(12, grad.get_target(17));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(16));
        assert_eq!(12, grad.get_target(16));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(15));
        assert_eq!(16, grad.get_target(15));
        // row 2
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(14));
        assert_eq!(18, grad.get_target(14));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(13));
        assert_eq!(12, grad.get_target(13));
        assert_eq!(DISTANCE_MIN, grad.get_distance(12));
        assert_eq!(12, grad.get_target(12));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(11));
        assert_eq!(12, grad.get_target(11));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(10));
        assert_eq!(11, grad.get_target(10));
        // row 1
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(9));
        assert_eq!(13, grad.get_target(9));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(8));
        assert_eq!(12, grad.get_target(8));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(7));
        assert_eq!(12, grad.get_target(7));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(6));
        assert_eq!(12, grad.get_target(6));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(5));
        assert_eq!(6, grad.get_target(5));
        // row 0
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(4));
        assert_eq!(8, grad.get_target(4));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(3));
        assert_eq!(8, grad.get_target(3));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(2));
        assert_eq!(7, grad.get_target(2));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(1));
        assert_eq!(7, grad.get_target(1));
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(0));
        assert_eq!(6, grad.get_target(0));
    }

    #[test]
    fn test_spread() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_mesh(&rect);
        grad.set_distance(12, 0.0);
        grad.spread(&rect);
        // row 0
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(0));
        assert_eq!(6, grad.get_target(0));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(1));
        assert_eq!(7, grad.get_target(1));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(2));
        assert_eq!(7, grad.get_target(2));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(3));
        assert_eq!(8, grad.get_target(3));
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(4));
        assert_eq!(8, grad.get_target(4));
        // row 1
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(5));
        assert_eq!(6, grad.get_target(5));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(6));
        assert_eq!(12, grad.get_target(6));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(7));
        assert_eq!(12, grad.get_target(7));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(8));
        assert_eq!(12, grad.get_target(8));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(9));
        assert_eq!(8, grad.get_target(9));
        // row 2
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(10));
        assert_eq!(11, grad.get_target(10));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(11));
        assert_eq!(12, grad.get_target(11));
        assert_eq!(DISTANCE_MIN, grad.get_distance(12));
        assert_eq!(12, grad.get_target(12));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(13));
        assert_eq!(12, grad.get_target(13));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(14));
        assert_eq!(13, grad.get_target(14));
        // row 3
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(15));
        assert_eq!(11, grad.get_target(15));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(16));
        assert_eq!(12, grad.get_target(16));
        assert_eq!(DISTANCE_STRAIGHT, grad.get_distance(17));
        assert_eq!(12, grad.get_target(17));
        assert_eq!(DISTANCE_DIAGONAL, grad.get_distance(18));
        assert_eq!(12, grad.get_target(18));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(19));
        assert_eq!(18, grad.get_target(19));
        // row 4
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(20));
        assert_eq!(16, grad.get_target(20));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(21));
        assert_eq!(16, grad.get_target(21));
        assert_eq!(2.0 * DISTANCE_STRAIGHT, grad.get_distance(22));
        assert_eq!(17, grad.get_target(22));
        assert_eq!(DISTANCE_STRAIGHT + DISTANCE_DIAGONAL, grad.get_distance(23));
        assert_eq!(17, grad.get_target(23));
        assert_eq!(2.0 * DISTANCE_DIAGONAL, grad.get_distance(24));
        assert_eq!(18, grad.get_target(24));
    }

    #[test]
    #[should_panic(expected = "does not match mesh len")]
    fn test_spread_len_mismatch() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_len(3);
        grad.spread_forward(&rect, 0.0);
    }
}