use ndarray::Array2;
use std::collections::VecDeque;
use crate::error::{Error, Result};
use crate::kernel::grid_map;
/// D8 code for flow towards the east (column + 1).
pub const D8_E: u8 = 0b0000_0001;
/// D8 code for flow towards the south-east (row + 1, column + 1).
pub const D8_SE: u8 = 0b0000_0010;
/// D8 code for flow towards the south (row + 1).
pub const D8_S: u8 = 0b0000_0100;
/// D8 code for flow towards the south-west (row + 1, column - 1).
pub const D8_SW: u8 = 0b0000_1000;
/// D8 code for flow towards the west (column - 1).
pub const D8_W: u8 = 0b0001_0000;
/// D8 code for flow towards the north-west (row - 1, column - 1).
pub const D8_NW: u8 = 0b0010_0000;
/// D8 code for flow towards the north (row - 1).
pub const D8_N: u8 = 0b0100_0000;
/// D8 code for flow towards the north-east (row - 1, column + 1).
pub const D8_NE: u8 = 0b1000_0000;

// One bit per direction, clockwise from east (the ESRI/ArcGIS D8 encoding).
// A code of 0 marks a cell with no downstream neighbour.
//
// The three tables below are index-aligned: entry `i` of each describes the same
// direction, and the opposite direction of index `i` is `(i + 4) % 8`.

