Skip to main content

oxigdal_algorithms/raster/hydrology/
flow_direction.rs

1//! Flow direction algorithms for hydrological analysis
2//!
3//! Implements D8 (Jenson & Domingue, 1988), D-Infinity (Tarboton, 1997),
4//! and MFD (Multiple Flow Direction, Freeman 1991 / Quinn et al. 1991) algorithms.
5//! Includes flat area resolution (Garbrecht & Martz, 1997) and proper pit handling.
6
7use crate::error::{AlgorithmError, Result};
8use oxigdal_core::buffer::RasterBuffer;
9use oxigdal_core::types::RasterDataType;
10use std::collections::VecDeque;
11
12// ---------------------------------------------------------------------------
13// Direction constants and types
14// ---------------------------------------------------------------------------
15
/// Selects which flow-routing algorithm [`compute_flow_direction`] runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FlowMethod {
    /// Single steepest-descent neighbour per cell (Jenson & Domingue, 1988).
    D8,
    /// Continuous flow angle shared between two neighbours (Tarboton, 1997).
    DInfinity,
    /// Flow shared among all lower neighbours (Freeman, 1991).
    MFD,
}
26
/// ESRI-style D8 direction codes.
///
/// Each direction is a distinct power of two, starting at East (1) and
/// stepping clockwise through South (4), West (16) and North (64).
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum D8Direction {
    /// Code 1 -- neighbour one column to the right.
    East = 1,
    /// Code 2 -- neighbour one column right, one row down.
    Southeast = 2,
    /// Code 4 -- neighbour one row down.
    South = 4,
    /// Code 8 -- neighbour one column left, one row down.
    Southwest = 8,
    /// Code 16 -- neighbour one column to the left.
    West = 16,
    /// Code 32 -- neighbour one column left, one row up.
    Northwest = 32,
    /// Code 64 -- neighbour one row up.
    North = 64,
    /// Code 128 -- neighbour one column right, one row up.
    Northeast = 128,
}
48
/// Code for a cell that has no strictly lower neighbour but is not a pit.
pub const D8_FLAT: u8 = 0;

/// Code for a cell surrounded entirely by higher neighbours.
pub const D8_PIT: u8 = 255;

