use crate::core_types::RasterType;
use crate::gdal_utils::create_rayon_pool;
use crate::rasterdataset::builder::n_block_cols;
use crate::types::{Coordinates, Index2d, Rectangle, SamplingMethod};
use crate::rasterdataset::RasterDataset;
use anyhow::Result;
use gdal::Dataset;
use gdal::vector::{Geometry, LayerAccess};
use itertools::Itertools;
use kdam::par_tqdm;
use ndarray::{s, Array2};
use rayon::prelude::*;
use std::collections::BTreeMap;
use std::hash::Hash;
use std::path::Path;
fn sample_value(
band_data: &ndarray::ArrayView2<i16>,
rect: Rectangle,
point: Index2d,
method: SamplingMethod,
) -> i16 {
match method {
SamplingMethod::Value => band_data[(point.row, point.col)],
SamplingMethod::Avg => {
let window_data: Vec<i16> =
crate::array_ops::rect_view(band_data, rect, point)
.iter()
.copied()
.collect();
let window_size = window_data.len();
let avg: f32 = window_data
.iter()
.map(|v| *v as f32 / window_size as f32)
.sum();
avg.round() as i16
}
SamplingMethod::Mode => {
let mut window_data: Vec<i16> =
crate::array_ops::rect_view(band_data, rect, point)
.iter()
.copied()
.collect();
window_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
window_data[window_data.len() / 2]
}
SamplingMethod::Min => {
let window_data: Vec<i16> =
crate::array_ops::rect_view(band_data, rect, point)
.iter()
.copied()
.collect();
*window_data.iter().min().unwrap()
}
SamplingMethod::StdDev => {
let window_data: Vec<i16> =
crate::array_ops::rect_view(band_data, rect, point)
.iter()
.copied()
.collect();
let sum: i32 = window_data.iter().map(|&x| x as i32).sum();
let mean = sum as f64 / window_data.len() as f64;
let variance: f64 = window_data
.iter()
.map(|&x| (x as f64 - mean).powi(2))
.sum::<f64>()
/ window_data.len() as f64;
variance.sqrt().round() as i16
}
}
}
/// Asserts that the sampling buffer does not exceed the block overlap.
///
/// Presumably this keeps a sampling window around a point near a block edge inside the
/// block's overlap margin (confirm against `read_block`).
///
/// # Panics
/// Panics if `buffer_size > overlap_size`.
fn validate_buffer_size(buffer_size: usize, overlap_size: usize) {
    assert!(
        buffer_size <= overlap_size,
        "Buffer size ({}) has to be <= overlap size ({})",
        buffer_size,
        overlap_size
    );
}
/// Builds the sampling `Rectangle` with all four sides (`left`, `top`, `right`, `bottom`)
/// set to `buffer_size`.
///
/// Presumably these are per-side margins in pixels around the sampled point, as consumed
/// by `crate::array_ops::rect_view` — confirm there.
fn make_rectangle(buffer_size: usize) -> Rectangle {
    let margin = buffer_size;
    Rectangle {
        top: margin,
        bottom: margin,
        left: margin,
        right: margin,
    }
}
/// Resolves point geometries to the raster blocks that contain them.
///
/// Returns, in order:
/// 1. point id -> raster-wide pixel index (from `geoms_to_global_indices`),
/// 2. one entry per point, in point-id order: `(block id, (point id, block-local index))`,
/// 3. the distinct block ids, in order of first occurrence.
fn build_block_index_pipeline<R: RasterType>(
    raster: &RasterDataset<R>,
    geoms: &BTreeMap<i64, Vec<(f64, f64, f64)>>,
) -> (BTreeMap<i64, Index2d>, Vec<(usize, (i64, Index2d))>, Vec<usize>) {
    let global_indices = raster.geoms_to_global_indices(geoms.clone());
    let point_blocks: Vec<(usize, (i64, Index2d))> = global_indices
        .par_iter()
        .map(|(&point_id, &global)| raster.block_id_rowcol(point_id, global))
        .collect();
    // `block_id_rowcol` already tags each point with `id_from_indices(global)`, so the
    // block list can be read straight off `point_blocks` rather than recomputed.
    let unique_blocks: Vec<usize> = point_blocks
        .iter()
        .map(|&(block_id, _)| block_id)
        .unique()
        .collect();
    (global_indices, point_blocks, unique_blocks)
}
fn collect_points_for_block(
id_indices: &[(usize, (i64, Index2d))],
block_id: usize,
) -> (Vec<Index2d>, Vec<usize>, Vec<usize>) {
let mut pos: Vec<Index2d> = Vec::new();
let mut idx: Vec<usize> = Vec::new();
let mut pids: Vec<usize> = Vec::new();
for (pid, p) in id_indices.iter().enumerate() {
if p.0 == block_id {
pos.push(Index2d {
col: p.1 .1.col,
row: p.1 .1.row,
});
pids.push(pid);
idx.push(p.1 .0 as usize);
}
}
(pos, idx, pids)
}
/// Flattens per-block sampling output into a single map keyed by point id.
///
/// Each element of `collected` describes one block as `(input_indices, point_ids, values)`:
/// - `input_indices` is one entry per point in the block (only its length is used here),
/// - `point_ids[i]` is the id of the block's `i`-th point, turned into a key with
///   `key_converter`,
/// - `values[band][i]` is that point's sample in `band`.
///
/// The value stored for a point is its samples across all of its block's bands, in band
/// order. If a key occurs more than once, the entry processed last wins. An empty
/// `collected` yields an empty map.
fn assemble_block_results<K>(
    collected: &[(Vec<usize>, Vec<usize>, Vec<Vec<i16>>)],
    key_converter: fn(usize) -> K,
) -> BTreeMap<K, Vec<i16>>
where
    K: Ord + Hash,
{
    let mut results = BTreeMap::new();
    for (input_indices, point_ids, band_values) in collected {
        for (i, &id) in point_ids.iter().enumerate().take(input_indices.len()) {
            // Band count is taken per block, so no block depends on the first one's shape.
            let vals_point: Vec<i16> = band_values.iter().map(|band| band[i]).collect();
            results.insert(key_converter(id), vals_point);
        }
    }
    results
}
impl<R> RasterDataset<R>
where
R: RasterType,
{
pub fn geoms_to_global_indices(
&self,
geoms: BTreeMap<i64, Vec<(f64, f64, f64)>>,
) -> BTreeMap<i64, Index2d> {
let idx_global: BTreeMap<_, _> = geoms
.par_iter()
.map(|(pid, p)| {
let point: Coordinates = Coordinates {
x: p[0].0,
y: p[0].1,
};
(*pid, self.geo_to_global_rc(point))
})
.collect();
idx_global
}
fn geo_to_global_rc(&self, point: Coordinates) -> Index2d {
let gt = self.metadata.geo_transform.to_array();
let row = ((point.y - gt[3]) / gt[5]) as usize;
let col = ((point.x - gt[0]) / gt[1]) as usize;
Index2d { col, row }
}
pub fn block_id_rowcol(&self, pid: i64, index: Index2d) -> (usize, (i64, Index2d)) {
let id = self.id_from_indices(index);
let row_col = self.global_rc_to_block_rc(index);
(id, (pid, row_col))
}
fn global_rc_to_block_rc(&self, global_index: Index2d) -> Index2d {
let mut block_col = global_index.col % self.metadata.block_size.cols;
let mut block_row = global_index.row % self.metadata.block_size.rows;
let block_col_ov = block_col + self.metadata.overlap_size;
let block_row_ov = block_row + self.metadata.overlap_size;
if (global_index.col as i16 - block_col_ov as i16) > 0 {
block_col = block_col_ov;
};
if global_index.row as i16 - block_row_ov as i16 > 0 {
block_row = block_row_ov;
};
Index2d {
col: block_col,
row: block_row,
}
}
fn id_from_indices(&self, index: Index2d) -> usize {
let n_block_cols = self.n_block_cols();
(index.col / self.metadata.block_size.cols)
+ (index.row / self.metadata.block_size.rows) * n_block_cols
}
fn n_block_cols(&self) -> usize {
let image_size = crate::types::ImageSize {
rows: self.metadata.shape.rows,
cols: self.metadata.shape.cols,
};
n_block_cols(image_size, self.metadata.block_size)
}
pub fn extract_blockwise(
&self,
vector_path: &std::path::PathBuf,
id_col_name: &str,
method: SamplingMethod,
buffer_size: Option<usize>,
) -> BTreeMap<i16, Vec<i16>> {
log::debug!("Starting extract.");
let buffer_size = buffer_size.unwrap_or(0);
validate_buffer_size(buffer_size, self.metadata.overlap_size);
let vector_dataset = Dataset::open(Path::new(vector_path)).unwrap();
let mut layer = vector_dataset.layer(0).unwrap();
let mut geoms = BTreeMap::new();
for feature in layer.features() {
let mut geom = Vec::new();
feature
.geometry()
.expect("Geometries")
.get_points(&mut geom);
let field_index = feature.field_index(id_col_name).expect("Bad column name.");
let pid_filed = feature.field(field_index).unwrap().unwrap();
let pid = pid_filed.into_int64().unwrap();
geoms.insert(pid, geom);
}
let (_idx_global, id_indices, blocks_to_process) =
build_block_index_pipeline(self, &geoms);
drop(geoms);
let pool = create_rayon_pool(1);
let handle = pool.install(|| {
par_tqdm!(blocks_to_process
.into_par_iter())
.map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
let mut res = Vec::new();
let rect = make_rectangle(buffer_size);
let data = self.read_block(id);
let bands = data.shape()[1];
log::debug!("Bands {:?}", bands);
for band_n in 0..bands {
let mut res_band = Vec::new();
let band_data = data.slice(s![0_i32, band_n, .., ..]);
for point in pos.iter() {
let val = sample_value(&band_data, rect, *point, method);
res_band.push(val);
}
res.push(res_band);
}
(pids, idx, res)
})
});
let collected: Vec<_> = handle.collect();
assemble_block_results(&collected, |id| id as i16)
}
pub fn extract(
&self,
geometries: &[Geometry],
point_ids: &[i64],
method: SamplingMethod,
buffer_size: Option<usize>,
) -> Result<(Array2<i16>, Vec<i64>)> {
let buffer_size = buffer_size.unwrap_or(0);
validate_buffer_size(buffer_size, self.metadata.overlap_size);
let mut geoms = BTreeMap::new();
for (idx, point_id) in point_ids.iter().enumerate() {
let geometry = &geometries[idx];
let point = geometry.get_point(0);
let (x, y, z) = point;
geoms.insert(*point_id, vec![(x, y, z)]);
}
let (_idx_global, id_indices, blocks_to_process) =
build_block_index_pipeline(self, &geoms);
drop(geoms);
let blocks_to_process: Vec<usize> = blocks_to_process;
let pool = create_rayon_pool(1);
let handle = pool.install(|| {
par_tqdm!(blocks_to_process
.into_par_iter())
.map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
log::debug!("Extracting {} points, from block: {}", pos.len(), id);
let mut res = Vec::new();
let rect = make_rectangle(buffer_size);
let data = self.read_block(id);
let n_times = data.shape()[0];
let n_layers = data.shape()[1];
for time in 0..n_times {
for layer in 0..n_layers {
let mut res_band = Vec::new();
let band_data = data.slice(s![time, layer, .., ..]);
for point in pos.iter() {
let col = point.col.checked_sub(self.blocks[id].overlap.left);
let row = point.row.checked_sub(self.blocks[id].overlap.top);
let col = col.unwrap_or(point.col);
let row = row.unwrap_or(point.row);
let val = sample_value(&band_data, rect, Index2d { col, row }, method);
res_band.push(val);
}
res.push(res_band);
}
}
(pids, idx, res)
})
});
let collected: Vec<_> = handle.collect();
let results = assemble_block_results(&collected, |id| id as i64);
let k = results.keys().next().unwrap();
let n_rows = results.len();
let n_cols = results[k].len();
let mut array: Array2<i16> = ndarray::Array::zeros((n_rows, n_cols));
for (row_index, values) in results.values().enumerate() {
for (col_index, value) in values.iter().enumerate() {
array[[row_index, col_index]] = *value;
}
}
let pids: Vec<i64> = results.into_keys().collect();
Ok((array, pids))
}
}