use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::tensor::Tensor;
/// A [`Tensor`] whose dimensions may each carry an optional name.
///
/// Invariant: `names.len() == inner.ndim()`, and no two `Some(_)` names are
/// equal. [`NamedTensor::new`] checks both conditions.
#[derive(Debug, Clone)]
pub struct NamedTensor<T: Float> {
    /// The wrapped tensor.
    inner: Tensor<T>,
    /// One entry per dimension of `inner`; `None` marks an anonymous dimension.
    names: Vec<Option<String>>,
}
impl<T: Float> NamedTensor<T> {
    /// Wraps `inner`, attaching one optional name per dimension.
    ///
    /// `names[i]` labels dimension `i`; `None` marks an anonymous dimension.
    /// Any number of anonymous dimensions may coexist.
    ///
    /// # Errors
    ///
    /// - [`FerrotorchError::ShapeMismatch`] if `names.len() != inner.ndim()`.
    /// - [`FerrotorchError::InvalidArgument`] if any `Some` name appears more
    ///   than once.
    pub fn new(inner: Tensor<T>, names: Vec<Option<String>>) -> FerrotorchResult<Self> {
        if names.len() != inner.ndim() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "NamedTensor::new: names.len()={} != ndim={}",
                    names.len(),
                    inner.ndim()
                ),
            });
        }
        // `flatten` skips `None`, so only real names take part in the uniqueness check.
        let mut seen: std::collections::HashSet<&str> =
            std::collections::HashSet::with_capacity(names.len());
        for n in names.iter().flatten() {
            if !seen.insert(n.as_str()) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("NamedTensor::new: duplicate dim name '{n}'"),
                });
            }
        }
        Ok(Self { inner, names })
    }

    /// Convenience constructor taking borrowed names.
    ///
    /// An empty string marks an anonymous dimension (stored as `None`).
    ///
    /// # Errors
    ///
    /// Same as [`NamedTensor::new`].
    pub fn refined(inner: Tensor<T>, names: &[&str]) -> FerrotorchResult<Self> {
        let owned: Vec<Option<String>> = names
            .iter()
            .map(|s| (!s.is_empty()).then(|| (*s).to_string()))
            .collect();
        Self::new(inner, owned)
    }

    /// Borrows the wrapped tensor.
    pub fn tensor(&self) -> &Tensor<T> {
        &self.inner
    }

    /// Consumes `self` and returns the wrapped tensor, discarding the names.
    pub fn into_tensor(self) -> Tensor<T> {
        self.inner
    }

    /// Per-dimension names; entry `i` labels dimension `i` (`None` = anonymous).
    pub fn names(&self) -> &[Option<String>] {
        &self.names
    }

    /// Shape of the wrapped tensor.
    pub fn shape(&self) -> &[usize] {
        self.inner.shape()
    }

    /// Number of dimensions of the wrapped tensor.
    pub fn ndim(&self) -> usize {
        self.inner.ndim()
    }

    /// Total number of elements of the wrapped tensor.
    pub fn numel(&self) -> usize {
        self.inner.numel()
    }

    /// Index of the dimension named `name`, or `None` if no dimension has it.
    ///
    /// Anonymous dimensions are never matched, including by `""`.
    pub fn dim_index(&self, name: &str) -> Option<usize> {
        self.names.iter().position(|n| n.as_deref() == Some(name))
    }

    /// Size of the dimension named `name`, or `None` if no dimension has it.
    pub fn size_of(&self, name: &str) -> Option<usize> {
        // `names.len() == ndim` (checked in `new`), so `i` is a valid shape index.
        self.dim_index(name).map(|i| self.shape()[i])
    }

    /// Returns a copy whose names are remapped by `(old, new)` pairs in `mapping`.
    ///
    /// - Names that do not appear as a key in `mapping` are kept.
    /// - Keys that match no dimension are ignored.
    /// - Anonymous dimensions stay anonymous.
    /// - If a key repeats, the last pair wins.
    ///
    /// Renaming to `""` stores `Some("")`; it does not make the dimension
    /// anonymous (unlike [`NamedTensor::refined`]).
    ///
    /// # Errors
    ///
    /// [`FerrotorchError::InvalidArgument`] if the renamed set contains duplicates.
    pub fn rename(&self, mapping: &[(&str, &str)]) -> FerrotorchResult<Self> {
        let map: std::collections::HashMap<&str, &str> = mapping.iter().copied().collect();
        let new_names: Vec<Option<String>> = self
            .names
            .iter()
            .map(|name| {
                name.as_deref()
                    .map(|old| map.get(old).copied().unwrap_or(old).to_string())
            })
            .collect();
        Self::new(self.inner.clone(), new_names)
    }

    /// Permutes dimensions so that they appear in the order of `target_names`.
    ///
    /// Every dimension must be named exactly once in `target_names`. As a
    /// result, a tensor with any anonymous dimension cannot be aligned:
    /// targets are matched by name only.
    ///
    /// # Errors
    ///
    /// - [`FerrotorchError::ShapeMismatch`] if `target_names.len() != ndim`.
    /// - [`FerrotorchError::InvalidArgument`] if a target name is not present,
    ///   or appears more than once.
    /// - Any error returned by `crate::methods::permute_t`.
    pub fn align_to(&self, target_names: &[&str]) -> FerrotorchResult<Self> {
        if target_names.len() != self.ndim() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "NamedTensor::align_to: target len={} != ndim={}",
                    target_names.len(),
                    self.ndim()
                ),
            });
        }
        let mut perm: Vec<usize> = Vec::with_capacity(self.ndim());
        // Tracks which source dims are already claimed, so `perm` is a true permutation.
        let mut used = vec![false; self.ndim()];
        for &t in target_names {
            // Lazily built: the message formats the whole name list.
            let idx = self
                .dim_index(t)
                .ok_or_else(|| FerrotorchError::InvalidArgument {
                    message: format!(
                        "NamedTensor::align_to: target name '{t}' not present in {:?}",
                        self.names
                    ),
                })?;
            if used[idx] {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("NamedTensor::align_to: duplicate target name '{t}'"),
                });
            }
            used[idx] = true;
            perm.push(idx);
        }
        let permuted = crate::methods::permute_t(&self.inner, &perm)?;
        let new_names: Vec<Option<String>> = perm.iter().map(|&i| self.names[i].clone()).collect();
        Self::new(permuted, new_names)
    }

    /// Returns a copy (cloning `inner`) with every dimension made anonymous.
    ///
    /// NOTE(review): despite the name, this only clears names; nothing here
    /// detaches `inner` from autograd — confirm that is intended.
    pub fn detached(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            names: vec![None; self.ndim()],
        }
    }
}
impl<T: Float> std::fmt::Display for NamedTensor<T> {
    /// Renders a compact summary: the shape and the per-dimension names,
    /// with `_` standing in for anonymous dimensions. Tensor data is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let labels = self
            .names
            .iter()
            .map(|slot| match slot {
                Some(label) => label.as_str(),
                None => "_",
            })
            .collect::<Vec<&str>>();
        f.write_fmt(format_args!(
            "NamedTensor(shape={:?}, names={:?})",
            self.shape(),
            labels
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    /// Builds an `f32` tensor of `shape` filled with `0.0, 1.0, 2.0, ...`.
    ///
    /// NOTE(review): `.max(1)` gives a zero-sized shape one element of data;
    /// revisit before adding tests with a dimension of size 0.
    fn t_f32(shape: &[usize]) -> Tensor<f32> {
        let n: usize = shape.iter().product::<usize>().max(1);
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false).unwrap()
    }

    #[test]
    fn named_tensor_basic_construction() {
        let nt = NamedTensor::refined(t_f32(&[2, 3, 4]), &["batch", "seq", "feat"]).unwrap();
        assert_eq!(nt.ndim(), 3);
        assert_eq!(nt.size_of("batch"), Some(2));
        assert_eq!(nt.size_of("seq"), Some(3));
        assert_eq!(nt.size_of("feat"), Some(4));
        assert_eq!(nt.size_of("missing"), None);
    }

    #[test]
    fn named_tensor_rejects_length_mismatch() {
        let err = NamedTensor::refined(t_f32(&[2, 3]), &["only", "two", "names"]).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn named_tensor_rejects_duplicate_names() {
        let err = NamedTensor::refined(t_f32(&[2, 3]), &["x", "x"]).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn named_tensor_allows_multiple_anonymous_dims() {
        // `None` entries are exempt from the uniqueness check.
        let nt = NamedTensor::new(t_f32(&[2, 3]), vec![None, None]).unwrap();
        assert!(nt.names().iter().all(Option::is_none));
    }

    #[test]
    fn named_tensor_anonymous_dim_via_empty_string() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["batch", ""]).unwrap();
        assert_eq!(nt.names()[0].as_deref(), Some("batch"));
        assert_eq!(nt.names()[1], None);
    }

    #[test]
    fn named_tensor_align_permutes_dims() {
        let nt = NamedTensor::refined(t_f32(&[2, 3, 4]), &["batch", "seq", "feat"]).unwrap();
        let aligned = nt.align_to(&["feat", "batch", "seq"]).unwrap();
        assert_eq!(aligned.shape(), &[4, 2, 3]);
        assert_eq!(aligned.names()[0].as_deref(), Some("feat"));
        assert_eq!(aligned.names()[1].as_deref(), Some("batch"));
        assert_eq!(aligned.names()[2].as_deref(), Some("seq"));
    }

    #[test]
    fn named_tensor_align_identity_is_clone() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let aligned = nt.align_to(&["a", "b"]).unwrap();
        assert_eq!(aligned.shape(), nt.shape());
        // An identity alignment must preserve the names as well as the shape.
        assert_eq!(aligned.names(), nt.names());
    }