/// Column (x) offsets of the eight neighbours, ordered E, SE, S, SW, W, NW, N, NE.
pub const D8_DX: [i64; 8] = [1, 1, 0, -1, -1, -1, 0, 1];
/// Row (y) offsets of the eight neighbours, ordered E, SE, S, SW, W, NW, N, NE
/// (rows grow downward, so +1 is South).
pub const D8_DY: [i64; 8] = [0, 1, 1, 1, 0, -1, -1, -1];
/// ESRI power-of-two code for each neighbour index: `code = 1 << index`.
const D8_CODES: [u8; 8] = [1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7];
60
61impl D8Direction {
62    /// Gets the (dx, dy) offset for this direction
63    #[must_use]
64    pub fn offset(&self) -> (i64, i64) {
65        let idx = self.index();
66        (D8_DX[idx], D8_DY[idx])
67    }
68
69    /// Returns the index [0..8) in the canonical D8 neighbour array
70    #[must_use]
71    pub fn index(&self) -> usize {
72        match self {
73            Self::East => 0,
74            Self::Southeast => 1,
75            Self::South => 2,
76            Self::Southwest => 3,
77            Self::West => 4,
78            Self::Northwest => 5,
79            Self::North => 6,
80            Self::Northeast => 7,
81        }
82    }
83
84    /// Returns all eight D8 directions in canonical order
85    #[must_use]
86    pub fn all() -> [Self; 8] {
87        [
88            Self::East,
89            Self::Southeast,
90            Self::South,
91            Self::Southwest,
92            Self::West,
93            Self::Northwest,
94            Self::North,
95            Self::Northeast,
96        ]
97    }
98
99    /// Gets the angle in degrees (0 = East, clockwise)
100    #[must_use]
101    pub fn angle_degrees(&self) -> f64 {
102        match self {
103            Self::East => 0.0,
104            Self::Southeast => 45.0,
105            Self::South => 90.0,
106            Self::Southwest => 135.0,
107            Self::West => 180.0,
108            Self::Northwest => 225.0,
109            Self::North => 270.0,
110            Self::Northeast => 315.0,
111        }
112    }
113
114    /// Returns the opposite direction
115    #[must_use]
116    pub fn opposite(&self) -> Self {
117        match self {
118            Self::East => Self::West,
119            Self::Southeast => Self::Northwest,
120            Self::South => Self::North,
121            Self::Southwest => Self::Northeast,
122            Self::West => Self::East,
123            Self::Northwest => Self::Southeast,
124            Self::North => Self::South,
125            Self::Northeast => Self::Southwest,
126        }
127    }
128
129    /// Construct from D8 code (power-of-2)
130    #[must_use]
131    pub fn from_code(code: u8) -> Option<Self> {
132        match code {
133            1 => Some(Self::East),
134            2 => Some(Self::Southeast),
135            4 => Some(Self::South),
136            8 => Some(Self::Southwest),
137            16 => Some(Self::West),
138            32 => Some(Self::Northwest),
139            64 => Some(Self::North),
140            128 => Some(Self::Northeast),
141            _ => None,
142        }
143    }
144}
145
146// ---------------------------------------------------------------------------
147// Helper: distance factor for D8 directions
148// ---------------------------------------------------------------------------
149
/// Distance, in cell units, to the D8 neighbour with canonical index `idx`.
///
/// Even indices are the cardinal neighbours (distance 1); odd indices are the
/// diagonal ones (distance sqrt(2)).
#[inline]
fn d8_distance(idx: usize) -> f64 {
    match idx & 1 {
        0 => 1.0,
        _ => std::f64::consts::SQRT_2,
    }
}
159
160// ---------------------------------------------------------------------------
161// Inline helpers for bounds
162// ---------------------------------------------------------------------------
163
/// Returns `true` when `(x, y)` addresses a cell of a `w` x `h` raster.
#[inline]
fn in_bounds(x: i64, y: i64, w: u64, h: u64) -> bool {
    if x < 0 || y < 0 {
        return false;
    }
    // Both coordinates are non-negative here, so the casts are lossless.
    (x as u64) < w && (y as u64) < h
}
168
169// ---------------------------------------------------------------------------
170// D8 flow direction  (Jenson & Domingue, 1988)
171// ---------------------------------------------------------------------------
172
/// Options for [`compute_d8_flow_direction_cfg`].
#[derive(Debug, Clone)]
pub struct D8Config {
    /// Grid spacing in map units; the elevation drop is divided by it to get a slope.
    pub cell_size: f64,
    /// When `true`, cells left flat by the first pass go through the
    /// flat-resolution step (Garbrecht & Martz, 1997).
    pub resolve_flats: bool,
}

