use rayon::prelude::*;
use std::collections::{HashMap, BinaryHeap, VecDeque};
use std::cmp::Ordering;
use std::cmp::Ordering::Equal;
use std::collections::hash_map::Entry::Vacant;
use ndarray::{Array2, par_azip};
use std::{fs::File, f64, path::PathBuf};
use thiserror::Error;
use bytemuck::cast_slice;
use tiff::decoder::DecodingResult;
use tiff::encoder::compression::Deflate;
use tiff::encoder::colortype::{Gray8,Gray16,Gray32,Gray64,Gray32Float,Gray64Float,GrayI8,GrayI16,GrayI32,GrayI64};
use tiff::tags::Tag;
use tiff::TiffFormatError;
#[derive(Debug, Error)]
pub enum RasterError {
#[error("TIFF error: {0}")]
Tiff(#[from] tiff::TiffError),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("NDarray: {0}")]
Shape(#[from] ndarray::ShapeError),
#[error("Failed to parse nodata value")]
ParseIntError(#[from] std::num::ParseIntError),
#[error("Failed to parse nodata value")]
ParseFloatError(#[from] std::num::ParseFloatError),
#[error("Unsupported type: {0}")]
UnsupportedType(String)
}
pub fn rasterfile_to_array<T>(fname: &PathBuf) -> Result<
(
Array2<T>,
T, // nodata
u16, // crs
[f64; 6], // geo transform [start_x, psize_x, rotation, starty, rotation, psize_y]
Vec<u64>, // geo dir, it has the crs in it
String // the projection string
),
RasterError
>
where T: std::str::FromStr + num::FromPrimitive,
<T as std::str::FromStr>::Err: std::fmt::Debug,
RasterError: std::convert::From<<T as std::str::FromStr>::Err> { let file = File::open(fname)?;
let mut decoder = tiff::decoder::Decoder::new(file)?;
decoder = decoder.with_limits(tiff::decoder::Limits::unlimited());
let (width, height) = decoder.dimensions()?;
fn estr<T>(etype: &'static str) -> RasterError {
RasterError::Tiff(TiffFormatError::Format(format!("Raster is {}, I was expecting {}", etype, std::any::type_name::<T>()).into()).into())
}
let data: Vec<T> = match decoder.read_image()? {
DecodingResult::I8(buf) => buf.into_iter().map(|v| <T>::from_i8(v).ok_or(estr::<T>("I8"))).collect::<Result<_, _>>(),
DecodingResult::I16(buf) => buf.into_iter().map(|v| <T>::from_i16(v).ok_or(estr::<T>("I16"))).collect::<Result<_, _>>(),
DecodingResult::I32(buf) => buf.into_iter().map(|v| <T>::from_i32(v).ok_or(estr::<T>("I32"))).collect::<Result<_, _>>(),
DecodingResult::I64(buf) => buf.into_iter().map(|v| <T>::from_i64(v).ok_or(estr::<T>("I64"))).collect::<Result<_, _>>(),
DecodingResult::U8(buf) => buf.into_iter().map(|v| <T>::from_u8(v).ok_or(estr::<T>("U8"))).collect::<Result<_, _>>(),
DecodingResult::U16(buf) => buf.into_iter().map(|v| <T>::from_u16(v).ok_or(estr::<T>("U16"))).collect::<Result<_, _>>(),
DecodingResult::U32(buf) => buf.into_iter().map(|v| <T>::from_u32(v).ok_or(estr::<T>("U32"))).collect::<Result<_, _>>(),
DecodingResult::U64(buf) => buf.into_iter().map(|v| <T>::from_u64(v).ok_or(estr::<T>("U64"))).collect::<Result<_, _>>(),
DecodingResult::F32(buf) => buf.into_iter().map(|v| <T>::from_f32(v).ok_or(estr::<T>("F32"))).collect::<Result<_, _>>(),
DecodingResult::F64(buf) => buf.into_iter().map(|v| <T>::from_f64(v).ok_or(estr::<T>("F64"))).collect::<Result<_, _>>(),
}?;
let array: Array2<T> = Array2::from_shape_vec((height as usize, width as usize), data)?;
let nodata: T = decoder.get_tag_ascii_string(Tag::GdalNodata)?.trim().parse::<T>()?;
let pscale: Vec<f64> = decoder.get_tag_f64_vec(Tag::ModelPixelScaleTag)?.into_iter().collect();
let tie: Vec<f64> = decoder.get_tag_f64_vec(Tag::ModelTiepointTag)?.into_iter().collect();
let geotrans: [f64; 6] = [tie[3], pscale[0], 0.0, tie[4], 0.0, -pscale[1]];
let projection: String = decoder.get_tag_ascii_string(Tag::GeoAsciiParamsTag)?;
let geokeydir: Vec<u64> = decoder .get_tag_u64_vec(Tag::GeoKeyDirectoryTag)?;
let crs = geokeydir.windows(4).find(|w| w[0] == 3072).map(|w| w[3])
.ok_or(RasterError::Tiff(tiff::TiffFormatError::InvalidTagValueType(Tag::GeoKeyDirectoryTag).into()))? as u16;
Ok((array, nodata, crs, geotrans, geokeydir, projection))
}
/// Write `data` as a single-band, Deflate-compressed GeoTIFF.
///
/// * `nd` – nodata value, stored in the `GDAL_NODATA` tag via `ToString`;
/// * `geotrans` – `[start_x, psize_x, _, start_y, _, psize_y]`; the pixel scale is written as
///   `(psize_x, -psize_y)` and the tie point maps raster (0, 0) to `(start_x, start_y)`;
/// * `geokeydir` – GeoKey directory values, written as SHORTs as the GeoTIFF spec requires;
/// * `proj` – the `GeoAsciiParamsTag` string.
///
/// Arrays in any memory layout are accepted; non-row-major arrays are copied first.
///
/// # Errors
/// `UnsupportedType` if `T` is not one of `u8..u64`, `i8..i64`, `f32`, `f64` (no file is
/// created in that case); a TIFF error if a GeoKey value does not fit in 16 bits; otherwise
/// any I/O or TIFF encoding error.
pub fn array_to_rasterfile<T>(
    data: &Array2<T>,
    nd: T,
    geotrans: &[f64; 6],
    geokeydir: &[u64],
    proj: &str,
    outfile: &PathBuf,
) -> Result<(), RasterError>
where
    T: bytemuck::Pod + ToString,
{
    use std::any::TypeId;

    let (nrows, ncols) = (data.nrows(), data.ncols());
    // GeoKeyDirectoryTag is typed SHORT; writing u64 would emit LONG8 entries.
    let geokeys: Vec<u16> = geokeydir
        .iter()
        .map(|&k| (k <= u64::from(u16::MAX)).then(|| k as u16))
        .collect::<Option<Vec<u16>>>()
        .ok_or_else(|| {
            RasterError::Tiff(TiffFormatError::InvalidTagValueType(Tag::GeoKeyDirectoryTag).into())
        })?;
    // TIFF strips are row-major; copy only if the array is not already in that layout.
    let standard = data.as_standard_layout();
    let pixels: &[T] = standard
        .as_slice()
        .expect("standard-layout arrays are contiguous");
    let nodata_text = nd.to_string();
    let pixel_scale = [geotrans[1], -geotrans[5], 0.0];
    let tie_point = [0.0, 0.0, 0.0, geotrans[0], geotrans[3], 0.0];

    // Creates the file only once the colour type is known, then writes tags and pixels.
    macro_rules! writit {
        ($pix:ty) => {{
            let fh = File::create(outfile)?;
            let mut encoder = tiff::encoder::TiffEncoder::new(fh)?;
            let mut image = encoder.new_image_with_compression::<$pix, Deflate>(
                ncols as u32,
                nrows as u32,
                Deflate::default(),
            )?;
            image.encoder().write_tag(Tag::GdalNodata, &nodata_text[..])?;
            image.encoder().write_tag(Tag::ModelPixelScaleTag, &pixel_scale[..])?;
            image.encoder().write_tag(Tag::ModelTiepointTag, &tie_point[..])?;
            image.encoder().write_tag(Tag::GeoKeyDirectoryTag, &geokeys[..])?;
            image.encoder().write_tag(Tag::GeoAsciiParamsTag, proj)?;
            image.write_data(cast_slice(pixels))?;
        }};
    }

    match TypeId::of::<T>() {
        id if id == TypeId::of::<u8>() => writit!(Gray8),
        id if id == TypeId::of::<u16>() => writit!(Gray16),
        id if id == TypeId::of::<u32>() => writit!(Gray32),
        id if id == TypeId::of::<u64>() => writit!(Gray64),
        id if id == TypeId::of::<f32>() => writit!(Gray32Float),
        id if id == TypeId::of::<f64>() => writit!(Gray64Float),
        id if id == TypeId::of::<i8>() => writit!(GrayI8),
        id if id == TypeId::of::<i16>() => writit!(GrayI16),
        id if id == TypeId::of::<i32>() => writit!(GrayI32),
        id if id == TypeId::of::<i64>() => writit!(GrayI64),
        _ => {
            return Err(RasterError::UnsupportedType(format!(
                "Cannot handle type {}",
                std::any::type_name::<T>()
            )))
        }
    };
    Ok(())
}
/// A raster cell queued in a `BinaryHeap` by `priority`.
///
/// The ordering is reversed so that the (max-)heap pops the *lowest* priority first.
#[derive(PartialEq, Debug)]
struct GridCell {
    row: usize,
    column: usize,
    priority: f64,
}
impl Eq for GridCell {}
impl PartialOrd for GridCell {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for GridCell {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` turns BinaryHeap into a min-heap; `total_cmp` keeps the order
        // total, so a NaN priority can no longer panic inside the heap.
        other.priority.total_cmp(&self.priority)
    }
}
/// A raster cell queued in a `BinaryHeap` by `priority`, also carrying a reference
/// elevation `z` for the flat it belongs to.
///
/// The ordering is reversed so that the (max-)heap pops the *lowest* priority first;
/// `z` does not take part in the ordering.
#[derive(PartialEq, Debug)]
struct GridCell2 {
    row: usize,
    column: usize,
    z: f64,
    priority: f64,
}
impl Eq for GridCell2 {}
impl PartialOrd for GridCell2 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for GridCell2 {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed for a min-heap; `total_cmp` cannot panic on NaN.
        other.priority.total_cmp(&self.priority)
    }
}
/// Fill depressions (pits) of a DEM in place, optionally imposing a tiny gradient on the
/// filled flats.
///
/// * `dem` – elevation grid, modified in place.
/// * `nodata` – value marking cells without data; such cells are never filled, and are not
///   accepted as outlets while a depression is being flooded.
/// * `resx`, `resy` – cell sizes; only the diagonal `sqrt(resx² + resy²)` is used, to scale
///   the flat-fixing increment `small_num`.
/// * `fix_flats` – when `true`, cells of filled flats are raised by multiples of `small_num`
///   so that they slope down towards the flat's outlets.
///
/// Every interior cell that is no higher than any of its 8 neighbours is flooded with a
/// priority queue until a neighbour that is lower and not nodata (the outlet) is reached;
/// the flooded cells below the outlet elevation are then raised to it.
pub fn fill_depressions(
    dem: &mut Array2<f64>, nodata: f64, resx: f64, resy: f64, fix_flats: bool
)
{
    let (rows, columns) = (dem.nrows(), dem.ncols());
    // Increment used to tilt flats, scaled by the number of integer digits of the highest
    // valid elevation. Nodata and non-finite cells are excluded: a huge nodata value
    // (e.g. 3.4e38) used to push the digit count past 9 and underflow the subtraction.
    let small_num = {
        let diagres = (resx * resx + resy * resy).sqrt();
        let max_elev = dem
            .iter()
            .copied()
            .filter(|&v| v != nodata && v.is_finite())
            .fold(f64::NEG_INFINITY, f64::max);
        let max_elev = if max_elev.is_finite() { max_elev } else { 0.0 };
        let elev_digits = (max_elev as i64).to_string().len() as i32;
        let elev_multiplier = 10.0_f64.powi(9 - elev_digits);
        1.0_f64 / elev_multiplier * diagres.ceil()
    };
    let dx = [1, 1, 1, 0, -1, -1, -1, 0];
    let dy = [-1, 0, 1, 1, 1, 0, -1, -1];

    // Interior pits: valid cells whose 8 neighbours are all valid and not lower.
    // `saturating_sub` keeps empty grids from underflowing the range bounds.
    let mut pits: Vec<_> = (1..rows.saturating_sub(1))
        .into_par_iter()
        .flat_map(|row| {
            let mut local_pits = Vec::new();
            for col in 1..columns.saturating_sub(1) {
                let z = dem[[row, col]];
                if z == nodata {
                    continue;
                }
                let mut apit = true;
                for n in 0..8 {
                    let zn = dem[[(row as isize + dy[n]) as usize, (col as isize + dx[n]) as usize]];
                    if zn < z || zn == nodata {
                        apit = false;
                        break;
                    }
                }
                if apit {
                    local_pits.push((row, col, z));
                }
            }
            local_pits
        })
        .collect();

    // Index of the `n`-th neighbour of (r, c), or `None` if it falls off the grid.
    let neighbour = move |r: usize, c: usize, n: usize| -> Option<(usize, usize)> {
        let rn = r as isize + dy[n];
        let cn = c as isize + dx[n];
        if rn < 0 || rn as usize >= rows || cn < 0 || cn as usize >= columns {
            None
        } else {
            Some((rn as usize, cn as usize))
        }
    };

    let mut minheap = BinaryHeap::new();
    let mut minheap2 = BinaryHeap::new();
    let mut visited = Array2::<u8>::zeros((rows, columns));
    // 1 = part of a filled flat (phase 1); 3 = already handled by flat fixing (phase 2).
    let mut flats = Array2::<u8>::zeros((rows, columns));
    let mut possible_outlets = vec![];
    let mut queue = VecDeque::new();

    // Sorted ascending, so `pop()` floods the highest remaining pit first.
    pits.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Equal));
    while let Some((row, col, _)) = pits.pop() {
        // Already raised onto a flat by an earlier flood.
        if flats[[row, col]] == 1 {
            continue;
        }
        minheap.clear();
        minheap.push(GridCell { row, column: col, priority: dem[[row, col]] });
        visited[[row, col]] = 1;
        let mut outlet_found = false;
        let mut outlet_z = f64::INFINITY;
        queue.clear();

        while let Some(cell2) = minheap.pop() {
            let z = cell2.priority;
            if outlet_found && z > outlet_z {
                break;
            }
            if !outlet_found {
                for n in 0..8 {
                    let Some((rn, cn)) = neighbour(cell2.row, cell2.column, n) else { continue };
                    if visited[[rn, cn]] != 0 {
                        continue;
                    }
                    let zn = dem[[rn, cn]];
                    if !outlet_found {
                        if zn >= z && zn != nodata {
                            minheap.push(GridCell { row: rn, column: cn, priority: zn });
                            visited[[rn, cn]] = 1;
                        } else if zn != nodata {
                            // A lower valid neighbour: this cell drains the depression.
                            outlet_found = true;
                            outlet_z = z;
                            queue.push_back((cell2.row, cell2.column));
                            possible_outlets.push((cell2.row, cell2.column));
                        }
                    } else if zn == outlet_z {
                        minheap.push(GridCell { row: rn, column: cn, priority: zn });
                        visited[[rn, cn]] = 1;
                    }
                }
            } else if z == outlet_z {
                // Other cells at the outlet level may also drain the depression.
                let mut anoutlet = false;
                for n in 0..8 {
                    let Some((rn, cn)) = neighbour(cell2.row, cell2.column, n) else { continue };
                    if visited[[rn, cn]] == 0 {
                        let zn = dem[[rn, cn]];
                        // NOTE(review): unlike the search above, a nodata neighbour lower than
                        // `z` counts as an outlet here — confirm this is intended.
                        if zn < z {
                            anoutlet = true;
                        } else if zn == outlet_z {
                            minheap.push(GridCell { row: rn, column: cn, priority: zn });
                            visited[[rn, cn]] = 1;
                        }
                    }
                }
                if anoutlet {
                    queue.push_back((cell2.row, cell2.column));
                    possible_outlets.push((cell2.row, cell2.column));
                } else {
                    visited[[cell2.row, cell2.column]] = 1;
                }
            }
        }

        if outlet_found {
            // Walk the flooded region from the outlets, clearing `visited` and raising
            // every cell below the outlet level to it.
            while let Some((qr, qc)) = queue.pop_front() {
                for n in 0..8 {
                    let Some((rn, cn)) = neighbour(qr, qc, n) else { continue };
                    if visited[[rn, cn]] == 1 {
                        visited[[rn, cn]] = 0;
                        queue.push_back((rn, cn));
                        let z = dem[[rn, cn]];
                        if z < outlet_z {
                            dem[[rn, cn]] = outlet_z;
                            flats[[rn, cn]] = 1;
                        } else if z == outlet_z {
                            flats[[rn, cn]] = 1;
                        }
                    }
                }
            }
        } else {
            // No outlet: only reset the `visited` marks left by this flood.
            queue.push_back((row, col));
            while let Some((qr, qc)) = queue.pop_front() {
                for n in 0..8 {
                    let Some((rn, cn)) = neighbour(qr, qc, n) else { continue };
                    if visited[[rn, cn]] == 1 {
                        visited[[rn, cn]] = 0;
                        queue.push_back((rn, cn));
                    }
                }
            }
        }
    }
    drop(visited);

    if small_num > 0.0 && fix_flats {
        // Keep only candidate outlets that still have a lower valid neighbour.
        minheap.clear();
        while let Some((orow, ocol)) = possible_outlets.pop() {
            let z = dem[[orow, ocol]];
            let anoutlet = (0..8).any(|n| match neighbour(orow, ocol, n) {
                Some((rn, cn)) => {
                    let zn = dem[[rn, cn]];
                    zn < z && zn != nodata
                }
                None => false,
            });
            if anoutlet {
                minheap.push(GridCell { row: orow, column: ocol, priority: z });
            }
        }

        let mut outlets = vec![];
        while let Some(cell) = minheap.pop() {
            if flats[[cell.row, cell.column]] == 3 {
                continue;
            }
            let z = dem[[cell.row, cell.column]];
            flats[[cell.row, cell.column]] = 3;
            outlets.clear();
            outlets.push(cell);
            // Gather every other outlet of the same flat (same elevation).
            while minheap.peek().map_or(false, |next| next.priority == z) {
                let next = minheap.pop().expect("Error during pop operation.");
                flats[[next.row, next.column]] = 3;
                outlets.push(next);
            }

            // Seed: flat cells adjacent to an outlet get one increment.
            minheap2.clear();
            for outlet in &outlets {
                let z = dem[[outlet.row, outlet.column]];
                for n in 0..8 {
                    let Some((rn, cn)) = neighbour(outlet.row, outlet.column, n) else { continue };
                    if flats[[rn, cn]] != 3 {
                        let zn = dem[[rn, cn]];
                        if zn == z && zn != nodata {
                            minheap2.push(GridCell2 { row: rn, column: cn, z, priority: dem[[rn, cn]] });
                            dem[[rn, cn]] = z + small_num;
                            flats[[rn, cn]] = 3;
                        }
                    }
                }
            }

            // Grow the gradient outwards over the rest of the flat.
            while let Some(cell2) = minheap2.pop() {
                let z = dem[[cell2.row, cell2.column]];
                for n in 0..8 {
                    let Some((rn, cn)) = neighbour(cell2.row, cell2.column, n) else { continue };
                    if flats[[rn, cn]] != 3 {
                        let zn = dem[[rn, cn]];
                        if zn < z + small_num && zn >= cell2.z && zn != nodata {
                            minheap2.push(GridCell2 {
                                row: rn,
                                column: cn,
                                z: cell2.z,
                                priority: dem[[rn, cn]],
                            });
                            dem[[rn, cn]] = z + small_num;
                            flats[[rn, cn]] = 3;
                        }
                    }
                }
            }
        }
    }
}
/// Compute D8 (steepest single-neighbour descent) flow directions for `dem`.
///
/// Each output cell holds `1 << i`, where `i` indexes the neighbour offsets
/// `(row + dy[i], col + dx[i])` with `dx = [1, 1, 1, 0, -1, -1, -1, 0]` and
/// `dy = [-1, 0, 1, 1, 1, 0, -1, -1]`. Slopes are elevation drop divided by the cell
/// distance (`resx`, `resy` or the diagonal). Cells with no strictly lower valid neighbour
/// get `0`; cells equal to `nodata` keep the output nodata code.
///
/// `resx`/`resy` are treated as magnitudes, so a signed (negative) y pixel size taken
/// straight from a geotransform is handled correctly.
///
/// Returns `(directions, 255)`, where 255 is the output nodata code.
pub fn d8_pointer(dem: &Array2<f64>, nodata: f64, resx: f64, resy: f64) -> (Array2<u8>, u8)
{
    let (nrows, ncols) = (dem.nrows(), dem.ncols());
    let out_nodata: u8 = 255;
    let mut d8: Array2<u8> = Array2::from_elem((nrows, ncols), out_nodata);
    // A negative resy would flip the sign of every vertical slope.
    let (resx, resy) = (resx.abs(), resy.abs());
    let diag = (resx * resx + resy * resy).sqrt();
    let grid_lengths = [diag, resx, diag, resy, diag, resx, diag, resy];
    let dx = [1, 1, 1, 0, -1, -1, -1, 0];
    let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
    d8.axis_iter_mut(ndarray::Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(row, mut d8_row)| {
            for col in 0..ncols {
                let z = dem[[row, col]];
                if z == nodata {
                    continue;
                }
                // Steepest strictly positive descent; on ties the first direction wins.
                let mut best: Option<(usize, f64)> = None;
                for i in 0..8 {
                    let rn: isize = row as isize + dy[i];
                    let cn: isize = col as isize + dx[i];
                    if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
                        continue;
                    }
                    let z_n = dem[[rn as usize, cn as usize]];
                    if z_n == nodata {
                        continue;
                    }
                    let slope = (z - z_n) / grid_lengths[i];
                    if slope > 0.0 && best.map_or(true, |(_, s)| slope > s) {
                        best = Some((i, slope));
                    }
                }
                d8_row[col] = best.map_or(0u8, |(i, _)| 1u8 << i);
            }
        });
    (d8, out_nodata)
}
/// Mask a D8 direction grid to land and stop flow that would enter masked cells.
///
/// First, every cell where `onland` is `false` is set to `nodata`. `onland` must have the
/// same shape as `d8`; ndarray's `Zip` panics otherwise. Then every remaining cell whose
/// direction points into a `nodata` cell is set to `0` (no flow). Cells whose code is `0`,
/// `nodata`, or not a single-bit D8 code are left untouched. Directions pointing off the
/// grid are also left untouched.
///
/// Works for arrays in any memory layout. Returns the row-major flat indices
/// (`row * ncols + col`) of the cells that were set to `0`.
pub fn d8_clipped(d8: &mut Array2<u8>, nodata: u8, onland: &Array2<bool>) -> Vec<usize>
{
    let (nrows, ncols) = (d8.nrows(), d8.ncols());
    par_azip!((d in &mut *d8, m in onland){ if !m { *d = nodata } });

    // Index by (row, col) rather than through `as_slice`, which panics on non-row-major arrays.
    let grid = d8.view();
    let tozero: Vec<usize> = (0..nrows)
        .into_par_iter()
        .flat_map(|r| {
            let mut local: Vec<usize> = Vec::new();
            for c in 0..ncols {
                let code = grid[[r, c]];
                if code == nodata || code == 0 {
                    continue;
                }
                let (dr, dc): (isize, isize) = match code {
                    1 => (-1, 1),
                    2 => (0, 1),
                    4 => (1, 1),
                    8 => (1, 0),
                    16 => (1, -1),
                    32 => (0, -1),
                    64 => (-1, -1),
                    128 => (-1, 0),
                    // Not a D8 code (e.g. foreign input): nothing to follow.
                    _ => continue,
                };
                let rn = r as isize + dr;
                let cn = c as isize + dc;
                if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
                    continue;
                }
                if grid[[rn as usize, cn as usize]] == nodata {
                    local.push(r * ncols + c);
                }
            }
            local
        })
        .collect();

    for &idx in &tozero {
        d8[[idx / ncols, idx % ncols]] = 0;
    }
    tozero
}
pub fn breach_depressions(dem: &mut Array2<f64>, nodata: f64, resx: f64, resy: f64, max_dist: usize) -> usize
{
let diagres: f64 = (resx * resx + resy * resy).sqrt();
let cost_dist = [diagres, resx, diagres, resy, diagres, resx, diagres, resy];
let (nrows, ncols) = (dem.nrows(), dem.ncols());
let small_num = {
let diagres = (resx * resx + resy * resy).sqrt();
let elev_digits = (dem.iter().cloned().fold(f64::NEG_INFINITY, f64::max) as i64).to_string().len();
let elev_multiplier = 10.0_f64.powi((9 - elev_digits) as i32);
1.0_f64 / elev_multiplier as f64 * diagres.ceil()
};
let dx = [1, 1, 1, 0, -1, -1, -1, 0];
let dy = [-1, 0, 1, 1, 1, 0, -1, -1];
let mut pits: Vec<_> = (1..nrows - 1).into_par_iter().flat_map(|row| {
let mut local_pits = Vec::new();
for col in 1..ncols - 1 {
let z = dem[[row, col]];
if z == nodata {
continue;
}
let mut apit = true;
for n in 0..8 {
let zn = dem[[(row as isize + dy[n]) as usize, (col as isize + dx[n]) as usize]];
if zn < z || zn == nodata {
apit = false;
break;
}
}
if apit {
local_pits.push((row, col, z));
}
}
local_pits
}).collect();
for &mut (row, col, ref mut z) in pits.iter_mut() {
let min_zn: f64 = dx.iter().zip(&dy).map(|(&dxi, &dyi)|
dem[[
(row as isize + dyi) as usize,
(col as isize + dxi) as usize,
]]
).fold(f64::INFINITY, f64::min);
*z = min_zn - small_num;
dem[[row, col]] = *z;
}
pits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Equal));
let mut num_unsolved: usize = 0;
let num_to_chk: usize = (max_dist * 2 + 1) * (max_dist * 2 + 1);
#[derive(Debug)]
struct NodeInfo {
prev: Option<(usize, usize)>,
length: usize,
}
while let Some((row, col, z)) = pits.pop() {
if dx.iter().zip(&dy).any(|(&dxi, &dyi)|
dem[[(row as isize + dyi) as usize, (col as isize + dxi) as usize]] < z
) {
continue;
}
let mut visited: HashMap<(usize,usize), NodeInfo> = HashMap::default();
visited.insert((row,col), NodeInfo {
prev: None,
length: 0,
});
let mut cells_to_chk = BinaryHeap::with_capacity(num_to_chk);
cells_to_chk.push(GridCell {row: row, column: col, priority: 0.0 });
let mut dugit = false;
'search: while let Some(GridCell {row: chkrow, column: chkcol, priority: accum}) = cells_to_chk.pop() {
let length: usize = visited[&(chkrow,chkcol)].length;
let zn: f64 = dem[[chkrow, chkcol]];
let cost1: f64 = zn - z + length as f64 * small_num;
let mut breach: Option<(usize, usize)> = None;
for n in 0..8 {
let rn: isize = chkrow as isize + dy[n];
let cn: isize = chkcol as isize + dx[n];
if rn < 0 || rn >= nrows as isize || cn < 0 || cn >= ncols as isize {
breach = Some((chkrow, chkcol));
dem[[chkrow, chkcol]] = z - (length as f64 * small_num);
break;
}
let rn: usize = rn as usize;
let cn: usize = cn as usize;
let nextlen = length + 1;
if let Vacant(e) = visited.entry((rn, cn)) {
e.insert(NodeInfo {
prev: Some((chkrow, chkcol)),
length: nextlen
});
} else {
continue;
}
let zn: f64 = dem[[rn, cn]];
if zn == nodata {
continue;
}
let zout: f64 = z - (nextlen as f64 * small_num);
if zn <= zout {
breach = Some((rn, cn));
break;
}
if zn > zout {
let cost2: f64 = zn - zout;
let new_cost: f64 = accum + (cost1 + cost2) / 2.0 * cost_dist[n];
if nextlen <= max_dist {
cells_to_chk.push(GridCell {
row: rn,
column: cn,
priority: new_cost
});
}
}
}
if let Some(mut cur) = breach {
loop {
let node: &NodeInfo = &visited[&cur];
let Some(parent) = node.prev else {
break;
};
let zn = dem[[parent.0, parent.1]];
let length = visited[&parent].length;
let zout = z - (length as f64 * small_num);
if zn > zout {
dem[[parent.0, parent.1]] = zout;
}
cur = parent;
}
dugit = true;
break 'search;
}
}
if !dugit {
num_unsolved += 1;
}
}
num_unsolved
}