Skip to main content

oxiphysics_geometry/signed_distance_field/
fmm.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Fast Marching Method for SDF initialization on a uniform 3D grid.
5
6use std::collections::BinaryHeap;
7
8// ─────────────────────────────────────────────────────────────────────────────
9// Fast Marching Method (FMM) for SDF initialization
10// ─────────────────────────────────────────────────────────────────────────────
11
/// Lifecycle of a single grid cell while the front is being marched.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FmmState {
    /// The cell's value is final and is never revisited.
    Known,
    /// The cell holds a tentative value and is queued in the narrow band.
    Trial,
    /// The front has not reached this cell yet.
    Far,
}
22
/// Entry in the FMM priority queue.
///
/// [`BinaryHeap`] is a max-heap, so the tentative distance is stored negated:
/// the entry with the *smallest* distance compares greatest and pops first.
#[derive(Debug, Clone, Copy)]
struct FmmEntry {
    /// Negated tentative distance (max-heap used as a min-heap).
    neg_dist: f64,
    /// Flat grid index of the cell this entry refers to.
    idx: usize,
}

// Equality and ordering both look only at `neg_dist` and both go through
// `f64::total_cmp`. That keeps `Eq` and `Ord` mutually consistent and makes
// the order total even for NaN, which `BinaryHeap` requires to keep its
// invariants (`partial_cmp` + "treat incomparable as Equal" is not a total
// order and disagrees with `==` on NaN).
impl PartialEq for FmmEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for FmmEntry {}

