use crate::core_types::{RasterData, RasterType};
use crate::data_sources::DateType;
use crate::metadata::{RasterDataBlock, RasterMetadata};
use crate::types::{Dimension, RasterDataShape};
use anyhow::Result;
use gdal::raster::GdalType;
use ndarray::{Array2, Array3, Axis};
use num_traits::{FromPrimitive, ToPrimitive};
use std::fmt::{self, Debug};
use std::ops::{Add, Div, Index, IndexMut};
/// Shape arithmetic used when combining raster blocks along one axis.
pub trait Stack {
    /// Grows `self` along `dim_to_stack` by `other`'s length on that axis.
    ///
    /// # Panics
    /// Panics if `self` and `other` differ on any axis other than `dim_to_stack`.
    fn stack(&mut self, other: RasterDataShape, dim_to_stack: Dimension) -> &mut RasterDataShape;
    /// Grows `self` along axis 0 (times) by `other`'s length on that axis.
    ///
    /// # Panics
    /// Panics if `self` and `other` differ on axes 1..4 (layers, rows, cols).
    fn extend(&mut self, other: RasterDataShape) -> &mut RasterDataShape;
}

/// Returns `true` when `lhs` and `rhs` have equal lengths on every one of the
/// four axes except `skip`.
fn shapes_match_except_axis(lhs: &RasterDataShape, rhs: &RasterDataShape, skip: usize) -> bool {
    (0..4)
        .filter(|&axis| axis != skip)
        .all(|axis| lhs[axis] == rhs[axis])
}

impl Stack for RasterDataShape {
    fn extend(&mut self, other: RasterDataShape) -> &mut RasterDataShape {
        if !shapes_match_except_axis(self, &other, 0) {
            panic!("Unable to extend layers");
        }
        self[0] += other[0];
        self
    }

    fn stack(&mut self, other: RasterDataShape, dim_to_stack: Dimension) -> &mut RasterDataShape {
        let axis = dim_to_stack.get_axis();
        if !shapes_match_except_axis(self, &other, axis) {
            panic!("Unable to stack layers");
        }
        self[axis] += other[axis];
        self
    }
}
impl Dimension {
    /// Position of this dimension in the `(times, layers, rows, cols)` ordering
    /// used when indexing a `RasterDataShape` by `usize`.
    pub fn get_axis(&self) -> usize {
        match *self {
            Dimension::Time => 0,
            Dimension::Layer => 1,
        }
    }
}
impl Index<usize> for RasterDataShape {
    type Output = usize;

    /// Reads an axis length by position: 0 = times, 1 = layers, 2 = rows, 3 = cols.
    ///
    /// # Panics
    /// Panics for any position greater than 3.
    fn index(&self, axis: usize) -> &usize {
        let lengths = [&self.times, &self.layers, &self.rows, &self.cols];
        match lengths.get(axis) {
            Some(&len) => len,
            None => panic!("Invalid index: {}", axis),
        }
    }
}
impl IndexMut<usize> for RasterDataShape {
    /// Mutable access to an axis length by position: 0 = times, 1 = layers,
    /// 2 = rows, 3 = cols.
    ///
    /// # Panics
    /// Panics for any position greater than 3.
    fn index_mut(&mut self, axis: usize) -> &mut usize {
        if axis == 0 {
            &mut self.times
        } else if axis == 1 {
            &mut self.layers
        } else if axis == 2 {
            &mut self.rows
        } else if axis == 3 {
            &mut self.cols
        } else {
            panic!("Invalid index: {}", axis)
        }
    }
}
/// Sums a raster cube over one of its leading dimensions (time or layer).
pub trait SumDimension<T>
where
    T: GdalType
        + num_traits::identities::Zero
        + Copy
        + FromPrimitive
        + Add<Output = T>
        + Div<Output = T>,
{
    /// Collapses `dimension` by summation and returns the remaining three axes.
    fn sum_dimension(&self, dimension: Dimension) -> Array3<T>;
}

