burn_ir/
tensor.rs

1use serde::{Deserialize, Serialize};
2
3use burn_tensor::{DType, Shape};
4
5/// The tensor unique identifier.
6#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
7pub struct TensorId {
8    value: u64,
9}
10
11impl core::fmt::Display for TensorId {
12    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13        f.write_fmt(format_args!("TensorId({:?})", self.value))
14    }
15}
16
17/// The status of the current tensor.
18#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub enum TensorStatus {
20    /// The tensor can be read, but not written.
21    ReadOnly,
22    /// The tensor can be mutated inplace.
23    ReadWrite,
24    /// No handle exists for that tensor.
25    NotInit,
26}
27
28/// A tensor definition represents a snapshot of a tensor when it was used.
29///
30/// # Example
31///
32/// A tensor that is used multiple times has its status updated for each operation.
33///
34///   1. Status::NotInit
35///   2. Status::ReadOnly
36///   3. Status::ReadOnly
37///   4. Status::ReadWrite
38#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
39pub struct TensorIr {
40    /// The [tensor id](TensorId).
41    pub id: TensorId,
42    /// The shape of the tensor.
43    pub shape: Shape,
44    /// The [status](TensorStatus) of the tensor when it was used.
45    pub status: TensorStatus,
46    /// The [type](DType) of the tensor.
47    pub dtype: DType,
48}
49
50impl TensorId {
51    /// Create a new tensor id.
52    pub fn new(value: u64) -> Self {
53        Self { value }
54    }
55}