Skip to main content

hydro_analysis/
lib.rs

1//! # Hydro-analysis
2//!
//! `hydro-analysis` provides functions for hydrology DEM manipulation.  There are
//! a couple of generic functions for reading/writing raster files of any common
//! primitive type (which, surprisingly, I couldn't find anywhere else, unless you
//! use GDAL, which I am trying to avoid).  There are also a couple of functions based
//! on [whitebox](https://github.com/jblindsay/whitebox-tools).  Whitebox is a
//! command-line tool; this crate provides the same functionality as functions, so
//! it can be called from your code.
10//!
11//! ```
12
13use rayon::prelude::*;
14use std::collections::{HashMap, BinaryHeap, VecDeque};
15use std::cmp::Ordering;
16use std::cmp::Ordering::Equal;
17use std::collections::hash_map::Entry::Vacant;
18use ndarray::{Array2, par_azip};
19
20use std::{fs::File, f64, path::PathBuf};
21use thiserror::Error;
22use bytemuck::cast_slice;
23
24use tiff::decoder::DecodingResult;
25use tiff::encoder::compression::Deflate;
26use tiff::encoder::colortype::{Gray8,Gray16,Gray32,Gray64,Gray32Float,Gray64Float,GrayI8,GrayI16,GrayI32,GrayI64};
27use tiff::tags::Tag;
28use tiff::TiffFormatError;
29
30
31#[derive(Debug, Error)]
32pub enum RasterError {
33    #[error("TIFF error: {0}")]
34    Tiff(#[from] tiff::TiffError),
35
36    #[error("I/O error: {0}")]
37    Io(#[from] std::io::Error),
38
39    #[error("NDarray: {0}")]
40    Shape(#[from] ndarray::ShapeError),
41
42    #[error("Failed to parse nodata value")]
43    ParseIntError(#[from] std::num::ParseIntError),
44
45    #[error("Failed to parse nodata value")]
46    ParseFloatError(#[from] std::num::ParseFloatError),
47
48    #[error("Unsupported type: {0}")]
49    UnsupportedType(String)
50}
51
52/// Reads a single-band grayscale GeoTIFF raster file returning ndarry2 and metadata
53///
54/// # Type Parameters
55/// - `T`: The pixel value type.  u8, u16, i16, f64 etc
56///
57/// # Parameters
58/// - `fname`: Path to the input `.tif` GeoTIFF file.
59///
60/// # Returns
61/// A `Result` with a tuple containing:
62///  - `Array2<T>`: The raster data in a 2D array.
63///  - `T`: nodata
64///  - `u16`: CRS (e.g. 2193)
65///  - `[f64; 6]`: The affine GeoTransform in the format:
66///   `[origin_x, pixel_size_x, rotation_x, origin_y, rotation_y, pixel_size_y]`.
67///  - `Vec<u64>`: raw GeoKeyDirectoryTag values (needed for writing to file)
68///  - `String`: PROJ string (needed for writing to file)
69///
70/// # Errors
71///  - Returns `RasterError` variants if reading fails, the type conversion for data or metadata
72///   fails, or required tags are missing from the TIFF file.
73///
74/// # Example
75///   See [examples/reading_and_writing_rasters.rs](examples/reading_and_writing_rasters.rs)
76pub fn rasterfile_to_array<T>(fname: &PathBuf) -> Result<
77    (
78        Array2<T>,
79        T,          // nodata
80        u16,        // crs
81        [f64; 6],   // geo transform [start_x, psize_x, rotation, starty, rotation, psize_y]
82        Vec<u64>,   // geo dir, it has the crs in it
83        String      // the projection string
84    ),
85    RasterError
86>
87    where T: std::str::FromStr + num::FromPrimitive,
88          <T as std::str::FromStr>::Err: std::fmt::Debug,
89          RasterError: std::convert::From<<T as std::str::FromStr>::Err> { // Open the file
90    let file = File::open(fname)?;
91
92    // Create a TIFF decoder
93    let mut decoder = tiff::decoder::Decoder::new(file)?;
94    decoder = decoder.with_limits(tiff::decoder::Limits::unlimited());
95
96    // Read the image dimensions
97    let (width, height) = decoder.dimensions()?;
98
99    fn estr<T>(etype: &'static str) -> RasterError {
100        RasterError::Tiff(TiffFormatError::Format(format!("Raster is {}, I was expecting {}", etype, std::any::type_name::<T>()).into()).into())
101    }
102    let data: Vec<T> = match decoder.read_image()? {
103        DecodingResult::I8(buf)  => buf.into_iter().map(|v| <T>::from_i8(v).ok_or(estr::<T>("I8"))).collect::<Result<_, _>>(),
104        DecodingResult::I16(buf) => buf.into_iter().map(|v| <T>::from_i16(v).ok_or(estr::<T>("I16"))).collect::<Result<_, _>>(),
105        DecodingResult::I32(buf) => buf.into_iter().map(|v| <T>::from_i32(v).ok_or(estr::<T>("I32"))).collect::<Result<_, _>>(),
106        DecodingResult::I64(buf) => buf.into_iter().map(|v| <T>::from_i64(v).ok_or(estr::<T>("I64"))).collect::<Result<_, _>>(),
107        DecodingResult::U8(buf)  => buf.into_iter().map(|v| <T>::from_u8(v).ok_or(estr::<T>("U8"))).collect::<Result<_, _>>(),
108        DecodingResult::U16(buf) => buf.into_iter().map(|v| <T>::from_u16(v).ok_or(estr::<T>("U16"))).collect::<Result<_, _>>(),
109        DecodingResult::U32(buf) => buf.into_iter().map(|v| <T>::from_u32(v).ok_or(estr::<T>("U32"))).collect::<Result<_, _>>(),
110        DecodingResult::U64(buf) => buf.into_iter().map(|v| <T>::from_u64(v).ok_or(estr::<T>("U64"))).collect::<Result<_, _>>(),
111        DecodingResult::F32(buf) => buf.into_iter().map(|v| <T>::from_f32(v).ok_or(estr::<T>("F32"))).collect::<Result<_, _>>(),
112        DecodingResult::F64(buf) => buf.into_iter().map(|v| <T>::from_f64(v).ok_or(estr::<T>("F64"))).collect::<Result<_, _>>(),
113    }?;
114
115    // Convert the flat vector into an ndarray::Array2
116    let array: Array2<T> = Array2::from_shape_vec((height as usize, width as usize), data)?;
117
118    // nodata value
119    let nodata: T = decoder.get_tag_ascii_string(Tag::GdalNodata)?.trim().parse::<T>()?;
120
121    // pixel scale [pixel scale x, pixel scale y, ...]
122    // NB pixel scale y is the absolute value, it is POSITIVE.  We have to make it negative later
123    let pscale: Vec<f64> = decoder.get_tag_f64_vec(Tag::ModelPixelScaleTag)?.into_iter().collect();
124
125    // tie point [0 0 0 startx starty 0]
126    let tie: Vec<f64>  = decoder.get_tag_f64_vec(Tag::ModelTiepointTag)?.into_iter().collect();
127
128    // transform, the zeros are the rotations [start x, x pixel size, 0, start y, 0, y pixel size]
129    let geotrans: [f64; 6] = [tie[3], pscale[0], 0.0, tie[4], 0.0, -pscale[1]];
130
131    let projection: String = decoder.get_tag_ascii_string(Tag::GeoAsciiParamsTag)?;
132    let geokeydir: Vec<u64> = decoder .get_tag_u64_vec(Tag::GeoKeyDirectoryTag)?;
133
134    // try and get the CRS out of the geokeydir, it is the bit after 3072
135    let crs = geokeydir.windows(4).find(|w| w[0] == 3072).map(|w| w[3])
136        .ok_or(RasterError::Tiff(tiff::TiffFormatError::InvalidTagValueType(Tag::GeoKeyDirectoryTag).into()))? as u16;
137
138    Ok((array, nodata, crs, geotrans, geokeydir, projection))
139}
140
141/// Writes a 2D array of values to a GeoTIFF raster with geo metadata.
142///
143/// # Type Parameters
144/// - `T`: The element type of the array, which must implement `bytemuck::Pod`
145/// (for safe byte casting) and `ToString` (for writing NoData values to
146/// metadata).
147///
148/// # Parameters
149/// `data`: A 2D array (`ndarray::Array2<T>`) containing raster pixel values.
150/// `nd`: NoData value
151/// `geotrans`: A 6-element array defining the affine geotransform:
152///   `[origin_x, pixel_size_x, rotation_x, origin_y, rotation_y, pixel_size_y]`.
153/// `geokeydir`: &[u64] the GeoKeyDirectoryTag (best got from reading a raster)
154/// `proj`: PROJ string (best got from reading a raster)
155/// `outfile`: The path to the output `.tif` file.
156///
157/// # Returns
158/// Ok() or a `RasterError`
159///
160/// # Errors
161/// - Returns `RasterError::UnsupportedType` if `T` can't be mapped to a TIFF format.
162/// - Propagates I/O and TIFF writing errors
163///
164/// # Example
165///   See [examples/reading_and_writing_rasters.rs](examples/reading_and_writing_rasters.rs)
166pub fn array_to_rasterfile<T>(
167    data: &Array2<T>,
168    nd: T,                      // nodata
169    geotrans: &[f64; 6],        // geo transform [start_x, psize_x, rotation, starty, rotation, psize_y]
170    geokeydir: &[u64],          // geo dir, it has the crs in it
171    proj: &str,                 // the projection string
172    outfile: &PathBuf
173) -> Result<(), RasterError>
174    where T: bytemuck::Pod + ToString
175{
176    let (nrows, ncols) = (data.nrows(), data.ncols());
177
178    let fh = File::create(outfile)?;
179    let mut encoder = tiff::encoder::TiffEncoder::new(fh)?;
180
181    // Because image doesn't have traits I couldn't figure out how to do this with generics
182    // This macro takes the tiff colortype
183    macro_rules! writit {
184        ($pix:ty) => {{
185            let mut image = encoder.new_image_with_compression::<$pix, Deflate>(ncols as u32, nrows as u32, Deflate::default())?;
186            image.encoder().write_tag(Tag::GdalNodata, &nd.to_string()[..])?;
187            // remember that geotrans is negative, but in tiff tags assumed to be positive
188            image.encoder().write_tag(Tag::ModelPixelScaleTag, &[geotrans[1], -geotrans[5], 0.0][..])?;
189            image.encoder().write_tag(Tag::ModelTiepointTag, &[0.0, 0.0, 0.0, geotrans[0], geotrans[3], 0.0][..])?;
190            image.encoder().write_tag(Tag::GeoKeyDirectoryTag, geokeydir)?;
191            image.encoder().write_tag(Tag::GeoAsciiParamsTag, &proj)?;
192            image.write_data(cast_slice(data.as_slice().unwrap()))?;
193        }};
194    }
195
196    match std::any::TypeId::of::<T>() {
197        id if id == std::any::TypeId::of::<u8>()  => writit!(Gray8),
198        id if id == std::any::TypeId::of::<u16>() => writit!(Gray16),
199        id if id == std::any::TypeId::of::<u32>() => writit!(Gray32),
200        id if id == std::any::TypeId::of::<u64>() => writit!(Gray64),
201        id if id == std::any::TypeId::of::<f32>() => writit!(Gray32Float),
202        id if id == std::any::TypeId::of::<f64>() => writit!(Gray64Float),
203        id if id == std::any::TypeId::of::<i8>()  => writit!(GrayI8),
204        id if id == std::any::TypeId::of::<i16>() => writit!(GrayI16),
205        id if id == std::any::TypeId::of::<i32>() => writit!(GrayI32),
206        id if id == std::any::TypeId::of::<i64>() => writit!(GrayI64),
207        _ => return Err(RasterError::UnsupportedType(format!("Cannot handle type {}", std::any::type_name::<T>())))
208    };
209
210    Ok(())
211}
212
213
/// A grid cell queued for priority-ordered processing.
///
/// The ordering is deliberately *reversed* on `priority` so that a
/// `std::collections::BinaryHeap` (a max-heap) pops the cell with the lowest priority first.
/// Only `priority` takes part in the ordering; `row` and `column` are payload.
#[derive(PartialEq, Debug)]
struct GridCell {
    row: usize,
    column: usize,
    priority: f64,
}

impl Eq for GridCell {}

impl PartialOrd for GridCell {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for GridCell {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives a total order over all f64 values, so a NaN priority can no
        // longer panic inside the heap (the old code unwrapped `partial_cmp`).
        other.priority.total_cmp(&self.priority)
    }
}
234
/// A grid cell queued for priority-ordered processing that also carries an elevation `z`.
///
/// As with a plain min-heap entry, the ordering is *reversed* on `priority` so that a
/// `std::collections::BinaryHeap` (a max-heap) pops the cell with the lowest priority first.
/// `row`, `column` and `z` do not take part in the ordering.
#[derive(PartialEq, Debug)]
struct GridCell2 {
    row: usize,
    column: usize,
    z: f64,
    priority: f64,
}

impl Eq for GridCell2 {}

impl PartialOrd for GridCell2 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for GridCell2 {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives a total order over all f64 values, so a NaN priority can no
        // longer panic inside the heap (the old code unwrapped `partial_cmp`).
        other.priority.total_cmp(&self.priority)
    }
}
256
257
258/// Fills depressions (sinks) in a digital elevation model (DEM).
259///
260/// More-or-less the contents of
261/// [whitebox fill_depressions](https://github.com/jblindsay/whitebox-tools/blob/master/whitebox-tools-app/src/tools/hydro_analysis/fill_depressions.rs)
262///
263/// This function modifies the input `dem` to ensure that all depressions (local minima that do not
264/// drain) are removed, making the surface hydrologically correct. It also considers no-data values
265/// and can optionally fix flat areas.
266///
267/// # Parameters
268///
269/// - `dem`: A mutable reference to a 2D array (`Array2<f64>`) representing the elevation data.
270/// - `nodata`: The value representing no-data cells in the DEM.
271/// - `resx`: The horizontal resolution (grid spacing in the x-direction).
272/// - `resy`: The vertical resolution (grid spacing in the y-direction).
273/// - `fix_flats`: A boolean flag to determine whether flat areas should be slightly sloped.
274///
275/// # Example
276///
277/// ```
278/// use ndarray::{array, Array2};
279/// use hydro_analysis::fill_depressions;
280///
281/// let mut dem: Array2<f64> = array![
282///         [10.0, 12.0, 10.0],
283///         [12.0, 1.0, 12.0],
284///         [10.0, 12.0, 10.0],
285/// ];
286/// let filled: Array2<f64> = array![
287///         [10.0, 12.0, 10.0],
288///         [12.0, 1.0, 12.0],
289///         [10.0, 12.0, 10.0],
290/// ];
291/// fill_depressions(&mut dem, -3.0, 8.0, 8.0, true);
292/// assert_eq!(dem, filled);
293/// let mut dem: Array2<f64> = array![
294///         [10.0, 12.0, 10.0, 10.0],
295///         [12.0, 1.0, 10.0, 12.0],
296///         [10.0, 12.0, 10.0, 9.0],
297/// ];
298/// let filled: Array2<f64> = array![
299///         [10.0, 12.0, 10.0, 10.0],
300///         [12.0, 10.0, 10.0, 12.0],
301///         [10.0, 12.0, 10.0, 9.0],
302/// ];
303/// fill_depressions(&mut dem, -3.0, 8.0, 8.0, true);
304/// for (x, y) in dem.iter().zip(filled.iter()) {
305///    assert!((*x - *y).abs() < 1e-4);
306/// }
307/// ```
308pub fn fill_depressions(
309    dem: &mut Array2<f64>, nodata: f64, resx: f64, resy: f64, fix_flats: bool
310)
311{
312    let (rows, columns) = (dem.nrows(), dem.ncols());
313    let small_num = {
314        let diagres = (resx * resx + resy * resy).sqrt();
315        let elev_digits = (dem.iter().cloned().fold(f64::NEG_INFINITY, f64::max) as i64).to_string().len();
316        let elev_multiplier = 10.0_f64.powi((9 - elev_digits) as i32);
317        1.0_f64 / elev_multiplier as f64 * diagres.ceil()
318    };
319
320    //let input = dem.clone();    // original whitebox used an input and output while doing fixing flats, don't think you really need it
321
322    let dx = [1, 1, 1, 0, -1, -1, -1, 0];
323    let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
324
325    // Find pit cells (we don't care about pits around edge, they won't be considered a pit). This step is parallelizable.
326	let mut pits: Vec<_> = (1..rows - 1)
327		.into_par_iter()
328		.flat_map(|row| {
329			let mut local_pits = Vec::new();
330			for col in 1..columns - 1 {
331				let z = dem[[row, col]];
332				if z == nodata {
333					continue;
334				}
335				let mut apit = true;
336            	// is anything lower than me?
337				for n in 0..8 {
338					let zn = dem[[(row as isize + dy[n]) as usize, (col as isize + dx[n]) as usize]];
339					if zn < z || zn == nodata {
340						apit = false;
341						break;
342					}
343				}
344				// no, so I am a pit
345				if apit {
346					local_pits.push((row, col, z));
347				}
348			}
349			local_pits
350		}).collect();
351
352    // Now we need to perform an in-place depression filling
353    let mut minheap = BinaryHeap::new();
354    let mut minheap2 = BinaryHeap::new();
355    let mut visited = Array2::<u8>::zeros((rows, columns));
356    let mut flats = Array2::<u8>::zeros((rows, columns));
357    let mut possible_outlets = vec![];
358    let mut queue = VecDeque::new();
359
360    // go through pits from highest to lowest
361    pits.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Equal));
362    while let Some(cell) = pits.pop() {
363        let row: usize = cell.0;
364        let col: usize = cell.1;
365
366        // if it's already in a solved site, don't do it a second time.
367        if flats[[row, col]] != 1 {
368
369            // First there is a priority region-growing operation to find the outlets.
370            minheap.clear();
371            minheap.push(GridCell {
372                row: row,
373                column: col,
374                priority: dem[[row, col]],
375            });
376            visited[[row, col]] = 1;
377            let mut outlet_found = false;
378            let mut outlet_z = f64::INFINITY;
379            if !queue.is_empty() {
380                queue.clear();
381            }
382            while let Some(cell2) = minheap.pop() {
383
384                let z = cell2.priority;
385                if outlet_found && z > outlet_z {
386                    break;
387                }
388                if !outlet_found {
389                    for n in 0..8 {
390                        let cn = cell2.column as isize + dx[n];
391                        let rn = cell2.row as isize + dy[n];
392                        if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
393                            continue;
394                        }
395                        let cn = cn as usize;
396                        let rn = rn as usize;
397                        if visited[[rn, cn]] == 0 {
398                            let zn = dem[[rn, cn]];
399                            if !outlet_found {
400                                if zn >= z && zn != nodata {
401                                    minheap.push(GridCell {
402                                        row: rn,
403                                        column: cn,
404                                        priority: zn,
405                                    });
406                                    visited[[rn, cn]] = 1;
407                                } else if zn != nodata {
408                                    // zn < z
409                                    // 'cell' has a lower neighbour that hasn't already passed through minheap.
410                                    // Therefore, 'cell' is a pour point cell.
411                                    outlet_found = true;
412                                    outlet_z = z;
413                                    queue.push_back((cell2.row, cell2.column));
414                                    possible_outlets.push((cell2.row, cell2.column));
415                                }
416                            } else if zn == outlet_z {
417                                // We've found the outlet but are still looking for additional depression cells.
418                                minheap.push(GridCell {
419                                    row: rn,
420                                    column: cn,
421                                    priority: zn,
422                                });
423                                visited[[rn, cn]] = 1;
424                            }
425                        }
426                    }
427                } else {
428                    // We've found the outlet but are still looking for additional depression cells and potential outlets.
429                    if z == outlet_z {
430                        let mut anoutlet = false;
431                        for n in 0..8 {
432                            let cn = cell2.column as isize + dx[n];
433                            let rn = cell2.row as isize + dy[n];
434                            if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
435                                continue;
436                            }
437                            let cn = cn as usize;
438                            let rn = rn as usize;
439                            if visited[[rn, cn]] == 0 {
440                                let zn = dem[[rn, cn]];
441                                if zn < z {
442                                    anoutlet = true;
443                                } else if zn == outlet_z {
444                                    minheap.push(GridCell {
445                                        row: rn,
446                                        column: cn,
447                                        priority: zn,
448                                    });
449                                    visited[[rn, cn]] = 1;
450                                }
451                            }
452                        }
453                        if anoutlet {
454                            queue.push_back((cell2.row, cell2.column));
455                            possible_outlets.push((cell2.row, cell2.column));
456                        } else {
457                            visited[[cell2.row, cell2.column]] = 1;
458                        }
459                    }
460                }
461            }
462
463            if outlet_found {
464                // Now that we have the outlets, raise the interior of the depression.
465                // Start from the outlets.
466                while let Some(cell2) = queue.pop_front() {
467                    for n in 0..8 {
468                        let cn = cell2.1 as isize + dx[n];
469                        let rn = cell2.0 as isize + dy[n];
470                        if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
471                            continue;
472                        }
473                        let cn = cn as usize;
474                        let rn = rn as usize;
475                        if visited[[rn, cn]] == 1 {
476                            visited[[rn, cn]] = 0;
477                            queue.push_back((rn, cn));
478                            let z = dem[[rn, cn]];
479                            if z < outlet_z {
480                                dem[[rn, cn]] = outlet_z;
481                                flats[[rn, cn]] = 1;
482                            } else if z == outlet_z {
483                                flats[[rn, cn]] = 1;
484                            }
485                        }
486                    }
487                }
488            } else {
489                queue.push_back((row, col)); // start at the pit cell and clean up visited
490                while let Some(cell2) = queue.pop_front() {
491                    for n in 0..8 {
492                        let cn = cell2.1 as isize + dx[n];
493                        let rn = cell2.0 as isize + dy[n];
494                        if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
495                            continue;
496                        }
497                        let cn = cn as usize;
498                        let rn = rn as usize;
499                        if visited[[rn, cn]] == 1 {
500                            visited[[rn, cn]] = 0;
501                            queue.push_back((rn, cn));
502                        }
503                    }
504                }
505            }
506        }
507
508    }
509
510    drop(visited);
511
512    if small_num > 0.0 && fix_flats {
513        // Some of the potential outlets really will have lower cells.
514        minheap.clear();
515        while let Some(cell) = possible_outlets.pop() {
516            let z = dem[[cell.0, cell.1]];
517            let mut anoutlet = false;
518            for n in 0..8 {
519                let rn: isize = cell.0 as isize + dy[n];
520                let cn: isize = cell.1 as isize + dx[n];
521                if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
522                    continue;
523                }
524                let zn = dem[[rn as usize, cn as usize]];
525                if zn < z && zn != nodata {
526                    anoutlet = true;
527                    break;
528                }
529            }
530            if anoutlet {
531                minheap.push(GridCell {
532                    row: cell.0,
533                    column: cell.1,
534                    priority: z,
535                });
536            }
537        }
538
539        let mut outlets = vec![];
540        while let Some(cell) = minheap.pop() {
541            if flats[[cell.row, cell.column]] != 3 {
542                let z = dem[[cell.row, cell.column]];
543                flats[[cell.row, cell.column]] = 3;
544                if !outlets.is_empty() {
545                    outlets.clear();
546                }
547                outlets.push(cell);
548                // Are there any other outlet cells at the same elevation (likely for the same feature)
549                let mut flag = true;
550                while flag {
551                    match minheap.peek() {
552                        Some(cell2) => {
553                            if cell2.priority == z {
554                                flats[[cell2.row, cell2.column]] = 3;
555                                outlets
556                                    .push(minheap.pop().expect("Error during pop operation."));
557                            } else {
558                                flag = false;
559                            }
560                        }
561                        None => {
562                            flag = false;
563                        }
564                    }
565                }
566                if !minheap2.is_empty() {
567                    minheap2.clear();
568                }
569                for cell2 in &outlets {
570                    let z = dem[[cell2.row, cell2.column]];
571                    for n in 0..8 {
572                        let cn = cell2.column as isize + dx[n];
573                        let rn = cell2.row as isize + dy[n];
574                        if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
575                            continue;
576                        }
577                        let cn = cn as usize;
578                        let rn = rn as usize;
579                        if flats[[rn, cn]] != 3 {
580                            let zn = dem[[rn, cn]];
581                            if zn == z && zn != nodata {
582                                minheap2.push(GridCell2 {
583                                    row: rn,
584                                    column: cn,
585                                    z: z,
586                                    priority: dem[[rn, cn]], // FIXME
587                                });
588                                dem[[rn, cn]] = z + small_num;
589                                flats[[rn, cn]] = 3;
590                            }
591                        }
592                    }
593                }
594
595                // Now fix the flats
596                while let Some(cell2) = minheap2.pop() {
597                    let z = dem[[cell2.row, cell2.column]];
598                    for n in 0..8 {
599                        let cn = cell2.column as isize + dx[n];
600                        let rn = cell2.row as isize + dy[n];
601                        if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
602                            continue;
603                        }
604                        let cn = cn as usize;
605                        let rn = rn as usize;
606                        if flats[[rn, cn]] != 3 {
607                            let zn = dem[[rn, cn]];
608                            if zn < z + small_num && zn >= cell2.z && zn != nodata {
609                                minheap2.push(GridCell2 {
610                                    row: rn,
611                                    column: cn,
612                                    z: cell2.z,
613                                    priority: dem[[rn, cn]], // FIXME
614                                });
615                                dem[[rn, cn]] = z + small_num;
616                                flats[[rn, cn]] = 3;
617                            }
618                        }
619                    }
620                }
621            }
622
623        }
624    }
625}
626
627/// Calculates the D8 flow direction from a digital elevation model (DEM).
628///
629/// More-or-less the contents of
630/// [whitebox d8_pointer](https://github.com/jblindsay/whitebox-tools/blob/master/whitebox-tools-app/src/tools/hydro_analysis/d8_pointer.rs)
631///
632/// This function computes the D8 flow direction for each cell in the provided DEM:
633///
634/// | .  |  .  |  . |
635/// |:--:|:---:|:--:|
636/// | 64 | 128 | 1  |
637/// | 32 |  0  | 2  |
638/// | 16 |  8  | 4  |
639///
640/// Grid cells that have no lower neighbours are assigned a flow direction of zero. In a DEM that
641/// has been pre-processed to remove all depressions and flat areas, this condition will only occur
642/// along the edges of the grid.
643///
644/// Grid cells possessing the NoData value in the input DEM are assigned the NoData value in the
645/// output image.
646///
647/// # Parameters
648/// - `dem`: A 2D array representing the digital elevation model (DEM)
649/// - `nodata`: The nodata in the DEM
650/// - `resx`: The resolution of the DEM in the x-direction in meters
651/// - `resy`: The resolution of the DEM in the y-direction in meters
652///
653/// # Returns
654/// - A tuple containing:
655/// - An `Array2<u8>` representing the D8 flow directions for each cell.
656/// - A `u8` nodata value (255)
657///
658/// # Example
659/// ```rust
660/// use anyhow::Result;
661/// use ndarray::{Array2, array};
662/// use hydro_analysis::{d8_pointer};
663/// let dem = Array2::from_shape_vec(
664///     (3, 3),
665///     vec![
666///         11.0, 12.0, 10.0,
667///         12.0, 13.0, 12.0,
668///         10.5, 12.0, 11.0,
669///     ],
670/// ).expect("Failed to create DEM");
671/// let d8: Array2<u8> = array![
672///     [0, 2, 0],
673///     [8, 1, 128],
674///     [0, 32, 0],
675/// ];
676/// let nodata = -9999.0;
677/// let resx = 8.0;
678/// let resy = 8.0;
679/// let (res, _nd) = d8_pointer(&dem, nodata, resx, resy);
680/// assert_eq!(res, d8);
681/// ```
682pub fn d8_pointer(dem: &Array2<f64>, nodata: f64, resx: f64, resy: f64) -> (Array2<u8>, u8)
683{
684    let (nrows, ncols) = (dem.nrows(), dem.ncols());
685    let out_nodata: u8 = 255;
686    let mut d8: Array2<u8> = Array2::from_elem((nrows, ncols), out_nodata);
687
688    let diag = (resx * resx + resy * resy).sqrt();
689    let grid_lengths = [diag, resx, diag, resy, diag, resx, diag, resy];
690
691    let dx = [1, 1, 1, 0, -1, -1, -1, 0];
692    let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
693
694    d8.axis_iter_mut(ndarray::Axis(0))
695        .into_par_iter()
696        .enumerate()
697        .for_each(|(row, mut d8_row)| {
698            for col in 0..ncols {
699                let z = dem[[row, col]];
700                if z == nodata {
701                    continue;
702                }
703
704                let mut dir = 0;
705                let mut max_slope = f64::MIN;
706                for i in 0..8 {
707                    let rn: isize = row as isize + dy[i];
708                    let cn: isize = col as isize + dx[i];
709                    if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
710                        continue;
711                    }
712                    let z_n = dem[[rn as usize, cn as usize]];
713                    if z_n != nodata {
714                        let slope = (z - z_n) / grid_lengths[i];
715                        if slope > max_slope && slope > 0.0 {
716                            max_slope = slope;
717                            dir = i;
718                        }
719                    }
720                }
721
722                if max_slope >= 0.0 {
723                    d8_row[col] = 1 << dir;
724                } else {
725                    d8_row[col] = 0u8;
726                }
727            }
728        });
729
730    return (d8, out_nodata);
731}
732
733
734/// Updates a D8 flow direction using a mask that indicates onland and ocean and return changed
735/// flattened indices.
736///
737/// This function will ensure your downstream traces calculated from a D8 will stop at the coast.
738/// Any cell whose direction leads us to the ocean will become a sink.  Any cell in the ocean will
739/// be nodata
740///
741/// An alternative approach (which doesn't work) is to first mask the DEM with nodata and then
742/// calculate the D8.  Unfortunately this results in downstream traces that get to the coast
743/// (nodata) and then head along the coast, sometimes for ages before coming to a natural sink.
744///
745/// # Parameters
746///   d8: A 2D u8 array representing the directions
747///   nodata: The nodata for the d8 (probably 255)
748///   onland: A 2D bool array representing the land
749///
750/// # Returns
751/// nothing, but updates the d8
752///
753/// # Example
754/// ```rust
755/// use ndarray::Array2;
756/// let mut d8 = Array2::from_shape_vec(
757///     (3, 3),
758///     vec![
759///         4, 2, 1,
760///         2, 2, 1,
761///         1, 128, 4,
762///     ],
763/// ).expect("Failed to create D8");
764/// let onland = Array2::from_shape_vec(
765///     (3, 3),
766///     vec![
767///         true, true, false,
768///         true, true, false,
769///         true, true, false,
770///     ],
771/// ).expect("Failed to create onland");
772/// let newd8 = Array2::from_shape_vec(
773///     (3, 3),
774///     vec![
775///         4, 0, 255,
776///         2, 0, 255,
777///         1, 128, 255,
778///     ],
779/// ).expect("Failed to create D8");
780/// let changed = hydro_analysis::d8_clipped(&mut d8, 255, &onland);
781/// assert_eq!(d8, newd8);
782/// assert_eq!(changed, vec![1, 4]);
783/// ```
784pub fn d8_clipped(d8: &mut Array2<u8>, nodata: u8, onland: &Array2<bool>) -> Vec<usize>
785{
786    let (nrows, ncols) = (d8.nrows(), d8.ncols());
787
788    // in the ocean is nodata
789    par_azip!((d in &mut *d8, m in onland){ if !m { *d = nodata } });
790
791    // so we can index using 1D 
792    let slice = d8.as_slice().unwrap();
793
794    let tozero: Vec<usize> = (0..nrows).into_par_iter().flat_map(|r| {
795        let mut local: Vec<usize> = Vec::new();
796        for c in 0..ncols {
797            let idx = r*ncols + c;
798            if slice[idx] == nodata || slice[idx] == 0 {
799                continue;
800            }
801
802            // get next cell downstream
803            let (dr, dc) = match slice[idx] {
804                1   => (-1,  1),
805                2   => ( 0,  1),
806                4   => ( 1,  1),
807                8   => ( 1,  0),
808                16  => ( 1, -1),
809                32  => ( 0, -1),
810                64  => (-1, -1),
811                128 => (-1,  0),
812                _ => unreachable!(),
813            };
814            let rn = r as isize + dr;
815            let cn = c as isize + dc;
816
817            // if next is outside, don't change idx
818            if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
819                continue;
820            }
821            // next is inside and nodata, then set idx to a sink
822            if d8[[rn as usize, cn as usize]] == nodata {
823                local.push(idx);
824            }
825        }
826        local
827    }).collect();
828
829    let slice = d8.as_slice_mut().unwrap();
830    for idx in &tozero {
831        slice[*idx] = 0;
832    }
833
834    tozero
835}
836
837
838/// Breach depressions least cost.  Implements
839///
840/// [whitebox breach_depressions_least_cost](https://github.com/jblindsay/whitebox-tools/blob/master/whitebox-tools-app/src/tools/hydro_analysis/breach_depressions_least_cost.rs)
841///
842/// with
843///   max_cost set to infinity,
844///   flat_increment is default, and
845///   minimize_dist set to true.
846///
847/// with more modern and less memory intensive datastructures.
848///
849/// The only real parameter to tune is max_dist (measured in cells) and a default would be 20
850///
851/// # Returns
852///  number of pits that remain
853///
854/// # Examples
855///
856/// See tests/breach_depressions.rs for lots of examples, here is just one
857///
858/// ```rust
859/// use ndarray::{Array2, array};
860/// use hydro_analysis::breach_depressions;
861/// let resx = 8.0;
862/// let resy = 8.0;
863/// let max_dist = 100;
864/// let ep = 0.00000012;
865/// let mut dem: Array2<f64> = array![
866///     [2.0, 2.0, 2.0, 2.0, 0.0],
867///     [2.0, 1.0, 2.0, 0.5, 0.0],
868///     [2.0, 2.0, 2.0, 2.0, 0.0],
869/// ];
870/// let breached: Array2<f64> = array![
871///     [2.0, 2.0, 2.0, 2.0, 0.0],
872///     [2.0, 2.0-ep, 2.0-2.0*ep, 0.5, 0.0],
873///     [2.0, 2.0, 2.0, 2.0, 0.0],
874/// ];
875/// let n = breach_depressions(&mut dem, -1.0, resx, resy, max_dist);
876/// assert_eq!(n, 0);
877/// for (x, y) in dem.iter().zip(breached.iter()) {
878///    assert!((*x - *y).abs() < 1e-10);
879/// }
880///
881pub fn breach_depressions(dem: &mut Array2<f64>, nodata: f64, resx: f64, resy: f64, max_dist: usize) -> usize
882{
883
884    let diagres: f64 = (resx * resx + resy * resy).sqrt();
885    let cost_dist = [diagres, resx, diagres, resy, diagres, resx, diagres, resy];
886
887    let (nrows, ncols) = (dem.nrows(), dem.ncols());
888    let small_num = {
889        let diagres = (resx * resx + resy * resy).sqrt();
890        let elev_digits = (dem.iter().cloned().fold(f64::NEG_INFINITY, f64::max) as i64).to_string().len();
891        let elev_multiplier = 10.0_f64.powi((9 - elev_digits) as i32);
892        1.0_f64 / elev_multiplier as f64 * diagres.ceil()
893    };
894
895    // want to raise interior (ie not on boundary or with a nodata neighbour) pit cells up to
896    // just below minimum neighbour height.
897    //
898    // so first find the pits
899    let dx = [1, 1, 1, 0, -1, -1, -1, 0];
900    let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
901	let mut pits: Vec<_> = (1..nrows - 1).into_par_iter().flat_map(|row| {
902        let mut local_pits = Vec::new();
903        for col in 1..ncols - 1 {
904            let z = dem[[row, col]];
905            if z == nodata {
906                continue;
907            }
908            let mut apit = true;
909            // is any neighbour lower than me or nodata?
910            for n in 0..8 {
911                let zn = dem[[(row as isize + dy[n]) as usize, (col as isize + dx[n]) as usize]];
912                if zn < z || zn == nodata {
913                    apit = false;
914                    break;
915                }
916            }
917            // no, so I am a pit
918            if apit {
919                local_pits.push((row, col, z));
920            }
921        }
922        local_pits
923    }).collect();
924
925    // set depth to just below min neighbour, can't do this in parallel, we update the dem and pits
926    for &mut (row, col, ref mut z) in pits.iter_mut() {
927        let min_zn: f64 = dx.iter().zip(&dy).map(|(&dxi, &dyi)|
928            dem[[
929                (row as isize + dyi) as usize,
930                (col as isize + dxi) as usize,
931            ]]
932        ).fold(f64::INFINITY, f64::min);
933        *z = min_zn - small_num;
934        dem[[row, col]] = *z;
935    }
936
937    // Sort highest to lowest so poping off the end will get the lowest, should do that first
938    // because might solve some higher ones
939    pits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Equal));
940
941    // keep track of number we can't deal with to return
942    let mut num_unsolved: usize = 0;
943
944    // roughly how many cells we will have to look at
945    let num_to_chk: usize = (max_dist * 2 + 1) * (max_dist * 2 + 1);
946
947    // for keeping track of path be make
948    #[derive(Debug)]
949    struct NodeInfo {
950        prev: Option<(usize, usize)>,
951        length: usize,
952    }
953
954    // try and dig a channel from row, col
955    while let Some((row, col, z)) = pits.pop() {
956
957        // May have been solved during previous depression step, so can skip
958        if dx.iter().zip(&dy).any(|(&dxi, &dyi)|
959            dem[[(row as isize + dyi) as usize, (col as isize + dxi) as usize]] < z
960        ) {
961            continue;
962        }
963
964        // keep our path info in here
965        let mut visited: HashMap<(usize,usize), NodeInfo> = HashMap::default();
966        visited.insert((row,col), NodeInfo {
967            prev: None,
968            length: 0,
969        });
970
971        // cells we will check to see if lower than (row, col, z)
972        let mut cells_to_chk = BinaryHeap::with_capacity(num_to_chk);
973        cells_to_chk.push(GridCell {row: row, column: col, priority: 0.0 });
974
975        let mut dugit = false;
976        'search: while let Some(GridCell {row: chkrow, column: chkcol, priority: accum}) = cells_to_chk.pop() {
977
978            let length: usize = visited[&(chkrow,chkcol)].length;
979            let zn: f64 = dem[[chkrow, chkcol]];
980            let cost1: f64 = zn - z + length as f64 * small_num;
981            let mut breach: Option<(usize, usize)> = None;
982
983            // lets look about chkrow, chkcol
984            for n in 0..8 {
985                let rn: isize = chkrow as isize + dy[n];
986                let cn: isize = chkcol as isize + dx[n];
987
988                // chkrow, chkcol is a breach if on the edge
989                if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
990                    breach = Some((chkrow, chkcol));
991                    // we need to force this boundary cell down
992                    dem[[chkrow, chkcol]] = z - (length as f64 * small_num);
993                    break;
994                }
995                let rn: usize = rn as usize;
996                let cn: usize = cn as usize;
997                let nextlen = length + 1;
998
999                // insert if not visited
1000                if let Vacant(e) = visited.entry((rn, cn)) {
1001                    e.insert(NodeInfo {
1002                        prev: Some((chkrow, chkcol)),
1003                        length: nextlen
1004                    });
1005                } else {
1006                    continue;
1007                }
1008
1009                let zn: f64 = dem[[rn, cn]];
1010                // an internal nodata cannot be breach point
1011                if zn == nodata {
1012                    continue;
1013                }
1014
1015                // zout is lowered from z by nextlen * slope
1016                let zout: f64 = z - (nextlen as f64 * small_num);
1017                if zn <= zout {
1018                    breach = Some((rn, cn));
1019                    break;
1020                }
1021
1022                // (rn, cn) no good, zn is too high, just push onto heap
1023                if zn > zout {
1024                    let cost2: f64 = zn - zout;
1025                    let new_cost: f64 = accum + (cost1 + cost2) / 2.0 * cost_dist[n];
1026                    // but haven't travelled too far, so we will scan from this cell
1027                    if nextlen <= max_dist {
1028                        cells_to_chk.push(GridCell {
1029                            row: rn,
1030                            column: cn,
1031                            priority: new_cost
1032                        });
1033                    }
1034                }
1035            }
1036
1037            if let Some(mut cur) = breach {
1038                loop {
1039                    let node: &NodeInfo = &visited[&cur];
1040                    // back at the pit?
1041                    let Some(parent) = node.prev else {
1042                        break;
1043                    };
1044
1045                    let zn = dem[[parent.0, parent.1]];
1046                    let length = visited[&parent].length;
1047                    let zout = z - (length as f64 * small_num);
1048                    if zn > zout {
1049                        dem[[parent.0, parent.1]] = zout;
1050                    }
1051                    cur = parent;
1052                }
1053                dugit = true;
1054                break 'search;
1055            }
1056
1057        }
1058
1059        if !dugit {
1060            // Didn't find any lower cells, tough luck
1061            num_unsolved += 1;
1062        }
1063    }
1064
1065    num_unsolved
1066}
1067