// eorst 1.0.1
//
// Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! Array manipulation operations for raster processing.
//!
//! This module provides functions for slicing, trimming, and transforming
//! multidimensional arrays used in raster data processing.

use crate::types::{Index2d, Overlap, Rectangle};
use csv::WriterBuilder;
use ndarray::{s, Array, Array2, Array3, Array4, Axis, ArrayView2, ArrayView3, ArrayView4};
use ndarray_csv::Array2Writer;
use rand::Rng;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use std::cmp::max;
use std::fs::File;
use std::path::PathBuf;

/// Slices a rectangular window out of a 2D view, positioned relative to `indices`.
///
/// Rows start at `indices.row - rect.top` (clamped to 0) and stop before
/// `indices.row + rect.bottom - 1`; columns are handled the same way with
/// `rect.left` and `rect.right`. The returned view borrows from `array`.
pub fn rect_view<'a, T>(
    array: &'a ArrayView2<T>,
    rect: Rectangle,
    indices: Index2d,
) -> ArrayView2<'a, T> {
    // Row extent: signed arithmetic so a window reaching above row 0 clamps instead of wrapping.
    let row_end = indices.row + rect.bottom;
    let row_start = max(0, indices.row as i32 - rect.top as i32) as usize;

    // Column extent, same scheme.
    let col_end = indices.col + rect.right;
    let col_start = max(0, indices.col as i32 - rect.left as i32) as usize;

    let window = s![row_start..row_end - 1, col_start..col_end - 1];
    array.slice(window)
}

/// Serializes a nested vector of `i16` values to `csv_filename` as CSV.
///
/// The row count is the number of inner vectors and the column count is the
/// length of the first inner vector.
///
/// # Panics
///
/// Panics if `data` is empty, if the flattened values cannot be reshaped to
/// `(rows, cols)`, or if the file cannot be created or written.
pub fn write_csv_array(data: Vec<Vec<i16>>, csv_filename: PathBuf) {
    let shape = (data.len(), data[0].len());
    let values: Vec<i16> = data.into_iter().flatten().collect();
    let matrix = Array::from(values)
        .into_shape_clone(shape)
        .expect("Unable to reshape");

    let output = File::create(csv_filename).expect("Unable to create file");
    let mut csv_writer = WriterBuilder::new().has_headers(true).from_writer(output);
    csv_writer
        .serialize_array2(&matrix)
        .expect("Unable to serialize array to file");
}

/// Strips `overlap_size` cells from both ends of axes 1 and 2 of a 3D array.
///
/// Axis 0 is kept whole. The result is a view borrowing from `array`.
pub fn trimm_array3<T>(array: &Array3<T>, overlap_size: usize) -> ArrayView3<'_, T> {
    let (_, rows, cols) = array.dim();
    let row_end = rows - overlap_size;
    let col_end = cols - overlap_size;

    array.slice(s![.., overlap_size..row_end, overlap_size..col_end])
}

/// Returns the index of the maximum value in a slice.
///
/// When the maximum occurs more than once, the index of its first occurrence
/// is returned. Elements that do not compare greater than the current best
/// (for example a floating-point NaN) are skipped.
///
/// # Panics
///
/// Panics if `xs` is empty, since there is no index to return.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    assert!(!xs.is_empty(), "argmax requires a non-empty slice");

    let mut best = 0;
    for (i, x) in xs.iter().enumerate().skip(1) {
        // Strict `>` keeps the earliest index on ties.
        if *x > xs[best] {
            best = i;
        }
    }
    best
}

/// Strips `overlap_size` cells from both ends of axes 2 and 3 of a 4D array.
///
/// Axes 0 and 1 are kept whole. The result is a view borrowing from `array`;
/// the memory layout of the output is not guaranteed.
pub fn trimm_array4<T>(array: &Array4<T>, overlap_size: usize) -> ArrayView4<'_, T> {
    let (_, _, rows, cols) = array.dim();
    let row_end = rows - overlap_size;
    let col_end = cols - overlap_size;

    array.slice(s![.., .., overlap_size..row_end, overlap_size..col_end])
}

