burn_tensor/repr/
tensor.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
use serde::{Deserialize, Serialize};

use alloc::vec::Vec;

use crate::DType;

/// The tensor unique identifier.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub struct TensorId {
    value: u64,
}

/// The status of the current tensor.
#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorStatus {
    /// The tensor can be read, but not written.
    ReadOnly,
    /// The tensor can be mutated inplace.
    ReadWrite,
    /// No handle exists for that tensor.
    NotInit,
}

/// A tensor definition represents a snapshot of a tensor when it was used.
///
/// # Example
///
/// A tensor that is used multiple times has its status updated for each operation.
///
///   1. Status::NotInit
///   2. Status::ReadOnly
///   3. Status::ReadOnly
///   4. Status::ReadWrite
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorDescription {
    /// The [tensor id](TensorId).
    pub id: TensorId,
    /// The shape of the tensor.
    pub shape: Vec<usize>,
    /// The [status](TensorStatus) of the tensor when it was used.
    pub status: TensorStatus,
    /// The [type](DType) of the tensor.
    pub dtype: DType,
}

impl TensorId {
    /// Create a new tensor id.
    pub fn new(value: u64) -> Self {
        Self { value }
    }
}