burn_ir/tensor.rs
1use serde::{Deserialize, Serialize};
2
3use burn_tensor::{DType, Shape};
4
5/// The tensor unique identifier.
6#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
7pub struct TensorId {
8 value: u64,
9}
10
11impl core::fmt::Display for TensorId {
12 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13 f.write_fmt(format_args!("TensorId({:?})", self.value))
14 }
15}
16
17/// The status of the current tensor.
18#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub enum TensorStatus {
20 /// The tensor can be read, but not written.
21 ReadOnly,
22 /// The tensor can be mutated inplace.
23 ReadWrite,
24 /// No handle exists for that tensor.
25 NotInit,
26}
27
28/// A tensor definition represents a snapshot of a tensor when it was used.
29///
30/// # Example
31///
32/// A tensor that is used multiple times has its status updated for each operation.
33///
34/// 1. Status::NotInit
35/// 2. Status::ReadOnly
36/// 3. Status::ReadOnly
37/// 4. Status::ReadWrite
38#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
39pub struct TensorIr {
40 /// The [tensor id](TensorId).
41 pub id: TensorId,
42 /// The shape of the tensor.
43 pub shape: Shape,
44 /// The [status](TensorStatus) of the tensor when it was used.
45 pub status: TensorStatus,
46 /// The [type](DType) of the tensor.
47 pub dtype: DType,
48}
49
50impl TensorId {
51 /// Create a new tensor id.
52 pub fn new(value: u64) -> Self {
53 Self { value }
54 }
55}