burn_tensor/repr/
tensor.rs

1use serde::{Deserialize, Serialize};
2
3use alloc::vec::Vec;
4
5use crate::DType;
6
7/// The tensor unique identifier.
8#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
9pub struct TensorId {
10    value: u64,
11}
12
13/// The status of the current tensor.
14#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
15pub enum TensorStatus {
16    /// The tensor can be read, but not written.
17    ReadOnly,
18    /// The tensor can be mutated inplace.
19    ReadWrite,
20    /// No handle exists for that tensor.
21    NotInit,
22}
23
24/// A tensor definition represents a snapshot of a tensor when it was used.
25///
26/// # Example
27///
28/// A tensor that is used multiple times has its status updated for each operation.
29///
30///   1. Status::NotInit
31///   2. Status::ReadOnly
32///   3. Status::ReadOnly
33///   4. Status::ReadWrite
34#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
35pub struct TensorDescription {
36    /// The [tensor id](TensorId).
37    pub id: TensorId,
38    /// The shape of the tensor.
39    pub shape: Vec<usize>,
40    /// The [status](TensorStatus) of the tensor when it was used.
41    pub status: TensorStatus,
42    /// The [type](DType) of the tensor.
43    pub dtype: DType,
44}
45
46impl TensorId {
47    /// Create a new tensor id.
48    pub fn new(value: u64) -> Self {
49        Self { value }
50    }
51}