1use crate::distance::*;
4use crate::gate::*;
5use crate::mesh::*;
6use std::fmt;
7use std::iter::Iterator;
8
9#[derive(Debug, Clone)]
10pub struct Gradient {
11 slots: Vec<Gate>,
12}
13
14impl Gradient {
15 pub fn new_with_len(len: usize) -> Self {
16 let mut slots = Vec::with_capacity(len);
17 for i in 0..len {
18 slots.push(Gate::new(i, DISTANCE_MAX));
19 }
20 Gradient { slots }
21 }
22
23 pub fn new_with_mesh(mesh: &impl Mesh) -> Self {
24 Self::new_with_len(mesh.len())
25 }
26
27 pub fn spread_forward(&mut self, mesh: &impl Mesh, incr: f64) -> usize {
28 if self.slots.len() != mesh.len() {
29 panic!(
30 "slots len {} does not match mesh len {}",
31 self.slots.len(),
32 mesh.len()
33 )
34 }
35 (0..(mesh.len()))
36 .into_iter()
37 .map(|i| self.spread_slot(i, mesh, incr))
38 .filter(|ok| *ok)
39 .count()
40 }
41
42 pub fn spread_backward(&mut self, mesh: &impl Mesh, incr: f64) -> usize {
43 if self.slots.len() != mesh.len() {
44 panic!(
45 "slots len {} does not match mesh len {}",
46 self.slots.len(),
47 mesh.len()
48 )
49 }
50 (0..(mesh.len()))
51 .rev()
52 .into_iter()
53 .map(|i| self.spread_slot(i, mesh, incr))
54 .filter(|ok| *ok)
55 .count()
56 }
57
58 pub fn spread_both(&mut self, mesh: &impl Mesh, incr: f64) -> usize {
59 self.spread_forward(mesh, incr) + self.spread_backward(mesh, incr)
60 }
61
62 pub fn spread_slot(&mut self, here_index: usize, mesh: &impl Mesh, incr: f64) -> bool {
63 let here = Gate::new(here_index, 0.0);
64 let best = mesh.successors(here_index).into_iter().fold(here, |a, b| {
65 if self.slots[a.target].distance + a.distance
66 <= self.slots[b.target].distance + b.distance
67 {
68 a
69 } else {
70 b
71 }
72 });
73 if best.target != here_index {
74 self.slots[here_index] = Gate::new(
75 best.target,
76 self.slots[best.target].distance + best.distance,
77 );
78 true
79 } else {
80 if incr != 0.0 {
83 self.slots[here_index].distance += incr;
84 }
85 false
86 }
87 }
88
89 pub fn incr(&mut self, incr: f64) {
90 self.slots.iter_mut().for_each(|i| i.distance += incr);
91 }
92
93 pub fn spread(&mut self, mesh: &impl Mesh) {
94 while self.spread_both(mesh, 0.0) > 0 {}
95 }
96
97 pub fn set_distance(&mut self, slot_index: usize, value: f64) {
98 self.slots[slot_index].distance = value
99 }
100
101 pub fn get_distance(&self, slot_index: usize) -> f64 {
102 self.slots[slot_index].distance
103 }
104
105 pub fn get_target(&self, slot_index: usize) -> usize {
106 self.slots[slot_index].target
107 }
108
109 pub fn get(&self, slot_index: usize) -> &Gate {
110 &self.slots[slot_index]
111 }
112}
113
114impl std::fmt::Display for Gradient {
115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116 write!(f, "TODO")
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mesh_2d::*;

    /// Checks every `(distance, target)` pair in `expected` against the slot
    /// with the same index in `grad`, naming the slot on failure.
    fn assert_slots(grad: &Gradient, expected: &[(f64, usize)]) {
        for (i, &(distance, target)) in expected.iter().enumerate() {
            assert_eq!(distance, grad.get_distance(i), "distance of slot {}", i);
            assert_eq!(target, grad.get_target(i), "target of slot {}", i);
        }
    }

    // The tests below use a 5x5 mesh seeded at its centre (slot 12). The
    // expectation tables are laid out row by row, mirroring the mesh.

    #[test]
    fn test_spread_forward() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_mesh(&rect);
        grad.set_distance(12, 0.0);
        grad.spread_forward(&rect, 0.0);
        let (s, d) = (DISTANCE_STRAIGHT, DISTANCE_DIAGONAL);
        let (max, min) = (DISTANCE_MAX, DISTANCE_MIN);
        // One ascending pass only sees distances known when each slot is
        // visited, so slots 0-5 stay unreached.
        #[rustfmt::skip]
        let expected = [
            (max, 0),      (max, 1),      (max, 2),      (max, 3),      (max, 4),
            (max, 5),      (d, 12),       (s, 12),       (d, 12),       (s + d, 8),
            (2.0 * d, 6),  (s, 12),       (min, 12),     (s, 12),       (2.0 * s, 13),
            (s + d, 11),   (d, 12),       (s, 12),       (d, 12),       (s + d, 18),
            (2.0 * d, 16), (s + d, 16),   (2.0 * s, 17), (s + d, 17),   (2.0 * d, 18),
        ];
        assert_slots(&grad, &expected);
    }

    #[test]
    fn test_spread_backward() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_mesh(&rect);
        grad.set_distance(12, 0.0);
        grad.spread_backward(&rect, 0.0);
        let (s, d) = (DISTANCE_STRAIGHT, DISTANCE_DIAGONAL);
        let (max, min) = (DISTANCE_MAX, DISTANCE_MIN);
        // One descending pass leaves slots 19-24 unreached.
        #[rustfmt::skip]
        let expected = [
            (2.0 * d, 6),  (s + d, 7),    (2.0 * s, 7),  (s + d, 8),    (2.0 * d, 8),
            (s + d, 6),    (d, 12),       (s, 12),       (d, 12),       (s + d, 13),
            (2.0 * s, 11), (s, 12),       (min, 12),     (s, 12),       (2.0 * d, 18),
            (s + d, 16),   (d, 12),       (s, 12),       (d, 12),       (max, 19),
            (max, 20),     (max, 21),     (max, 22),     (max, 23),     (max, 24),
        ];
        assert_slots(&grad, &expected);
    }

    #[test]
    fn test_spread() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_mesh(&rect);
        grad.set_distance(12, 0.0);
        grad.spread(&rect);
        let (s, d) = (DISTANCE_STRAIGHT, DISTANCE_DIAGONAL);
        let min = DISTANCE_MIN;
        #[rustfmt::skip]
        let expected = [
            (2.0 * d, 6),  (s + d, 7),    (2.0 * s, 7),  (s + d, 8),    (2.0 * d, 8),
            (s + d, 6),    (d, 12),       (s, 12),       (d, 12),       (s + d, 8),
            (2.0 * s, 11), (s, 12),       (min, 12),     (s, 12),       (2.0 * s, 13),
            (s + d, 11),   (d, 12),       (s, 12),       (d, 12),       (s + d, 18),
            (2.0 * d, 16), (s + d, 16),   (2.0 * s, 17), (s + d, 17),   (2.0 * d, 18),
        ];
        assert_slots(&grad, &expected);
    }

    #[test]
    #[should_panic(expected = "does not match mesh len")]
    fn test_spread_len_mismatch() {
        let rect = Full2D::new(5, 5);
        let mut grad = Gradient::new_with_len(3);
        grad.spread_forward(&rect, 0.0);
    }
}