// eorst 1.0.1
//
// Earth Observation and Remote Sensing Toolkit - library for raster processing pipelines.
//! Zonal statistics methods for RasterDataset (Polars integration).
//!
//! This module requires the `use_polars` feature flag.

use crate::core_types::RasterType;
use crate::gdal_utils::{create_rayon_pool, read_raster_band};
use crate::rasterdataset::RasterDataset;

use anyhow::Result;
use ndarray::{Axis, ArrayView2, ArrayView3};
use ndhistogram::axis::UniformNoFlow;
use ndhistogram::Histogram;
use polars::prelude::{Column, DataFrame, ParquetReader, SerReader};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::process::Command;
use uuid::Uuid;

/// One-dimensional histogram over a uniformly binned `f32` axis
/// (`UniformNoFlow`), used to accumulate the values of one zone for one
/// data column.
type ZoneHistogram = ndhistogram::Hist1D<UniformNoFlow<f32>>;

/// Flattens a nested `zone -> column -> histogram` map into a long-format
/// Polars `DataFrame` and writes it to a Parquet file.
///
/// Replaces the duplicated block found in both `zonal_histograms_raster()`
/// and `zonal_histograms_polygons()`.
///
/// * `$histograms_thread` - something iterable as `(zone, per-column map)`
///   pairs, where the per-column map iterates as `(column name, histogram)`.
/// * `$histogram_fn` - destination path, forwarded to `save_zonal_histograms`.
///
/// One row is emitted per (zone, column, bin) with the columns `zone`, `time`,
/// `layer`, `bin_star`, `bin_end` and `count`. Column names are split on
/// `"--"`: the first part becomes `time`, the second becomes `layer`.
///
/// # Panics
///
/// Panics if a column name does not contain `"--"`, if the `DataFrame` cannot
/// be built, or if writing the Parquet file fails.
macro_rules! histograms_to_dataframe {
    ($histograms_thread:expr, $histogram_fn:expr) => {{
        let mut zone_ids = Vec::new();
        let mut time_labels = Vec::new();
        let mut layer_labels = Vec::new();
        let mut bin_starts = Vec::new();
        let mut bin_ends = Vec::new();
        let mut bin_counts = Vec::new();

        for (zone, per_column) in $histograms_thread {
            for (column, histogram) in per_column {
                // The time/layer labels only depend on the column name, so
                // parse them once per column instead of once per bin.
                let mut parts = column.split("--");
                let time = parts.next().unwrap_or_else(|| {
                    panic!("column name {:?} is not of the form `time--layer`", column)
                });
                let layer = parts.next().unwrap_or_else(|| {
                    panic!("column name {:?} is not of the form `time--layer`", column)
                });

                for bin in histogram.iter() {
                    zone_ids.push(i64::from(*zone));
                    time_labels.push(time);
                    layer_labels.push(layer);
                    bin_starts.push(bin.bin.start());
                    bin_ends.push(bin.bin.end());
                    bin_counts.push(*bin.value as i64);
                }
            }
        }

        // NOTE(review): "bin_star" looks like a typo for "bin_start"; it is
        // kept as-is because the column name is part of the output schema.
        // Confirm with downstream consumers before renaming.
        let vec_columns = vec![
            Column::new("zone".into(), &zone_ids),
            Column::new("time".into(), &time_labels),
            Column::new("layer".into(), &layer_labels),
            Column::new("bin_star".into(), &bin_starts),
            Column::new("bin_end".into(), &bin_ends),
            Column::new("count".into(), &bin_counts),
        ];

        let mut df = DataFrame::new_infer_height(vec_columns).unwrap();
        save_zonal_histograms(&mut df, $histogram_fn);
    }};
}

/// Saves zonal histogram data to a Parquet file.
///
/// Creates (or truncates) the file at `path` and writes `df` to it with
/// Polars' `ParquetWriter`.
///
/// # Panics
///
/// Panics if the file cannot be created or if writing the Parquet data fails.
/// The panic message names the offending path so a failing block can be
/// identified.
pub fn save_zonal_histograms(df: &mut DataFrame, path: &std::path::Path) {
    use polars::prelude::ParquetWriter;

    let file = std::fs::File::create(path)
        .unwrap_or_else(|e| panic!("failed to create parquet file {}: {e}", path.display()));
    ParquetWriter::new(file)
        .finish(df)
        .unwrap_or_else(|e| panic!("failed to write parquet file {}: {e}", path.display()));
}

