use alloc::format;
use crate::backend::Backend;
use crate::{Distribution, NamedDims, Shape, Tensor};
/// A tensor whose dimensions are labelled by the dimension-name type `D`.
///
/// The wrapped value is stored as `D::Tensor`, the tensor type that `D`
/// associates with backend `B`.
#[derive(Debug, Clone)]
pub struct NamedTensor<B, D>
where
    B: Backend,
    D: NamedDims<B>,
{
    /// The underlying (unnamed) tensor value.
    pub(crate) tensor: D::Tensor,
}
impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<NamedTensor<B, ND>>
for Tensor<B, D>
{
fn from(nt: NamedTensor<B, ND>) -> Self {
nt.tensor
}
}
impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<Tensor<B, D>>
for NamedTensor<B, ND>
{
fn from(tensor: Tensor<B, D>) -> Self {
Self::from_tensor(tensor)
}
}
/// Renders the tensor as `NamedTensor[shape=<dims>, dims=<names>]`.
///
/// `<dims>` is the `Debug` form of the shape's `dims`, and `<names>` is
/// whatever `ND::to_string()` produces.
impl<B, const D: usize, ND> core::fmt::Display for NamedTensor<B, ND>
where
    B: Backend,
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let shape = self.shape();
        let rendered = format!(
            "NamedTensor[shape={:?}, dims={}]",
            shape.dims,
            ND::to_string(),
        );
        f.write_str(&rendered)
    }
}
/// Inherent API for named tensors backed by a `D`-dimensional [`Tensor`].
impl<B, const D: usize, ND> NamedTensor<B, ND>
where
    B: Backend,
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    /// Wraps `tensor` under the dimension names `ND` without inspecting it.
    pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
        Self { tensor }
    }

    /// Builds a tensor with `Tensor::random` and wraps it under the names `ND`.
    ///
    /// `shape`, `distribution` and `device` are forwarded unchanged to
    /// `Tensor::random`.
    pub fn random<S: Into<Shape>>(
        shape: S,
        distribution: Distribution,
        device: &B::Device,
    ) -> Self {
        let sampled: Tensor<B, D> = Tensor::random(shape, distribution, device);
        Self { tensor: sampled }
    }

    /// Returns the shape of the underlying tensor.
    pub fn shape(&self) -> Shape {
        self.tensor.shape()
    }

    /// Combines `self` and `rhs` with `Tensor::mul`, keeping the names `ND`.
    ///
    /// Both operands must carry the same dimension names, which the `Self`
    /// parameter type enforces.
    // NOTE(review): this is an inherent method rather than an
    // `impl core::ops::Mul`, so clippy's `should_implement_trait` is silenced.
    // The reason is not visible here; confirm it before converting to the trait.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, rhs: Self) -> Self {
        let product: Tensor<B, D> = self.tensor.mul(rhs.tensor);
        Self { tensor: product }
    }

    /// Reshapes the underlying tensor to `shape` and relabels it with `ND2`.
    ///
    /// The `ND2` argument only selects the target dimension names; its value
    /// is ignored. Validation of `shape` is left to `Tensor::reshape`.
    pub fn reshape<const D2: usize, S, ND2>(self, shape: S, _: ND2) -> NamedTensor<B, ND2>
    where
        S: Into<Shape>,
        ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
    {
        let target: Shape = shape.into();
        let reshaped: Tensor<B, D2> = self.tensor.reshape(target);
        NamedTensor::from_tensor(reshaped)
    }
}