/// Strips a per-side border from axes 1 and 2 of a 3D array.
///
/// `overlap.top` / `overlap.bottom` are removed from the start / end of axis 1,
/// and `overlap.left` / `overlap.right` from the start / end of axis 2.
/// Axis 0 is kept whole. The result is a view borrowing from `array`.
pub fn trimm_array3_asymmetric<'a, T>(
    array: &'a Array3<T>,
    overlap: &Overlap,
) -> ArrayView3<'a, T> {
    let (_, rows, cols) = array.dim();
    let row_end = rows - overlap.bottom;
    let col_end = cols - overlap.right;

    array.slice(s![.., overlap.top..row_end, overlap.left..col_end])
}

/// Trims the overlap border from a 4D array, returning an owned array with standard layout.
///
/// `overlap.top` / `overlap.bottom` are removed from the start / end of axis 2,
/// and `overlap.left` / `overlap.right` from the start / end of axis 3.
/// Axes 0 and 1 are kept whole.
pub fn trimm_array4_owned<T: std::clone::Clone + std::fmt::Debug>(
    array: &Array4<T>,
    overlap: &Overlap,
) -> Array4<T> {
    let min_row = overlap.top;
    let max_row = array.shape()[2] - overlap.bottom;

    let min_col = overlap.left;
    let max_col = array.shape()[3] - overlap.right;

    let trimmed = array.slice(s![.., .., min_row..max_row, min_col..max_col]);
    // `as_standard_layout` already clones into a fresh buffer when the slice is not in
    // standard layout; `into_owned` takes that buffer over instead of copying it again.
    trimmed.as_standard_layout().into_owned()
}

/// Creates an array with clustered values (for testing).
pub fn create_clustered_array(size: usize, num_values: u32, cluster_size: usize) -> Array2<u32> {
    let mut array: Array2<u32> = Array::zeros((size, size));

    let loop_size = size * size / cluster_size / cluster_size;
    for idx in 0..loop_size {
        let mut rng = ChaCha8Rng::seed_from_u64(idx.try_into().unwrap());
        let value = rng.gen_range(1..=num_values);
        let row = rng.gen_range(0..size);
        let col = rng.gen_range(0..size);

        for i in 0..cluster_size {
            for j in 0..cluster_size {
                let x = (row + i) % size;
                let y = (col + j) % size;
                array[[x, y]] = value;
            }
        }
    }

    array
}

/// Copies each lane along axis 0 of a 2D array into its own `Vec`, preserving order.
#[allow(dead_code)]
pub(crate) fn array2_to_nested_vec<T: std::clone::Clone>(arr: &Array2<T>) -> Vec<Vec<T>> {
    arr.axis_iter(Axis(0)).map(|row| row.to_vec()).collect()
}

/// Fills no-data values in a 3D array using simple neighbor propagation.
pub fn fill_nodata_simple(array: &mut Array3<f32>, nodata: f32) {
    let (bands, rows, cols) = array.dim();

    for b in 0..bands {
        let mut band_view = array.slice_mut(s![b, .., ..]);
        let mut changes = true;

        // repeat until all NaNs are filled
        while changes {
            changes = false;
            let copy = band_view.to_owned();
            for r in 0..rows {
                for c in 0..cols {
                    if band_view[[r, c]] == nodata {
                        let mut sum = 0.0;
                        let mut count = 0;
                        for dr in -1..=1 {
                            for dc in -1..=1 {
                                let nr = r as isize + dr;
                                let nc = c as isize + dc;
                                if nr >= 0 && nr < rows as isize && nc >= 0 && nc < cols as isize {
                                    let val = copy[[nr as usize, nc as usize]];
                                    if val != nodata {
                                        sum += val;
                                        count += 1;
                                    }
                                }
                            }
                        }
                        if count > 0 {
                            band_view[[r, c]] = sum / count as f32;
                            changes = true;
                        }
                    }
                }
            }
        }
    }
}