impl<T> SumDimension<T> for RasterData<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
{
    fn sum_dimension(&self, dimension: Dimension) -> Array3<T> {
        // `get_axis` maps Time -> 0 and Layer -> 1.
        let axis = Axis(dimension.get_axis());
        self.sum_axis(axis)
    }
}
/// Averages a raster cube over one of its leading dimensions (time or layer).
// NOTE(review): the lint is silenced presumably because nothing in the crate calls this yet.
#[allow(dead_code)]
pub(crate) trait MeanDimension<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
{
    /// Collapses `dimension` into its arithmetic mean and returns the remaining
    /// three axes.
    ///
    /// # Panics
    /// Panics if the reduced dimension has length zero, because no mean exists.
    fn mean_dimension(&self, dimension: Dimension) -> Array3<T>;
}

impl<T> MeanDimension<T> for RasterData<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
{
    fn mean_dimension(&self, dimension: Dimension) -> Array3<T> {
        // `get_axis` maps Time -> 0 and Layer -> 1.
        let axis = Axis(dimension.get_axis());
        // `mean_axis` yields nothing only for a zero-length axis; make that
        // failure self-explanatory instead of a bare `unwrap` panic.
        self.mean_axis(axis)
            .expect("mean_dimension: cannot average over a dimension of length zero")
    }
}
/// Variance of a raster cube over one of its leading dimensions (time or layer).
pub trait VarDimension<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
{
    /// Collapses `dimension` into its variance. `ddof` is the delta degrees of
    /// freedom forwarded to ndarray's `var_axis`; see its documentation for the
    /// panic conditions on `ddof`.
    fn var_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T>;
}

impl<T> VarDimension<T> for RasterData<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
{
    fn var_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T> {
        let axis = Axis(dimension.get_axis());
        self.var_axis(axis, ddof)
    }
}
/// Standard deviation of a raster cube over one of its leading dimensions.
// NOTE(review): the lint is silenced presumably because nothing in the crate calls this yet.
#[allow(dead_code)]
pub(crate) trait StdDimension<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
{
    /// Collapses `dimension` into its standard deviation. `ddof` is the delta
    /// degrees of freedom forwarded to ndarray's `std_axis`.
    fn std_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T>;
}

impl<T> StdDimension<T> for RasterData<T>
where
    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
{
    fn std_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T> {
        let axis = Axis(dimension.get_axis());
        self.std_axis(axis, ddof)
    }
}
/// Failure modes of layer/time selection on a raster block.
#[derive(Debug)]
pub enum SelectError {
    /// A requested layer name is not in the block's layer list.
    LayerNotFound {
        requested: String,
        available: Vec<String>,
    },
    /// A requested date is not in the block's date list.
    TimeNotFound {
        requested: DateType,
        available: Vec<DateType>,
    },
    /// The caller asked for zero layers or zero dates.
    EmptySelection,
    /// Stacking the selected slices failed; carries the underlying error text.
    ConcatenationError(String),
}

impl fmt::Display for SelectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LayerNotFound { requested, available } => write!(
                f,
                "Layer '{}' not found. Available: {:?}",
                requested, available
            ),
            Self::TimeNotFound { requested, available } => write!(
                f,
                "Time {:?} not found. Available: {:?}",
                requested, available
            ),
            Self::EmptySelection => f.write_str("Empty selection requested"),
            Self::ConcatenationError(msg) => {
                write!(f, "Array concatenation failed: {}", msg)
            }
        }
    }
}

