use crate::core_types::RasterType;
use crate::gdal_utils::{create_rayon_pool, read_raster_band};
use crate::rasterdataset::RasterDataset;
use anyhow::Result;
use ndarray::{Axis, ArrayView2, ArrayView3};
use ndhistogram::axis::UniformNoFlow;
use ndhistogram::Histogram;
use polars::prelude::{Column, DataFrame, ParquetReader, SerReader};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::process::Command;
use uuid::Uuid;
// 1-D histogram over a uniform f32 axis with no under/overflow bins;
// one instance is kept per (zone, "time--layer") pair.
type ZoneHistogram = ndhistogram::Hist1D<UniformNoFlow<f32>>;
/// Flattens a `&HashMap<i16, HashMap<String, ZoneHistogram>>` (inner keys are
/// "time--layer" strings) into a tidy DataFrame with columns
/// zone / time / layer / bin_star / bin_end / count, then writes it to
/// `$histogram_fn` as parquet via `save_zonal_histograms`.
macro_rules! histograms_to_dataframe {
    ($histograms_thread:expr, $histogram_fn:expr) => {{
        let mut zones = Vec::new();
        let mut times = Vec::new();
        let mut layers = Vec::new();
        let mut bin_starts = Vec::new();
        let mut bin_ends = Vec::new();
        let mut counts = Vec::new();
        for (z, histogram) in $histograms_thread {
            for col_histogram in histogram {
                // Split the "time--layer" key once per histogram; previously
                // this was recomputed for every single bin of the histogram.
                let tl: Vec<&str> = col_histogram.0.split("--").collect();
                let t = tl[0];
                let l = tl[1];
                for hist in col_histogram.1.iter() {
                    zones.push(*z as i64);
                    times.push(t);
                    layers.push(l);
                    bin_starts.push(hist.bin.start());
                    bin_ends.push(hist.bin.end());
                    counts.push(*hist.value as i64);
                }
            }
        }
        let zones = Column::new("zone".into(), &zones);
        // Fixed mojibake: the reference `&times` had been mangled into `×`
        // (HTML-entity decoding of "&times;"), which did not compile.
        let times = Column::new("time".into(), &times);
        let layers = Column::new("layer".into(), &layers);
        // NOTE(review): "bin_star" looks like a typo for "bin_start", but it is
        // part of the output schema — confirm no downstream reader depends on
        // it before renaming.
        let bin_starts = Column::new("bin_star".into(), &bin_starts);
        let bin_ends = Column::new("bin_end".into(), &bin_ends);
        let counts = Column::new("count".into(), &counts);
        let vec_columns = vec![zones, times, layers, bin_starts, bin_ends, counts];
        let mut df = DataFrame::new_infer_height(vec_columns).unwrap();
        save_zonal_histograms(&mut df, $histogram_fn);
    }};
}
/// Writes `df` to `path` as a parquet file.
///
/// # Panics
/// Panics (with the offending path and underlying error in the message) if the
/// file cannot be created or the parquet serialization fails. The bare
/// `unwrap()`s used before gave no hint which file was involved.
pub fn save_zonal_histograms(df: &mut DataFrame, path: &std::path::Path) {
    use polars::prelude::ParquetWriter;
    let file = std::fs::File::create(path)
        .unwrap_or_else(|err| panic!("failed to create {}: {err}", path.display()));
    ParquetWriter::new(file)
        .finish(df)
        .unwrap_or_else(|err| panic!("failed to write parquet to {}: {err}", path.display()));
}
/// Reads every parquet file in `paths`, deletes each file after reading
/// (best-effort), and diagonally concatenates the frames into one DataFrame.
fn merge_histogram_dfs(paths: Vec<PathBuf>) -> Result<DataFrame> {
    let mut frames = Vec::with_capacity(paths.len());
    for part in &paths {
        let handle = std::fs::File::open(part)?;
        frames.push(ParquetReader::new(handle).finish()?);
        // Temporary per-block spill file; removal failure is non-fatal.
        std::fs::remove_file(part).ok();
    }
    Ok(polars::functions::concat_df_diagonal(&frames)?)
}
/// Builds one histogram per (zone, column) and fills it from `data`.
///
/// `data` is laid out as (time, layer, row, col); `col_names` must contain one
/// entry per (time, layer) pair, in time-major order. Pixels whose zone value
/// equals `na` are skipped. All histograms share a uniform axis of `hist_bins`
/// bins over [`hist_min`, `hist_max`).
///
/// # Panics
/// Panics if the axis definition is invalid, or if a pixel's zone value is
/// neither `na` nor contained in `zones_unique`.
fn compute_zonal_histograms(
    data: ndarray::Array4<i16>,
    zones: ndarray::Array2<i16>,
    col_names: Vec<String>,
    zones_unique: &HashSet<i16>,
    na: i16,
    hist_bins: usize,
    hist_min: f32,
    hist_max: f32,
) -> HashMap<i16, HashMap<String, ZoneHistogram>> {
    // Validate the axis once up front; the previous code stored the Result and
    // repeated `.clone().unwrap()` for every zone x column combination.
    let axis = UniformNoFlow::new(hist_bins, hist_min, hist_max)
        .expect("invalid histogram axis (hist_bins/hist_min/hist_max)");
    let mut histograms_thread: HashMap<i16, HashMap<String, ZoneHistogram>> = HashMap::new();
    for zone in zones_unique.iter() {
        let mut hm = HashMap::new();
        for column in col_names.iter() {
            hm.insert(column.clone(), ndhistogram::ndhistogram!(axis.clone()));
        }
        histograms_thread.insert(*zone, hm);
    }
    // Running index into col_names across the flattened (time, layer) pairs.
    let mut idx = 0;
    for time in 0..data.shape()[0] {
        let data_t: ArrayView3<_> = data.index_axis(Axis(0), time);
        for layer in 0..data_t.shape()[0] {
            let column = col_names[idx].clone();
            idx += 1;
            let data_l: ArrayView2<_> = data_t.index_axis(Axis(0), layer);
            ndarray::azip!((d in &data_l, z in &zones) {
                let zone = *z;
                if zone != na {
                    // Panic message is built lazily: the old eager `format!`
                    // allocated a String for every non-NA pixel on this hot path.
                    let histogram = histograms_thread
                        .get_mut(&zone)
                        .unwrap_or_else(|| panic!("Zone: {zone:?} not found"));
                    let histogram_time_layer =
                        histogram.get_mut(&column).expect("key not available");
                    histogram_time_layer.fill(&(*d as f32));
                }
            });
        }
    }
    histograms_thread
}
impl<R> RasterDataset<R>
where
    R: RasterType,
{
    /// Per-zone histograms of this dataset against a zones raster.
    ///
    /// `other` must contain exactly one time and one layer; its values are
    /// cast to `i16` and the value 0 is treated as no-data. Each block is
    /// processed independently, spilled to a temporary parquet file, and the
    /// files are merged into the returned DataFrame.
    #[cfg(feature = "use_polars")]
    pub fn zonal_histograms_raster(
        &self,
        other: &crate::rasterdataset::RasterDataset<u8>,
        hist_bins: usize,
        hist_min: f32,
        hist_max: f32,
    ) -> Result<DataFrame> {
        let blocks_to_process: Vec<usize> = (0..self.blocks.len()).collect();
        let col_names = self.column_names();
        let zones_unique: HashSet<i16> = other
            .get_unique_values()
            .into_iter()
            .filter(|&x| x != 0)
            .map(|x| x as i16)
            .collect();
        let pool = create_rayon_pool(1);
        // `collect()` must run INSIDE `install`: rayon executes the parallel
        // work on whichever pool drives it, and the old code collected outside,
        // silently running on the global pool instead of the 1-thread pool.
        let collected_histograms: Vec<PathBuf> = pool.install(|| {
            blocks_to_process
                .into_par_iter()
                .map(|id| -> PathBuf {
                    let data = self.read_block::<i16>(id);
                    let zones_block = other.read_block::<i32>(id);
                    // Logical `||` instead of bitwise `|` on bools.
                    if (zones_block.shape()[0] != 1) || (zones_block.shape()[1] != 1) {
                        panic!("Zones raster dataset can only have 1 time and 1 layer");
                    }
                    let zones_t = zones_block.index_axis(Axis(0), 0);
                    let zones_view: ArrayView2<_> = zones_t.index_axis(Axis(0), 0);
                    // mapv handles any memory layout in one pass; the old
                    // as_slice().unwrap() round-trip panicked on
                    // non-contiguous views and allocated twice.
                    let zones: ndarray::Array2<i16> = zones_view.mapv(|v| v as i16);
                    let histograms_thread = compute_zonal_histograms(
                        data,
                        zones,
                        col_names.clone(),
                        &zones_unique,
                        0, // zone value 0 = no-data for raster zones
                        hist_bins,
                        hist_min,
                        hist_max,
                    );
                    let uid = Uuid::new_v4();
                    // OS temp dir instead of hard-coded /tmp (portable).
                    let histogram_fn =
                        std::env::temp_dir().join(format!("eorst_{uid}_{id}.parquet"));
                    histograms_to_dataframe!(&histograms_thread, &histogram_fn);
                    histogram_fn
                })
                .collect()
        });
        let df = merge_histogram_dfs(collected_histograms)?;
        Ok(df)
    }

    /// Per-zone histograms against a polygon layer.
    ///
    /// For each block, rasterizes `polygons` (burning `id_column_name`) over
    /// the block's extent with `gdal_rasterize`, then histograms this
    /// dataset's values per rasterized zone. Zone value `na` is skipped.
    ///
    /// # Panics
    /// Panics if `gdal_rasterize` cannot be started or exits unsuccessfully.
    #[cfg(feature = "use_polars")]
    pub fn zonal_histograms_polygons(
        &self,
        polygons: &str,
        id_column_name: &str,
        na: i16,
        hist_bins: usize,
        hist_min: f32,
        hist_max: f32,
    ) -> Result<DataFrame> {
        let blocks_to_process: Vec<usize> = (0..self.blocks.len()).collect();
        let col_names = self.column_names();
        let pool = create_rayon_pool(1);
        // As above: collect inside install so the work runs on this pool.
        let collected_histograms: Vec<PathBuf> = pool.install(|| {
            blocks_to_process
                .into_par_iter()
                .map(|id| -> PathBuf {
                    let block = self.blocks[id].clone();
                    let target_gt = block.geo_transform;
                    let rows = block.read_window.size.rows;
                    let cols = block.read_window.size.cols;
                    let x_ll = target_gt.x_ul + (cols as f64 * target_gt.x_res);
                    let y_ll = target_gt.y_ul + (rows as f64 * target_gt.y_res);
                    // Round to 2 decimals so gdal_rasterize gets a stable extent.
                    let extent = crate::metadata::Extent {
                        xmin: (target_gt.x_ul * 100.).round() / 100.,
                        ymin: (y_ll * 100.).round() / 100.,
                        xmax: (x_ll * 100.).round() / 100.,
                        ymax: (target_gt.y_ul * 100.).round() / 100.,
                    };
                    let rasterized_fn = format!("rasterized_block_{id:?}.tif");
                    let mut cmd = Command::new("gdal_rasterize");
                    cmd.arg("-q");
                    cmd.args([
                        "-tr",
                        &format!("{}", self.metadata.geo_transform.x_res.abs()),
                        &format!("{}", self.metadata.geo_transform.y_res.abs()),
                    ]);
                    cmd.args(["-a", id_column_name]);
                    cmd.args([
                        "-te",
                        &format!("{}", (extent.xmin * 100.).round() / 100.),
                        &format!("{}", (extent.ymin * 100.).round() / 100.),
                        &format!("{}", (extent.xmax * 100.).round() / 100.),
                        &format!("{}", (extent.ymax * 100.).round() / 100.),
                    ]);
                    cmd.arg(polygons);
                    cmd.arg(rasterized_fn.clone());
                    // status() spawns and waits in one call; also check the
                    // exit code — the old spawn()/wait() ignored failures and
                    // only blew up later when reading the missing tif.
                    let status = cmd.status().expect("failed to run gdal_rasterize");
                    if !status.success() {
                        panic!("gdal_rasterize exited with {status}");
                    }
                    let data = self.read_block::<i16>(id);
                    let zones = read_raster_band::<i16>(
                        std::path::Path::new(&rasterized_fn),
                        1,
                        (0, 0),
                        (cols as usize, rows as usize),
                    );
                    // Remove the temporary tif (was leaked in the CWD before).
                    std::fs::remove_file(&rasterized_fn).ok();
                    let mut zones_unique: HashSet<_> = zones.iter().copied().collect();
                    zones_unique.retain(|&x| x != na);
                    let histograms_thread = compute_zonal_histograms(
                        data,
                        zones,
                        col_names.clone(),
                        &zones_unique,
                        na,
                        hist_bins,
                        hist_min,
                        hist_max,
                    );
                    let uid = Uuid::new_v4();
                    // Same temp-file naming scheme as zonal_histograms_raster.
                    let histogram_fn =
                        std::env::temp_dir().join(format!("eorst_{uid}_{id}.parquet"));
                    histograms_to_dataframe!(&histograms_thread, &histogram_fn);
                    histogram_fn
                })
                .collect()
        });
        let df = merge_histogram_dfs(collected_histograms)?;
        Ok(df)
    }
}