eorst 1.0.1

Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines
//! Sampling and extraction methods for RasterDataset.
//!
//! This module contains methods for extracting values from raster blocks at vector points.

use crate::core_types::RasterType;
use crate::gdal_utils::create_rayon_pool;
use crate::rasterdataset::builder::n_block_cols;
use crate::types::{Coordinates, Index2d, Rectangle, SamplingMethod};
use crate::rasterdataset::RasterDataset;

use anyhow::Result;
use gdal::Dataset;
use gdal::vector::{Geometry, LayerAccess};
use itertools::Itertools;
use kdam::par_tqdm;
use ndarray::{s, Array2};
use rayon::prelude::*;
use std::collections::BTreeMap;
use std::hash::Hash;
use std::path::Path;

/// Copies the values of the sampling window around `point` into an owned vector.
///
/// The window itself is produced by `crate::array_ops::rect_view`; this helper only
/// materialises it so the statistics below can sort / scan it freely.
fn collect_window_values(
    band_data: &ndarray::ArrayView2<i16>,
    rect: Rectangle,
    point: Index2d,
) -> Vec<i16> {
    crate::array_ops::rect_view(band_data, rect, point)
        .iter()
        .copied()
        .collect()
}

/// Returns the most frequent value of an ascending-sorted slice.
///
/// Ties are resolved in favour of the smallest value, which keeps the result
/// deterministic.
///
/// # Panics
///
/// Panics if `sorted` is empty.
fn mode_of_sorted(sorted: &[i16]) -> i16 {
    let mut best_value = sorted[0];
    let mut best_count = 0_usize;
    let mut run_value = sorted[0];
    let mut run_len = 0_usize;
    for &value in sorted {
        if value == run_value {
            run_len += 1;
        } else {
            run_value = value;
            run_len = 1;
        }
        // Strictly greater: on a tie the earlier (smaller) value is kept.
        if run_len > best_count {
            best_value = run_value;
            best_count = run_len;
        }
    }
    best_value
}

/// Samples a value from a 2D array at the given point using the specified method.
///
/// Replaces the duplicated `match method { SamplingMethod::* }` pattern found in
/// `extract_blockwise()` and `extract()`.
///
/// * `SamplingMethod::Value` - the single cell at (`point.row`, `point.col`).
/// * `SamplingMethod::Avg` - arithmetic mean of the window, rounded to the nearest integer.
/// * `SamplingMethod::Mode` - most frequent value in the window (smallest value wins ties).
/// * `SamplingMethod::Min` - smallest value in the window.
/// * `SamplingMethod::StdDev` - population standard deviation of the window, rounded.
///
/// For every method except `Value` the window is whatever
/// `crate::array_ops::rect_view(band_data, rect, point)` yields.
///
/// # Panics
///
/// Panics if `point` is outside `band_data` (`Value`), or if the window is empty
/// (`Mode`, `Min`). `Avg` and `StdDev` yield `0` for an empty window because the
/// resulting NaN saturates to zero in the final cast.
fn sample_value(
    band_data: &ndarray::ArrayView2<i16>,
    rect: Rectangle,
    point: Index2d,
    method: SamplingMethod,
) -> i16 {
    match method {
        SamplingMethod::Value => band_data[(point.row, point.col)],
        SamplingMethod::Avg => {
            let window = collect_window_values(band_data, rect, point);
            // Sum exactly in integers and divide once; summing per-element f32
            // quotients accumulates rounding error that can flip the final round().
            let total: i64 = window.iter().map(|&v| i64::from(v)).sum();
            (total as f64 / window.len() as f64).round() as i16
        }
        SamplingMethod::Mode => {
            let mut window = collect_window_values(band_data, rect, point);
            window.sort_unstable();
            // The middle element of the sorted window would be the median;
            // the mode needs the longest run of equal values.
            mode_of_sorted(&window)
        }
        SamplingMethod::Min => collect_window_values(band_data, rect, point)
            .into_iter()
            .min()
            .expect("sampling window must contain at least one value"),
        SamplingMethod::StdDev => {
            let window = collect_window_values(band_data, rect, point);
            let count = window.len() as f64;
            let total: i64 = window.iter().map(|&v| i64::from(v)).sum();
            let mean = total as f64 / count;
            let variance = window
                .iter()
                .map(|&v| (f64::from(v) - mean).powi(2))
                .sum::<f64>()
                / count;
            variance.sqrt().round() as i16
        }
    }
}