impl std::error::Error for SelectError {}
pub trait Select<T>
where
T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
{
fn select_layers(&self, layer_names: &[&str]) -> Result<RasterDataBlock<T>, SelectError>;
fn select_times(&self, dates: &[DateType]) -> Result<RasterDataBlock<T>, SelectError>;
fn find_layer_index(&self, name: &str) -> Result<usize, SelectError>;
fn find_time_index(&self, date: &DateType) -> Result<usize, SelectError>;
}
impl<T> Select<T> for RasterDataBlock<T>
where
T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
{
fn find_layer_index(&self, name: &str) -> Result<usize, SelectError> {
self.metadata
.layer_indices
.iter()
.position(|s| s.as_str() == name)
.ok_or(SelectError::LayerNotFound {
requested: name.to_string(),
available: self.metadata.layer_indices.clone(),
})
}
fn find_time_index(&self, date: &DateType) -> Result<usize, SelectError> {
self.metadata
.date_indices
.iter()
.position(|d| d == date)
.ok_or(SelectError::TimeNotFound {
requested: date.clone(),
available: self.metadata.date_indices.clone(),
})
}
fn select_layers(&self, layer_names: &[&str]) -> Result<RasterDataBlock<T>, SelectError> {
if layer_names.is_empty() {
return Err(SelectError::EmptySelection);
}
let indices: Vec<usize> = layer_names
.iter()
.map(|name| self.find_layer_index(name))
.collect::<Result<_, _>>()?;
let views: Vec<_> = indices
.iter()
.map(|&idx| self.data.index_axis(Axis(1), idx))
.collect();
let data = ndarray::stack(Axis(1), &views)
.map_err(|e| SelectError::ConcatenationError(e.to_string()))?;
let new_metadata = RasterMetadata {
layer_indices: layer_names.iter().map(|s| s.to_string()).collect(),
shape: RasterDataShape {
layers: layer_names.len(),
..self.metadata.shape
},
..self.metadata.clone()
};
Ok(RasterDataBlock {
data,
metadata: new_metadata,
no_data: self.no_data,
})
}
fn select_times(&self, dates: &[DateType]) -> Result<RasterDataBlock<T>, SelectError> {
if dates.is_empty() {
return Err(SelectError::EmptySelection);
}
let indices: Vec<usize> = dates
.iter()
.map(|d| self.find_time_index(d))
.collect::<Result<_, _>>()?;
let views: Vec<_> = indices
.iter()
.map(|&idx| self.data.index_axis(Axis(0), idx))
.collect();
let data = ndarray::stack(Axis(0), &views)
.map_err(|e| SelectError::ConcatenationError(e.to_string()))?;
let new_metadata = RasterMetadata {
date_indices: dates.to_vec(),
shape: RasterDataShape {
times: dates.len(),
..self.metadata.shape
},
..self.metadata.clone()
};
Ok(RasterDataBlock {
data,
metadata: new_metadata,
no_data: self.no_data,
})
}
}
impl<T: RasterType> RasterDataBlock<T> {
    /// Layer names of this block, in the order of the data's axis 1.
    pub fn available_layer_names(&self) -> &[String] {
        self.metadata.layer_indices.as_slice()
    }

    /// Date labels of this block, in the order of the data's axis 0.
    pub fn available_time_indices(&self) -> &[DateType] {
        self.metadata.date_indices.as_slice()
    }
}
/// Output helpers for raster blocks whose element type is `U`.
pub trait RasterBlockTrait<U: RasterType> {
    /// Lifts a 2-D array into a 3-D one.
    /// NOTE(review): the name suggests a (feature, row, col) layout — confirm against implementors.
    fn into_frc(&self, data: &Array2<U>) -> Array3<U>;

    /// Presumably writes 2-D sample/feature data to `file_name`, with `na` as the
    /// no-data value — NOTE(review): confirm against implementors.
    fn write_samples_feature<T: RasterType + ToPrimitive>(
        &self,
        data: &Array2<T>,
        file_name: &std::path::PathBuf,
        na: T,
    );

    /// Presumably writes a 3-D array to `out_fn` — NOTE(review): confirm against implementors.
    fn write3<T: RasterType + ToPrimitive>(&self, data: Array3<T>, out_fn: &std::path::PathBuf);
}
#[cfg(test)]
mod tests {
    use crate::data_sources::DateType;
    use crate::metadata::{RasterDataBlock, RasterMetadata};
    use crate::types::RasterDataShape;
    use ndarray::{Array4, Axis};
    use num_traits::NumCast;
    use super::{Select, SelectError};