impl Default for D8Config {
    /// Unit cell size with flat resolution switched on.
    fn default() -> Self {
        D8Config {
            resolve_flats: true,
            cell_size: 1.0,
        }
    }
}
190
191/// Computes D8 flow direction from a DEM.
192///
193/// Each cell receives the code of the steepest-descent neighbour.
194/// Flat cells get `D8_FLAT` (0), pit cells get `D8_PIT` (255) unless
195/// `resolve_flats` is enabled, in which case flats are resolved via the
196/// Garbrecht & Martz (1997) algorithm.
197///
198/// # Errors
199///
200/// Returns an error if raster pixel access fails.
201pub fn compute_d8_flow_direction(dem: &RasterBuffer, cell_size: f64) -> Result<RasterBuffer> {
202    let cfg = D8Config {
203        cell_size,
204        resolve_flats: true,
205    };
206    compute_d8_flow_direction_cfg(dem, &cfg)
207}
208
209/// Computes D8 flow direction with full configuration
210///
211/// # Errors
212///
213/// Returns an error if raster pixel access fails.
214pub fn compute_d8_flow_direction_cfg(dem: &RasterBuffer, cfg: &D8Config) -> Result<RasterBuffer> {
215    let w = dem.width();
216    let h = dem.height();
217    let mut flow_dir = RasterBuffer::zeros(w, h, RasterDataType::UInt8);
218
219    for y in 0..h {
220        for x in 0..w {
221            let center = dem.get_pixel(x, y).map_err(AlgorithmError::Core)?;
222            let mut max_slope = 0.0_f64; // only accept downhill
223            let mut best_code: u8 = D8_FLAT;
224            let mut is_pit = true;
225
226            for i in 0..8 {
227                let nx = x as i64 + D8_DX[i];
228                let ny = y as i64 + D8_DY[i];
229                if !in_bounds(nx, ny, w, h) {
230                    // Boundary cells drain outward -- they are never pits
231                    is_pit = false;
232                    continue;
233                }
234                let ne = dem
235                    .get_pixel(nx as u64, ny as u64)
236                    .map_err(AlgorithmError::Core)?;
237                if ne < center {
238                    is_pit = false;
239                    let dist = d8_distance(i) * cfg.cell_size;
240                    let slope = (center - ne) / dist;
241                    if slope > max_slope {
242                        max_slope = slope;
243                        best_code = D8_CODES[i];
244                    }
245                } else if (ne - center).abs() < f64::EPSILON {
246                    // Flat neighbour -- not a pit
247                    is_pit = false;
248                }
249            }
250
251            if is_pit && best_code == D8_FLAT {
252                best_code = D8_PIT;
253            }
254            flow_dir
255                .set_pixel(x, y, f64::from(best_code))
256                .map_err(AlgorithmError::Core)?;
257        }
258    }
259
260    // Resolve flat areas if requested
261    if cfg.resolve_flats {
262        resolve_flat_areas(dem, &mut flow_dir)?;
263    }
264
265    Ok(flow_dir)
266}
267
268// ---------------------------------------------------------------------------
269// Flat area resolution  (Garbrecht & Martz, 1997)
270// ---------------------------------------------------------------------------
271
272/// Resolves flat areas in a D8 flow direction grid.
273///
274/// Uses a two-pass approach inspired by Garbrecht & Martz (1997):
275///   1. Gradient *away from* higher terrain (increments toward lower edges)
276///   2. Gradient *toward* lower terrain (increments toward outlets)
277///
278/// The combined surface is used to assign flow directions to formerly flat cells.
279fn resolve_flat_areas(dem: &RasterBuffer, flow_dir: &mut RasterBuffer) -> Result<()> {
280    let w = dem.width();
281    let h = dem.height();
282
283    // Identify flat cells (code == D8_FLAT)
284    let mut is_flat = vec![false; (w * h) as usize];
285    let mut has_flat = false;
286
287    for y in 0..h {
288        for x in 0..w {
289            let code = flow_dir.get_pixel(x, y).map_err(AlgorithmError::Core)? as u8;
290            if code == D8_FLAT {
291                is_flat[(y * w + x) as usize] = true;
292                has_flat = true;
293            }
294        }
295    }
296
297    if !has_flat {
298        return Ok(());
299    }
300
301    // Build increment surfaces
302    let toward_lower = build_toward_lower_gradient(dem, &is_flat, w, h)?;
303    let away_higher = build_away_higher_gradient(dem, &is_flat, w, h)?;
304
305    // Combine: virtual elevation = toward_lower + away_higher
306    // Then assign D8 directions on the combined surface for flat cells
307    for y in 0..h {
308        for x in 0..w {
309            let idx = (y * w + x) as usize;
310            if !is_flat[idx] {
311                continue;
312            }
313            let combined = toward_lower[idx] + away_higher[idx];
314            let mut max_drop = 0.0_f64;
315            let mut best_code: u8 = D8_FLAT;
316
317            for i in 0..8 {
318                let nx = x as i64 + D8_DX[i];
319                let ny = y as i64 + D8_DY[i];
320                if !in_bounds(nx, ny, w, h) {
321                    continue;
322                }
323                let nidx = (ny as u64 * w + nx as u64) as usize;
324                let n_combined = toward_lower[nidx] + away_higher[nidx];
325                let drop = (combined - n_combined) / d8_distance(i);
326                if drop > max_drop {
327                    max_drop = drop;
328                    best_code = D8_CODES[i];
329                }
330            }
331            flow_dir
332                .set_pixel(x, y, f64::from(best_code))
333                .map_err(AlgorithmError::Core)?;
334        }
335    }
336
337    Ok(())
338}
339
340/// BFS from flat-cell edges that border *lower* terrain, assigning increasing
341/// values inward. This creates a gradient *toward* lower terrain.
342fn build_toward_lower_gradient(
343    dem: &RasterBuffer,
344    is_flat: &[bool],
345    w: u64,
346    h: u64,
347) -> Result<Vec<f64>> {
348    let n = (w * h) as usize;
349    let mut grad = vec![0.0_f64; n];
350    let mut visited = vec![false; n];
351    let mut queue = VecDeque::new();
352
353    // Seed: flat cells adjacent to lower cells
354    for y in 0..h {
355        for x in 0..w {
356            let idx = (y * w + x) as usize;
357            if !is_flat[idx] {
358                continue;
359            }
360            let center = dem.get_pixel(x, y).map_err(AlgorithmError::Core)?;
361            let mut borders_lower = false;
362            for i in 0..8 {
363                let nx = x as i64 + D8_DX[i];
364                let ny = y as i64 + D8_DY[i];
365                if !in_bounds(nx, ny, w, h) {
366                    // Edge cells can drain outward
367                    borders_lower = true;
368                    break;
369                }
370                let ne = dem
371                    .get_pixel(nx as u64, ny as u64)
372                    .map_err(AlgorithmError::Core)?;
373                if ne < center {
374                    borders_lower = true;
375                    break;
376                }
377            }
378            if borders_lower {
379                queue.push_back((x, y));
380                visited[idx] = true;
381                grad[idx] = 0.0;
382            }
383        }
384    }
385
386    // BFS inward
387    while let Some((x, y)) = queue.pop_front() {
388        let idx = (y * w + x) as usize;
389        let cur_val = grad[idx];
390        for i in 0..8 {
391            let nx = x as i64 + D8_DX[i];
392            let ny = y as i64 + D8_DY[i];
393            if !in_bounds(nx, ny, w, h) {
394                continue;
395            }
396            let nidx = (ny as u64 * w + nx as u64) as usize;
397            if !is_flat[nidx] || visited[nidx] {
398                continue;
399            }
400            visited[nidx] = true;
401            grad[nidx] = cur_val + 1.0;
402            queue.push_back((nx as u64, ny as u64));
403        }
404    }
405
406    Ok(grad)
407}
408
409/// BFS from flat-cell edges that border *higher* terrain, assigning increasing
410/// values inward. This creates a gradient *away from* higher terrain.
411fn build_away_higher_gradient(
412    dem: &RasterBuffer,
413    is_flat: &[bool],
414    w: u64,
415    h: u64,
416) -> Result<Vec<f64>> {
417    let n = (w * h) as usize;
418    let mut grad = vec![0.0_f64; n];
419    let mut visited = vec![false; n];
420    let mut queue = VecDeque::new();
421
422    // Seed: flat cells adjacent to higher cells
423    for y in 0..h {
424        for x in 0..w {
425            let idx = (y * w + x) as usize;
426            if !is_flat[idx] {
427                continue;
428            }
429            let center = dem.get_pixel(x, y).map_err(AlgorithmError::Core)?;
430            let mut borders_higher = false;
431            for i in 0..8 {
432                let nx = x as i64 + D8_DX[i];
433                let ny = y as i64 + D8_DY[i];
434                if !in_bounds(nx, ny, w, h) {
435                    continue;
436                }
437                let ne = dem
438                    .get_pixel(nx as u64, ny as u64)
439                    .map_err(AlgorithmError::Core)?;
440                if ne > center {
441                    borders_higher = true;
442                    break;
443                }
444            }
445            if borders_higher {
446                queue.push_back((x, y));
447                visited[idx] = true;
448                grad[idx] = 0.0;
449            }
450        }
451    }
452
453    // BFS inward
454    while let Some((x, y)) = queue.pop_front() {
455        let idx = (y * w + x) as usize;
456        let cur_val = grad[idx];
457        for i in 0..8 {
458            let nx = x as i64 + D8_DX[i];
459            let ny = y as i64 + D8_DY[i];
460            if !in_bounds(nx, ny, w, h) {
461                continue;
462            }
463            let nidx = (ny as u64 * w + nx as u64) as usize;
464            if !is_flat[nidx] || visited[nidx] {
465                continue;
466            }
467            visited[nidx] = true;
468            grad[nidx] = cur_val + 1.0;
469            queue.push_back((nx as u64, ny as u64));
470        }
471    }
472
473    Ok(grad)
474}
475
476// ---------------------------------------------------------------------------
477// D-Infinity flow direction  (Tarboton, 1997)
478// ---------------------------------------------------------------------------
479
480/// Computes D-infinity flow direction.
481///
482/// Returns two rasters:
483///   - **angle** in radians [0, 2 pi), measured counter-clockwise from East
484///   - **proportion**: fraction of flow going to the "left" cell of the
485///     steepest triangular facet (the remainder goes to the "right" cell).
486///
487/// Based on Tarboton (1997) "A new method for the determination of flow
488/// directions and upslope areas in grid digital elevation models."
489///
490/// # Errors
491///
492/// Returns an error if raster pixel access fails.
493pub fn compute_dinf_flow_direction(
494    dem: &RasterBuffer,
495    cell_size: f64,
496) -> Result<(RasterBuffer, RasterBuffer)> {
497    let w = dem.width();
498    let h = dem.height();
499    let mut flow_angle = RasterBuffer::zeros(w, h, RasterDataType::Float64);
500    let mut flow_prop = RasterBuffer::zeros(w, h, RasterDataType::Float64);
501
502    for y in 0..h {
503        for x in 0..w {
504            let center = dem.get_pixel(x, y).map_err(AlgorithmError::Core)?;
505
506            // Gather 8-neighbour elevations (use center if out of bounds)
507            let mut e = [0.0_f64; 8];
508            for i in 0..8 {
509                let nx = x as i64 + D8_DX[i];
510                let ny = y as i64 + D8_DY[i];
511                e[i] = if in_bounds(nx, ny, w, h) {
512                    dem.get_pixel(nx as u64, ny as u64)
513                        .map_err(AlgorithmError::Core)?
514                } else {
515                    center // treat boundary as same elevation (no flow outward via Dinf)
516                };
517            }
518
519            let (angle, prop) = dinf_facet_steepest(center, &e, cell_size);
520            flow_angle
521                .set_pixel(x, y, angle)
522                .map_err(AlgorithmError::Core)?;
523            flow_prop
524                .set_pixel(x, y, prop)
525                .map_err(AlgorithmError::Core)?;
526        }
527    }
528
529    Ok((flow_angle, flow_prop))
530}
531
/// Finds the steepest downhill triangular facet around a cell (Tarboton, 1997).
///
/// Returns `(angle_degrees, proportion)`.
///
/// `e` holds the eight neighbour elevations in D8 order (E, SE, S, SW, W, NW,
/// N, NE). Facet `k` is the triangle spanned by the centre, neighbour `k` and
/// neighbour `(k + 1) % 8`. Following Tarboton, each facet pairs one cardinal
/// neighbour `e1` with one diagonal neighbour `e2` (square cells of side `d`):
///
/// * `s1 = (e0 - e1) / d` is the slope towards the cardinal neighbour.
/// * `s2 = (e1 - e2) / d` is the cross-slope along the facet's outer edge.
/// * `r = atan2(s2, s1)` is the flow angle measured from the cardinal
///   direction.
/// * If `r < 0`, it is clamped to the cardinal edge (slope `s1`).
/// * If `r > pi/4`, it is clamped to the diagonal edge (slope
///   `(e0 - e2) / (d * sqrt(2))`).
/// * Otherwise the slope is `hypot(s1, s2)`.
///
/// The returned angle is in degrees in `[0, 360)`. 0 is East and each 45-degree
/// step follows the D8 neighbour order (90 = South). With
/// `k = floor(angle / 45)`, `proportion` is the fraction of flow sent to
/// neighbour `k`, and `1 - proportion` goes to neighbour `(k + 1) % 8`. A flow
/// pointing (within rounding) exactly at a neighbour is reported as that
/// neighbour with proportion 1.
///
/// If no facet has a strictly positive downhill slope (flat or pit), the
/// fallback `(0.0, 1.0)` is returned.
fn dinf_facet_steepest(center: f64, e: &[f64; 8], cell_size: f64) -> (f64, f64) {
    use std::f64::consts::{FRAC_PI_4, SQRT_2};

    let diagonal_dist = cell_size * SQRT_2;
    let mut max_slope = 0.0_f64;
    // (facet index k, angle in radians measured from neighbour k towards k + 1)
    let mut best: Option<(usize, f64)> = None;

    for k in 0..8 {
        let next = (k + 1) % 8;
        // Even facets start at a cardinal neighbour, odd facets at a diagonal one.
        let (card, diag) = if k % 2 == 0 { (k, next) } else { (next, k) };
        let s1 = (center - e[card]) / cell_size;
        let s2 = (e[card] - e[diag]) / cell_size;

        let raw = s2.atan2(s1);
        let (r, slope) = if raw < 0.0 {
            (0.0, s1)
        } else if raw > FRAC_PI_4 {
            (FRAC_PI_4, (center - e[diag]) / diagonal_dist)
        } else {
            (raw, s1.hypot(s2))
        };

        if slope > max_slope {
            max_slope = slope;
            // `r` is measured from the cardinal neighbour; convert it to an
            // offset from neighbour k (which is the diagonal for odd facets).
            let theta = if k % 2 == 0 { r } else { FRAC_PI_4 - r };
            best = Some((k, theta));
        }
    }

    match best {
        None => (0.0, 1.0),
        Some((k, theta)) => {
            // Report a flow aimed at neighbour k + 1 as that neighbour with
            // proportion 1, so `floor(angle / 45)` always names the main receiver.
            let (k, theta) = if theta >= FRAC_PI_4 - f64::EPSILON {
                ((k + 1) % 8, 0.0)
            } else {
                (k, theta)
            };
            let angle = (k as f64 * 45.0 + theta.to_degrees()).rem_euclid(360.0);
            (angle, 1.0 - theta / FRAC_PI_4)
        }
    }
}
596
597// ---------------------------------------------------------------------------
598// MFD -- Multiple Flow Direction  (Freeman, 1991)
599// ---------------------------------------------------------------------------
600
/// Parameters for [`compute_mfd_flow_direction`].
#[derive(Debug, Clone)]
pub struct MfdConfig {
    /// Grid spacing in map units, used to turn elevation drops into slopes.
    pub cell_size: f64,
    /// Exponent `p` applied to each downslope gradient before normalising.
    /// Larger values concentrate flow on the steepest neighbour; Freeman
    /// (1991) suggests p = 1.1, Holmgren (1994) uses p = 4-8.
    pub exponent: f64,
}

