use crate::{client::FusionClient, stream::StreamId, Client, FusionBackend, FusionRuntime};
use burn_tensor::{
quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy},
repr::{TensorDescription, TensorId, TensorStatus},
DType, Shape, TensorData,
};
use std::sync::Arc;
/// Handle to a tensor managed through a fusion [`Client`].
///
/// The handle carries the tensor's metadata (shape, dtype, stream) and a shared
/// [`TensorId`]. How many strong references to `id` exist decides whether this handle
/// is the sole owner (read-write) or shares the tensor with other handles (read-only).
pub struct FusionTensor<R: FusionRuntime> {
    /// Tensor identifier, shared (via `Arc`) between all clones of this handle.
    pub id: Arc<TensorId>,
    /// Tensor shape, one entry per dimension.
    pub shape: Vec<usize>,
    /// Client used to read the tensor back and to register it as an orphan on drop.
    pub client: Client<R>,
    /// Element data type of the tensor.
    pub dtype: DType,
    /// When `true`, dropping the last reference registers the tensor as an orphan
    /// with the client; cleared when the last reference is consumed into a description.
    pub(crate) is_orphan: bool,
    /// Stream identifier passed to the client when reading this tensor back.
    pub(crate) stream: StreamId,
}
impl<R: FusionRuntime> Clone for FusionTensor<R> {
fn clone(&self) -> Self {
Self {
id: self.id.clone(),
shape: self.shape.clone(),
client: self.client.clone(),
dtype: self.dtype,
is_orphan: self.is_orphan,
stream: self.stream,
}
}
}
impl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {
    /// Formats the handle as `{ id: .., shape: .., should_drop: .., device: .. }`.
    ///
    /// The `should_drop` entry reports the `is_orphan` flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write directly into the formatter: no intermediate `String` allocation and no
        // clone of the device just to print it.
        write!(
            f,
            "{{ id: {:?}, shape: {:?}, should_drop: {:?}, device: {:?} }}",
            self.id,
            self.shape,
            self.is_orphan,
            self.client.device(),
        )
    }
}
impl<R: FusionRuntime> FusionTensor<R> {
    /// Creates a new handle for the tensor identified by `id`.
    ///
    /// The handle starts with `is_orphan = true`, so if it is dropped while holding the
    /// only reference to `id`, the tensor is registered with the client as an orphan.
    pub(crate) fn new(
        id: Arc<TensorId>,
        shape: Vec<usize>,
        dtype: DType,
        client: Client<R>,
        stream: StreamId,
    ) -> Self {
        Self {
            id,
            shape,
            client,
            dtype,
            is_orphan: true,
            stream,
        }
    }

    /// Returns the shape as a statically ranked [`Shape<D>`].
    ///
    /// NOTE(review): `D` is presumably expected to equal `self.shape.len()`; what happens
    /// on mismatch is decided by `Shape::from` — confirm.
    pub(crate) fn shape<const D: usize>(&self) -> Shape<D> {
        Shape::from(self.shape.clone())
    }

    /// Access status derived from the number of strong references to `id`.
    ///
    /// Returns `ReadWrite` when this handle holds the only reference, `ReadOnly` otherwise.
    fn status(&self) -> TensorStatus {
        if Arc::strong_count(&self.id) <= 1 {
            TensorStatus::ReadWrite
        } else {
            TensorStatus::ReadOnly
        }
    }

    /// Describes this tensor as the output of an operation (status `NotInit`).
    pub(crate) fn to_description_out(&self) -> TensorDescription {
        TensorDescription {
            status: TensorStatus::NotInit,
            shape: self.shape.clone(),
            id: *self.id.as_ref(),
            dtype: self.dtype,
        }
    }

    /// Consumes the handle and returns its description, with a status computed from the
    /// current reference count.
    ///
    /// When this is the last reference (`ReadWrite`), the handle is marked as non-orphan
    /// so that its `Drop` impl does not register the tensor as an orphan.
    pub(crate) fn into_description(mut self) -> TensorDescription {
        let status = self.status();
        // `self` is dropped at the end of this call, so move the shape out instead of
        // cloning it.
        let shape = core::mem::take(&mut self.shape);

        if let TensorStatus::ReadWrite = status {
            self.is_orphan = false;
        }

        TensorDescription {
            status,
            shape,
            id: *self.id.as_ref(),
            dtype: self.dtype,
        }
    }

    /// Reads the tensor back as float data, consuming the handle.
    pub(crate) async fn into_data<B, const D: usize>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let stream = self.stream;
        // Clone the client before `self` is consumed by `into_description`.
        let client = self.client.clone();
        client
            .read_tensor_float::<B, D>(self.into_description(), stream)
            .await
    }

    /// Reads the tensor back as integer data, consuming the handle.
    pub(crate) async fn int_into_data<B, const D: usize>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let stream = self.stream;
        // Clone the client before `self` is consumed by `into_description`.
        let client = self.client.clone();
        client
            .read_tensor_int::<B, D>(self.into_description(), stream)
            .await
    }

    /// Reads the tensor back as boolean data, consuming the handle.
    pub(crate) async fn bool_into_data<B, const D: usize>(self) -> TensorData
    where
        B: FusionBackend<FusionRuntime = R>,
    {
        let stream = self.stream;
        // Clone the client before `self` is consumed by `into_description`.
        let client = self.client.clone();
        client
            .read_tensor_bool::<B, D>(self.into_description(), stream)
            .await
    }
}
impl<R: FusionRuntime> Drop for FusionTensor<R> {
    /// Registers the tensor as an orphan with the client when this handle is still
    /// flagged as orphan and holds the last reference to its id.
    fn drop(&mut self) {
        // `is_orphan` is cleared by `into_description` when it consumed the last reference;
        // the status is only queried when the flag is set (short-circuit `&&`).
        if self.is_orphan && matches!(self.status(), TensorStatus::ReadWrite) {
            self.client.register_orphan(&self.id);
        }
    }
}
/// A [`FusionTensor`] paired with the [`QuantizationScheme`] describing its quantization.
#[derive(Debug)]
pub struct QFusionTensor<R: FusionRuntime> {
    /// Underlying fusion tensor handle (presumably holding the quantized values — confirm).
    pub qtensor: FusionTensor<R>,
    /// Quantization scheme associated with `qtensor`.
    pub scheme: QuantizationScheme,
}
impl<R: FusionRuntime> QTensorPrimitive for QFusionTensor<R> {
    /// Returns the quantization scheme stored alongside the tensor.
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }
    /// Not implemented: always panics via `todo!()`.
    ///
    /// NOTE(review): building a strategy presumably needs quantization parameters
    /// (e.g. scale/offset) that this handle does not store — confirm how it should be
    /// obtained before implementing.
    fn strategy(&self) -> QuantizationStrategy {
        todo!()
    }
}
impl<R: FusionRuntime> Clone for QFusionTensor<R> {
    /// Clones the underlying tensor handle and its quantization scheme.
    ///
    /// Implemented by hand rather than derived so that no `R: Clone` bound is imposed.
    fn clone(&self) -> Self {
        let Self { qtensor, scheme } = self;
        Self {
            qtensor: qtensor.clone(),
            scheme: scheme.clone(),
        }
    }
}