1use serde::{Deserialize, Serialize};
2
3use burn_backend::{DType, Shape};
4
5#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
7pub struct TensorId {
8 value: u64,
9}
10
11impl core::fmt::Display for TensorId {
12 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13 f.write_fmt(format_args!("TensorId({:?})", self.value))
14 }
15}
16
17#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub enum TensorStatus {
20 ReadOnly,
22 ReadWrite,
24 NotInit,
26}
27
28#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
39pub struct TensorIr {
40 pub id: TensorId,
42 pub shape: Shape,
44 pub status: TensorStatus,
46 pub dtype: DType,
48}
49
50impl TensorId {
51 pub fn new(value: u64) -> Self {
53 Self { value }
54 }
55}
56
57impl TensorIr {
58 pub fn uninit(id: TensorId, shape: Shape, dtype: DType) -> Self {
60 Self {
61 id,
62 status: TensorStatus::NotInit,
63 shape,
64 dtype,
65 }
66 }
67}