Skip to main content

burn_named_tensor/
base.rs

1use alloc::format;
2
3use crate::NamedDims;
4use burn_tensor::backend::Backend;
5use burn_tensor::{Distribution, Shape, Tensor};
6
7/// A tensor with named dimensions.
8#[derive(Debug, Clone)]
9pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
10    pub(crate) tensor: D::Tensor,
11}
12
13impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<NamedTensor<B, ND>>
14    for Tensor<B, D>
15{
16    fn from(nt: NamedTensor<B, ND>) -> Self {
17        nt.tensor
18    }
19}
20
21impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<Tensor<B, D>>
22    for NamedTensor<B, ND>
23{
24    fn from(tensor: Tensor<B, D>) -> Self {
25        Self::from_tensor(tensor)
26    }
27}
28
29impl<B: Backend, const D: usize, ND: NamedDims<B>> core::fmt::Display for NamedTensor<B, ND>
30where
31    ND: NamedDims<B, Tensor = Tensor<B, D>>,
32{
33    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34        f.write_str(&format!(
35            "NamedTensor[shape={:?}, dims={}]",
36            self.shape(),
37            ND::to_string(),
38        ))
39    }
40}
41
42impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
43where
44    ND: NamedDims<B, Tensor = Tensor<B, D>>,
45{
46    /// Create a named tensor from a tensor.
47    pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
48        Self { tensor }
49    }
50
51    /// Create a random named tensor of the given shape where each element is sampled from
52    /// the given distribution.
53    pub fn random<S: Into<Shape>>(
54        shape: S,
55        distribution: Distribution,
56        device: &B::Device,
57    ) -> Self {
58        Self::from_tensor(Tensor::random(shape, distribution, device))
59    }
60
61    /// Returns the shape of the current tensor.
62    pub fn shape(&self) -> Shape {
63        self.tensor.shape()
64    }
65
66    /// Applies element wise multiplication operation.
67    ///
68    /// `y = x2 * x1`
69    #[allow(clippy::should_implement_trait)]
70    pub fn mul(self, rhs: Self) -> Self {
71        Self::from_tensor(self.tensor.mul(rhs.tensor))
72    }
73
74    /// Reshape the tensor to have the given shape.
75    ///
76    /// # Panics
77    ///
78    /// If the tensor can not be reshape to the given shape.
79    pub fn reshape<const D2: usize, S, ND2>(self, shape: S, _: ND2) -> NamedTensor<B, ND2>
80    where
81        S: Into<Shape>,
82        ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
83    {
84        NamedTensor::from_tensor(self.tensor.reshape(shape.into()))
85    }
86}