impl PartialOrd for FmmEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for FmmEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.neg_dist.total_cmp(&other.neg_dist)
    }
}
50
51/// Fast Marching Method SDF solver on a uniform 3D grid.
52///
53/// Given initial interface cells (known SDF values near zero), propagates
54/// the signed distance function throughout the grid.
55#[derive(Debug, Clone)]
56pub struct FastMarchingMethod {
57    /// Grid size.
58    pub nx: usize,
59    /// Grid size.
60    pub ny: usize,
61    /// Grid size.
62    pub nz: usize,
63    /// Grid spacing.
64    pub dx: f64,
65    /// Computed signed distances.
66    pub dist: Vec<f64>,
67    /// FMM state flags.
68    state: Vec<FmmState>,
69}
70
71impl FastMarchingMethod {
72    /// Construct a new FMM solver for a grid of given size and spacing.
73    pub fn new(nx: usize, ny: usize, nz: usize, dx: f64) -> Self {
74        let n = nx * ny * nz;
75        Self {
76            nx,
77            ny,
78            nz,
79            dx,
80            dist: vec![f64::MAX; n],
81            state: vec![FmmState::Far; n],
82        }
83    }
84
85    #[inline]
86    pub(crate) fn flat(&self, ix: usize, iy: usize, iz: usize) -> usize {
87        iz * self.ny * self.nx + iy * self.nx + ix
88    }
89
90    /// Set known interface cells from (index, distance) pairs.
91    pub fn set_known(&mut self, known: &[(usize, f64)]) {
92        for &(idx, d) in known {
93            if idx < self.dist.len() {
94                self.dist[idx] = d;
95                self.state[idx] = FmmState::Known;
96            }
97        }
98    }
99
100    /// Run the FMM to propagate distances from known cells.
101    pub fn run(&mut self) {
102        let mut heap: BinaryHeap<FmmEntry> = BinaryHeap::new();
103
104        // Seed with neighbours of known cells
105        for iz in 0..self.nz {
106            for iy in 0..self.ny {
107                for ix in 0..self.nx {
108                    let idx = self.flat(ix, iy, iz);
109                    if self.state[idx] == FmmState::Known {
110                        self.push_neighbours(ix, iy, iz, &mut heap);
111                    }
112                }
113            }
114        }
115
116        while let Some(entry) = heap.pop() {
117            let cidx = entry.idx;
118            if self.state[cidx] == FmmState::Known {
119                continue;
120            }
121            self.state[cidx] = FmmState::Known;
122            let iz = cidx / (self.ny * self.nx);
123            let rem = cidx % (self.ny * self.nx);
124            let iy = rem / self.nx;
125            let ix = rem % self.nx;
126            self.push_neighbours(ix, iy, iz, &mut heap);
127        }
128    }
129
130    fn push_neighbours(
131        &mut self,
132        ix: usize,
133        iy: usize,
134        iz: usize,
135        heap: &mut BinaryHeap<FmmEntry>,
136    ) {
137        let neighbors = self.get_neighbors(ix, iy, iz);
138        for (nx_i, ny_i, nz_i) in neighbors {
139            let nidx = self.flat(nx_i, ny_i, nz_i);
140            if self.state[nidx] == FmmState::Known {
141                continue;
142            }
143            let d = self.solve_eikonal(nx_i, ny_i, nz_i);
144            if d < self.dist[nidx] {
145                self.dist[nidx] = d;
146                self.state[nidx] = FmmState::Trial;
147                heap.push(FmmEntry {
148                    neg_dist: -d,
149                    idx: nidx,
150                });
151            }
152        }
153    }
154
155    fn get_neighbors(&self, ix: usize, iy: usize, iz: usize) -> Vec<(usize, usize, usize)> {
156        let mut ns = Vec::with_capacity(6);
157        if ix > 0 {
158            ns.push((ix - 1, iy, iz));
159        }
160        if ix + 1 < self.nx {
161            ns.push((ix + 1, iy, iz));
162        }
163        if iy > 0 {
164            ns.push((ix, iy - 1, iz));
165        }
166        if iy + 1 < self.ny {
167            ns.push((ix, iy + 1, iz));
168        }
169        if iz > 0 {
170            ns.push((ix, iy, iz - 1));
171        }
172        if iz + 1 < self.nz {
173            ns.push((ix, iy, iz + 1));
174        }
175        ns
176    }
177
178    fn solve_eikonal(&self, ix: usize, iy: usize, iz: usize) -> f64 {
179        // 1st-order upwind Eikonal: solve (dx1² + dy1² + dz1²) = dx²
180        let dx = self.dx;
181        let mut terms: [f64; 3] = [f64::MAX; 3];
182
183        // x-direction
184        let mut d_x = f64::MAX;
185        if ix > 0 {
186            d_x = d_x.min(self.dist[self.flat(ix - 1, iy, iz)]);
187        }
188        if ix + 1 < self.nx {
189            d_x = d_x.min(self.dist[self.flat(ix + 1, iy, iz)]);
190        }
191        terms[0] = d_x;
192
193        // y-direction
194        let mut d_y = f64::MAX;
195        if iy > 0 {
196            d_y = d_y.min(self.dist[self.flat(ix, iy - 1, iz)]);
197        }
198        if iy + 1 < self.ny {
199            d_y = d_y.min(self.dist[self.flat(ix, iy + 1, iz)]);
200        }
201        terms[1] = d_y;
202
203        // z-direction
204        let mut d_z = f64::MAX;
205        if iz > 0 {
206            d_z = d_z.min(self.dist[self.flat(ix, iy, iz - 1)]);
207        }
208        if iz + 1 < self.nz {
209            d_z = d_z.min(self.dist[self.flat(ix, iy, iz + 1)]);
210        }
211        terms[2] = d_z;
212
213        terms.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
214
215        // Quadratic solve: try adding terms one by one
216        for k in 1..=3 {
217            let valid: Vec<f64> = terms[..k]
218                .iter()
219                .filter(|&&t| t < f64::MAX)
220                .copied()
221                .collect();
222            if valid.is_empty() {
223                continue;
224            }
225            let sum_t = valid.iter().sum::<f64>();
226            let sum_t2 = valid.iter().map(|t| t * t).sum::<f64>();
227            let n_v = valid.len() as f64;
228            let discriminant = sum_t * sum_t - n_v * (sum_t2 - dx * dx);
229            if discriminant >= 0.0 {
230                let sol = (sum_t + discriminant.sqrt()) / n_v;
231                if k == 1 || sol > *valid.last().expect("collection should not be empty") {
232                    return sol;
233                }
234            }
235        }
236
237        // Fallback: nearest neighbour + one cell
238        terms
239            .iter()
240            .copied()
241            .filter(|&t| t < f64::MAX)
242            .fold(f64::MAX, f64::min)
243            + dx
244    }
245
246    /// Get the distance at grid index (ix, iy, iz).
247    pub fn distance_at(&self, ix: usize, iy: usize, iz: usize) -> f64 {
248        self.dist[self.flat(ix, iy, iz)]
249    }
250}