impl Default for MfdConfig {
    /// Unit cell size with Freeman's exponent of 1.1.
    fn default() -> Self {
        let exponent = 1.1;
        Self {
            exponent,
            cell_size: 1.0,
        }
    }
}
620
/// Result of [`compute_mfd_flow_direction`].
///
/// For every cell, stores the fraction of its outflow sent to each of the
/// eight D8 neighbours. Proportions are stored row-major, eight per cell, in
/// D8 order (E, SE, S, SW, W, NW, N, NE). A cell's proportions sum to 1 (up to
/// rounding) when it has at least one lower neighbour, and are all 0 otherwise.
#[derive(Debug, Clone, PartialEq)]
pub struct MfdResult {
    /// Width of the raster
    pub width: u64,
    /// Height of the raster
    pub height: u64,
    /// Proportions: `proportions[cell_idx * 8 + dir_idx]`, with
    /// `cell_idx = y * width + x`.
    pub proportions: Vec<f64>,
}

impl MfdResult {
    /// Returns the eight flow proportions of cell `(x, y)` in D8 order.
    ///
    /// # Panics
    ///
    /// Panics if `y * width + x` is not a valid cell index for `proportions`.
    #[must_use]
    pub fn get_proportions(&self, x: u64, y: u64) -> [f64; 8] {
        let base = ((y * self.width + x) * 8) as usize;
        let mut out = [0.0; 8];
        out.copy_from_slice(&self.proportions[base..base + 8]);
        out
    }
}
644
645/// Computes MFD (Multiple Flow Direction) proportions.
646///
647/// Each cell partitions flow among all downslope neighbours according to
648/// `w_i = (tan(slope_i))^p / sum((tan(slope_j))^p)` where the sum is over
649/// all downslope neighbours.
650///
651/// # Errors
652///
653/// Returns an error if raster pixel access fails.
654pub fn compute_mfd_flow_direction(dem: &RasterBuffer, cfg: &MfdConfig) -> Result<MfdResult> {
655    let w = dem.width();
656    let h = dem.height();
657    let total_cells = (w * h) as usize;
658    let mut proportions = vec![0.0_f64; total_cells * 8];
659
660    for y in 0..h {
661        for x in 0..w {
662            let center = dem.get_pixel(x, y).map_err(AlgorithmError::Core)?;
663            let base = ((y * w + x) * 8) as usize;
664
665            let mut slopes = [0.0_f64; 8];
666            let mut total_weight = 0.0_f64;
667
668            for i in 0..8 {
669                let nx = x as i64 + D8_DX[i];
670                let ny = y as i64 + D8_DY[i];
671                if !in_bounds(nx, ny, w, h) {
672                    continue;
673                }
674                let ne = dem
675                    .get_pixel(nx as u64, ny as u64)
676                    .map_err(AlgorithmError::Core)?;
677                let drop = center - ne;
678                if drop > 0.0 {
679                    let dist = d8_distance(i) * cfg.cell_size;
680                    let tan_slope = drop / dist;
681                    let weight = tan_slope.powf(cfg.exponent);
682                    slopes[i] = weight;
683                    total_weight += weight;
684                }
685            }
686
687            if total_weight > 0.0 {
688                for i in 0..8 {
689                    proportions[base + i] = slopes[i] / total_weight;
690                }
691            }
692            // If total_weight == 0, cell is flat/pit: all proportions stay 0
693        }
694    }
695
696    Ok(MfdResult {
697        width: w,
698        height: h,
699        proportions,
700    })
701}
702
703// ---------------------------------------------------------------------------
704// Unified flow direction computation
705// ---------------------------------------------------------------------------
706
707/// Result from a flow direction computation.
708///
709/// Different methods produce different result types, unified here.
710pub enum FlowDirectionResult {
711    /// D8: a single raster with direction codes
712    D8(RasterBuffer),
713    /// D-Infinity: (angle raster, proportion raster)
714    DInfinity(RasterBuffer, RasterBuffer),
715    /// MFD: per-cell proportions
716    Mfd(MfdResult),
717}
718
719/// Computes flow direction using the specified method.
720///
721/// # Errors
722///
723/// Returns an error if raster pixel access fails.
724pub fn compute_flow_direction(
725    dem: &RasterBuffer,
726    method: FlowMethod,
727    cell_size: f64,
728) -> Result<FlowDirectionResult> {
729    match method {
730        FlowMethod::D8 => {
731            let fd = compute_d8_flow_direction(dem, cell_size)?;
732            Ok(FlowDirectionResult::D8(fd))
733        }
734        FlowMethod::DInfinity => {
735            let (a, p) = compute_dinf_flow_direction(dem, cell_size)?;
736            Ok(FlowDirectionResult::DInfinity(a, p))
737        }
738        FlowMethod::MFD => {
739            let cfg = MfdConfig {
740                cell_size,
741                ..MfdConfig::default()
742            };
743            let mfd = compute_mfd_flow_direction(dem, &cfg)?;
744            Ok(FlowDirectionResult::Mfd(mfd))
745        }
746    }
747}
748
749// ---------------------------------------------------------------------------
750// Tests
751// ---------------------------------------------------------------------------
752
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a `Float32` DEM whose elevation at `(x, y)` is `elev(x, y)`.
    fn dem_from_fn<F: Fn(u64, u64) -> f64>(width: u64, height: u64, elev: F) -> RasterBuffer {
        let mut dem = RasterBuffer::zeros(width, height, RasterDataType::Float32);
        for y in 0..height {
            for x in 0..width {
                let _ = dem.set_pixel(x, y, elev(x, y));
            }
        }
        dem
    }

    /// Plane falling one unit per column towards the east: `elev = width - 1 - x`.
    fn make_slope_dem(width: u64, height: u64) -> RasterBuffer {
        dem_from_fn(width, height, |x, _| (width - 1 - x) as f64)
    }

    /// Plane falling one unit per column east and one unit per row south.
    fn make_se_slope_dem(width: u64, height: u64) -> RasterBuffer {
        dem_from_fn(width, height, |x, y| ((width - 1 - x) + (height - 1 - y)) as f64)
    }

    #[test]
    fn test_d8_direction_offset() {
        let expected = [
            (D8Direction::East, (1, 0)),
            (D8Direction::South, (0, 1)),
            (D8Direction::West, (-1, 0)),
            (D8Direction::North, (0, -1)),
        ];
        for (dir, offset) in expected {
            assert_eq!(dir.offset(), offset);
        }
    }

    #[test]
    fn test_d8_direction_angle() {
        let expected = [
            (D8Direction::East, 0.0),
            (D8Direction::South, 90.0),
            (D8Direction::West, 180.0),
            (D8Direction::North, 270.0),
        ];
        for (dir, angle) in expected {
            assert_abs_diff_eq!(dir.angle_degrees(), angle);
        }
    }

    #[test]
    fn test_d8_from_code_round_trip() {
        for dir in D8Direction::all().iter().copied() {
            assert_eq!(D8Direction::from_code(dir as u8), Some(dir));
        }
    }

    #[test]
    fn test_d8_opposite() {
        let pairs = [
            (D8Direction::East, D8Direction::West),
            (D8Direction::North, D8Direction::South),
            (D8Direction::Southeast, D8Direction::Northwest),
        ];
        for (dir, opposite) in pairs {
            assert_eq!(dir.opposite(), opposite);
        }
    }

    #[test]
    fn test_d8_simple_east_slope() {
        let fd = compute_d8_flow_direction(&make_slope_dem(7, 7), 1.0).expect("should succeed");
        // Interior cell (3,3) sits on an eastward slope.
        let code = fd.get_pixel(3, 3).expect("should succeed") as u8;
        assert_eq!(code, D8Direction::East as u8);
    }

    #[test]
    fn test_d8_pit_detection() {
        // Uniform 10.0 surface with a single depression at (2,2).
        let dem = dem_from_fn(5, 5, |x, y| if (x, y) == (2, 2) { 1.0 } else { 10.0 });
        let cfg = D8Config {
            cell_size: 1.0,
            resolve_flats: false,
        };
        let fd = compute_d8_flow_direction_cfg(&dem, &cfg).expect("should succeed");
        let code = fd.get_pixel(2, 2).expect("should succeed") as u8;
        assert_eq!(code, D8_PIT, "Cell (2,2) should be marked as pit");
    }

    #[test]
    fn test_d8_flat_resolution() {
        // Flat 10.0 surface with a lower outlet at (6,3) on the east edge.
        let dem = dem_from_fn(7, 7, |x, y| if (x, y) == (6, 3) { 5.0 } else { 10.0 });
        let fd = compute_d8_flow_direction(&dem, 1.0).expect("should succeed");

        // Interior flat cells should now carry a real direction.
        let code = fd.get_pixel(3, 3).expect("should succeed") as u8;
        assert_ne!(code, D8_FLAT, "Flat cell should be resolved");
        assert_ne!(code, D8_PIT, "Flat cell should not be marked as pit");
    }

    #[test]
    fn test_dinf_se_slope() {
        let (angle_raster, prop_raster) =
            compute_dinf_flow_direction(&make_se_slope_dem(7, 7), 1.0).expect("should succeed");

        let angle = angle_raster.get_pixel(3, 3).expect("should succeed");
        let prop = prop_raster.get_pixel(3, 3).expect("should succeed");

        assert!((0.0..=360.0).contains(&angle), "Angle {angle} out of range");
        assert!(
            (0.0..=1.0).contains(&prop),
            "Proportion {prop} out of range"
        );
    }

    #[test]
    fn test_mfd_east_slope() {
        let cfg = MfdConfig {
            cell_size: 1.0,
            exponent: 1.1,
        };
        let mfd = compute_mfd_flow_direction(&make_slope_dem(7, 7), &cfg).expect("should succeed");

        let props = mfd.get_proportions(3, 3);
        // East (index 0) must receive at least as much as any other neighbour.
        let east_prop = props[0];
        for (i, &other) in props.iter().enumerate().skip(1) {
            assert!(
                east_prop >= other,
                "East proportion {east_prop} should be >= props[{i}] = {}",
                other
            );
        }
        assert_abs_diff_eq!(props.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mfd_flat_no_crash() {
        // A completely flat DEM routes no flow anywhere.
        let dem = RasterBuffer::zeros(5, 5, RasterDataType::Float32);
        let mfd = compute_mfd_flow_direction(&dem, &MfdConfig::default()).expect("should succeed");

        let props = mfd.get_proportions(2, 2);
        assert_abs_diff_eq!(props.iter().sum::<f64>(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_unified_compute() {
        let dem = make_slope_dem(7, 7);
        for method in [FlowMethod::D8, FlowMethod::DInfinity, FlowMethod::MFD] {
            assert!(compute_flow_direction(&dem, method, 1.0).is_ok());
        }
    }
}