use serde::{Deserialize, Serialize};
use burn_backend::{DType, Shape};
/// Identifier of a tensor in the IR.
///
/// Wraps a raw `u64` (build one with `TensorId::new`). The derived
/// `PartialOrd`/`Ord`/`Hash`/`Eq` all act on that single `value` field, so ids
/// compare and hash exactly like their raw numbers.
///
/// NOTE(review): with the derived serde impls the field name `value` is part of
/// the serialized form in self-describing formats; renaming it or turning this
/// into a tuple struct is a wire-format change.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub struct TensorId {
// Raw numeric id; private, so it can only be set through `TensorId::new`.
value: u64,
}
/// Human-readable form of a tensor id: `TensorId(<value>)`, e.g. `TensorId(42)`.
impl core::fmt::Display for TensorId {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `write!` expands to `formatter.write_fmt(format_args!(..))`.
        let raw = self.value;
        write!(formatter, "TensorId({:?})", raw)
    }
}
/// Status tag carried by every [`TensorIr`].
///
/// NOTE(review): variant order is observable. It affects the derived `Hash`
/// and possibly index-based serde encodings, so do not reorder variants
/// without checking consumers.
#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorStatus {
// Presumably: the tensor may be read but not modified in place. TODO confirm.
ReadOnly,
// Presumably: the tensor may be read and modified in place. TODO confirm.
ReadWrite,
// Tensor not initialized yet; this is the status `TensorIr::uninit` assigns.
NotInit,
}
/// IR description of a tensor: its id, shape, status and data type.
///
/// The derived `PartialEq`/`Eq`/`Hash` cover all four fields. Two descriptions
/// with the same `id` but a different `status` (or shape/dtype) are therefore
/// unequal and hash differently.
///
/// NOTE(review): field order is observable through serde and `Hash`; keep it
/// stable.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorIr {
// Identifier of the tensor.
pub id: TensorId,
// Shape of the tensor (`burn_backend::Shape`).
pub shape: Shape,
// Status tag of the tensor.
pub status: TensorStatus,
// Data type of the tensor (`burn_backend::DType`).
pub dtype: DType,
}
impl TensorId {
    /// Creates a tensor identifier from its raw `u64` value.
    ///
    /// This is a `const fn`, so identifiers can also be built in constant
    /// contexts, e.g. `const ID: TensorId = TensorId::new(0);`.
    pub const fn new(value: u64) -> Self {
        Self { value }
    }
}
impl TensorIr {
pub fn uninit(id: TensorId, shape: Shape, dtype: DType) -> Self {
Self {
id,
status: TensorStatus::NotInit,
shape,
dtype,
}
}
}