use ndarray::{ArrayD, Axis};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
/// Object-safe, type-erased view of a `Vec<T>`.
///
/// Lets vectors with different element types be stored behind a single
/// `Box<dyn AnyVec>` while still exposing their length, supporting cloning,
/// and allowing a downcast back to the concrete vector type.
pub trait AnyVec: Any + Debug {
    /// Number of elements in the underlying vector.
    fn len(&self) -> usize;

    /// Returns `true` when the underlying vector holds no elements.
    fn is_empty(&self) -> bool;

    /// Exposes the value as `&dyn Any` so callers can `downcast_ref` to the
    /// concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Produces a boxed clone of the value; used to make `Box<dyn AnyVec>`
    /// cloneable.
    fn clone_box(&self) -> Box<dyn AnyVec>;
}
/// Every vector of cloneable, debuggable, `'static` elements can be
/// type-erased as an [`AnyVec`].
impl<T> AnyVec for Vec<T>
where
    T: Any + Clone + Debug,
{
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn clone_box(&self) -> Box<dyn AnyVec> {
        let copy: Vec<T> = Clone::clone(self);
        Box::new(copy)
    }

    fn len(&self) -> usize {
        // Call the inherent method explicitly so it is obvious this does not
        // recurse into the trait method of the same name.
        Vec::len(self)
    }

    fn is_empty(&self) -> bool {
        Vec::is_empty(self)
    }
}
impl Clone for Box<dyn AnyVec> {
fn clone(&self) -> Self {
self.clone_box()
}
}
/// Type-erased coordinate labels for a single dimension.
///
/// Wraps a boxed [`AnyVec`] so that coordinate vectors with different element
/// types can be stored together in one map. The derived `Clone` relies on the
/// `Clone` implementation for `Box<dyn AnyVec>`.
#[derive(Clone)]
pub struct CoordinateVec(pub Box<dyn AnyVec>);
/// Compact debug output: reports only the number of labels, not their values.
impl Debug for CoordinateVec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let element_count = self.0.len();
        f.write_fmt(format_args!(
            "Coordinate vector with {} elements",
            element_count
        ))
    }
}
/// Marker trait for types usable as coordinate labels.
///
/// A label must be `'static` (implied by `Any`), cloneable, debuggable and
/// comparable for equality.
pub trait Coordinate
where
    Self: Any + Clone + Debug + PartialEq,
{
}

/// Blanket implementation: every type meeting the bounds is a [`Coordinate`].
impl<T> Coordinate for T where T: Any + Clone + Debug + PartialEq {}
/// A single label that can locate itself inside a type-erased coordinate
/// vector. Object-safe, so it can be stored as `Box<dyn FindIndex>`.
pub trait FindIndex: 'static {
/// Returns the position of `self` within `coords`, or `None` when no
/// position can be determined. The blanket implementation for
/// [`Coordinate`] types yields `None` both when `coords` holds a different
/// element type and when no element equals `self`.
fn find_index_in(&self, coords: &dyn AnyVec) -> Option<usize>;
}
/// Any [`Coordinate`] value can look itself up in a coordinate vector of the
/// same element type.
impl<C: Coordinate> FindIndex for C {
    /// Returns the index of the first element of `coords` equal to `self`.
    ///
    /// Yields `None` if `coords` is not a `Vec<C>` or contains no equal
    /// element.
    fn find_index_in(&self, coords: &dyn AnyVec) -> Option<usize> {
        // A failed downcast means the stored labels have a different type.
        let typed_labels = coords.as_any().downcast_ref::<Vec<C>>()?;
        typed_labels.iter().position(|candidate| candidate == self)
    }
}
/// A collection of labels that selects several positions along one
/// dimension. Object-safe, so it can be stored as `Box<dyn SliceSelector>`.
pub trait SliceSelector {
/// Maps the selector's labels to their positions within `coords`.
/// The implementation for `Vec<C>` panics when `coords` has a different
/// element type or when a label is absent.
fn find_indices(&self, coords: &dyn AnyVec) -> Vec<usize>;
/// Returns the type-erased coordinate vector to attach to the dimension
/// after slicing.
fn new_coords(&self) -> Box<dyn AnyVec>;
}
/// A vector of labels selects the matching positions, in the order given.
impl<C: Coordinate> SliceSelector for Vec<C> {
    /// Looks up every label of `self` in `coords` and returns the positions
    /// of the first matches, preserving the order of `self`.
    ///
    /// # Panics
    ///
    /// Panics if `coords` is not a `Vec<C>`, or if any label is missing from
    /// it.
    fn find_indices(&self, coords: &dyn AnyVec) -> Vec<usize> {
        let axis_labels = coords
            .as_any()
            .downcast_ref::<Vec<C>>()
            .expect("Coordinate type mismatch for slicing");

        let mut positions = Vec::with_capacity(self.len());
        for wanted in self {
            let hit = axis_labels
                .iter()
                .position(|have| have == wanted)
                .expect("Label not found");
            positions.push(hit);
        }
        positions
    }

