use crate::datatypes::{TensorData, TensorDimension};
use re_types_core::ArrowString;
use super::Tensor;
impl Tensor {
    /// Borrows the [`TensorData`] stored inside this tensor.
    pub fn data(&self) -> &TensorData {
        &self.data.0
    }

    /// Builds a tensor from any value that can be converted into [`TensorData`].
    ///
    /// # Errors
    ///
    /// Returns the conversion error of `T` unchanged when `data` cannot be turned into
    /// [`TensorData`].
    pub fn try_from<T: TryInto<TensorData>>(data: T) -> Result<Self, T::Error> {
        TryInto::<TensorData>::try_into(data).map(|tensor_data| Self {
            data: tensor_data.into(),
        })
    }

    /// Assigns names to the tensor's dimensions, in order.
    ///
    /// The i-th provided name replaces the name of the i-th dimension. If fewer names
    /// than dimensions are given, the remaining dimensions keep whatever name they
    /// already had. Names beyond the number of dimensions are ignored. Whenever the count
    /// does not match the number of dimensions, a warning is emitted via
    /// `re_log::warn_once!`.
    pub fn with_dim_names(self, names: impl IntoIterator<Item = impl Into<ArrowString>>) -> Self {
        let provided: Vec<ArrowString> = names.into_iter().map(Into::into).collect();
        let expected = self.data.0.shape.len();

        if provided.len() != expected {
            re_log::warn_once!(
                "Wrong number of names provided for tensor dimension. {} provided but {} expected.",
                provided.len(),
                expected,
            );
        }

        let TensorData { shape, buffer } = self.data.0;

        // Pull one name per dimension; once the supplied names run out, `next()` yields
        // `None` and the dimension's existing name is kept.
        let mut provided = provided.into_iter();
        let shape = shape
            .into_iter()
            .map(|dim| TensorDimension {
                size: dim.size,
                name: provided.next().or(dim.name),
            })
            .collect();

        Self {
            data: TensorData { shape, buffer }.into(),
        }
    }
}
/// Implements `TryFrom<&Wrapper>` for an `ndarray` dynamic-rank view of the given element
/// type by delegating to the conversion already defined for the wrapped [`TensorData`].
macro_rules! forward_array_views {
    ($elem:ty, $wrapper:ty) => {
        impl<'a> TryFrom<&'a $wrapper> for ::ndarray::ArrayViewD<'a, $elem> {
            type Error = crate::tensor_data::TensorCastError;

            #[inline]
            fn try_from(tensor: &'a $wrapper) -> Result<Self, Self::Error> {
                // The view borrows directly from the inner tensor data for lifetime `'a`.
                let inner = tensor.data();
                inner.try_into()
            }
        }
    };
}
// 8-bit element types.
forward_array_views!(u8, Tensor);
forward_array_views!(i8, Tensor);

// 16-bit element types.
forward_array_views!(u16, Tensor);
forward_array_views!(i16, Tensor);
forward_array_views!(half::f16, Tensor);

// 32-bit element types.
forward_array_views!(u32, Tensor);
forward_array_views!(i32, Tensor);
forward_array_views!(f32, Tensor);

// 64-bit element types.
forward_array_views!(u64, Tensor);
forward_array_views!(i64, Tensor);
forward_array_views!(f64, Tensor);