hydro_analysis/lib.rs
//! # Hydro-analysis
//!
//! `hydro-analysis` provides functions for hydrological DEM manipulation. There are
//! a couple of generic functions for reading/writing raster files of any common
//! primitive type (which, surprisingly, I couldn't find anywhere else unless you
//! use GDAL, which I am trying to avoid). There are also a few functions based
//! on [whitebox](https://github.com/jblindsay/whitebox-tools). Whitebox is a
//! command-line tool; this crate provides the same functionality as functions that
//! can be called from your code.
12
13use rayon::prelude::*;
14use std::collections::{HashMap, BinaryHeap, VecDeque};
15use std::cmp::Ordering;
16use std::cmp::Ordering::Equal;
17use std::collections::hash_map::Entry::Vacant;
18use ndarray::{Array2, par_azip};
19
20use std::{fs::File, f64, path::PathBuf};
21use thiserror::Error;
22use bytemuck::cast_slice;
23
24use tiff::decoder::DecodingResult;
25use tiff::encoder::compression::Deflate;
26use tiff::encoder::colortype::{Gray8,Gray16,Gray32,Gray64,Gray32Float,Gray64Float,GrayI8,GrayI16,GrayI32,GrayI64};
27use tiff::tags::Tag;
28use tiff::TiffFormatError;
29
30
31#[derive(Debug, Error)]
32pub enum RasterError {
33 #[error("TIFF error: {0}")]
34 Tiff(#[from] tiff::TiffError),
35
36 #[error("I/O error: {0}")]
37 Io(#[from] std::io::Error),
38
39 #[error("NDarray: {0}")]
40 Shape(#[from] ndarray::ShapeError),
41
42 #[error("Failed to parse nodata value")]
43 ParseIntError(#[from] std::num::ParseIntError),
44
45 #[error("Failed to parse nodata value")]
46 ParseFloatError(#[from] std::num::ParseFloatError),
47
48 #[error("Unsupported type: {0}")]
49 UnsupportedType(String)
50}
51
52/// Reads a single-band grayscale GeoTIFF raster file returning ndarry2 and metadata
53///
54/// # Type Parameters
55/// - `T`: The pixel value type. u8, u16, i16, f64 etc
56///
57/// # Parameters
58/// - `fname`: Path to the input `.tif` GeoTIFF file.
59///
60/// # Returns
61/// A `Result` with a tuple containing:
62/// - `Array2<T>`: The raster data in a 2D array.
63/// - `T`: nodata
64/// - `u16`: CRS (e.g. 2193)
65/// - `[f64; 6]`: The affine GeoTransform in the format:
66/// `[origin_x, pixel_size_x, rotation_x, origin_y, rotation_y, pixel_size_y]`.
67/// - `Vec<u64>`: raw GeoKeyDirectoryTag values (needed for writing to file)
68/// - `String`: PROJ string (needed for writing to file)
69///
70/// # Errors
71/// - Returns `RasterError` variants if reading fails, the type conversion for data or metadata
72/// fails, or required tags are missing from the TIFF file.
73///
74/// # Example
75/// See [examples/reading_and_writing_rasters.rs](examples/reading_and_writing_rasters.rs)
76pub fn rasterfile_to_array<T>(fname: &PathBuf) -> Result<
77 (
78 Array2<T>,
79 T, // nodata
80 u16, // crs
81 [f64; 6], // geo transform [start_x, psize_x, rotation, starty, rotation, psize_y]
82 Vec<u64>, // geo dir, it has the crs in it
83 String // the projection string
84 ),
85 RasterError
86>
87 where T: std::str::FromStr + num::FromPrimitive,
88 <T as std::str::FromStr>::Err: std::fmt::Debug,
89 RasterError: std::convert::From<<T as std::str::FromStr>::Err> { // Open the file
90 let file = File::open(fname)?;
91
92 // Create a TIFF decoder
93 let mut decoder = tiff::decoder::Decoder::new(file)?;
94 decoder = decoder.with_limits(tiff::decoder::Limits::unlimited());
95
96 // Read the image dimensions
97 let (width, height) = decoder.dimensions()?;
98
99 fn estr<T>(etype: &'static str) -> RasterError {
100 RasterError::Tiff(TiffFormatError::Format(format!("Raster is {}, I was expecting {}", etype, std::any::type_name::<T>()).into()).into())
101 }
102 let data: Vec<T> = match decoder.read_image()? {
103 DecodingResult::I8(buf) => buf.into_iter().map(|v| <T>::from_i8(v).ok_or(estr::<T>("I8"))).collect::<Result<_, _>>(),
104 DecodingResult::I16(buf) => buf.into_iter().map(|v| <T>::from_i16(v).ok_or(estr::<T>("I16"))).collect::<Result<_, _>>(),
105 DecodingResult::I32(buf) => buf.into_iter().map(|v| <T>::from_i32(v).ok_or(estr::<T>("I32"))).collect::<Result<_, _>>(),
106 DecodingResult::I64(buf) => buf.into_iter().map(|v| <T>::from_i64(v).ok_or(estr::<T>("I64"))).collect::<Result<_, _>>(),
107 DecodingResult::U8(buf) => buf.into_iter().map(|v| <T>::from_u8(v).ok_or(estr::<T>("U8"))).collect::<Result<_, _>>(),
108 DecodingResult::U16(buf) => buf.into_iter().map(|v| <T>::from_u16(v).ok_or(estr::<T>("U16"))).collect::<Result<_, _>>(),
109 DecodingResult::U32(buf) => buf.into_iter().map(|v| <T>::from_u32(v).ok_or(estr::<T>("U32"))).collect::<Result<_, _>>(),
110 DecodingResult::U64(buf) => buf.into_iter().map(|v| <T>::from_u64(v).ok_or(estr::<T>("U64"))).collect::<Result<_, _>>(),
111 DecodingResult::F32(buf) => buf.into_iter().map(|v| <T>::from_f32(v).ok_or(estr::<T>("F32"))).collect::<Result<_, _>>(),
112 DecodingResult::F64(buf) => buf.into_iter().map(|v| <T>::from_f64(v).ok_or(estr::<T>("F64"))).collect::<Result<_, _>>(),
113 }?;
114
115 // Convert the flat vector into an ndarray::Array2
116 let array: Array2<T> = Array2::from_shape_vec((height as usize, width as usize), data)?;
117
118 // nodata value
119 let nodata: T = decoder.get_tag_ascii_string(Tag::GdalNodata)?.trim().parse::<T>()?;
120
121 // pixel scale [pixel scale x, pixel scale y, ...]
122 // NB pixel scale y is the absolute value, it is POSITIVE. We have to make it negative later
123 let pscale: Vec<f64> = decoder.get_tag_f64_vec(Tag::ModelPixelScaleTag)?.into_iter().collect();
124
125 // tie point [0 0 0 startx starty 0]
126 let tie: Vec<f64> = decoder.get_tag_f64_vec(Tag::ModelTiepointTag)?.into_iter().collect();
127
128 // transform, the zeros are the rotations [start x, x pixel size, 0, start y, 0, y pixel size]
129 let geotrans: [f64; 6] = [tie[3], pscale[0], 0.0, tie[4], 0.0, -pscale[1]];
130
131 let projection: String = decoder.get_tag_ascii_string(Tag::GeoAsciiParamsTag)?;
132 let geokeydir: Vec<u64> = decoder .get_tag_u64_vec(Tag::GeoKeyDirectoryTag)?;
133
134 // try and get the CRS out of the geokeydir, it is the bit after 3072
135 let crs = geokeydir.windows(4).find(|w| w[0] == 3072).map(|w| w[3])
136 .ok_or(RasterError::Tiff(tiff::TiffFormatError::InvalidTagValueType(Tag::GeoKeyDirectoryTag).into()))? as u16;
137
138 Ok((array, nodata, crs, geotrans, geokeydir, projection))
139}
140
141/// Writes a 2D array of values to a GeoTIFF raster with geo metadata.
142///
143/// # Type Parameters
144/// - `T`: The element type of the array, which must implement `bytemuck::Pod`
145/// (for safe byte casting) and `ToString` (for writing NoData values to
146/// metadata).
147///
148/// # Parameters
149/// `data`: A 2D array (`ndarray::Array2<T>`) containing raster pixel values.
150/// `nd`: NoData value
151/// `geotrans`: A 6-element array defining the affine geotransform:
152/// `[origin_x, pixel_size_x, rotation_x, origin_y, rotation_y, pixel_size_y]`.
153/// `geokeydir`: &[u64] the GeoKeyDirectoryTag (best got from reading a raster)
154/// `proj`: PROJ string (best got from reading a raster)
155/// `outfile`: The path to the output `.tif` file.
156///
157/// # Returns
158/// Ok() or a `RasterError`
159///
160/// # Errors
161/// - Returns `RasterError::UnsupportedType` if `T` can't be mapped to a TIFF format.
162/// - Propagates I/O and TIFF writing errors
163///
164/// # Example
165/// See [examples/reading_and_writing_rasters.rs](examples/reading_and_writing_rasters.rs)
166pub fn array_to_rasterfile<T>(
167 data: &Array2<T>,
168 nd: T, // nodata
169 geotrans: &[f64; 6], // geo transform [start_x, psize_x, rotation, starty, rotation, psize_y]
170 geokeydir: &[u64], // geo dir, it has the crs in it
171 proj: &str, // the projection string
172 outfile: &PathBuf
173) -> Result<(), RasterError>
174 where T: bytemuck::Pod + ToString
175{
176 let (nrows, ncols) = (data.nrows(), data.ncols());
177
178 let fh = File::create(outfile)?;
179 let mut encoder = tiff::encoder::TiffEncoder::new(fh)?;
180
181 // Because image doesn't have traits I couldn't figure out how to do this with generics
182 // This macro takes the tiff colortype
183 macro_rules! writit {
184 ($pix:ty) => {{
185 let mut image = encoder.new_image_with_compression::<$pix, Deflate>(ncols as u32, nrows as u32, Deflate::default())?;
186 image.encoder().write_tag(Tag::GdalNodata, &nd.to_string()[..])?;
187 // remember that geotrans is negative, but in tiff tags assumed to be positive
188 image.encoder().write_tag(Tag::ModelPixelScaleTag, &[geotrans[1], -geotrans[5], 0.0][..])?;
189 image.encoder().write_tag(Tag::ModelTiepointTag, &[0.0, 0.0, 0.0, geotrans[0], geotrans[3], 0.0][..])?;
190 image.encoder().write_tag(Tag::GeoKeyDirectoryTag, geokeydir)?;
191 image.encoder().write_tag(Tag::GeoAsciiParamsTag, &proj)?;
192 image.write_data(cast_slice(data.as_slice().unwrap()))?;
193 }};
194 }
195
196 match std::any::TypeId::of::<T>() {
197 id if id == std::any::TypeId::of::<u8>() => writit!(Gray8),
198 id if id == std::any::TypeId::of::<u16>() => writit!(Gray16),
199 id if id == std::any::TypeId::of::<u32>() => writit!(Gray32),
200 id if id == std::any::TypeId::of::<u64>() => writit!(Gray64),
201 id if id == std::any::TypeId::of::<f32>() => writit!(Gray32Float),
202 id if id == std::any::TypeId::of::<f64>() => writit!(Gray64Float),
203 id if id == std::any::TypeId::of::<i8>() => writit!(GrayI8),
204 id if id == std::any::TypeId::of::<i16>() => writit!(GrayI16),
205 id if id == std::any::TypeId::of::<i32>() => writit!(GrayI32),
206 id if id == std::any::TypeId::of::<i64>() => writit!(GrayI64),
207 _ => return Err(RasterError::UnsupportedType(format!("Cannot handle type {}", std::any::type_name::<T>())))
208 };
209
210 Ok(())
211}
212
213
/// A grid cell queued in a `BinaryHeap` by `priority`.
///
/// The ordering is reversed, so a `BinaryHeap<GridCell>` (a max-heap) pops the cell with the
/// *lowest* priority first. Comparison uses `f64::total_cmp`, so the order is total even when a
/// priority is NaN (positive NaN sorts above +inf, so it is popped last). Equality is defined by
/// the same ordering, which keeps `PartialEq`, `Eq`, `PartialOrd` and `Ord` consistent.
#[derive(Debug)]
struct GridCell {
    row: usize,
    column: usize,
    priority: f64,
}

impl PartialEq for GridCell {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for GridCell {}

impl PartialOrd for GridCell {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for GridCell {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` (not the other way round) turns the max-heap into a min-heap.
        other.priority.total_cmp(&self.priority)
    }
}
234
/// A grid cell queued by `priority` while sloping flats, carrying `z`, the elevation of the
/// outlet the flat is being drained towards (used as a lower bound when growing the flat).
///
/// The ordering is reversed, so a `BinaryHeap<GridCell2>` pops the *lowest* priority first.
/// Comparison uses `f64::total_cmp` (a total order even with NaN), and equality follows the
/// same ordering so that `PartialEq`, `Eq`, `PartialOrd` and `Ord` agree.
#[derive(Debug)]
struct GridCell2 {
    row: usize,
    column: usize,
    z: f64,
    priority: f64,
}

impl PartialEq for GridCell2 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for GridCell2 {}

impl PartialOrd for GridCell2 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for GridCell2 {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` (not the other way round) turns the max-heap into a min-heap.
        other.priority.total_cmp(&self.priority)
    }
}
256
257
258/// Fills depressions (sinks) in a digital elevation model (DEM).
259///
260/// More-or-less the contents of
261/// [whitebox fill_depressions](https://github.com/jblindsay/whitebox-tools/blob/master/whitebox-tools-app/src/tools/hydro_analysis/fill_depressions.rs)
262///
263/// This function modifies the input `dem` to ensure that all depressions (local minima that do not
264/// drain) are removed, making the surface hydrologically correct. It also considers no-data values
265/// and can optionally fix flat areas.
266///
267/// # Parameters
268///
269/// - `dem`: A mutable reference to a 2D array (`Array2<f64>`) representing the elevation data.
270/// - `nodata`: The value representing no-data cells in the DEM.
271/// - `resx`: The horizontal resolution (grid spacing in the x-direction).
272/// - `resy`: The vertical resolution (grid spacing in the y-direction).
273/// - `fix_flats`: A boolean flag to determine whether flat areas should be slightly sloped.
274///
275/// # Example
276///
277/// ```
278/// use ndarray::{array, Array2};
279/// use hydro_analysis::fill_depressions;
280///
281/// let mut dem: Array2<f64> = array![
282/// [10.0, 12.0, 10.0],
283/// [12.0, 1.0, 12.0],
284/// [10.0, 12.0, 10.0],
285/// ];
286/// let filled: Array2<f64> = array![
287/// [10.0, 12.0, 10.0],
288/// [12.0, 1.0, 12.0],
289/// [10.0, 12.0, 10.0],
290/// ];
291/// fill_depressions(&mut dem, -3.0, 8.0, 8.0, true);
292/// assert_eq!(dem, filled);
293/// let mut dem: Array2<f64> = array![
294/// [10.0, 12.0, 10.0, 10.0],
295/// [12.0, 1.0, 10.0, 12.0],
296/// [10.0, 12.0, 10.0, 9.0],
297/// ];
298/// let filled: Array2<f64> = array![
299/// [10.0, 12.0, 10.0, 10.0],
300/// [12.0, 10.0, 10.0, 12.0],
301/// [10.0, 12.0, 10.0, 9.0],
302/// ];
303/// fill_depressions(&mut dem, -3.0, 8.0, 8.0, true);
304/// for (x, y) in dem.iter().zip(filled.iter()) {
305/// assert!((*x - *y).abs() < 1e-4);
306/// }
307/// ```
308pub fn fill_depressions(
309 dem: &mut Array2<f64>, nodata: f64, resx: f64, resy: f64, fix_flats: bool
310)
311{
312 let (rows, columns) = (dem.nrows(), dem.ncols());
313 let small_num = {
314 let diagres = (resx * resx + resy * resy).sqrt();
315 let elev_digits = (dem.iter().cloned().fold(f64::NEG_INFINITY, f64::max) as i64).to_string().len();
316 let elev_multiplier = 10.0_f64.powi((9 - elev_digits) as i32);
317 1.0_f64 / elev_multiplier as f64 * diagres.ceil()
318 };
319
320 //let input = dem.clone(); // original whitebox used an input and output while doing fixing flats, don't think you really need it
321
322 let dx = [1, 1, 1, 0, -1, -1, -1, 0];
323 let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
324
325 // Find pit cells (we don't care about pits around edge, they won't be considered a pit). This step is parallelizable.
326 let mut pits: Vec<_> = (1..rows - 1)
327 .into_par_iter()
328 .flat_map(|row| {
329 let mut local_pits = Vec::new();
330 for col in 1..columns - 1 {
331 let z = dem[[row, col]];
332 if z == nodata {
333 continue;
334 }
335 let mut apit = true;
336 // is anything lower than me?
337 for n in 0..8 {
338 let zn = dem[[(row as isize + dy[n]) as usize, (col as isize + dx[n]) as usize]];
339 if zn < z || zn == nodata {
340 apit = false;
341 break;
342 }
343 }
344 // no, so I am a pit
345 if apit {
346 local_pits.push((row, col, z));
347 }
348 }
349 local_pits
350 }).collect();
351
352 // Now we need to perform an in-place depression filling
353 let mut minheap = BinaryHeap::new();
354 let mut minheap2 = BinaryHeap::new();
355 let mut visited = Array2::<u8>::zeros((rows, columns));
356 let mut flats = Array2::<u8>::zeros((rows, columns));
357 let mut possible_outlets = vec![];
358 let mut queue = VecDeque::new();
359
360 // go through pits from highest to lowest
361 pits.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Equal));
362 while let Some(cell) = pits.pop() {
363 let row: usize = cell.0;
364 let col: usize = cell.1;
365
366 // if it's already in a solved site, don't do it a second time.
367 if flats[[row, col]] != 1 {
368
369 // First there is a priority region-growing operation to find the outlets.
370 minheap.clear();
371 minheap.push(GridCell {
372 row: row,
373 column: col,
374 priority: dem[[row, col]],
375 });
376 visited[[row, col]] = 1;
377 let mut outlet_found = false;
378 let mut outlet_z = f64::INFINITY;
379 if !queue.is_empty() {
380 queue.clear();
381 }
382 while let Some(cell2) = minheap.pop() {
383
384 let z = cell2.priority;
385 if outlet_found && z > outlet_z {
386 break;
387 }
388 if !outlet_found {
389 for n in 0..8 {
390 let cn = cell2.column as isize + dx[n];
391 let rn = cell2.row as isize + dy[n];
392 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
393 continue;
394 }
395 let cn = cn as usize;
396 let rn = rn as usize;
397 if visited[[rn, cn]] == 0 {
398 let zn = dem[[rn, cn]];
399 if !outlet_found {
400 if zn >= z && zn != nodata {
401 minheap.push(GridCell {
402 row: rn,
403 column: cn,
404 priority: zn,
405 });
406 visited[[rn, cn]] = 1;
407 } else if zn != nodata {
408 // zn < z
409 // 'cell' has a lower neighbour that hasn't already passed through minheap.
410 // Therefore, 'cell' is a pour point cell.
411 outlet_found = true;
412 outlet_z = z;
413 queue.push_back((cell2.row, cell2.column));
414 possible_outlets.push((cell2.row, cell2.column));
415 }
416 } else if zn == outlet_z {
417 // We've found the outlet but are still looking for additional depression cells.
418 minheap.push(GridCell {
419 row: rn,
420 column: cn,
421 priority: zn,
422 });
423 visited[[rn, cn]] = 1;
424 }
425 }
426 }
427 } else {
428 // We've found the outlet but are still looking for additional depression cells and potential outlets.
429 if z == outlet_z {
430 let mut anoutlet = false;
431 for n in 0..8 {
432 let cn = cell2.column as isize + dx[n];
433 let rn = cell2.row as isize + dy[n];
434 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
435 continue;
436 }
437 let cn = cn as usize;
438 let rn = rn as usize;
439 if visited[[rn, cn]] == 0 {
440 let zn = dem[[rn, cn]];
441 if zn < z {
442 anoutlet = true;
443 } else if zn == outlet_z {
444 minheap.push(GridCell {
445 row: rn,
446 column: cn,
447 priority: zn,
448 });
449 visited[[rn, cn]] = 1;
450 }
451 }
452 }
453 if anoutlet {
454 queue.push_back((cell2.row, cell2.column));
455 possible_outlets.push((cell2.row, cell2.column));
456 } else {
457 visited[[cell2.row, cell2.column]] = 1;
458 }
459 }
460 }
461 }
462
463 if outlet_found {
464 // Now that we have the outlets, raise the interior of the depression.
465 // Start from the outlets.
466 while let Some(cell2) = queue.pop_front() {
467 for n in 0..8 {
468 let cn = cell2.1 as isize + dx[n];
469 let rn = cell2.0 as isize + dy[n];
470 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
471 continue;
472 }
473 let cn = cn as usize;
474 let rn = rn as usize;
475 if visited[[rn, cn]] == 1 {
476 visited[[rn, cn]] = 0;
477 queue.push_back((rn, cn));
478 let z = dem[[rn, cn]];
479 if z < outlet_z {
480 dem[[rn, cn]] = outlet_z;
481 flats[[rn, cn]] = 1;
482 } else if z == outlet_z {
483 flats[[rn, cn]] = 1;
484 }
485 }
486 }
487 }
488 } else {
489 queue.push_back((row, col)); // start at the pit cell and clean up visited
490 while let Some(cell2) = queue.pop_front() {
491 for n in 0..8 {
492 let cn = cell2.1 as isize + dx[n];
493 let rn = cell2.0 as isize + dy[n];
494 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
495 continue;
496 }
497 let cn = cn as usize;
498 let rn = rn as usize;
499 if visited[[rn, cn]] == 1 {
500 visited[[rn, cn]] = 0;
501 queue.push_back((rn, cn));
502 }
503 }
504 }
505 }
506 }
507
508 }
509
510 drop(visited);
511
512 if small_num > 0.0 && fix_flats {
513 // Some of the potential outlets really will have lower cells.
514 minheap.clear();
515 while let Some(cell) = possible_outlets.pop() {
516 let z = dem[[cell.0, cell.1]];
517 let mut anoutlet = false;
518 for n in 0..8 {
519 let rn: isize = cell.0 as isize + dy[n];
520 let cn: isize = cell.1 as isize + dx[n];
521 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
522 continue;
523 }
524 let zn = dem[[rn as usize, cn as usize]];
525 if zn < z && zn != nodata {
526 anoutlet = true;
527 break;
528 }
529 }
530 if anoutlet {
531 minheap.push(GridCell {
532 row: cell.0,
533 column: cell.1,
534 priority: z,
535 });
536 }
537 }
538
539 let mut outlets = vec![];
540 while let Some(cell) = minheap.pop() {
541 if flats[[cell.row, cell.column]] != 3 {
542 let z = dem[[cell.row, cell.column]];
543 flats[[cell.row, cell.column]] = 3;
544 if !outlets.is_empty() {
545 outlets.clear();
546 }
547 outlets.push(cell);
548 // Are there any other outlet cells at the same elevation (likely for the same feature)
549 let mut flag = true;
550 while flag {
551 match minheap.peek() {
552 Some(cell2) => {
553 if cell2.priority == z {
554 flats[[cell2.row, cell2.column]] = 3;
555 outlets
556 .push(minheap.pop().expect("Error during pop operation."));
557 } else {
558 flag = false;
559 }
560 }
561 None => {
562 flag = false;
563 }
564 }
565 }
566 if !minheap2.is_empty() {
567 minheap2.clear();
568 }
569 for cell2 in &outlets {
570 let z = dem[[cell2.row, cell2.column]];
571 for n in 0..8 {
572 let cn = cell2.column as isize + dx[n];
573 let rn = cell2.row as isize + dy[n];
574 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
575 continue;
576 }
577 let cn = cn as usize;
578 let rn = rn as usize;
579 if flats[[rn, cn]] != 3 {
580 let zn = dem[[rn, cn]];
581 if zn == z && zn != nodata {
582 minheap2.push(GridCell2 {
583 row: rn,
584 column: cn,
585 z: z,
586 priority: dem[[rn, cn]], // FIXME
587 });
588 dem[[rn, cn]] = z + small_num;
589 flats[[rn, cn]] = 3;
590 }
591 }
592 }
593 }
594
595 // Now fix the flats
596 while let Some(cell2) = minheap2.pop() {
597 let z = dem[[cell2.row, cell2.column]];
598 for n in 0..8 {
599 let cn = cell2.column as isize + dx[n];
600 let rn = cell2.row as isize + dy[n];
601 if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
602 continue;
603 }
604 let cn = cn as usize;
605 let rn = rn as usize;
606 if flats[[rn, cn]] != 3 {
607 let zn = dem[[rn, cn]];
608 if zn < z + small_num && zn >= cell2.z && zn != nodata {
609 minheap2.push(GridCell2 {
610 row: rn,
611 column: cn,
612 z: cell2.z,
613 priority: dem[[rn, cn]], // FIXME
614 });
615 dem[[rn, cn]] = z + small_num;
616 flats[[rn, cn]] = 3;
617 }
618 }
619 }
620 }
621 }
622
623 }
624 }
625}
626
627/// Calculates the D8 flow direction from a digital elevation model (DEM).
628///
629/// More-or-less the contents of
630/// [whitebox d8_pointer](https://github.com/jblindsay/whitebox-tools/blob/master/whitebox-tools-app/src/tools/hydro_analysis/d8_pointer.rs)
631///
632/// This function computes the D8 flow direction for each cell in the provided DEM:
633///
634/// | . | . | . |
635/// |:--:|:---:|:--:|
636/// | 64 | 128 | 1 |
637/// | 32 | 0 | 2 |
638/// | 16 | 8 | 4 |
639///
640/// Grid cells that have no lower neighbours are assigned a flow direction of zero. In a DEM that
641/// has been pre-processed to remove all depressions and flat areas, this condition will only occur
642/// along the edges of the grid.
643///
644/// Grid cells possessing the NoData value in the input DEM are assigned the NoData value in the
645/// output image.
646///
647/// # Parameters
648/// - `dem`: A 2D array representing the digital elevation model (DEM)
649/// - `nodata`: The nodata in the DEM
650/// - `resx`: The resolution of the DEM in the x-direction in meters
651/// - `resy`: The resolution of the DEM in the y-direction in meters
652///
653/// # Returns
654/// - A tuple containing:
655/// - An `Array2<u8>` representing the D8 flow directions for each cell.
656/// - A `u8` nodata value (255)
657///
658/// # Example
659/// ```rust
660/// use anyhow::Result;
661/// use ndarray::{Array2, array};
662/// use hydro_analysis::{d8_pointer};
663/// let dem = Array2::from_shape_vec(
664/// (3, 3),
665/// vec![
666/// 11.0, 12.0, 10.0,
667/// 12.0, 13.0, 12.0,
668/// 10.5, 12.0, 11.0,
669/// ],
670/// ).expect("Failed to create DEM");
671/// let d8: Array2<u8> = array![
672/// [0, 2, 0],
673/// [8, 1, 128],
674/// [0, 32, 0],
675/// ];
676/// let nodata = -9999.0;
677/// let resx = 8.0;
678/// let resy = 8.0;
679/// let (res, _nd) = d8_pointer(&dem, nodata, resx, resy);
680/// assert_eq!(res, d8);
681/// ```
682pub fn d8_pointer(dem: &Array2<f64>, nodata: f64, resx: f64, resy: f64) -> (Array2<u8>, u8)
683{
684 let (nrows, ncols) = (dem.nrows(), dem.ncols());
685 let out_nodata: u8 = 255;
686 let mut d8: Array2<u8> = Array2::from_elem((nrows, ncols), out_nodata);
687
688 let diag = (resx * resx + resy * resy).sqrt();
689 let grid_lengths = [diag, resx, diag, resy, diag, resx, diag, resy];
690
691 let dx = [1, 1, 1, 0, -1, -1, -1, 0];
692 let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
693
694 d8.axis_iter_mut(ndarray::Axis(0))
695 .into_par_iter()
696 .enumerate()
697 .for_each(|(row, mut d8_row)| {
698 for col in 0..ncols {
699 let z = dem[[row, col]];
700 if z == nodata {
701 continue;
702 }
703
704 let mut dir = 0;
705 let mut max_slope = f64::MIN;
706 for i in 0..8 {
707 let rn: isize = row as isize + dy[i];
708 let cn: isize = col as isize + dx[i];
709 if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
710 continue;
711 }
712 let z_n = dem[[rn as usize, cn as usize]];
713 if z_n != nodata {
714 let slope = (z - z_n) / grid_lengths[i];
715 if slope > max_slope && slope > 0.0 {
716 max_slope = slope;
717 dir = i;
718 }
719 }
720 }
721
722 if max_slope >= 0.0 {
723 d8_row[col] = 1 << dir;
724 } else {
725 d8_row[col] = 0u8;
726 }
727 }
728 });
729
730 return (d8, out_nodata);
731}
732
733
734/// Updates a D8 flow direction using a mask that indicates onland and ocean and return changed
735/// flattened indices.
736///
737/// This function will ensure your downstream traces calculated from a D8 will stop at the coast.
738/// Any cell whose direction leads us to the ocean will become a sink. Any cell in the ocean will
739/// be nodata
740///
741/// An alternative approach (which doesn't work) is to first mask the DEM with nodata and then
742/// calculate the D8. Unfortunately this results in downstream traces that get to the coast
743/// (nodata) and then head along the coast, sometimes for ages before coming to a natural sink.
744///
745/// # Parameters
746/// d8: A 2D u8 array representing the directions
747/// nodata: The nodata for the d8 (probably 255)
748/// onland: A 2D bool array representing the land
749///
750/// # Returns
751/// nothing, but updates the d8
752///
753/// # Example
754/// ```rust
755/// use ndarray::Array2;
756/// let mut d8 = Array2::from_shape_vec(
757/// (3, 3),
758/// vec![
759/// 4, 2, 1,
760/// 2, 2, 1,
761/// 1, 128, 4,
762/// ],
763/// ).expect("Failed to create D8");
764/// let onland = Array2::from_shape_vec(
765/// (3, 3),
766/// vec![
767/// true, true, false,
768/// true, true, false,
769/// true, true, false,
770/// ],
771/// ).expect("Failed to create onland");
772/// let newd8 = Array2::from_shape_vec(
773/// (3, 3),
774/// vec![
775/// 4, 0, 255,
776/// 2, 0, 255,
777/// 1, 128, 255,
778/// ],
779/// ).expect("Failed to create D8");
780/// let changed = hydro_analysis::d8_clipped(&mut d8, 255, &onland);
781/// assert_eq!(d8, newd8);
782/// assert_eq!(changed, vec![1, 4]);
783/// ```
784pub fn d8_clipped(d8: &mut Array2<u8>, nodata: u8, onland: &Array2<bool>) -> Vec<usize>
785{
786 let (nrows, ncols) = (d8.nrows(), d8.ncols());
787
788 // in the ocean is nodata
789 par_azip!((d in &mut *d8, m in onland){ if !m { *d = nodata } });
790
791 // so we can index using 1D
792 let slice = d8.as_slice().unwrap();
793
794 let tozero: Vec<usize> = (0..nrows).into_par_iter().flat_map(|r| {
795 let mut local: Vec<usize> = Vec::new();
796 for c in 0..ncols {
797 let idx = r*ncols + c;
798 if slice[idx] == nodata || slice[idx] == 0 {
799 continue;
800 }
801
802 // get next cell downstream
803 let (dr, dc) = match slice[idx] {
804 1 => (-1, 1),
805 2 => ( 0, 1),
806 4 => ( 1, 1),
807 8 => ( 1, 0),
808 16 => ( 1, -1),
809 32 => ( 0, -1),
810 64 => (-1, -1),
811 128 => (-1, 0),
812 _ => unreachable!(),
813 };
814 let rn = r as isize + dr;
815 let cn = c as isize + dc;
816
817 // if next is outside, don't change idx
818 if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
819 continue;
820 }
821 // next is inside and nodata, then set idx to a sink
822 if d8[[rn as usize, cn as usize]] == nodata {
823 local.push(idx);
824 }
825 }
826 local
827 }).collect();
828
829 let slice = d8.as_slice_mut().unwrap();
830 for idx in &tozero {
831 slice[*idx] = 0;
832 }
833
834 tozero
835}
836
837
838/// Breach depressions least cost. Implements
839///
840/// [whitebox breach_depressions_least_cost](https://github.com/jblindsay/whitebox-tools/blob/master/whitebox-tools-app/src/tools/hydro_analysis/breach_depressions_least_cost.rs)
841///
842/// with
843/// max_cost set to infinity,
844/// flat_increment is default, and
845/// minimize_dist set to true.
846///
847/// with more modern and less memory intensive datastructures.
848///
849/// The only real parameter to tune is max_dist (measured in cells) and a default would be 20
850///
851/// # Returns
852/// number of pits that remain
853///
854/// # Examples
855///
856/// See tests/breach_depressions.rs for lots of examples, here is just one
857///
858/// ```rust
859/// use ndarray::{Array2, array};
860/// use hydro_analysis::breach_depressions;
861/// let resx = 8.0;
862/// let resy = 8.0;
863/// let max_dist = 100;
864/// let ep = 0.00000012;
865/// let mut dem: Array2<f64> = array![
866/// [2.0, 2.0, 2.0, 2.0, 0.0],
867/// [2.0, 1.0, 2.0, 0.5, 0.0],
868/// [2.0, 2.0, 2.0, 2.0, 0.0],
869/// ];
870/// let breached: Array2<f64> = array![
871/// [2.0, 2.0, 2.0, 2.0, 0.0],
872/// [2.0, 2.0-ep, 2.0-2.0*ep, 0.5, 0.0],
873/// [2.0, 2.0, 2.0, 2.0, 0.0],
874/// ];
875/// let n = breach_depressions(&mut dem, -1.0, resx, resy, max_dist);
876/// assert_eq!(n, 0);
877/// for (x, y) in dem.iter().zip(breached.iter()) {
878/// assert!((*x - *y).abs() < 1e-10);
879/// }
880///
881pub fn breach_depressions(dem: &mut Array2<f64>, nodata: f64, resx: f64, resy: f64, max_dist: usize) -> usize
882{
883
884 let diagres: f64 = (resx * resx + resy * resy).sqrt();
885 let cost_dist = [diagres, resx, diagres, resy, diagres, resx, diagres, resy];
886
887 let (nrows, ncols) = (dem.nrows(), dem.ncols());
888 let small_num = {
889 let diagres = (resx * resx + resy * resy).sqrt();
890 let elev_digits = (dem.iter().cloned().fold(f64::NEG_INFINITY, f64::max) as i64).to_string().len();
891 let elev_multiplier = 10.0_f64.powi((9 - elev_digits) as i32);
892 1.0_f64 / elev_multiplier as f64 * diagres.ceil()
893 };
894
895 // want to raise interior (ie not on boundary or with a nodata neighbour) pit cells up to
896 // just below minimum neighbour height.
897 //
898 // so first find the pits
899 let dx = [1, 1, 1, 0, -1, -1, -1, 0];
900 let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
901 let mut pits: Vec<_> = (1..nrows - 1).into_par_iter().flat_map(|row| {
902 let mut local_pits = Vec::new();
903 for col in 1..ncols - 1 {
904 let z = dem[[row, col]];
905 if z == nodata {
906 continue;
907 }
908 let mut apit = true;
909 // is any neighbour lower than me or nodata?
910 for n in 0..8 {
911 let zn = dem[[(row as isize + dy[n]) as usize, (col as isize + dx[n]) as usize]];
912 if zn < z || zn == nodata {
913 apit = false;
914 break;
915 }
916 }
917 // no, so I am a pit
918 if apit {
919 local_pits.push((row, col, z));
920 }
921 }
922 local_pits
923 }).collect();
924
925 // set depth to just below min neighbour, can't do this in parallel, we update the dem and pits
926 for &mut (row, col, ref mut z) in pits.iter_mut() {
927 let min_zn: f64 = dx.iter().zip(&dy).map(|(&dxi, &dyi)|
928 dem[[
929 (row as isize + dyi) as usize,
930 (col as isize + dxi) as usize,
931 ]]
932 ).fold(f64::INFINITY, f64::min);
933 *z = min_zn - small_num;
934 dem[[row, col]] = *z;
935 }
936
937 // Sort highest to lowest so poping off the end will get the lowest, should do that first
938 // because might solve some higher ones
939 pits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Equal));
940
941 // keep track of number we can't deal with to return
942 let mut num_unsolved: usize = 0;
943
944 // roughly how many cells we will have to look at
945 let num_to_chk: usize = (max_dist * 2 + 1) * (max_dist * 2 + 1);
946
947 // for keeping track of path be make
948 #[derive(Debug)]
949 struct NodeInfo {
950 prev: Option<(usize, usize)>,
951 length: usize,
952 }
953
954 // try and dig a channel from row, col
955 while let Some((row, col, z)) = pits.pop() {
956
957 // May have been solved during previous depression step, so can skip
958 if dx.iter().zip(&dy).any(|(&dxi, &dyi)|
959 dem[[(row as isize + dyi) as usize, (col as isize + dxi) as usize]] < z
960 ) {
961 continue;
962 }
963
964 // keep our path info in here
965 let mut visited: HashMap<(usize,usize), NodeInfo> = HashMap::default();
966 visited.insert((row,col), NodeInfo {
967 prev: None,
968 length: 0,
969 });
970
971 // cells we will check to see if lower than (row, col, z)
972 let mut cells_to_chk = BinaryHeap::with_capacity(num_to_chk);
973 cells_to_chk.push(GridCell {row: row, column: col, priority: 0.0 });
974
975 let mut dugit = false;
976 'search: while let Some(GridCell {row: chkrow, column: chkcol, priority: accum}) = cells_to_chk.pop() {
977
978 let length: usize = visited[&(chkrow,chkcol)].length;
979 let zn: f64 = dem[[chkrow, chkcol]];
980 let cost1: f64 = zn - z + length as f64 * small_num;
981 let mut breach: Option<(usize, usize)> = None;
982
983 // lets look about chkrow, chkcol
984 for n in 0..8 {
985 let rn: isize = chkrow as isize + dy[n];
986 let cn: isize = chkcol as isize + dx[n];
987
988 // chkrow, chkcol is a breach if on the edge
989 if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
990 breach = Some((chkrow, chkcol));
991 // we need to force this boundary cell down
992 dem[[chkrow, chkcol]] = z - (length as f64 * small_num);
993 break;
994 }
995 let rn: usize = rn as usize;
996 let cn: usize = cn as usize;
997 let nextlen = length + 1;
998
999 // insert if not visited
1000 if let Vacant(e) = visited.entry((rn, cn)) {
1001 e.insert(NodeInfo {
1002 prev: Some((chkrow, chkcol)),
1003 length: nextlen
1004 });
1005 } else {
1006 continue;
1007 }
1008
1009 let zn: f64 = dem[[rn, cn]];
1010 // an internal nodata cannot be breach point
1011 if zn == nodata {
1012 continue;
1013 }
1014
1015 // zout is lowered from z by nextlen * slope
1016 let zout: f64 = z - (nextlen as f64 * small_num);
1017 if zn <= zout {
1018 breach = Some((rn, cn));
1019 break;
1020 }
1021
1022 // (rn, cn) no good, zn is too high, just push onto heap
1023 if zn > zout {
1024 let cost2: f64 = zn - zout;
1025 let new_cost: f64 = accum + (cost1 + cost2) / 2.0 * cost_dist[n];
1026 // but haven't travelled too far, so we will scan from this cell
1027 if nextlen <= max_dist {
1028 cells_to_chk.push(GridCell {
1029 row: rn,
1030 column: cn,
1031 priority: new_cost
1032 });
1033 }
1034 }
1035 }
1036
1037 if let Some(mut cur) = breach {
1038 loop {
1039 let node: &NodeInfo = &visited[&cur];
1040 // back at the pit?
1041 let Some(parent) = node.prev else {
1042 break;
1043 };
1044
1045 let zn = dem[[parent.0, parent.1]];
1046 let length = visited[&parent].length;
1047 let zout = z - (length as f64 * small_num);
1048 if zn > zout {
1049 dem[[parent.0, parent.1]] = zout;
1050 }
1051 cur = parent;
1052 }
1053 dugit = true;
1054 break 'search;
1055 }
1056
1057 }
1058
1059 if !dugit {
1060 // Didn't find any lower cells, tough luck
1061 num_unsolved += 1;
1062 }
1063 }
1064
1065 num_unsolved
1066}
1067