    /// The sliced dimension is labeled with a copy of the requested labels.
    fn new_coords(&self) -> Box<dyn AnyVec> {
        let requested: Vec<C> = self.clone();
        Box::new(requested)
    }
}
/// How to select along one dimension in `LabeledArray::sel`.
pub enum Selector {
/// Select the single position holding this label; `sel` removes the
/// dimension (and its coordinates) from the result.
Label(Box<dyn FindIndex>),
/// Select the positions of several labels; `sel` keeps the dimension and
/// replaces its coordinates with the selector's labels.
Slice(Box<dyn SliceSelector>),
}
/// Errors reported by fallible `LabeledArray` operations such as
/// `set_coords`.
#[derive(Debug, thiserror::Error)]
pub enum LabeledError {
/// The named dimension does not exist on the array; carries the name that
/// was requested.
#[error("Dimension not found: {0}")]
DimensionNotFound(String),
/// The number of coordinate labels differs from the length of the axis
/// they were meant to describe.
#[error(
"Coordinate length mismatch for dimension '{dim}': expected {expected}, found {found}"
)]
CoordinateLengthMismatch {
// Name of the dimension the labels were supplied for.
dim: String,
// Axis length taken from the array's shape.
expected: usize,
// Number of labels actually supplied.
found: usize,
},
}
/// An N-dimensional array whose axes carry names and, optionally, per-axis
/// coordinate labels.
#[derive(Debug, Clone)]
pub struct LabeledArray<T> {
// The underlying dynamically-dimensioned data.
data: ArrayD<T>,
// One name per axis of `data`, in axis order.
dims: Vec<String>,
// Coordinate labels keyed by dimension name; a dimension may have no entry.
coords: HashMap<String, CoordinateVec>,
}
impl<T> LabeledArray<T>
where
    T: Clone,
{
    /// Creates a labeled array with one name per axis and no coordinates.
    ///
    /// # Panics
    ///
    /// Panics if `dims.len()` differs from `data.ndim()`.
    pub fn new(data: ArrayD<T>, dims: Vec<String>) -> Self {
        assert_eq!(
            data.ndim(),
            dims.len(),
            "number of dimension names must equal the number of array axes"
        );
        Self {
            data,
            dims,
            coords: HashMap::new(),
        }
    }

    /// Creates a labeled array with one name per axis and the given
    /// coordinates.
    ///
    /// Only coordinates whose key matches a name in `dims` are validated;
    /// other entries are stored unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `dims.len()` differs from `data.ndim()`, or if a coordinate
    /// vector's length differs from the length of the axis it is keyed to.
    pub fn new_with_coords(
        data: ArrayD<T>,
        dims: Vec<String>,
        coords: HashMap<String, CoordinateVec>,
    ) -> Self {
        assert_eq!(
            data.ndim(),
            dims.len(),
            "number of dimension names must equal the number of array axes"
        );
        for (axis_len, dim_name) in data.shape().iter().zip(&dims) {
            if let Some(coord) = coords.get(dim_name) {
                assert_eq!(
                    coord.0.len(),
                    *axis_len,
                    "coordinate length must equal the axis length for dimension '{}'",
                    dim_name
                );
            }
        }
        Self { data, dims, coords }
    }

    /// Dimension names, in axis order.
    pub fn dims(&self) -> &[String] {
        &self.dims
    }

    /// Length of every axis of the underlying array.
    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    /// Number of axes of the underlying array.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Shared access to the underlying array.
    pub fn data(&self) -> &ArrayD<T> {
        &self.data
    }

    /// Mutable access to the underlying array.
    ///
    /// Dimension names and coordinates are not re-validated after changes
    /// made through this reference.
    pub fn data_mut(&mut self) -> &mut ArrayD<T> {
        &mut self.data
    }

    /// Coordinate labels of `dim` as a typed slice.
    ///
    /// Returns `None` if `dim` has no coordinates or if they are not stored
    /// as a `Vec<C>`.
    pub fn coords<C: Coordinate>(&self, dim: &str) -> Option<&[C]> {
        self.coords
            .get(dim)
            .and_then(|any_vec| any_vec.0.as_any().downcast_ref::<Vec<C>>())
            .map(|vec| vec.as_slice())
    }

    /// All stored coordinates, keyed by dimension name.
    pub fn all_coords(&self) -> &HashMap<String, CoordinateVec> {
        &self.coords
    }

    /// Sets (or replaces) the coordinate labels of `dim`.
    ///
    /// # Errors
    ///
    /// Returns [`LabeledError::DimensionNotFound`] if `dim` is not a
    /// dimension of this array, and
    /// [`LabeledError::CoordinateLengthMismatch`] if `coords.len()` differs
    /// from the length of that axis. The array is left unchanged on error.
    pub fn set_coords<C: Coordinate>(
        &mut self,
        dim: &str,
        coords: Vec<C>,
    ) -> Result<(), LabeledError> {
        let dim_index = self
            .dim_index(dim)
            .ok_or_else(|| LabeledError::DimensionNotFound(dim.to_string()))?;
        let expected_len = self.data.shape()[dim_index];
        if coords.len() != expected_len {
            return Err(LabeledError::CoordinateLengthMismatch {
                dim: dim.to_string(),
                expected: expected_len,
                found: coords.len(),
            });
        }
        self.coords
            .insert(dim.to_string(), CoordinateVec(Box::new(coords)));
        Ok(())
    }

    /// Axis index of the first dimension named `dim`, if any.
    pub fn dim_index(&self, dim: &str) -> Option<usize> {
        self.dims.iter().position(|d| d == dim)
    }

    /// Position of `label` among the coordinates of `dim`.
    ///
    /// Returns `None` if `dim` has no coordinates of type `C` or if no
    /// coordinate equals `label`.
    pub fn select_by_label<C: Coordinate>(&self, dim: &str, label: &C) -> Option<usize> {
        self.coords::<C>(dim)
            .and_then(|coords| coords.iter().position(|c| c == label))
    }

    /// Selects along one or more dimensions by coordinate label and returns
    /// the result as a new array.
    ///
    /// For each entry of `selectors`:
    /// * [`Selector::Label`] picks the position of one label and removes the
    ///   dimension (and its coordinates) from the result.
    /// * [`Selector::Slice`] picks the positions of several labels, in the
    ///   order given, keeps the dimension, and replaces its coordinates with
    ///   the requested labels.
    ///
    /// Dimensions not mentioned in `selectors` are kept unchanged.
    ///
    /// # Panics
    ///
    /// Panics if a selector names an unknown dimension, if a selected
    /// dimension has no coordinates, if a label is not found, or if a slice
    /// selector's element type differs from the stored coordinate type.
    pub fn sel(&self, selectors: HashMap<&str, Selector>) -> Self {
        // Resolve every dimension name up front so an unknown dimension is
        // reported before any selection work is done.
        let mut resolved: Vec<(usize, &str, &Selector)> = selectors
            .iter()
            .map(|(dim_name, selector)| {
                let dim_index = self.dim_index(dim_name).expect("Dimension not found");
                (dim_index, *dim_name, selector)
            })
            .collect();
        // Process the highest axis first: removing an axis then never shifts
        // the index of an axis that is still waiting to be processed.
        resolved.sort_unstable_by_key(|&(dim_index, _, _)| std::cmp::Reverse(dim_index));

        let mut new_data = self.data.clone();
        let mut new_dims = self.dims.clone();
        let mut new_coords = self.coords.clone();

        for (dim_index, dim_name, selector) in resolved {
            let current_coords = self
                .coords
                .get(dim_name)
                .expect("Coordinates not found for dimension.");
            match selector {
                Selector::Label(label_selector) => {
                    let index = label_selector
                        .find_index_in(&*current_coords.0)
                        .expect("Label not found in coordinates.");
                    // `index_axis` yields a view, so it must be copied out.
                    new_data = new_data.index_axis(Axis(dim_index), index).to_owned();
                    new_dims.remove(dim_index);
                    new_coords.remove(dim_name);
                }
                Selector::Slice(slice_selector) => {
                    let indices = slice_selector.find_indices(&*current_coords.0);
                    // `select` already returns an owned array; no extra copy.
                    new_data = new_data.select(Axis(dim_index), &indices);
                    new_coords.insert(
                        dim_name.to_string(),
                        CoordinateVec(slice_selector.new_coords()),
                    );
                }
            }
        }

        LabeledArray {
            data: new_data,
            dims: new_dims,
            coords: new_coords,
        }
    }
}
impl<T: Clone + Debug> LabeledArray<T> {
pub fn info(&self) -> String {
let mut info = String::new();
info.push_str(&format!("LabeledArray<{:?}>\n", std::any::type_name::<T>()));
info.push_str(&format!("Dimensions: {:?}\n", self.dims));
info.push_str(&format!("Shape: {:?}\n", self.shape()));
if !self.coords.is_empty() {
info.push_str("Coordinates:\n");
for (dim, coord_vec) in &self.coords {
info.push_str(&format!(" {}: {:?}\n", dim, coord_vec));
}
}
info
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned `String` labels from string literals.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| String::from(*name)).collect()
    }

    #[test]
    fn test_new_labeled_array() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![1, 2, 3, 4, 5, 6]).unwrap();
        let dims = labels(&["y", "x"]);
        let array: LabeledArray<i32> = LabeledArray::new(data, dims);
        assert_eq!(array.shape(), &[2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_new_labeled_array_dimension_mismatch() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![1, 2, 3, 4, 5, 6]).unwrap();
        let dims = labels(&["y"]);
        LabeledArray::new(data, dims);
    }

    #[test]
    fn test_set_coords_and_coords() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![0; 6]).unwrap();
        let mut array = LabeledArray::new(data, labels(&["y", "x"]));
        array
            .set_coords("y", vec![10.0, 20.0])
            .expect("Failed to set y coords");
        array
            .set_coords("x", labels(&["a", "b", "c"]))
            .expect("Failed to set x coords");
        assert_eq!(array.coords::<f64>("y").unwrap(), &[10.0, 20.0]);
        assert_eq!(
            array.coords::<String>("x").unwrap(),
            labels(&["a", "b", "c"]).as_slice()
        );
    }

    /// Both error paths of `set_coords` are reported and leave the array
    /// without coordinates.
    #[test]
    fn test_set_coords_errors() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![0; 6]).unwrap();
        let mut array = LabeledArray::new(data, labels(&["y", "x"]));
        assert!(matches!(
            array.set_coords("z", vec![1, 2]),
            Err(LabeledError::DimensionNotFound(name)) if name == "z"
        ));
        assert!(matches!(
            array.set_coords("y", vec![1, 2, 3]),
            Err(LabeledError::CoordinateLengthMismatch {
                expected: 2,
                found: 3,
                ..
            })
        ));
        assert!(array.all_coords().is_empty());
    }

    #[test]
    fn test_select_by_label_generic() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![0; 6]).unwrap();
        let mut array = LabeledArray::new(data, labels(&["y", "x"]));
        array
            .set_coords("y", vec![10, 20])
            .expect("Failed to set y coords");
        array
            .set_coords("x", labels(&["a", "b", "c"]))
            .expect("Failed to set x coords");
        assert_eq!(array.select_by_label("y", &20), Some(1));
        assert_eq!(array.select_by_label("x", &"b".to_string()), Some(1));
        // A missing label and a mismatched label type both yield `None`.
        assert_eq!(array.select_by_label("y", &30), None);
        assert_eq!(array.select_by_label("y", &"b".to_string()), None);
    }

    #[test]
    fn test_sel_label_and_slice() {
        let data = ArrayD::from_shape_vec(vec![2, 3, 4], vec![0.0; 24]).unwrap();
        let mut array = LabeledArray::new(data, labels(&["time", "y", "x"]));
        array
            .set_coords("time", vec![0, 1])
            .expect("Failed to set time coords");
        array
            .set_coords("y", vec![10.0, 20.0, 30.0])
            .expect("Failed to set y coords");
        array
            .set_coords("x", labels(&["a", "b", "c", "d"]))
            .expect("Failed to set x coords");
        let result = array.sel(HashMap::from([
            ("time", Selector::Label(Box::new(1))),
            ("y", Selector::Slice(Box::new(vec![10.0, 30.0]))),
        ]));
        assert_eq!(result.ndim(), 2);
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(result.dims(), &["y", "x"]);
        assert_eq!(result.coords::<f64>("y").unwrap(), &[10.0, 30.0]);
        // The dropped dimension no longer carries coordinates.
        assert!(result.coords::<i32>("time").is_none());
    }

    #[test]
    fn test_sel_single_label() {
        // Distinct values so the test can tell which row was selected.
        let data = ArrayD::from_shape_vec(vec![2, 3], (0..6).collect::<Vec<i32>>()).unwrap();
        let mut array = LabeledArray::new(data, labels(&["y", "x"]));
        array
            .set_coords("y", vec![10, 20])
            .expect("Failed to set y coords");
        array
            .set_coords("x", labels(&["a", "b", "c"]))
            .expect("Failed to set x coords");
        let result = array.sel(HashMap::from([(
            "y",
            Selector::Label(Box::new(20) as Box<dyn FindIndex>),
        )]));
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.dims(), &["x"]);
        assert_eq!(
            result.coords::<String>("x").unwrap(),
            labels(&["a", "b", "c"]).as_slice()
        );
        // Label 20 is the second row of the original data.
        let values: Vec<i32> = result.data().iter().copied().collect();
        assert_eq!(values, vec![3, 4, 5]);
    }

    #[test]
    fn test_sel_multi_slice() {
        let data = ArrayD::from_shape_vec(vec![2, 3, 4], (0..24).collect::<Vec<i32>>()).unwrap();
        let mut array = LabeledArray::new(data, labels(&["time", "y", "x"]));
        array
            .set_coords("time", vec![100, 200])
            .expect("Failed to set time coords");
        array
            .set_coords("y", vec![10.0, 20.0, 30.0])
            .expect("Failed to set y coords");
        array
            .set_coords("x", labels(&["a", "b", "c", "d"]))
            .expect("Failed to set x coords");
        let result = array.sel(HashMap::from([
            ("x", Selector::Slice(Box::new(labels(&["a", "c"])))),
            ("y", Selector::Slice(Box::new(vec![10.0, 30.0]))),
        ]));
        assert_eq!(result.shape(), &[2, 2, 2]);
        assert_eq!(result.dims(), &["time", "y", "x"]);
        assert_eq!(result.coords::<f64>("y").unwrap(), &[10.0, 30.0]);
        assert_eq!(
            result.coords::<String>("x").unwrap(),
            labels(&["a", "c"]).as_slice()
        );
        // Element (t, y, x) of the source holds t * 12 + y * 4 + x; the
        // selection keeps y in {0, 2} and x in {0, 2}.
        let values: Vec<i32> = result.data().iter().copied().collect();
        assert_eq!(values, vec![0, 2, 8, 10, 12, 14, 20, 22]);
    }

    #[test]
    fn test_info_with_mixed_coords() {
        let data = ArrayD::from_shape_vec(vec![2, 3], vec![0; 6]).unwrap();
        let mut array = LabeledArray::new(data, labels(&["y", "x"]));
        array
            .set_coords("y", vec![10, 20])
            .expect("Failed to set y coords");
        array
            .set_coords("x", labels(&["a", "b", "c"]))
            .expect("Failed to set x coords");
        let info = array.info();
        assert!(info.contains("y: Coordinate vector with 2 elements"));
        assert!(info.contains("x: Coordinate vector with 3 elements"));
    }
}