use crate::types::{Index2d, Overlap, Rectangle};
use csv::WriterBuilder;
use ndarray::{s, Array, Array2, Array3, Array4, Axis, ArrayView2, ArrayView3, ArrayView4};
use ndarray_csv::Array2Writer;
use rand::Rng;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use std::cmp::max;
use std::fs::File;
use std::path::PathBuf;
/// Returns a view of the window of `array` around `indices` described by `rect`.
///
/// The first row is `indices.row - rect.top` and the first column is
/// `indices.col - rect.left`. Both are computed in signed `i32` arithmetic and clamped at 0,
/// so a window reaching past the top/left edge starts at the edge instead of underflowing.
/// The exclusive ends are `indices.row + rect.bottom - 1` and `indices.col + rect.right - 1`.
/// The ends are not clamped to the array shape, so an end past the array edge makes `slice`
/// panic.
// NOTE(review): because of the `- 1`, the last included row is `indices.row + rect.bottom - 2`.
// Confirm this matches how `Rectangle` extents are meant to be interpreted.
// NOTE(review): the `as i32` casts truncate indices larger than `i32::MAX`.
pub fn rect_view<'a, T>(
    array: &'a ArrayView2<T>,
    rect: Rectangle,
    indices: Index2d,
) -> ArrayView2<'a, T> {
    let row_stop = indices.row + rect.bottom;
    let row_start = max(0, indices.row as i32 - rect.top as i32) as usize;
    let col_stop = indices.col + rect.right;
    let col_start = max(0, indices.col as i32 - rect.left as i32) as usize;
    let row_range = row_start..row_stop - 1;
    let col_range = col_start..col_stop - 1;
    array.slice(s![row_range, col_range])
}
pub fn write_csv_array(data: Vec<Vec<i16>>, csv_filename: PathBuf) {
let rows = data.len();
let cols = data[0].len();
let flat: Vec<i16> = data.into_iter().flatten().collect();
let array = Array::from(flat)
.into_shape_clone((rows, cols))
.expect("Unable to reshape");
{
let file = File::create(csv_filename).expect("Unable to create file");
let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
writer
.serialize_array2(&array)
.expect("Unable to serialize array to file");
}
}
/// Returns a view of `array` with `overlap_size` elements removed from both ends of axes 1
/// and 2. Axis 0 is kept whole.
///
/// `overlap_size` is expected to be at most half of each trimmed axis length. A value larger
/// than an axis length underflows the `usize` subtraction (a panic in debug builds).
pub fn trimm_array3<T>(array: &Array3<T>, overlap_size: usize) -> ArrayView3<'_, T> {
    let (_, height, width) = array.dim();
    let (row_end, col_end) = (height - overlap_size, width - overlap_size);
    array.slice(s![.., overlap_size..row_end, overlap_size..col_end])
}
/// Returns the index of the first maximum element of `xs`.
///
/// Ties resolve to the earliest index. Comparisons use `>` against the current best, starting
/// from `xs[0]`. An element that is incomparable with the current best (e.g. `NaN`) therefore
/// never replaces it, and a leading `NaN` makes the result `0`.
///
/// # Panics
///
/// Panics if `xs` is empty.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    let (first, rest) = xs
        .split_first()
        .expect("argmax called on an empty slice");
    // Track only the best index: the first maximum is all that is returned, so there is no
    // need to collect every tied index into a freshly allocated Vec.
    let mut best_idx = 0;
    let mut best = first;
    for (offset, x) in rest.iter().enumerate() {
        if x > best {
            best = x;
            best_idx = offset + 1;
        }
    }
    best_idx
}
/// Returns a view of `array` with `overlap_size` elements removed from both ends of axes 2
/// and 3. Axes 0 and 1 are kept whole.
///
/// `overlap_size` is expected to be at most half of each trimmed axis length. A value larger
/// than an axis length underflows the `usize` subtraction (a panic in debug builds).
pub fn trimm_array4<T>(array: &Array4<T>, overlap_size: usize) -> ArrayView4<'_, T> {
    let (_, _, height, width) = array.dim();
    let kept_rows = overlap_size..height - overlap_size;
    let kept_cols = overlap_size..width - overlap_size;
    array.slice(s![.., .., kept_rows, kept_cols])
}
/// Returns a view of `array` trimmed by a possibly different amount on each side.
///
/// The trimming works as follows:
/// - `overlap.top` elements are dropped from the start of axis 1 and `overlap.bottom` from its
///   end;
/// - `overlap.left` and `overlap.right` do the same for axis 2;
/// - axis 0 is kept whole.
///
/// A `bottom`/`right` larger than the axis length underflows the `usize` subtraction
/// (a panic in debug builds).
pub fn trimm_array3_asymmetric<'a, T>(
    array: &'a Array3<T>,
    overlap: &Overlap,
) -> ArrayView3<'a, T> {
    let (_, height, width) = array.dim();
    let kept_rows = overlap.top..height - overlap.bottom;
    let kept_cols = overlap.left..width - overlap.right;
    array.slice(s![.., kept_rows, kept_cols])
}
/// Returns an owned, standard-layout (row-major) copy of `array` trimmed on each side.
///
/// The trimming works as follows:
/// - `overlap.top` elements are dropped from the start of axis 2 and `overlap.bottom` from its
///   end;
/// - `overlap.left` and `overlap.right` do the same for axis 3;
/// - axes 0 and 1 are kept whole.
///
/// A `bottom`/`right` larger than the axis length underflows the `usize` subtraction
/// (a panic in debug builds).
pub fn trimm_array4_owned<T: Clone>(array: &Array4<T>, overlap: &Overlap) -> Array4<T> {
    let min_row = overlap.top;
    let max_row = array.shape()[2] - overlap.bottom;
    let min_col = overlap.left;
    let max_col = array.shape()[3] - overlap.right;
    let trimmed = array.slice(s![.., .., min_row..max_row, min_col..max_col]);
    // When the trimmed view is not contiguous, `as_standard_layout` already allocates a fresh
    // C-ordered array. `into_owned` moves that allocation out instead of cloning it again
    // (the previous `to_owned` copied the data twice); a contiguous view is cloned once.
    trimmed.as_standard_layout().into_owned()
}
/// Builds a deterministic `size` x `size` grid of `u32` cluster labels.
///
/// Starting from all zeros, `size * size / cluster_size / cluster_size` square clusters of
/// side `cluster_size` are stamped onto the grid. Cluster `k` draws its label (in
/// `1..=num_values`) and then its top-left row and column from a `ChaCha8Rng` seeded with
/// `k`. Clusters wrap around the grid edges, and later clusters overwrite earlier ones where
/// they overlap.
///
/// # Panics
///
/// Panics if `cluster_size` is zero (division by zero). Also panics if at least one cluster
/// is placed and `num_values` is zero (empty label range).
pub fn create_clustered_array(size: usize, num_values: u32, cluster_size: usize) -> Array2<u32> {
    let mut grid: Array2<u32> = Array::zeros((size, size));
    let cluster_count = size * size / cluster_size / cluster_size;
    for seed in 0..cluster_count {
        // A fresh generator per cluster makes each cluster reproducible from its index alone.
        let mut rng = ChaCha8Rng::seed_from_u64(u64::try_from(seed).unwrap());
        let label = rng.gen_range(1..=num_values);
        let top = rng.gen_range(0..size);
        let left = rng.gen_range(0..size);
        let offsets = (0..cluster_size).flat_map(|di| (0..cluster_size).map(move |dj| (di, dj)));
        for (di, dj) in offsets {
            grid[[(top + di) % size, (left + dj) % size]] = label;
        }
    }
    grid
}
/// Copies a 2-D array into nested vectors, with one inner `Vec` per index along axis 0.
// NOTE(review): `dead_code` is allowed presumably because this helper is only used from tests
// or debugging code. Confirm, and consider `#[cfg(test)]` instead.
#[allow(dead_code)]
pub(crate) fn array2_to_nested_vec<T: Clone>(arr: &Array2<T>) -> Vec<Vec<T>> {
    arr.axis_iter(Axis(0)).map(|row| row.to_vec()).collect()
}
pub fn fill_nodata_simple(array: &mut Array3<f32>, nodata: f32) {
let (bands, rows, cols) = array.dim();
for b in 0..bands {
let mut band_view = array.slice_mut(s![b, .., ..]);
let mut changes = true;
while changes {
changes = false;
let copy = band_view.to_owned();
for r in 0..rows {
for c in 0..cols {
if band_view[[r, c]] == nodata {
let mut sum = 0.0;
let mut count = 0;
for dr in -1..=1 {
for dc in -1..=1 {
let nr = r as isize + dr;
let nc = c as isize + dc;
if nr >= 0 && nr < rows as isize && nc >= 0 && nc < cols as isize {
let val = copy[[nr as usize, nc as usize]];
if val != nodata {
sum += val;
count += 1;
}
}
}
}
if count > 0 {
band_view[[r, c]] = sum / count as f32;
changes = true;
}
}
}
}
}
}
}