use core::sync::atomic::{AtomicU32, Ordering};
use alloc::format;
use alloc::{sync::Arc, vec::Vec};
use super::RunnerClient;
use burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};
use burn_ir::{TensorId, TensorIr, TensorStatus};
/// Tensor handle tied to a [`RunnerClient`].
///
/// Handles created through `Clone` share the same `count`; the counter is
/// incremented on clone and decremented on drop.
pub struct RouterTensor<C: RunnerClient> {
/// Identifier of the tensor.
pub(crate) id: TensorId,
/// Shape of the tensor.
pub(crate) shape: Shape,
/// Data type of the tensor.
pub(crate) dtype: DType,
/// Client used to read the tensor and to register operations on it.
pub client: C,
/// Counter shared between clones of this handle; starts at 1 in `new`.
pub(crate) count: Arc<AtomicU32>,
}
impl<C: RunnerClient> TensorMetadata for RouterTensor<C> {
fn dtype(&self) -> DType {
self.dtype
}
fn shape(&self) -> Shape {
self.shape.clone()
}
fn rank(&self) -> usize {
self.shape.num_dims()
}
}
impl<C: RunnerClient> RouterTensor<C> {
    /// Creates a new tensor handle whose shared counter starts at 1.
    pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {
        let count = Arc::new(AtomicU32::new(1));
        Self {
            id,
            shape,
            dtype,
            client,
            count,
        }
    }

    /// Consumes the handle and awaits `read_tensor_async` on a clone of the
    /// client, passing the representation produced by [`Self::into_ir`].
    pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {
        // The client is cloned before `self` is consumed by `into_ir`.
        let client = self.client.clone();
        let ir = self.into_ir();
        client.read_tensor_async(ir).await
    }

    /// Consumes the handle and returns its intermediate representation.
    ///
    /// The status is derived from the current counter value through
    /// [`Self::status`], and the shape is moved out of the handle rather than
    /// cloned.
    pub fn into_ir(mut self) -> TensorIr {
        let status = self.status(self.count.load(Ordering::Relaxed));

        // Move the shape out, leaving an empty placeholder behind.
        let shape = core::mem::replace(&mut self.shape, Shape::from(Vec::<usize>::new()));

        if matches!(status, TensorStatus::ReadWrite) {
            // `self` is dropped when this function returns. Bumping the counter
            // makes that drop observe a previous value above 1, so it does not
            // register a drop operation for a tensor handed out as `ReadWrite`.
            self.count.fetch_add(1, Ordering::Relaxed);
        }

        TensorIr {
            id: self.id,
            shape,
            status,
            dtype: self.dtype,
        }
    }

    /// Builds a representation of this tensor with the `NotInit` status.
    ///
    /// The handle is only borrowed: the shape is cloned and the counter is
    /// not touched.
    pub(crate) fn to_ir_out(&self) -> TensorIr {
        let shape = self.shape.clone();
        TensorIr {
            id: self.id,
            shape,
            status: TensorStatus::NotInit,
            dtype: self.dtype,
        }
    }

    /// Maps a counter value to a status: `ReadWrite` when `count` is at most
    /// 1, `ReadOnly` otherwise.
    pub(crate) fn status(&self, count: u32) -> TensorStatus {
        match count {
            0 | 1 => TensorStatus::ReadWrite,
            _ => TensorStatus::ReadOnly,
        }
    }
}
impl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {
    /// Writes the id, shape, dtype and client device of the tensor as a
    /// single brace-delimited line.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let device = self.client.device().clone();
        let rendered = format!(
            "{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}",
            self.id, self.shape, self.dtype, device,
        );
        f.write_str(&rendered)
    }
}
impl<C: RunnerClient> Clone for RouterTensor<C> {
fn clone(&self) -> Self {
self.count.fetch_add(1, Ordering::Relaxed);
Self {
id: self.id,
shape: self.shape.clone(),
client: self.client.clone(),
dtype: self.dtype,
count: self.count.clone(),
}
}
}
impl<C: RunnerClient> Drop for RouterTensor<C> {
    /// Decrements the shared counter and, when the value before the decrement
    /// maps to `ReadWrite` (at most 1), registers a drop operation for the
    /// tensor on the client.
    fn drop(&mut self) {
        // `fetch_sub` yields the value held *before* the decrement.
        let previous = self.count.fetch_sub(1, Ordering::Relaxed);

        match self.status(previous) {
            TensorStatus::ReadWrite => {
                // Move the shape out, leaving an empty placeholder behind.
                let shape =
                    core::mem::replace(&mut self.shape, Shape::from(Vec::<usize>::new()));
                let ir = TensorIr {
                    id: self.id,
                    shape,
                    status: TensorStatus::ReadWrite,
                    dtype: self.dtype,
                };
                self.client.register_op(burn_ir::OperationIr::Drop(ir));
            }
            // Any other status means there is nothing to register.
            TensorStatus::ReadOnly | TensorStatus::NotInit => {}
        }
    }
}