use ndarray::{ArrayD, Axis, SliceInfoElem};
use std::collections::HashMap;
/// Describes how a single dimension is selected by `SpatialArray::sel`.
///
/// The borrowed strings are coordinate labels; they are matched against the
/// labels stored for the dimension being selected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Selector<'a> {
    /// Pick exactly one label; the dimension is removed from the result.
    Label(&'a str),
    /// Pick several labels, in the given order; the dimension is kept.
    Slice(Vec<&'a str>),
}
/// An n-dimensional array whose axes carry names and, optionally, string
/// coordinate labels.
#[derive(Debug, Clone)]
pub struct SpatialArray<T> {
    /// The underlying dynamically-dimensioned array.
    data: ArrayD<T>,
    /// One name per axis of `data`, in axis order (the constructors assert
    /// that the counts match).
    dims: Vec<String>,
    /// Coordinate labels keyed by dimension name. A dimension without an
    /// entry here simply has no labels.
    coords: HashMap<String, Vec<String>>,
}
impl<T> SpatialArray<T>
where
    T: Clone,
{
    /// Creates an array with named dimensions and no coordinate labels.
    ///
    /// # Panics
    ///
    /// Panics if `dims.len()` differs from `data.ndim()`.
    pub fn new(data: ArrayD<T>, dims: Vec<String>) -> Self {
        assert_eq!(
            data.ndim(),
            dims.len(),
            "Number of dimension names must match the number of dimensions in data"
        );
        Self {
            data,
            dims,
            coords: HashMap::new(),
        }
    }

    /// Creates an array with named dimensions and the given coordinate labels.
    ///
    /// Dimensions that have no entry in `coords` are left unlabelled.
    ///
    /// # Panics
    ///
    /// Panics if
    /// * `dims.len()` differs from `data.ndim()`,
    /// * a coordinate vector's length differs from the size of its dimension, or
    /// * `coords` contains a key that is not one of `dims`.
    pub fn new_with_coords(
        data: ArrayD<T>,
        dims: Vec<String>,
        coords: HashMap<String, Vec<String>>,
    ) -> Self {
        assert_eq!(
            data.ndim(),
            dims.len(),
            "Number of dimension names must match the number of dimensions in data"
        );
        for (axis, dim_name) in dims.iter().enumerate() {
            if let Some(labels) = coords.get(dim_name) {
                assert_eq!(
                    labels.len(),
                    data.shape()[axis],
                    "Coordinate length for dimension '{}' must match dimension size",
                    dim_name
                );
            }
        }
        // `set_coords` and `with_coords` both refuse labels for a dimension
        // that does not exist; enforce the same rule here so the coordinate
        // table never holds entries that no axis corresponds to.
        for dim_name in coords.keys() {
            assert!(
                dims.contains(dim_name),
                "Dimension '{}' not found",
                dim_name
            );
        }
        Self { data, dims, coords }
    }

    /// Returns the dimension names in axis order.
    pub fn dims(&self) -> &[String] {
        &self.dims
    }

    /// Returns the size of every axis.
    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    /// Returns the number of axes.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Borrows the underlying array.
    pub fn data(&self) -> &ArrayD<T> {
        &self.data
    }

    /// Mutably borrows the underlying array.
    ///
    /// Nothing in this type re-validates dimension names or coordinate
    /// lengths after the array has been mutated through this reference.
    pub fn data_mut(&mut self) -> &mut ArrayD<T> {
        &mut self.data
    }

    /// Returns the coordinate labels stored for `dim`, if any.
    pub fn coords(&self, dim: &str) -> Option<&Vec<String>> {
        self.coords.get(dim)
    }

    /// Returns the whole coordinate table, keyed by dimension name.
    pub fn all_coords(&self) -> &HashMap<String, Vec<String>> {
        &self.coords
    }

    /// Sets (or replaces) the coordinate labels of dimension `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim` is not a dimension of this array, or if `coords` does
    /// not have exactly one label per element along that dimension.
    pub fn set_coords(&mut self, dim: &str, coords: Vec<String>) {
        let axis = self.dim_index(dim).expect("Dimension not found");
        assert_eq!(
            coords.len(),
            self.data.shape()[axis],
            "Coordinate length must match dimension size"
        );
        self.coords.insert(dim.to_string(), coords);
    }

    /// Returns the axis index of dimension `dim`, or `None` if there is no
    /// such dimension.
    pub fn dim_index(&self, dim: &str) -> Option<usize> {
        self.dims.iter().position(|d| d == dim)
    }

    /// Returns the position of `label` along dimension `dim`.
    ///
    /// Returns `None` when `dim` has no coordinate labels or when `label` is
    /// not among them.
    pub fn select_by_label(&self, dim: &str, label: &str) -> Option<usize> {
        self.coords.get(dim)?.iter().position(|c| c == label)
    }

    /// Returns a new array restricted to the labels named in `selectors`.
    ///
    /// Each key of `selectors` is a dimension name:
    /// * [`Selector::Slice`] keeps the dimension and picks the listed labels,
    ///   in the listed order;
    /// * [`Selector::Label`] picks a single label and removes the dimension
    ///   (and its coordinates) from the result.
    ///
    /// Dimensions without a selector are returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if a selector names a dimension that does not exist
    /// ("Dimension not found"), a selected dimension has no coordinates
    /// ("Coordinates not found for dimension"), or a requested label is not
    /// among the dimension's coordinates ("Label not found").
    pub fn sel(&self, selectors: HashMap<&str, Selector>) -> Self {
        let mut data = self.data.clone();
        let mut coords = self.coords.clone();

        // Pass 1: validate every selector's dimension and apply the `Slice`
        // selectors. Resolving the axis for *all* selectors here means a
        // `Label` on an unknown dimension fails exactly like a `Slice` does,
        // instead of being silently ignored.
        for (dim_name, selector) in &selectors {
            let axis = self
                .dims
                .iter()
                .position(|d| d == *dim_name)
                .expect("Dimension not found");
            if let Selector::Slice(labels) = selector {
                let current = coords
                    .get(*dim_name)
                    .expect("Coordinates not found for dimension");
                let indices: Vec<usize> = labels
                    .iter()
                    .map(|label| {
                        current
                            .iter()
                            .position(|c| c == *label)
                            .expect("Label not found")
                    })
                    .collect();
                // `select` already returns an owned array, so no extra copy
                // is needed before storing it.
                data = data.select(Axis(axis), &indices);
                let kept_labels: Vec<String> = labels.iter().map(|l| l.to_string()).collect();
                coords.insert(dim_name.to_string(), kept_labels);
            }
        }

        // Pass 2: build one slice element per axis. `Label` selectors index
        // into their axis (which collapses it); every other axis is kept whole.
        let mut slice_info = Vec::with_capacity(self.dims.len());
        let mut dropped: Vec<&String> = Vec::new();
        for dim_name in &self.dims {
            if let Some(Selector::Label(label)) = selectors.get(dim_name.as_str()) {
                let index = coords
                    .get(dim_name)
                    .expect("Coordinates not found for dimension")
                    .iter()
                    .position(|c| c == *label)
                    .expect("Label not found");
                // A position inside a `Vec` always fits in `isize`.
                slice_info.push(SliceInfoElem::Index(index as isize));
                dropped.push(dim_name);
            } else {
                slice_info.push(SliceInfoElem::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                });
            }
        }

        let sliced_data = data.slice(slice_info.as_slice()).to_owned();
        let dims: Vec<String> = self
            .dims
            .iter()
            .filter(|d| !dropped.contains(d))
            .cloned()
            .collect();
        coords.retain(|name, _| !dropped.contains(&name));

        Self {
            data: sliced_data,
            dims,
            coords,
        }
    }

    /// Consumes the array and returns it with the given numeric coordinates,
    /// stored as strings.
    ///
    /// Each value is converted with its `Display` implementation (so `10.0`
    /// becomes `"10"`). The returned array carries **only** the coordinates
    /// passed here; labels previously attached to the array are dropped.
    ///
    /// # Errors
    ///
    /// Returns an error message if a key of `coords` is not a dimension of
    /// this array, or if a coordinate vector's length differs from the size of
    /// its dimension.
    pub fn with_coords(self, coords: HashMap<String, Vec<f64>>) -> Result<Self, String> {
        let mut new_coords: HashMap<String, Vec<String>> = HashMap::with_capacity(coords.len());
        for (dim, values) in coords {
            let axis = self
                .dim_index(&dim)
                .ok_or_else(|| format!("Dimension '{}' not found", dim))?;
            if values.len() != self.shape()[axis] {
                return Err(format!(
                    "Coordinate length for dimension '{}' must match dimension size",
                    dim
                ));
            }
            let labels: Vec<String> = values.iter().map(|v| v.to_string()).collect();
            new_coords.insert(dim, labels);
        }
        Ok(SpatialArray {
            data: self.data,
            dims: self.dims,
            coords: new_coords,
        })
    }
}
impl<T> SpatialArray<T>
where
    T: Clone + std::fmt::Display,
{
    /// Builds a short, human-readable summary of the array.
    ///
    /// The summary lists the element type name, the dimension names and the
    /// shape. If any coordinates are stored, it also lists the number of
    /// labels per dimension. Coordinate lines are emitted in a deterministic
    /// order: by axis position first, then alphabetically for any entry whose
    /// name is not a dimension.
    pub fn info(&self) -> String {
        let mut info = String::new();
        info.push_str(&format!("SpatialArray<{}>\n", std::any::type_name::<T>()));
        info.push_str(&format!("Dimensions: {:?}\n", self.dims));
        info.push_str(&format!("Shape: {:?}\n", self.shape()));
        if !self.coords.is_empty() {
            info.push_str("Coordinates:\n");
            // `HashMap` iteration order is unspecified, so impose one:
            // alphabetical first, then a stable sort by axis position keeps
            // the alphabetical order only as a tie-breaker.
            let mut entries: Vec<(&String, &Vec<String>)> = self.coords.iter().collect();
            entries.sort_by(|a, b| a.0.cmp(b.0));
            entries.sort_by_key(|entry| {
                self.dims
                    .iter()
                    .position(|d| d == entry.0)
                    .unwrap_or(usize::MAX)
            });
            for (dim, labels) in entries {
                info.push_str(&format!(" {}: {} labels\n", dim, labels.len()));
            }
        }
        info
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `i32` array of the given shape; panics if `values` does not
    /// fill the shape exactly.
    fn array_of(shape: &[usize], values: Vec<i32>) -> ArrayD<i32> {
        ArrayD::from_shape_vec(shape.to_vec(), values).unwrap()
    }

    /// Turns string literals into owned dimension names.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_new_spatial_array() {
        let data = array_of(&[2, 3], vec![1, 2, 3, 4, 5, 6]);
        let _array: SpatialArray<i32> = SpatialArray::new(data, names(&["y", "x"]));
    }

    #[test]
    #[should_panic(
        expected = "Number of dimension names must match the number of dimensions in data"
    )]
    fn test_new_spatial_array_dimension_mismatch() {
        let data = array_of(&[2, 3], vec![1, 2, 3, 4, 5, 6]);
        let _array: SpatialArray<i32> = SpatialArray::new(data, names(&["y"]));
    }

    #[test]
    fn test_dim_index_and_missing() {
        let data = array_of(&[2, 3, 4], vec![0; 24]);
        let array: SpatialArray<i32> = SpatialArray::new(data, names(&["time", "y", "x"]));
        assert_eq!(array.dim_index("time"), Some(0));
        assert_eq!(array.dim_index("band"), None);
    }

    #[test]
    fn test_data_mut_roundtrip() {
        let data = array_of(&[2, 2], vec![1, 2, 3, 4]);
        let mut array: SpatialArray<i32> = SpatialArray::new(data, names(&["y", "x"]));
        array.data_mut().as_slice_mut().unwrap()[0] = 99;
        assert_eq!(array.data().as_slice().unwrap()[0], 99);
    }

    #[test]
    #[should_panic(expected = "Dimension not found")]
    fn test_set_coords_nonexistent_dimension_panics() {
        let data = array_of(&[2], vec![1, 2]);
        let mut array: SpatialArray<i32> = SpatialArray::new(data, names(&["a"]));
        array.set_coords("b", names(&["x", "y"]));
    }

    #[test]
    fn test_dims_and_all_coords_empty() {
        let data = array_of(&[2, 2], vec![1, 2, 3, 4]);
        let dims = names(&["y", "x"]);
        let array: SpatialArray<i32> = SpatialArray::new(data, dims.clone());
        assert_eq!(array.dims(), &dims);
        assert!(array.all_coords().is_empty());
    }

    #[test]
    fn test_with_coords_success_and_all_coords() {
        let data = array_of(&[2, 3], vec![1, 2, 3, 4, 5, 6]);
        let numeric_coords: HashMap<String, Vec<f64>> = vec![
            ("y".to_string(), vec![0.0, 10.0]),
            ("x".to_string(), vec![0.0, 10.0, 20.0]),
        ]
        .into_iter()
        .collect();
        let array: SpatialArray<i32> = SpatialArray::new(data, names(&["y", "x"]));
        let result: SpatialArray<i32> = array
            .with_coords(numeric_coords)
            .expect("with_coords should succeed");
        assert_eq!(result.coords("y").unwrap().len(), 2);
        assert_eq!(result.coords("x").unwrap().len(), 3);
    }

    #[test]
    fn test_with_coords_wrong_length_returns_err() {
        let data = array_of(&[2, 3], vec![1, 2, 3, 4, 5, 6]);
        let numeric_coords: HashMap<String, Vec<f64>> =
            vec![("y".to_string(), vec![0.0])].into_iter().collect();
        let array: SpatialArray<i32> = SpatialArray::new(data, names(&["y", "x"]));
        let err: String = array.with_coords(numeric_coords).unwrap_err();
        assert!(err.contains("must match dimension size"));
    }

    #[test]
    fn test_set_coords_success() {
        let data = array_of(&[2], vec![1, 2]);
        let mut array: SpatialArray<i32> = SpatialArray::new(data, names(&["a"]));
        array.set_coords("a", names(&["x", "y"]));
        assert_eq!(
            array.coords("a").unwrap(),
            &vec!["x".to_string(), "y".to_string()]
        );
        assert_eq!(array.select_by_label("a", "y"), Some(1));
    }

    #[test]
    fn test_with_coords_dimension_not_found_returns_err() {
        let data = array_of(&[2], vec![1, 2]);
        let numeric_coords: HashMap<String, Vec<f64>> =
            vec![("z".to_string(), vec![0.0, 1.0])].into_iter().collect();
        let array: SpatialArray<i32> = SpatialArray::new(data, names(&["a"]));
        let err = array.with_coords(numeric_coords).unwrap_err();
        assert!(err.contains("not found"));
    }

    #[test]
    fn test_info_no_coords() {
        let data = array_of(&[1, 2], vec![1, 2]);
        let array: SpatialArray<i32> = SpatialArray::new(data, names(&["y", "x"]));
        let info = array.info();
        for expected in ["SpatialArray", "Dimensions", "Shape"].iter() {
            assert!(info.contains(expected));
        }
        assert!(!info.contains("Coordinates:"));
    }
}