/// `(row, col)` step to the neighbour in each direction; rows grow southward.
const D8_OFFSETS: [(isize, isize); 8] = [
    (0, 1),   // E
    (1, 1),   // SE
    (1, 0),   // S
    (1, -1),  // SW
    (0, -1),  // W
    (-1, -1), // NW
    (-1, 0),  // N
    (-1, 1),  // NE
];
/// Direction codes in the same order as `D8_OFFSETS`.
const D8_CODES: [u8; 8] = [D8_E, D8_SE, D8_S, D8_SW, D8_W, D8_NW, D8_N, D8_NE];
/// Length of a diagonal step, in cell widths.
const DIAG: f64 = std::f64::consts::SQRT_2;
/// Distance to the neighbour in each direction, in the same order as `D8_OFFSETS`.
const D8_DIST: [f64; 8] = [1.0, DIAG, 1.0, DIAG, 1.0, DIAG, 1.0, DIAG];
/// Follows the D8 code stored at `(r, c)` one step and returns the receiving cell.
///
/// Returns `None` when the cell holds `0` (no outflow), holds a byte that is not
/// one of the eight D8 codes, or points to a position outside the `h × w` grid.
fn downstream(
    fdir: &Array2<u8>,
    r: usize,
    c: usize,
    h: usize,
    w: usize,
) -> Option<(usize, usize)> {
    let code = fdir[[r, c]];
    // `0` is not in `D8_CODES`, so pits fail this lookup along with unknown codes.
    let slot = D8_CODES.iter().position(|&d| d == code)?;
    let (dr, dc) = D8_OFFSETS[slot];
    let row = r as isize + dr;
    let col = c as isize + dc;
    let inside = (0..h as isize).contains(&row) && (0..w as isize).contains(&col);
    if inside {
        Some((row as usize, col as usize))
    } else {
        None
    }
}
fn validate_pour_points(pour_points: &[(usize, usize)], h: usize, w: usize) -> Result<()> {
for &(row, col) in pour_points {
if row >= h || col >= w {
return Err(Error::PourPointOutOfBounds {
row,
col,
height: h,
width: w,
});
}
}
Ok(())
}
/// Fills depressions (sinks) in `dem` and returns the filled elevation grid.
///
/// Iterative Planchon–Darboux-style scheme:
/// - Fixed cells keep their original value: edge cells, NaN (nodata) cells, and
///   finite cells that have a NaN cell among their 8 neighbours.
/// - Every other cell starts at `+inf` and is lowered over repeated sweeps of the
///   interior until a full sweep changes nothing. A cell is lowered either back to
///   its own elevation (when some neighbour's current value plus `eps` does not
///   exceed that elevation) or to `eps` above a neighbour's current value.
///
/// For non-NaN cells the result is never below the input, and NaN cells stay NaN.
/// Raised cells end up `eps` (1e-5) above their lowest neighbour, so filled areas
/// keep a tiny gradient instead of being perfectly flat. Grids with fewer than
/// 3 rows or columns consist only of edge cells and are returned unchanged.
///
/// The number of sweeps depends on how far lowered values have to propagate, so
/// large depressions can take many passes over the grid.
pub fn fill(dem: &Array2<f64>) -> Array2<f64> {
let (h, w) = dem.dim();
if h < 3 || w < 3 {
return dem.clone();
}
// Minimum elevation step imposed across filled areas.
let eps = 1e-5;
// `+inf` means "not yet resolved"; fixed cells are overwritten with `dem` below.
let mut filled = Array2::from_elem((h, w), f64::INFINITY);
for r in 0..h {
for c in 0..w {
let is_edge = r == 0 || r == h - 1 || c == 0 || c == w - 1;
let is_nodata = dem[[r, c]].is_nan();
// A finite cell next to nodata can drain into it, so it is fixed as well.
let touches_nodata = !is_nodata
&& D8_OFFSETS.iter().any(|&(dr, dc)| {
let nr = r as isize + dr;
let nc = c as isize + dc;
nr >= 0
&& nr < h as isize
&& nc >= 0
&& nc < w as isize
&& dem[[nr as usize, nc as usize]].is_nan()
});
if is_edge || is_nodata || touches_nodata {
filled[[r, c]] = dem[[r, c]];
}
}
}
// Sweep the interior until one full pass makes no change.
let mut changed = true;
while changed {
changed = false;
for r in 1..h - 1 {
for c in 1..w - 1 {
if dem[[r, c]].is_nan() {
continue;
}
// Only cells still above their own elevation can be lowered further.
if filled[[r, c]] > dem[[r, c]] {
for &(dr, dc) in &D8_OFFSETS {
let nr = r as isize + dr;
let nc = c as isize + dc;
if nr >= 0 && nr < h as isize && nc >= 0 && nc < w as isize {
let nv = filled[[nr as usize, nc as usize]];
if !nv.is_nan() {
let candidate = nv + eps;
if dem[[r, c]] >= candidate {
// This neighbour is low enough to drain the cell at its own elevation.
filled[[r, c]] = dem[[r, c]];
changed = true;
} else if filled[[r, c]] > candidate {
// Otherwise route through this neighbour, `eps` above it.
filled[[r, c]] = candidate;
changed = true;
}
}
}
}
}
}
}
}
filled
}
/// Computes a D8 flow-direction grid from `dem`.
///
/// Each cell receives the code of the in-grid, non-NaN neighbour with the largest
/// strictly positive distance-weighted drop `(z - z_neighbour) / distance`. Cells
/// that are NaN, or that have no strictly lower neighbour inside the grid (pits,
/// flats, edge cells with nothing lower), receive `0`. On ties, the first
/// direction in E, SE, S, SW, W, NW, N, NE order wins.
pub fn flow_direction(dem: &Array2<f64>) -> Array2<u8> {
    let (h, w) = dem.dim();
    grid_map(h, w, |r, c| {
        let z = dem[[r, c]];
        if z.is_nan() {
            return 0u8;
        }
        let mut steepest = 0.0;
        let mut chosen: u8 = 0;
        let neighbours = D8_OFFSETS.iter().zip(D8_CODES.iter()).zip(D8_DIST.iter());
        for ((&(dr, dc), &code), &dist) in neighbours {
            let nr = r as isize + dr;
            let nc = c as isize + dc;
            if nr < 0 || nc < 0 || nr >= h as isize || nc >= w as isize {
                continue;
            }
            let nz = dem[[nr as usize, nc as usize]];
            if nz.is_nan() {
                continue;
            }
            let slope = (z - nz) / dist;
            // Strict `>`: an equal slope never displaces an earlier direction.
            if slope > steepest {
                steepest = slope;
                chosen = code;
            }
        }
        chosen
    })
}
/// Computes D8 flow accumulation from a flow-direction grid.
///
/// Each cell counts `1.0` for itself plus the totals of every cell that drains
/// into it, following `downstream`. Cells are processed in topological order
/// (Kahn's algorithm), so a cell's total is forwarded only after all of its
/// upstream neighbours have reported. Cells on a flow cycle never become ready
/// and keep a partial total.
pub fn flow_accumulation(fdir: &Array2<u8>) -> Array2<f64> {
    let (h, w) = fdir.dim();

    // Number of upstream neighbours whose totals have not been forwarded yet.
    let mut pending = Array2::<u32>::zeros((h, w));
    for ((r, c), _) in fdir.indexed_iter() {
        if let Some((tr, tc)) = downstream(fdir, r, c, h, w) {
            pending[[tr, tc]] += 1;
        }
    }

    // Cells nothing drains into are ready immediately, queued in row-major order.
    let mut ready: VecDeque<(usize, usize)> = pending
        .indexed_iter()
        .filter(|&(_, &n)| n == 0)
        .map(|(cell, _)| cell)
        .collect();

    let mut acc = Array2::<f64>::ones((h, w));
    while let Some((r, c)) = ready.pop_front() {
        if let Some((tr, tc)) = downstream(fdir, r, c, h, w) {
            let carried = acc[[r, c]];
            acc[[tr, tc]] += carried;
            pending[[tr, tc]] -= 1;
            if pending[[tr, tc]] == 0 {
                ready.push_back((tr, tc));
            }
        }
    }
    acc
}
/// Delineates the catchment draining to each pour point.
///
/// Cells draining to `pour_points[idx]` are labelled `idx + 1`; cells that drain
/// to none of the pour points stay `0`. All pour points are seeded before any
/// search starts, so nested pour points partition the area. Each cell belongs to
/// the first pour point met when following its flow path downstream, whatever
/// the order of `pour_points`. If the same cell is listed more than once, its
/// first occurrence owns it.
///
/// # Errors
/// Returns `Error::PourPointOutOfBounds` if any pour point lies outside `fdir`.
pub fn watershed(fdir: &Array2<u8>, pour_points: &[(usize, usize)]) -> Result<Array2<i32>> {
    let (h, w) = fdir.dim();
    validate_pour_points(pour_points, h, w)?;
    let mut labels = Array2::<i32>::zeros((h, w));

    // Seed every pour point first. Otherwise a downstream pour point searched
    // earlier would claim a nested upstream catchment, and the upstream pour point
    // would then overwrite only its own cell.
    for (idx, &(pr, pc)) in pour_points.iter().enumerate() {
        if labels[[pr, pc]] == 0 {
            labels[[pr, pc]] = (idx + 1) as i32;
        }
    }

    let mut queue = VecDeque::new();
    for (idx, &(pr, pc)) in pour_points.iter().enumerate() {
        let label = (idx + 1) as i32;
        // A repeated pour point is already owned by its first occurrence.
        if labels[[pr, pc]] != label {
            continue;
        }
        queue.push_back((pr, pc));
        while let Some((r, c)) = queue.pop_front() {
            for (i, &(dr, dc)) in D8_OFFSETS.iter().enumerate() {
                let nr = r as isize + dr;
                let nc = c as isize + dc;
                if nr < 0 || nr >= h as isize || nc < 0 || nc >= w as isize {
                    continue;
                }
                let (nr, nc) = (nr as usize, nc as usize);
                // The neighbour drains into (r, c) iff it points the opposite way.
                let reverse_code = D8_CODES[(i + 4) % 8];
                // Only unlabelled cells: never cross into another pour point's area.
                if fdir[[nr, nc]] == reverse_code && labels[[nr, nc]] == 0 {
                    labels[[nr, nc]] = label;
                    queue.push_back((nr, nc));
                }
            }
        }
    }
    Ok(labels)
}
pub fn basin(fdir: &Array2<u8>) -> Array2<i32> {
let (h, w) = fdir.dim();
let mut labels = Array2::zeros((h, w));
let mut current_label = 0i32;
let mut outlets = Vec::new();
for r in 0..h {
for c in 0..w {
let is_boundary = r == 0 || r == h - 1 || c == 0 || c == w - 1;
let is_pit = fdir[[r, c]] == 0;
if is_boundary || is_pit {
outlets.push((r, c));
}
}
}
for &(or, oc) in &outlets {
if labels[[or, oc]] != 0 {
continue;
}
current_label += 1;
let mut queue = VecDeque::new();
labels[[or, oc]] = current_label;
queue.push_back((or, oc));
while let Some((r, c)) = queue.pop_front() {
for (i, &(dr, dc)) in D8_OFFSETS.iter().enumerate() {
let nr = r as isize + dr;
let nc = c as isize + dc;
if nr >= 0 && nr < h as isize && nc >= 0 && nc < w as isize {
let nr = nr as usize;
let nc = nc as usize;
let reverse_idx = (i + 4) % 8;
if fdir[[nr, nc]] == D8_CODES[reverse_idx] && labels[[nr, nc]] == 0 {
labels[[nr, nc]] = current_label;
queue.push_back((nr, nc));
}
}
}
}
}
labels
}
/// Assigns Strahler stream orders to the channel network.
///
/// A cell belongs to the network when `accumulation >= threshold`, so NaN never
/// qualifies. Channel heads (network cells with no upstream network cell) get
/// order 1. Moving downstream, a cell whose two highest incoming orders are equal
/// gets that order plus one; otherwise it inherits the highest incoming order.
/// Non-network cells, and network cells never reached because they sit on a
/// flow cycle, stay `0`.
///
/// # Errors
/// Returns `Error::ShapeMismatch` if `fdir` and `accumulation` differ in shape.
pub fn stream_order_strahler(
    fdir: &Array2<u8>,
    accumulation: &Array2<f64>,
    threshold: f64,
) -> Result<Array2<i32>> {
    let (h, w) = fdir.dim();
    let acc_shape = accumulation.dim();
    if acc_shape != (h, w) {
        return Err(Error::ShapeMismatch {
            left: "fdir",
            left_shape: (h, w),
            right: "accumulation",
            right_shape: acc_shape,
        });
    }

    let is_stream = accumulation.mapv(|a| a >= threshold);
    // The receiving cell of (r, c), but only when it is itself a channel cell.
    let next_channel = |r: usize, c: usize| {
        downstream(fdir, r, c, h, w).filter(|&(nr, nc)| is_stream[[nr, nc]])
    };

    // How many upstream channel cells still have to report into each cell.
    let mut waiting = Array2::<u32>::zeros((h, w));
    for ((r, c), &channel) in is_stream.indexed_iter() {
        if channel {
            if let Some((nr, nc)) = next_channel(r, c) {
                waiting[[nr, nc]] += 1;
            }
        }
    }

    // Channel heads start at order 1 and are queued in row-major order.
    let mut order = Array2::<i32>::zeros((h, w));
    let mut queue = VecDeque::new();
    for ((r, c), &channel) in is_stream.indexed_iter() {
        if channel && waiting[[r, c]] == 0 {
            order[[r, c]] = 1;
            queue.push_back((r, c));
        }
    }

    // Highest and second-highest incoming order recorded at each cell.
    let mut incoming = Array2::from_elem((h, w), [0i32; 2]);
    while let Some((r, c)) = queue.pop_front() {
        let own = order[[r, c]];
        if let Some((nr, nc)) = next_channel(r, c) {
            let best = &mut incoming[[nr, nc]];
            if own > best[0] {
                best[1] = best[0];
                best[0] = own;
            } else if own > best[1] {
                best[1] = own;
            }
            waiting[[nr, nc]] -= 1;
            if waiting[[nr, nc]] == 0 {
                // Strahler rule: two equal top tributaries raise the order by one.
                order[[nr, nc]] = match incoming[[nr, nc]] {
                    [0, _] => 1,
                    [top, second] if top == second => top + 1,
                    [top, _] => top,
                };
                queue.push_back((nr, nc));
            }
        }
    }
    Ok(order)
}
/// Moves each pour point to the highest-accumulation cell within `snap_distance`.
///
/// The search area is the disc of Euclidean radius `snap_distance` (in cells)
/// centred on the pour point, clipped to the grid. A candidate replaces the
/// current best only when its accumulation is strictly greater. Ties therefore
/// favour the original point first, then the earliest cell in row-major order.
/// NaN cells are never selected. If the pour point itself is NaN (e.g. masked
/// nodata), any non-NaN cell in the window can replace it instead of the point
/// being pinned to nodata.
///
/// # Errors
/// Returns `Error::PourPointOutOfBounds` if any pour point lies outside
/// `accumulation`.
pub fn snap_pour_point(
    accumulation: &Array2<f64>,
    pour_points: &[(usize, usize)],
    snap_distance: usize,
) -> Result<Vec<(usize, usize)>> {
    let (h, w) = accumulation.dim();
    validate_pour_points(pour_points, h, w)?;
    // Squared search radius, computed in f64 so `usize::MAX` cannot overflow.
    let radius_sq = (snap_distance as f64) * (snap_distance as f64);
    let snapped = pour_points
        .iter()
        .map(|&(pr, pc)| {
            let start = accumulation[[pr, pc]];
            // `x > NaN` is always false, so a NaN start would never be replaced.
            let mut best_acc = if start.is_nan() {
                f64::NEG_INFINITY
            } else {
                start
            };
            let mut best = (pr, pc);
            let r_min = pr.saturating_sub(snap_distance);
            let r_max = pr.saturating_add(snap_distance).saturating_add(1).min(h);
            let c_min = pc.saturating_sub(snap_distance);
            let c_max = pc.saturating_add(snap_distance).saturating_add(1).min(w);
            for r in r_min..r_max {
                for c in c_min..c_max {
                    let dr = r as f64 - pr as f64;
                    let dc = c as f64 - pc as f64;
                    let value = accumulation[[r, c]];
                    if dr * dr + dc * dc <= radius_sq && value > best_acc {
                        best_acc = value;
                        best = (r, c);
                    }
                }
            }
            best
        })
        .collect();
    Ok(snapped)
}
#[cfg(test)]
mod tests {
// Unit tests for the hydrology routines in this module.
use super::*;
// 5x5 inclined plane z = 10*(4 - r) + 10*(4 - c); its lowest cell is (4, 4).
fn slope_dem() -> Array2<f64> {
Array2::from_shape_fn((5, 5), |(r, c)| {
(4 - r) as f64 * 10.0 + (4 - c) as f64 * 10.0
})
}
// A plane has no depressions, so fill must leave it unchanged (within 1e-3).
#[test]
fn fill_no_sinks() {
let dem = slope_dem();
let filled = fill(&dem);
for r in 0..5 {
for c in 0..5 {
assert!(
(filled[[r, c]] - dem[[r, c]]).abs() < 1e-3,
"fill should not change a sink-free DEM at ({r}, {c})"
);
}
}
}
// A single-cell pit dug into the centre must be raised.
#[test]
fn fill_raises_sink() {
let mut dem = slope_dem();
dem[[2, 2]] = 0.0;
let filled = fill(&dem);
assert!(
filled[[2, 2]] > 0.0,
"sink should be filled above 0, got {}",
filled[[2, 2]]
);
}
// Grids without interior cells still come back with their original shape.
#[test]
fn fill_small_grid() {
let dem = Array2::from_elem((2, 2), 10.0);
let filled = fill(&dem);
assert_eq!(filled.dim(), (2, 2));
}
// A finite cell ringed by nodata is fixed in place, and nodata stays NaN.
#[test]
fn fill_preserves_finite_cell_enclosed_by_nodata() {
let dem = Array2::from_shape_vec(
(3, 3),
vec![
f64::NAN,
f64::NAN,
f64::NAN,
f64::NAN,
10.0,
f64::NAN,
f64::NAN,
f64::NAN,
f64::NAN,
],
)
.unwrap();
let filled = fill(&dem);
assert_eq!(filled[[1, 1]], 10.0);
assert!(filled[[1, 1]].is_finite());
for ((r, c), &z) in dem.indexed_iter() {
if z.is_nan() {
assert!(filled[[r, c]].is_nan(), "nodata changed at ({r}, {c})");
}
}
}
// On the plane the diagonal SE step has the steepest distance-weighted drop.
#[test]
fn flow_direction_se_slope() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
assert_eq!(fdir[[2, 2]], D8_SE, "interior cell should flow SE");
}
// With no lower neighbour anywhere, interior cells get direction 0.
#[test]
fn flow_direction_flat() {
let dem = Array2::from_elem((5, 5), 100.0);
let fdir = flow_direction(&dem);
for r in 1..4 {
for c in 1..4 {
assert_eq!(fdir[[r, c]], 0, "flat terrain should have direction 0");
}
}
}
// Every cell counts itself, so accumulation is never below 1.
#[test]
fn flow_accumulation_all_at_least_one() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
let acc = flow_accumulation(&fdir);
for &v in acc.iter() {
assert!(v >= 1.0, "accumulation should be >= 1, got {v}");
}
}
// Flow converges on the plane, so some cell must collect more than itself.
#[test]
fn flow_accumulation_outlet_has_max() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
let acc = flow_accumulation(&fdir);
let max_acc = acc.iter().cloned().fold(0.0f64, f64::max);
assert!(max_acc > 1.0, "max accumulation should be > 1");
}
// The lowest corner (4, 4) drains the plane, so its watershed spans many cells.
#[test]
fn watershed_labels_upstream() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
let ws = watershed(&fdir, &[(4, 4)]).unwrap();
assert_eq!(ws[[4, 4]], 1);
let count = ws.iter().filter(|&&v| v == 1).count();
assert!(count > 1, "watershed should contain multiple cells");
}
// Row 5 is outside a 5x5 grid and must be reported, not indexed.
#[test]
fn watershed_rejects_out_of_bounds_pour_point() {
let fdir = Array2::from_elem((5, 5), 0);
assert!(matches!(
watershed(&fdir, &[(5, 0)]),
Err(Error::PourPointOutOfBounds {
row: 5,
col: 0,
height: 5,
width: 5,
})
));
}
// Every cell of the plane drains to some outlet, so none may stay unlabelled.
#[test]
fn basin_labels_all() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
let b = basin(&fdir);
for r in 0..5 {
for c in 0..5 {
assert!(
b[[r, c]] > 0,
"all cells should be labeled, got 0 at ({r}, {c})"
);
}
}
}
// Every cell at or above the threshold is in the network and must get an order.
#[test]
fn stream_order_headwaters_are_one() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
let acc = flow_accumulation(&fdir);
let order = stream_order_strahler(&fdir, &acc, 3.0).unwrap();
for r in 0..5 {
for c in 0..5 {
if acc[[r, c]] >= 3.0 {
assert!(
order[[r, c]] >= 1,
"stream cell at ({r},{c}) should have order >= 1"
);
}
}
}
}
// Mismatched input shapes are reported with both shapes.
#[test]
fn stream_order_rejects_shape_mismatch() {
let fdir = Array2::from_elem((3, 4), 0);
let acc = Array2::from_elem((3, 3), 1.0);
assert!(matches!(
stream_order_strahler(&fdir, &acc, 1.0),
Err(Error::ShapeMismatch {
left: "fdir",
left_shape: (3, 4),
right: "accumulation",
right_shape: (3, 3),
})
));
}
// Snapping never moves to a cell with lower accumulation than the start.
#[test]
fn snap_pour_point_finds_higher_acc() {
let dem = slope_dem();
let fdir = flow_direction(&dem);
let acc = flow_accumulation(&fdir);
let snapped = snap_pour_point(&acc, &[(3, 3)], 2).unwrap();
assert_eq!(snapped.len(), 1);
assert!(
acc[[snapped[0].0, snapped[0].1]] >= acc[[3, 3]],
"snapped point should have >= accumulation"
);
}
// Column 5 is outside a 5x5 grid and must be reported, not indexed.
#[test]
fn snap_pour_point_rejects_out_of_bounds_pour_point() {
let acc = Array2::from_elem((5, 5), 1.0);
assert!(matches!(
snap_pour_point(&acc, &[(0, 5)], 2),
Err(Error::PourPointOutOfBounds {
row: 0,
col: 5,
height: 5,
width: 5,
})
));
}
// `usize::MAX` must not overflow the window bounds; the whole grid is searched.
#[test]
fn snap_pour_point_handles_huge_snap_distance() {
let mut acc = Array2::from_elem((2, 2), 1.0);
acc[[1, 1]] = 10.0;
let snapped = snap_pour_point(&acc, &[(0, 0)], usize::MAX).unwrap();
assert_eq!(snapped, vec![(1, 1)]);
}
}