fn merge_histogram_dfs(paths: Vec<PathBuf>) -> Result<DataFrame> {
    let mut dfs = Vec::new();
    for path in paths {
        let file = std::fs::File::open(&path)?;
        let df = ParquetReader::new(file).finish()?;
        dfs.push(df);
        std::fs::remove_file(path).ok();
    }
    let merged = polars::functions::concat_df_diagonal(&dfs)?;
    Ok(merged)
}

/// Computes zonal histograms from data and zones arrays.
///
/// Shared by both `zonal_histograms_raster()` and `zonal_histograms_polygons()`.
///
/// * `data` - 4-D array; the first two axes are iterated as time and layer,
///   the remaining two must match the shape of `zones`.
/// * `zones` - zone id for every pixel.
/// * `col_names` - one name per (time, layer) pair, in time-major order.
/// * `zones_unique` - zone ids for which histograms are created.
/// * `na` - zone id marking pixels that are skipped.
/// * `hist_bins`, `hist_min`, `hist_max` - uniform binning of the histograms.
///
/// Returns, for every zone in `zones_unique`, a map from column name to the
/// histogram of that column's values inside the zone.
///
/// # Panics
///
/// Panics if the binning parameters are rejected by `UniformNoFlow::new`
/// (only when at least one histogram is created), if `col_names` has fewer
/// entries than there are (time, layer) pairs, or if a non-`na` zone id found
/// in `zones` is missing from `zones_unique`.
fn compute_zonal_histograms(
    data: ndarray::Array4<i16>,
    zones: ndarray::Array2<i16>,
    col_names: Vec<String>,
    zones_unique: &HashSet<i16>,
    na: i16,
    hist_bins: usize,
    hist_min: f32,
    hist_max: f32,
) -> HashMap<i16, HashMap<String, ZoneHistogram>> {
    let axes_thread = UniformNoFlow::new(hist_bins, hist_min, hist_max);

    // Initialize one empty histogram per zone/column combination.
    let mut histograms_thread: HashMap<i16, HashMap<String, ZoneHistogram>> =
        HashMap::with_capacity(zones_unique.len());
    for zone in zones_unique.iter() {
        let mut per_column = HashMap::with_capacity(col_names.len());
        for column in col_names.iter() {
            let h = ndhistogram::ndhistogram!(axes_thread.clone().unwrap());
            per_column.insert(column.clone(), h);
        }
        histograms_thread.insert(*zone, per_column);
    }

    // Fill histograms by iterating over time and layer dimensions; `idx`
    // walks `col_names` in the same time-major order.
    let mut idx = 0;
    for time in 0..data.shape()[0] {
        let data_t: ArrayView3<_> = data.index_axis(Axis(0), time);
        for layer in 0..data_t.shape()[0] {
            // Borrow the name: no need to clone a String for every layer.
            let column: &str = &col_names[idx];
            idx += 1;
            let data_l: ArrayView2<_> = data_t.index_axis(Axis(0), layer);

            ndarray::azip!((d in &data_l, z in &zones) {
                let zone = *z;
                if zone != na {
                    // Build the panic message lazily: this runs per pixel, so
                    // formatting it eagerly would allocate on every sample.
                    let histogram = histograms_thread
                        .get_mut(&zone)
                        .unwrap_or_else(|| panic!("Zone: {zone:?} not found"));
                    let histogram_time_layer =
                        histogram.get_mut(column).expect("key not available");
                    histogram_time_layer.fill(&f32::from(*d));
                }
            });
        }
    }

    histograms_thread
}

impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Computes zonal histograms using another raster as zones.
    ///
    /// For every block, the data of `self` is read as `i16` and the zones are
    /// read from `other` (block with the same index). Zone value `0` is
    /// treated as "no zone" and skipped. Per-block results are written to
    /// temporary Parquet files and merged into the returned `DataFrame`.
    ///
    /// * `other` - zones raster; must have exactly 1 time step and 1 layer.
    /// * `hist_bins`, `hist_min`, `hist_max` - uniform binning of the histograms.
    ///
    /// # Errors
    ///
    /// Returns an error if the per-block Parquet files cannot be read back or
    /// merged.
    ///
    /// # Panics
    ///
    /// Panics if `other` has more than one time step or layer, or if writing
    /// a per-block Parquet file fails.
    #[cfg(feature = "use_polars")]
    pub fn zonal_histograms_raster(
        &self,
        other: &crate::rasterdataset::RasterDataset<u8>,
        hist_bins: usize,
        hist_min: f32,
        hist_max: f32,
    ) -> Result<DataFrame> {
        let blocks_to_process: Vec<usize> = (0..self.blocks.len()).collect();
        let col_names = self.column_names();

        // 0 is the "no zone" value of the zones raster, so it gets no histogram.
        let zones_unique: HashSet<i16> = other
            .get_unique_values()
            .into_iter()
            .filter(|&x| x != 0)
            .map(|x| x as i16)
            .collect();

        // NOTE(review): the parallel iterator built here is lazy and is only
        // driven by `collect()` below, i.e. outside `pool.install`. The work
        // therefore most likely does not run on this pool - confirm whether
        // the pool was meant to bound the parallelism before changing this.
        let pool = create_rayon_pool(1);
        let handle = pool.install(|| {
            blocks_to_process.into_par_iter().map(|id| -> PathBuf {
                let data = self.read_block::<i16>(id);
                let zones_block = other.read_block::<i32>(id);

                if zones_block.shape()[0] != 1 || zones_block.shape()[1] != 1 {
                    panic!("Zones raster dataset can only have 1 time and 1 layer");
                }
                let zones_t = zones_block.index_axis(Axis(0), 0);
                let zones_view: ArrayView2<_> = zones_t.index_axis(Axis(0), 0);
                // Element-wise narrowing to i16; works for any memory layout
                // and keeps the 2-D shape without intermediate buffers.
                let zones: ndarray::Array2<i16> = zones_view.mapv(|v| v as i16);

                let histograms_thread = compute_zonal_histograms(
                    data,
                    zones,
                    col_names.clone(),
                    &zones_unique,
                    0,
                    hist_bins,
                    hist_min,
                    hist_max,
                );

                // Unique name in the platform temp dir; the file is consumed
                // and removed by `merge_histogram_dfs`.
                let uid = Uuid::new_v4();
                let histogram_fn =
                    std::env::temp_dir().join(format!("eorst_{uid}_{id}.parquet"));
                histograms_to_dataframe!(&histograms_thread, &histogram_fn);
                histogram_fn
            })
        });

        let collected_histograms: Vec<PathBuf> = handle.collect();
        let df = merge_histogram_dfs(collected_histograms)?;
        Ok(df)
    }

    /// Computes zonal histograms using polygons as zones.
    ///
    /// For every block, `gdal_rasterize` is invoked to burn the attribute
    /// `id_column_name` of `polygons` into a temporary raster covering the
    /// block extent at the dataset resolution. That raster provides the zone
    /// ids for the block; it is deleted once it has been read. Per-block
    /// results are written to temporary Parquet files and merged into the
    /// returned `DataFrame`.
    ///
    /// * `polygons` - vector source passed to `gdal_rasterize`.
    /// * `id_column_name` - attribute whose value is burned as zone id (`-a`).
    /// * `na` - zone id marking pixels that belong to no zone.
    /// * `hist_bins`, `hist_min`, `hist_max` - uniform binning of the histograms.
    ///
    /// # Errors
    ///
    /// Returns an error if the per-block Parquet files cannot be read back or
    /// merged.
    ///
    /// # Panics
    ///
    /// Panics if `gdal_rasterize` cannot be started, cannot be waited for, or
    /// exits unsuccessfully, or if writing a per-block Parquet file fails.
    #[cfg(feature = "use_polars")]
    pub fn zonal_histograms_polygons(
        &self,
        polygons: &str,
        id_column_name: &str,
        na: i16,
        hist_bins: usize,
        hist_min: f32,
        hist_max: f32,
    ) -> Result<DataFrame> {
        let blocks_to_process: Vec<usize> = (0..self.blocks.len()).collect();
        let col_names = self.column_names();

        // NOTE(review): the parallel iterator built here is lazy and is only
        // driven by `collect()` below, i.e. outside `pool.install`. The work
        // therefore most likely does not run on this pool - confirm whether
        // the pool was meant to bound the parallelism before changing this.
        let pool = create_rayon_pool(1);
        let handle = pool.install(|| {
            blocks_to_process.into_par_iter().map(|id| -> PathBuf {
                let block = &self.blocks[id];
                let target_gt = &block.geo_transform;
                let rows = block.read_window.size.rows;
                let cols = block.read_window.size.cols;
                // Opposite corner of the block (x grows with cols, y with rows).
                let x_lr = target_gt.x_ul + (cols as f64 * target_gt.x_res);
                let y_lr = target_gt.y_ul + (rows as f64 * target_gt.y_res);

                // Block extent rounded to two decimals before it is handed
                // to gdal_rasterize.
                let round2 = |v: f64| (v * 100.).round() / 100.;
                let extent = crate::metadata::Extent {
                    xmin: round2(target_gt.x_ul),
                    ymin: round2(y_lr),
                    xmax: round2(x_lr),
                    ymax: round2(target_gt.y_ul),
                };

                // Unique temp file so concurrent or earlier runs can never
                // collide with (or leave stale data for) this block.
                let uid = Uuid::new_v4();
                let rasterized_fn =
                    std::env::temp_dir().join(format!("eorst_rasterized_{uid}_{id}.tif"));

                let x_res = self.metadata.geo_transform.x_res.abs().to_string();
                let y_res = self.metadata.geo_transform.y_res.abs().to_string();
                let xmin = extent.xmin.to_string();
                let ymin = extent.ymin.to_string();
                let xmax = extent.xmax.to_string();
                let ymax = extent.ymax.to_string();

                let mut cmd = Command::new("gdal_rasterize");
                cmd.arg("-q");
                cmd.args(["-tr", x_res.as_str(), y_res.as_str()]);
                cmd.args(["-a", id_column_name]);
                cmd.args([
                    "-te",
                    xmin.as_str(),
                    ymin.as_str(),
                    xmax.as_str(),
                    ymax.as_str(),
                ]);
                cmd.arg(polygons);
                cmd.arg(&rasterized_fn);

                let status = cmd
                    .spawn()
                    .expect("failed to start creating raster")
                    .wait()
                    .expect("failed to wait for the raster");
                // A failed rasterization must not be silently followed by
                // reading a missing or incomplete zones raster.
                if !status.success() {
                    panic!("gdal_rasterize exited with {status} while rasterizing block {id}");
                }

                let data = self.read_block::<i16>(id);
                let zones = read_raster_band::<i16>(
                    rasterized_fn.as_path(),
                    1,
                    (0, 0),
                    (cols as usize, rows as usize),
                );
                // The rasterized zones are only needed in memory from here on.
                std::fs::remove_file(&rasterized_fn).ok();

                let zones_unique: HashSet<i16> =
                    zones.iter().copied().filter(|&z| z != na).collect();

                let histograms_thread = compute_zonal_histograms(
                    data,
                    zones,
                    col_names.clone(),
                    &zones_unique,
                    na,
                    hist_bins,
                    hist_min,
                    hist_max,
                );

                // Unique name in the platform temp dir; the file is consumed
                // and removed by `merge_histogram_dfs`.
                let histogram_fn =
                    std::env::temp_dir().join(format!("eorst_{uid}_{id}.parquet"));
                histograms_to_dataframe!(&histograms_thread, &histogram_fn);
                histogram_fn
            })
        });

        let collected_histograms: Vec<PathBuf> = handle.collect();
        let df = merge_histogram_dfs(collected_histograms)?;

        Ok(df)
    }
}