    /// Builds a 3-time x 4-layer x 2x2 block whose cells encode their own
    /// coordinates as `t*1000 + l*100 + r*10 + c`. Selection tests can then
    /// check that the right slices were picked, not just the right shape.
    fn make_test_block() -> RasterDataBlock<f32> {
        let data = Array4::<f32>::from_shape_fn((3, 4, 2, 2), |(t, l, r, c)| {
            (t * 1000 + l * 100 + r * 10 + c) as f32
        });
        let metadata = RasterMetadata {
            layer_indices: vec!["red".into(), "green".into(), "nir".into(), "swir".into()],
            date_indices: vec![
                DateType::Index(0),
                DateType::Index(1),
                DateType::Index(2),
            ],
            shape: RasterDataShape {
                times: 3,
                layers: 4,
                rows: 2,
                cols: 2,
            },
            ..RasterMetadata::new()
        };
        RasterDataBlock {
            data,
            metadata,
            no_data: NumCast::from(0.0f32).unwrap(),
        }
    }

    #[test]
    fn test_select_layers_basic() {
        let block = make_test_block();
        let result = block.select_layers(&["red", "nir"]).unwrap();
        assert_eq!(result.metadata.shape.layers, 2);
        assert_eq!(result.data.shape(), &[3, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["red", "nir"]);
        assert_eq!(result.metadata.shape.times, 3);
        assert_eq!(result.metadata.shape.rows, 2);
        assert_eq!(result.metadata.shape.cols, 2);
    }

    #[test]
    fn test_select_layers_single() {
        let block = make_test_block();
        let result = block.select_layers(&["nir"]).unwrap();
        assert_eq!(result.metadata.shape.layers, 1);
        assert_eq!(result.data.shape(), &[3, 1, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["nir"]);
        // "nir" is source layer 2.
        assert_eq!(result.data.index_axis(Axis(1), 0), block.data.index_axis(Axis(1), 2));
    }

    #[test]
    fn test_select_layers_all() {
        let block = make_test_block();
        let result = block
            .select_layers(&["red", "green", "nir", "swir"])
            .unwrap();
        assert_eq!(result.metadata.shape.layers, 4);
        assert_eq!(result.data.shape(), &[3, 4, 2, 2]);
        assert_eq!(result.data, block.data);
    }

    #[test]
    fn test_select_layers_not_found() {
        let block = make_test_block();
        let err = block.select_layers(&["red", "blue"]).unwrap_err();
        assert!(matches!(err, SelectError::LayerNotFound { .. }));
        if let SelectError::LayerNotFound { requested, available } = err {
            assert_eq!(requested, "blue");
            assert_eq!(available.len(), 4);
        }
    }

    #[test]
    fn test_select_layers_empty() {
        let block = make_test_block();
        let err = block.select_layers(&[]).unwrap_err();
        assert!(matches!(err, SelectError::EmptySelection));
    }

    #[test]
    fn test_select_layers_picks_correct_data() {
        let block = make_test_block();
        let result = block.select_layers(&["swir", "red"]).unwrap();
        // "swir" is source layer 3 and "red" is source layer 0.
        assert_eq!(result.data.index_axis(Axis(1), 0), block.data.index_axis(Axis(1), 3));
        assert_eq!(result.data.index_axis(Axis(1), 1), block.data.index_axis(Axis(1), 0));
        // t=0, l=3 (swir), r=1, c=1
        assert_eq!(result.data[[0, 0, 1, 1]], 311.0);
    }

    #[test]
    fn test_select_times_basic() {
        let block = make_test_block();
        let dates = vec![DateType::Index(0), DateType::Index(2)];
        let result = block.select_times(&dates).unwrap();
        assert_eq!(result.metadata.shape.times, 2);
        assert_eq!(result.data.shape(), &[2, 4, 2, 2]);
        assert_eq!(result.metadata.shape.layers, 4);
    }

    #[test]
    fn test_select_times_single() {
        let block = make_test_block();
        let result = block.select_times(&[DateType::Index(1)]).unwrap();
        assert_eq!(result.metadata.shape.times, 1);
        assert_eq!(result.data.shape(), &[1, 4, 2, 2]);
        assert_eq!(result.data.index_axis(Axis(0), 0), block.data.index_axis(Axis(0), 1));
    }

    #[test]
    fn test_select_times_not_found() {
        let block = make_test_block();
        let err = block
            .select_times(&[DateType::Index(99)])
            .unwrap_err();
        assert!(matches!(err, SelectError::TimeNotFound { .. }));
    }

    #[test]
    fn test_select_times_empty() {
        let block = make_test_block();
        let err = block.select_times(&[]).unwrap_err();
        assert!(matches!(err, SelectError::EmptySelection));
    }

    #[test]
    fn test_select_times_picks_correct_data() {
        let block = make_test_block();
        let result = block
            .select_times(&[DateType::Index(2), DateType::Index(0)])
            .unwrap();
        assert_eq!(result.data.index_axis(Axis(0), 0), block.data.index_axis(Axis(0), 2));
        assert_eq!(result.data.index_axis(Axis(0), 1), block.data.index_axis(Axis(0), 0));
        assert_eq!(
            result.metadata.date_indices,
            vec![DateType::Index(2), DateType::Index(0)]
        );
    }

    #[test]
    fn test_select_chaining_layers_then_times() {
        let block = make_test_block();
        let result = block
            .select_layers(&["red", "nir"])
            .unwrap()
            .select_times(&[DateType::Index(0)])
            .unwrap();
        assert_eq!(result.data.shape(), &[1, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["red", "nir"]);
        assert_eq!(result.metadata.date_indices.len(), 1);
        // t=0, l=2 (nir), r=0, c=0
        assert_eq!(result.data[[0, 1, 0, 0]], 200.0);
    }

    #[test]
    fn test_select_chaining_times_then_layers() {
        let block = make_test_block();
        let result = block
            .select_times(&[DateType::Index(1), DateType::Index(2)])
            .unwrap()
            .select_layers(&["swir"])
            .unwrap();
        assert_eq!(result.data.shape(), &[2, 1, 2, 2]);
        assert_eq!(result.metadata.date_indices.len(), 2);
        assert_eq!(result.metadata.layer_indices, vec!["swir"]);
        // t=2, l=3 (swir), r=1, c=0
        assert_eq!(result.data[[1, 0, 1, 0]], 2310.0);
    }

    #[test]
    fn test_available_layer_names() {
        let block = make_test_block();
        assert_eq!(
            block.available_layer_names(),
            &["red", "green", "nir", "swir"]
        );
    }

    #[test]
    fn test_available_time_indices() {
        let block = make_test_block();
        let times = block.available_time_indices();
        assert_eq!(times.len(), 3);
        assert_eq!(times[0], DateType::Index(0));
        assert_eq!(times[1], DateType::Index(1));
        assert_eq!(times[2], DateType::Index(2));
    }

    #[test]
    fn test_select_preserves_no_data() {
        let block = make_test_block();
        let result = block.select_layers(&["red"]).unwrap();
        assert_eq!(result.no_data, block.no_data);
    }

    #[test]
    fn test_select_error_display() {
        let block = make_test_block();
        let err = block.select_layers(&["missing"]).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("missing"));
        assert!(msg.contains("not found"));
    }

    #[test]
    fn test_select_layers_order_preserved() {
        let block = make_test_block();
        let result = block.select_layers(&["swir", "red"]).unwrap();
        assert_eq!(result.metadata.layer_indices, vec!["swir", "red"]);
    }
}