    #[test]
    fn named_tensor_align_rejects_unknown_name() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let err = nt.align_to(&["a", "z"]).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn named_tensor_align_rejects_length_mismatch() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let err = nt.align_to(&["a"]).unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
    }

    #[test]
    fn named_tensor_rename_replaces_specified_names() {
        let nt = NamedTensor::refined(t_f32(&[2, 3, 4]), &["b", "s", "f"]).unwrap();
        let renamed = nt.rename(&[("b", "batch"), ("s", "seq")]).unwrap();
        assert_eq!(renamed.names()[0].as_deref(), Some("batch"));
        assert_eq!(renamed.names()[1].as_deref(), Some("seq"));
        assert_eq!(renamed.names()[2].as_deref(), Some("f"));
    }

    #[test]
    fn named_tensor_rename_ignores_unknown_keys() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let renamed = nt.rename(&[("zzz", "q")]).unwrap();
        assert_eq!(renamed.names(), nt.names());
    }

    #[test]
    fn named_tensor_rename_rejects_collision() {
        // Renaming "a" -> "b" would leave two dims named "b".
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let err = nt.rename(&[("a", "b")]).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn named_tensor_detached_drops_names() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let d = nt.detached();
        for n in d.names() {
            assert!(n.is_none());
        }
    }

    #[test]
    fn named_tensor_into_tensor_recovers_inner() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["a", "b"]).unwrap();
        let t = nt.into_tensor();
        assert_eq!(t.shape(), &[2, 3]);
    }

    #[test]
    fn named_tensor_dim_index_lookup() {
        let nt = NamedTensor::refined(t_f32(&[2, 3, 4]), &["batch", "seq", "feat"]).unwrap();
        assert_eq!(nt.dim_index("batch"), Some(0));
        assert_eq!(nt.dim_index("seq"), Some(1));
        assert_eq!(nt.dim_index("feat"), Some(2));
        assert_eq!(nt.dim_index("missing"), None);
    }

    #[test]
    fn named_tensor_display_marks_anonymous_dims() {
        let nt = NamedTensor::refined(t_f32(&[2, 3]), &["batch", ""]).unwrap();
        assert_eq!(
            format!("{nt}"),
            "NamedTensor(shape=[2, 3], names=[\"batch\", \"_\"])"
        );
    }
}