// edt/fast_marcher.rs

use super::BoolLike;
use std::{
    cmp::{Ordering, Reverse},
    collections::BinaryHeap,
    ops::{Index, IndexMut},
};

8/// Shorthand function for EDT using Fast Marching method.
9///
10/// Fast Marching method is inexact, but much faster algorithm to compute EDT especially for large images.
11pub fn edt_fmm<T: BoolLike>(map: &[T], shape: (usize, usize), invert: bool) -> Vec<f64> {
12    let mut grid = Grid {
13        storage: map
14            .iter()
15            .map(|b| ((b.as_bool() != invert) as usize) as f64)
16            .collect::<Vec<f64>>(),
17        dims: shape,
18    };
19    let mut fast_marcher = FastMarcher::new_from_map(&grid, shape);
20
21    fast_marcher.evolve(&mut grid);
22
23    grid.storage
24}
25
26/// EDT with Fast Marching method with a callback.
27///
28/// The callback can terminate the process by returning false.
29pub fn edt_fmm_cb<T: BoolLike>(
30    map: &[T],
31    shape: (usize, usize),
32    invert: bool,
33    callback: impl FnMut(FMMCallbackData) -> bool,
34) -> Vec<f64> {
35    let mut grid = Grid {
36        storage: map
37            .iter()
38            .map(|b| ((b.as_bool() != invert) as usize) as f64)
39            .collect::<Vec<f64>>(),
40        dims: shape,
41    };
42    let mut fast_marcher = FastMarcher::new_from_map(&grid, shape);
43
44    fast_marcher.evolve_cb(&mut grid, callback);
45
46    grid.storage
47}
48
49/// A type representing a position in Grid
50pub type GridPos = (usize, usize);
51
52pub(super) struct Grid {
53    pub storage: Vec<f64>,
54    pub dims: (usize, usize),
55}
56
57impl Grid {
58    pub(super) fn find_boundary(&self) -> Vec<GridPos> {
59        // let storage = self.storage.as_ref();
60        let mut boundary = Vec::new();
61        for y in 0..self.dims.1 {
62            for x in 0..self.dims.0 {
63                if self[(x, y)] != 0.
64                    && (x < 1
65                        || self[(x - 1, y)] == 0.
66                        || y < 1
67                        || self[(x, y - 1)] == 0.
68                        || self.dims.0 <= x + 1
69                        || self[(x + 1, y)] == 0.
70                        || self.dims.1 <= y + 1
71                        || self[(x, y + 1)] == 0.)
72                {
73                    let pos = (x, y);
74                    boundary.push(pos);
75                }
76            }
77        }
78
79        boundary
80    }
81}
82
83impl Index<GridPos> for Grid {
84    type Output = f64;
85    fn index(&self, pos: GridPos) -> &Self::Output {
86        let idx = pos.1 * self.dims.0 + pos.0;
87        self.storage.index(idx)
88    }
89}
90
91impl IndexMut<GridPos> for Grid {
92    fn index_mut(&mut self, pos: GridPos) -> &mut Self::Output {
93        let idx = pos.1 * self.dims.0 + pos.0;
94        self.storage.index_mut(idx)
95    }
96}
97
98#[derive(Clone)]
99pub(super) struct NextCell {
100    pos: GridPos,
101    cost: f64,
102}
103
104impl PartialEq for NextCell {
105    fn eq(&self, other: &Self) -> bool {
106        self.cost.eq(&other.cost)
107    }
108}
109
110impl Eq for NextCell {}
111
112impl PartialOrd for NextCell {
113    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
114        Reverse(self.cost).partial_cmp(&Reverse(other.cost))
115    }
116}
117
118impl Ord for NextCell {
119    fn cmp(&self, other: &Self) -> Ordering {
120        self.partial_cmp(other).unwrap_or(Ordering::Equal)
121    }
122}
123
124#[derive(Clone)]
125pub(super) struct FastMarcher {
126    next_cells: BinaryHeap<NextCell>,
127    visited: Vec<f64>,
128    dims: (usize, usize),
129}
130
131impl FastMarcher {
132    pub(super) fn new_from_map(grid: &Grid, dims: (usize, usize)) -> Self {
133        Self::new(grid.find_boundary().into_iter(), dims)
134    }
135
136    pub(super) fn new(next_cells: impl Iterator<Item = GridPos>, dims: (usize, usize)) -> Self {
137        let next_cells: BinaryHeap<_> = next_cells
138            .map(|gpos| NextCell {
139                pos: gpos,
140                cost: 1.,
141            })
142            .collect();
143        let mut visited = vec![0.; dims.0 * dims.1];
144        for NextCell { pos: (x, y), .. } in &next_cells {
145            visited[x + y * dims.0] = 1.;
146        }
147        Self {
148            next_cells,
149            visited,
150            dims,
151        }
152    }
153
154    /// Returns whether a pixel has changed; if not, there is no point iterating again
155    fn evolve_single(&mut self, grid: &mut Grid) -> bool {
156        while let Some(next) = self.next_cells.pop() {
157            let x = next.pos.0 as isize;
158            let y = next.pos.1 as isize;
159
160            let mut check_neighbor = |x, y| {
161                if x < 0 || self.dims.0 as isize <= x || y < 0 || self.dims.1 as isize <= y {
162                    return false;
163                }
164                let get_visited = |x, y| {
165                    if x < 0 || self.dims.0 as isize <= x || y < 0 || self.dims.1 as isize <= y {
166                        0.
167                    } else {
168                        self.visited[x as usize + y as usize * self.dims.0]
169                    }
170                };
171                let delta_1d = |p: f64, n: f64| {
172                    if p == 0. && n == 0. {
173                        None
174                    } else if p == 0. {
175                        Some(n)
176                    } else if n == 0. {
177                        Some(p)
178                    } else {
179                        Some(p.min(n))
180                    }
181                };
182                let u_h = delta_1d(get_visited(x + 1, y), get_visited(x - 1, y));
183                let u_v = delta_1d(get_visited(x, y + 1), get_visited(x, y - 1));
184                let next_cost = match (u_h, u_v) {
185                    (Some(u_h), Some(u_v)) => {
186                        let delta = 2. - (u_v - u_h).powf(2.);
187                        if delta < 0. {
188                            u_h.min(u_v) + 1.
189                        } else {
190                            (u_v + u_h + delta.sqrt()) / 2.
191                        }
192                    }
193                    (Some(u_h), None) => u_h + 1.,
194                    (None, Some(u_v)) => u_v + 1.,
195                    _ => panic!("No way"),
196                };
197                let (x, y) = (x as usize, y as usize);
198                let visited = self.visited[x + y * self.dims.0];
199                if (visited == 0. || next_cost < visited) && grid[(x, y)] != 0. {
200                    self.visited[x + y * self.dims.0] = next_cost;
201                    let pos = (x, y);
202                    let cost = next_cost;
203                    grid[pos] = cost;
204                    self.next_cells.push(NextCell {
205                        pos,
206                        cost: next_cost,
207                    });
208                    true
209                } else {
210                    false
211                }
212            };
213            let mut f = false;
214            f |= check_neighbor(x - 1, y);
215            f |= check_neighbor(x, y - 1);
216            f |= check_neighbor(x + 1, y);
217            f |= check_neighbor(x, y + 1);
218            if f {
219                return true;
220            }
221        }
222        false
223    }
224}
225
226#[non_exhaustive]
227/// A type that will be given as the argument to the callback with [`crate::edt_fmm_cb`].
228///
229/// It has `non_exhaustive` attribute so that the library can add more arguments in
230/// the future.
231pub struct FMMCallbackData<'src> {
232    /// The buffer for Fast Marching output in progress.
233    pub map: &'src [f64],
234    /// A dynamically dispatched iterator for positions of next pixels.
235    ///
236    /// You can examine "expanding wavefront" by iterating this iterator.
237    pub next_pixels: &'src mut dyn Iterator<Item = GridPos>,
238}
239
240impl FastMarcher {
241    pub(super) fn evolve_cb(
242        &mut self,
243        grid: &mut Grid,
244        mut callback: impl FnMut(FMMCallbackData) -> bool,
245    ) {
246        while self.evolve_single(grid) {
247            if !callback(FMMCallbackData {
248                map: &grid.storage,
249                next_pixels: &mut self.next_cells.iter().map(|nc| nc.pos),
250            }) {
251                return;
252            }
253        }
254    }
255
256    pub(super) fn evolve(&mut self, grid: &mut Grid) {
257        loop {
258            if !self.evolve_single(grid) {
259                break;
260            }
261        }
262    }
263}
264
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_util::*;

    /// Asserts that `a` and `b` agree within 20% relative error; two zeros always agree.
    fn approx_eq(a: f64, b: f64) {
        if a == 0. && b == 0. {
            return;
        }
        let rel_err = (a - b).abs() / a.abs().max(b.abs());
        assert!(rel_err < 0.2, "a: {}, b: {}", a, b);
    }

    #[test]
    fn test_edt() {
        let map = test_map();
        // Expected *squared* distances, one string per image row.
        let str_edt = [
            "0000000000",
            "0001111000",
            "0013443110",
            "0013443100",
            "0001111000",
        ];
        let shape = (map.len() / str_edt.len(), str_edt.len());
        let mut edt = edt_fmm(&map, shape, false);
        for cell in &mut edt {
            *cell = cell.powf(2.);
        }
        eprintln!("edt({:?}):", shape);
        print_2d(&reshape(&edt, shape));
        let expected = parse_edt_str(&str_edt);
        // `zip` stops at the shorter input, so a length mismatch would otherwise go unnoticed.
        assert_eq!(edt.len(), expected.len());
        for (a, b) in edt.iter().zip(expected.iter()) {
            approx_eq(*a, *b);
        }
    }
}