/// Shared helper: validates buffer size against overlap size.
///
/// The sampling buffer must fit inside the block overlap, i.e.
/// `buffer_size <= overlap_size`.
///
/// # Panics
///
/// Panics when `buffer_size` is larger than `overlap_size`; the message reports
/// both values.
fn validate_buffer_size(buffer_size: usize, overlap_size: usize) {
    assert!(
        buffer_size <= overlap_size,
        "Buffer size ({}) has to be <= overlap size ({})",
        buffer_size,
        overlap_size
    );
}

/// Shared helper: builds a `Rectangle` whose four sides all equal `buffer_size`.
fn make_rectangle(buffer_size: usize) -> Rectangle {
    let margin = buffer_size;
    Rectangle {
        top: margin,
        bottom: margin,
        left: margin,
        right: margin,
    }
}

/// Shared helper: builds the index-to-blocks pipeline.
///
/// Returns `(idx_global, id_indices, blocks_to_process)`:
///
/// * `idx_global` - point id -> global raster index, as produced by
///   `RasterDataset::geoms_to_global_indices`.
/// * `id_indices` - one `(block_id, (point_id, block_local_index))` entry per point,
///   as produced by `RasterDataset::block_id_rowcol`, in the key order of `idx_global`.
/// * `blocks_to_process` - the distinct block ids found in `id_indices`, in order of
///   first appearance.
fn build_block_index_pipeline<R: RasterType>(
    raster: &RasterDataset<R>,
    geoms: &BTreeMap<i64, Vec<(f64, f64, f64)>>,
) -> (BTreeMap<i64, Index2d>, Vec<(usize, (i64, Index2d))>, Vec<usize>) {
    // `geoms_to_global_indices` takes the map by value, hence the clone.
    let idx_global = raster.geoms_to_global_indices(geoms.clone());

    let id_indices: Vec<(usize, (i64, Index2d))> = idx_global
        .par_iter()
        .map(|(pid, index)| raster.block_id_rowcol(*pid, *index))
        .collect();

    // `block_id_rowcol` already carries the block id as the first tuple element,
    // so the distinct ids can be read back without a second pass over the points.
    let blocks_to_process: Vec<usize> = id_indices
        .iter()
        .map(|(block_id, _)| *block_id)
        .unique()
        .collect();

    (idx_global, id_indices, blocks_to_process)
}

/// Shared helper: gathers every point that falls into `block_id`.
///
/// Scans `id_indices` (entries of `(block_id, (point_id, block_local_index))`) and
/// returns three parallel vectors for the matching entries:
///
/// 1. the block-local positions,
/// 2. the point ids, cast from `i64` to `usize` with `as`,
/// 3. the positions of the entries within `id_indices`.
fn collect_points_for_block(
    id_indices: &[(usize, (i64, Index2d))],
    block_id: usize,
) -> (Vec<Index2d>, Vec<usize>, Vec<usize>) {
    let mut positions = Vec::new();
    let mut point_ids = Vec::new();
    let mut ordinals = Vec::new();

    let in_block = id_indices
        .iter()
        .enumerate()
        .filter(|(_, entry)| entry.0 == block_id);

    for (ordinal, (_, (point_id, local))) in in_block {
        positions.push(Index2d {
            col: local.col,
            row: local.row,
        });
        ordinals.push(ordinal);
        point_ids.push(*point_id as usize);
    }

    (positions, point_ids, ordinals)
}

