burn_named_tensor/
base.rs1use alloc::format;
2
3use crate::NamedDims;
4use burn_tensor::backend::Backend;
5use burn_tensor::{Distribution, Shape, Tensor};
6
7#[derive(Debug, Clone)]
9pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
10 pub(crate) tensor: D::Tensor,
11}
12
13impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<NamedTensor<B, ND>>
14 for Tensor<B, D>
15{
16 fn from(nt: NamedTensor<B, ND>) -> Self {
17 nt.tensor
18 }
19}
20
21impl<B: Backend, ND: NamedDims<B, Tensor = Tensor<B, D>>, const D: usize> From<Tensor<B, D>>
22 for NamedTensor<B, ND>
23{
24 fn from(tensor: Tensor<B, D>) -> Self {
25 Self::from_tensor(tensor)
26 }
27}
28
29impl<B: Backend, const D: usize, ND: NamedDims<B>> core::fmt::Display for NamedTensor<B, ND>
30where
31 ND: NamedDims<B, Tensor = Tensor<B, D>>,
32{
33 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34 f.write_str(&format!(
35 "NamedTensor[shape={:?}, dims={}]",
36 self.shape(),
37 ND::to_string(),
38 ))
39 }
40}
41
42impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
43where
44 ND: NamedDims<B, Tensor = Tensor<B, D>>,
45{
46 pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
48 Self { tensor }
49 }
50
51 pub fn random<S: Into<Shape>>(
54 shape: S,
55 distribution: Distribution,
56 device: &B::Device,
57 ) -> Self {
58 Self::from_tensor(Tensor::random(shape, distribution, device))
59 }
60
61 pub fn shape(&self) -> Shape {
63 self.tensor.shape()
64 }
65
66 #[allow(clippy::should_implement_trait)]
70 pub fn mul(self, rhs: Self) -> Self {
71 Self::from_tensor(self.tensor.mul(rhs.tensor))
72 }
73
74 pub fn reshape<const D2: usize, S, ND2>(self, shape: S, _: ND2) -> NamedTensor<B, ND2>
80 where
81 S: Into<Shape>,
82 ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
83 {
84 NamedTensor::from_tensor(self.tensor.reshape(shape.into()))
85 }
86}