/// Shared helper: assembles parallel block results into a BTreeMap.
/// Generic over the key type (i16 or i64).
///
/// Each element of `collected` describes one block as
/// `(positions, point_ids, values)`, where `values[band][i]` is the sample of the
/// block's `i`-th point in `band`. The result maps `key_converter(point_ids[i])` to
/// that point's values across all bands, i.e. the band-major layout is regrouped
/// per point.
///
/// The number of bands is taken from the first block and applied to every block.
/// When two points convert to the same key, the later one replaces the earlier one.
/// An empty `collected` yields an empty map.
///
/// # Panics
///
/// Panics if a block has fewer bands than the first block, or if its `point_ids` or
/// any band vector is shorter than its `positions`.
fn assemble_block_results<K>(
    collected: &[(Vec<usize>, Vec<usize>, Vec<Vec<i16>>)],
    key_converter: fn(usize) -> K,
) -> BTreeMap<K, Vec<i16>>
where
    K: Ord + Hash,
{
    let mut results = BTreeMap::new();

    // Without any block there is no band count to read; the loop below simply
    // does not run and the empty map is returned.
    let num_bands = collected.first().map_or(0, |(_, _, bands)| bands.len());

    for (positions, point_ids, bands) in collected {
        let bands = &bands[..num_bands];
        for i in 0..positions.len() {
            let point_values: Vec<i16> = bands.iter().map(|band| band[i]).collect();
            results.insert(key_converter(point_ids[i]), point_values);
        }
    }
    results
}

impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Converts geometries to global array indices.
    pub fn geoms_to_global_indices(
        &self,
        geoms: BTreeMap<i64, Vec<(f64, f64, f64)>>,
    ) -> BTreeMap<i64, Index2d> {
        let idx_global: BTreeMap<_, _> = geoms
            .par_iter()
            .map(|(pid, p)| {
                let point: Coordinates = Coordinates {
                    x: p[0].0,
                    y: p[0].1,
                };
                (*pid, self.geo_to_global_rc(point))
            })
            .collect();
        idx_global
    }

    fn geo_to_global_rc(&self, point: Coordinates) -> Index2d {
        let gt = self.metadata.geo_transform.to_array();
        let row = ((point.y - gt[3]) / gt[5]) as usize;
        let col = ((point.x - gt[0]) / gt[1]) as usize;
        Index2d { col, row }
    }

    /// Gets the block ID and local row/col for a global point ID.
    pub fn block_id_rowcol(&self, pid: i64, index: Index2d) -> (usize, (i64, Index2d)) {
        let id = self.id_from_indices(index);
        let row_col = self.global_rc_to_block_rc(index);
        (id, (pid, row_col))
    }

    fn global_rc_to_block_rc(&self, global_index: Index2d) -> Index2d {
        let mut block_col = global_index.col % self.metadata.block_size.cols;
        let mut block_row = global_index.row % self.metadata.block_size.rows;

        let block_col_ov = block_col + self.metadata.overlap_size;
        let block_row_ov = block_row + self.metadata.overlap_size;

        if (global_index.col as i16 - block_col_ov as i16) > 0 {
            block_col = block_col_ov;
        };

        if global_index.row as i16 - block_row_ov as i16 > 0 {
            block_row = block_row_ov;
        };

        Index2d {
            col: block_col,
            row: block_row,
        }
    }

    fn id_from_indices(&self, index: Index2d) -> usize {
        let n_block_cols = self.n_block_cols();
        (index.col / self.metadata.block_size.cols)
            + (index.row / self.metadata.block_size.rows) * n_block_cols
    }

    fn n_block_cols(&self) -> usize {
        let image_size = crate::types::ImageSize {
            rows: self.metadata.shape.rows,
            cols: self.metadata.shape.cols,
        };
        n_block_cols(image_size, self.metadata.block_size)
    }

    /// Extracts values from the raster dataset for vector features, block-wise.
    pub fn extract_blockwise(
        &self,
        vector_path: &std::path::PathBuf,
        id_col_name: &str,
        method: SamplingMethod,
        buffer_size: Option<usize>,
    ) -> BTreeMap<i16, Vec<i16>> {
        log::debug!("Starting extract.");
        let buffer_size = buffer_size.unwrap_or(0);
        validate_buffer_size(buffer_size, self.metadata.overlap_size);

        let vector_dataset = Dataset::open(Path::new(vector_path)).unwrap();
        let mut layer = vector_dataset.layer(0).unwrap();
        let mut geoms = BTreeMap::new();

        for feature in layer.features() {
            let mut geom = Vec::new();
            feature
                .geometry()
                .expect("Geometries")
                .get_points(&mut geom);
            let field_index = feature.field_index(id_col_name).expect("Bad column name.");
            let pid_filed = feature.field(field_index).unwrap().unwrap();
            let pid = pid_filed.into_int64().unwrap();
            geoms.insert(pid, geom);
        }

        let (_idx_global, id_indices, blocks_to_process) =
            build_block_index_pipeline(self, &geoms);
        drop(geoms);

        let pool = create_rayon_pool(1);
        let handle = pool.install(|| {
            par_tqdm!(blocks_to_process
                .into_par_iter())
                .map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
                    let (pos, idx, pids) = collect_points_for_block(&id_indices, id);

                    let mut res = Vec::new();
                    let rect = make_rectangle(buffer_size);

                    let data = self.read_block(id);

                    let bands = data.shape()[1];
                    log::debug!("Bands {:?}", bands);
                    for band_n in 0..bands {
                        let mut res_band = Vec::new();
                        let band_data = data.slice(s![0_i32, band_n, .., ..]);
                        for point in pos.iter() {
                            let val = sample_value(&band_data, rect, *point, method);
                            res_band.push(val);
                        }
                        res.push(res_band);
                    }
                    (pids, idx, res)
                })
        });

        let collected: Vec<_> = handle.collect();
        assemble_block_results(&collected, |id| id as i16)
    }

    /// Extracts values from the raster dataset for point geometries.
    pub fn extract(
        &self,
        geometries: &[Geometry],
        point_ids: &[i64],
        method: SamplingMethod,
        buffer_size: Option<usize>,
    ) -> Result<(Array2<i16>, Vec<i64>)> {
        let buffer_size = buffer_size.unwrap_or(0);
        validate_buffer_size(buffer_size, self.metadata.overlap_size);

        let mut geoms = BTreeMap::new();
        for (idx, point_id) in point_ids.iter().enumerate() {
            let geometry = &geometries[idx];
            let point = geometry.get_point(0);
            let (x, y, z) = point;
            geoms.insert(*point_id, vec![(x, y, z)]);
        }

        let (_idx_global, id_indices, blocks_to_process) =
            build_block_index_pipeline(self, &geoms);
        drop(geoms);

        let blocks_to_process: Vec<usize> = blocks_to_process;

        let pool = create_rayon_pool(1);
        let handle = pool.install(|| {
            par_tqdm!(blocks_to_process
                .into_par_iter())
                .map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
                    let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
                    log::debug!("Extracting {} points, from block: {}", pos.len(), id);

                    let mut res = Vec::new();
                    let rect = make_rectangle(buffer_size);

                    let data = self.read_block(id);
                    let n_times = data.shape()[0];
                    let n_layers = data.shape()[1];
                    for time in 0..n_times {
                        for layer in 0..n_layers {
                            let mut res_band = Vec::new();
                            let band_data = data.slice(s![time, layer, .., ..]);
                            for point in pos.iter() {
                                let col = point.col.checked_sub(self.blocks[id].overlap.left);
                                let row = point.row.checked_sub(self.blocks[id].overlap.top);
                                let col = col.unwrap_or(point.col);
                                let row = row.unwrap_or(point.row);
                                let val = sample_value(&band_data, rect, Index2d { col, row }, method);
                                res_band.push(val);
                            }
                            res.push(res_band);
                        }
                    }
                    (pids, idx, res)
                })
        });

        let collected: Vec<_> = handle.collect();
        let results = assemble_block_results(&collected, |id| id as i64);

        let k = results.keys().next().unwrap();
        let n_rows = results.len();
        let n_cols = results[k].len();
        let mut array: Array2<i16> = ndarray::Array::zeros((n_rows, n_cols));
        for (row_index, values) in results.values().enumerate() {
            for (col_index, value) in values.iter().enumerate() {
                array[[row_index, col_index]] = *value;
            }
        }
        let pids: Vec<i64> = results.into_keys().collect();
        Ok((array, pids))
    }
}