use crate::{LibTorchDevice, QuantElement};
use burn_tensor::{
quantization::{
AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
QuantizationType, SymmetricQuantization,
},
Element, Shape, TensorData,
};
use libc::c_void;
use std::{marker::PhantomData, sync::Arc};
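
/// A reference to a tensor's underlying storage buffer.
///
/// The raw pointer is wrapped in an [`Arc`] so that the strong count tracks how many
/// tensors currently alias the same memory.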
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;
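
/// Describes how a tensor relates to the memory that backs it.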
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// When a tensor is a view on part (or all) of another tensor's buffer.
    View {
        /// Reference to the buffer of the original tensor.
        buffer_ref: StorageRef,
        /// Reference to the start of the view.
        view_ref: StorageRef,
    },
    /// When a tensor owns its entire buffer.
    Owned {
        /// Reference to the owned buffer.
        buffer_ref: StorageRef,
    },
}

impl Storage {
    /// Checks if the storage can be mutated in place, i.e. no other tensor currently
    /// holds a reference to the same buffer or view.
    pub fn can_mut(&self) -> bool {
        match self {
            Storage::View {
                buffer_ref,
                view_ref,
            } => Arc::strong_count(buffer_ref) == 1 && Arc::strong_count(view_ref) == 1,
            Storage::Owned { buffer_ref } => Arc::strong_count(buffer_ref) == 1,
        }
    }

    /// Returns a reference to the buffer that owns the underlying data.
    pub fn buffer_ref(&self) -> &StorageRef {
        match self {
            Storage::View { buffer_ref, .. } => buffer_ref,
            Storage::Owned { buffer_ref } => buffer_ref,
        }
    }
}
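
/// A tensor backed by the tch (libtorch) backend.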
#[derive(Debug, PartialEq)]
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
    /// The underlying tch tensor.
    pub tensor: tch::Tensor,
    /// The tensor's storage, used to detect aliasing for in-place optimizations.
    pub storage: Storage,
    phantom: PhantomData<E>,
}

impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
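    /// Creates a new tensor that owns its storage.
    ///
    /// If the tensor may share storage with a parent tensor (e.g. it was produced by an
    /// operation on that parent), use [`TchTensor::from_existing`] instead so that
    /// aliasing is tracked correctly.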
pub fn new(tensor: tch::Tensor) -> Self {
#[allow(clippy::arc_with_non_send_sync)]
let storage = Storage::Owned {
buffer_ref: Arc::new(tensor.data_ptr()),
};
Self {
tensor,
storage,
phantom: PhantomData,
}
}
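
    /// Creates a tensor from the result of an operation executed on a parent tensor.
    ///
    /// If the child shares its parent's data pointer, it inherits the parent's storage
    /// reference; otherwise a fresh `Owned` storage is created for it.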
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage_child = tensor.data_ptr();

        // The child is a genuinely new tensor only if its data pointer differs from
        // every pointer tracked by the parent's storage.
        let is_a_new_tensor = match &storage_parent {
            Storage::View {
                buffer_ref,
                view_ref,
            } => storage_child != *buffer_ref.as_ref() && storage_child != *view_ref.as_ref(),
            Storage::Owned { buffer_ref } => storage_child != *buffer_ref.as_ref(),
        };

        #[allow(clippy::arc_with_non_send_sync)]
        let storage = if is_a_new_tensor {
            Storage::Owned {
                buffer_ref: Arc::new(storage_child),
            }
        } else {
            storage_parent.clone()
        };
Self {
tensor,
storage,
phantom: PhantomData,
}
}
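
    /// Creates a tensor that is a view into part of its parent's buffer (e.g. a slice).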
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
let storage = Storage::View {
buffer_ref: storage_parent.buffer_ref().clone(),
#[allow(clippy::arc_with_non_send_sync)]
view_ref: Arc::new(tensor.data_ptr()),
};
Self {
tensor,
storage,
phantom: PhantomData,
}
}
}

impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
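    /// Returns the shape of the tensor.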
pub(crate) fn shape(&self) -> Shape<D> {
Shape::from(self.tensor.size())
}
}
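
// SAFETY: the raw pointers held in `Storage` are only used for identity comparison and
// reference counting, never dereferenced, so they do not prevent `TchTensor` from being
// sent or shared across threads.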
unsafe impl<E: tch::kind::Element, const D: usize> Send for TchTensor<E, D> {}
unsafe impl<E: tch::kind::Element, const D: usize> Sync for TchTensor<E, D> {}

impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
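    /// Checks if the tensor can be mutated in place: its storage must be exclusively
    /// owned and it must not be a broadcasted view.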
pub fn can_mut(&self) -> bool {
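        // A stride of zero means the dimension is broadcasted: distinct logical positions
        // share the same element, so mutating in place would corrupt the data.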
let stride_contains_zero = self.tensor.stride().iter().any(|&s| s == 0);
!stride_contains_zero && self.storage.can_mut()
}
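
    /// Runs `func` with a mutable reference to the underlying tensor when no other tensor
    /// aliases its storage; returns `None` when in-place execution would be unsafe.
    ///
    /// A minimal sketch, assuming tch's in-place `zero_`:
    ///
    /// ```ignore
    /// let output: Option<TchTensor<f32, 2>> = tensor.mut_ops(|t| t.zero_());
    /// ```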
pub fn mut_ops<
F: Fn(&mut tch::Tensor) -> tch::Tensor,
EOut: tch::kind::Element,
const D_OUT: usize,
>(
&mut self,
func: F,
) -> Option<TchTensor<EOut, D_OUT>> {
if !self.can_mut() {
return None;
}
let data = self.storage.clone();
Some(TchTensor::from_existing(func(&mut self.tensor), data))
}
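
    /// Executes a unary operation, using the owning variant `fown` when the storage can
    /// be safely reused, and falling back to the by-reference variant `fref` otherwise.
    ///
    /// A minimal sketch of a typical call, assuming tch's `relu`/`relu_` pair:
    ///
    /// ```ignore
    /// tensor.unary_ops(|mut t| t.relu_(), |t| t.relu())
    /// ```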
pub fn unary_ops<FOwn, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
self,
fown: FOwn,
fref: FRef,
) -> TchTensor<EOut, D_OUT>
where
FOwn: Fn(tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor) -> tch::Tensor,
{
if !self.can_mut() {
return TchTensor::from_existing(fref(&self.tensor), self.storage);
}
TchTensor::from_existing(fown(self.tensor), self.storage)
}
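
    /// Executes a binary operation, preferring the in-place variants: `flmut` mutates the
    /// left-hand side when its shape matches the broadcasted output and its storage is
    /// exclusive, then `frmut` is tried on the right-hand side, before falling back to
    /// the allocating variant `fref`.
    ///
    /// A minimal sketch for element-wise addition, assuming tch's `f_add`/`f_add_`:
    ///
    /// ```ignore
    /// TchTensor::binary_ops_tensor(
    ///     lhs,
    ///     rhs,
    ///     |lhs, rhs| lhs.f_add_(rhs).unwrap(),
    ///     |lhs, rhs| rhs.f_add_(lhs).unwrap(),
    ///     |lhs, rhs| lhs.f_add(rhs).unwrap(),
    /// )
    /// ```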
pub fn binary_ops_tensor<FLMut, FRMut, FRef, EOut: tch::kind::Element, const D_OUT: usize>(
mut lhs: Self,
mut rhs: Self,
flmut: FLMut,
frmut: FRMut,
fref: FRef,
) -> TchTensor<EOut, D_OUT>
where
FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
{
let lhs_shape = lhs.shape();
let rhs_shape = rhs.shape();
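
        // Broadcasting: both inputs are expected to carry `D_OUT` dimensions here; each
        // output extent is the larger of the two input extents.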
let mut out_shape = Shape::new([1; D_OUT]);
for i in 0..D_OUT {
out_shape.dims[i] = usize::max(lhs_shape.dims[i], rhs_shape.dims[i]);
}
let num_elements_out = out_shape.num_elements();
if lhs_shape.num_elements() == num_elements_out {
if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
return output;
}
}
if rhs_shape.num_elements() == num_elements_out {
if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
return output;
}
}
let storage = lhs.storage;
let tensor = fref(&lhs.tensor, &rhs.tensor);
TchTensor::from_existing(tensor, storage)
}
}

impl<P: tch::kind::Element, const D: usize> Clone for TchTensor<P, D> {
fn clone(&self) -> Self {
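        // `shallow_clone` shares the underlying buffer, so the clone also shares the
        // storage reference; the bumped strong count then blocks in-place operations.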
Self {
tensor: self.tensor.shallow_clone(),
phantom: PhantomData,
storage: self.storage.clone(),
}
}
}
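
/// A tensor shape expressed with `i64` dimensions, as expected by the tch API.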
#[derive(Debug)]
pub struct TchShape<const D: usize> {
pub dims: [i64; D],
}

impl<const D: usize> From<Shape<D>> for TchShape<D> {
fn from(shape: Shape<D>) -> Self {
let mut dims = [0; D];
for (i, dim) in dims.iter_mut().enumerate().take(D) {
*dim = shape.dims[i] as i64;
}
TchShape { dims }
}
}

impl<const D: usize> From<&[usize]> for TchShape<D> {
fn from(shape: &[usize]) -> Self {
let mut dims = [0; D];
for (i, dim) in dims.iter_mut().enumerate().take(D) {
*dim = shape[i] as i64;
}
TchShape { dims }
}
}

impl<E: tch::kind::Element + Default + Element, const D: usize> TchTensor<E, D> {
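    /// Creates a tensor from the given data, placed on the given device.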
pub fn from_data(data: TensorData, device: tch::Device) -> Self {
let shape_tch = TchShape::<D>::from(data.shape.as_slice());
let tensor =
tch::Tensor::from_slice(data.convert::<E>().as_slice::<E>().unwrap()).to(device);
let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);
Self::new(tensor)
}
}

impl<E: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<E, D> {
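    /// Creates a tensor with the given shape and uninitialized contents on the given device.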
pub fn empty(shape: Shape<D>, device: LibTorchDevice) -> Self {
let shape_tch = TchShape::from(shape);
let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
Self::new(tensor)
}
}
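
/// A quantized tensor for the tch backend.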
#[derive(Clone, Debug)]
pub struct TchQTensor<Q: QuantElement, const D: usize> {
    /// The underlying quantized tensor.
    pub qtensor: TchTensor<Q, D>,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
}

impl<Q: QuantElement, const D: usize> QTensorPrimitive for TchQTensor<Q, D> {
fn scheme(&self) -> &QuantizationScheme {
&self.scheme
}

fn strategy(&self) -> QuantizationStrategy {
match &self.scheme {
QuantizationScheme::PerTensorAffine(dtype) => match dtype {
QuantizationType::QInt8 => {
let scale = self.qtensor.tensor.q_scale();
let offset = self.qtensor.tensor.q_zero_point();
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
scale as f32,
offset as i8,
))
}
},
QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => {
let scale = self.qtensor.tensor.q_scale();
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
scale as f32,
))
}
},
}
}
}

#[cfg(test)]
mod tests {
use crate::LibTorch;
use super::*;
use burn_tensor::ops::QTensorOps;
use burn_tensor::quantization::QuantizationParametersPrimitive;
use burn_tensor::{Distribution, Tensor, TensorPrimitive};
use rand::prelude::StdRng;
use rand::SeedableRng;

#[test]
fn should_support_into_and_from_data_1d() {
let data_expected = TensorData::random::<f32, _, _>(
Shape::new([3]),
Distribution::Default,
&mut StdRng::from_entropy(),
);
let tensor = TchTensor::<f32, 1>::from_data(data_expected.clone(), tch::Device::Cpu);
let data_actual =
Tensor::<LibTorch<f32>, 1>::from_primitive(TensorPrimitive::Float(tensor)).into_data();
assert_eq!(data_expected, data_actual);
}

#[test]
fn should_support_into_and_from_data_2d() {
let data_expected = TensorData::random::<f32, _, _>(
Shape::new([2, 3]),
Distribution::Default,
&mut StdRng::from_entropy(),
);
let tensor = TchTensor::<f32, 2>::from_data(data_expected.clone(), tch::Device::Cpu);
let data_actual =
Tensor::<LibTorch<f32>, 2>::from_primitive(TensorPrimitive::Float(tensor)).into_data();
assert_eq!(data_expected, data_actual);
}

#[test]
fn should_not_update_inplace_after_reshape() {
let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
let tensor_2 = tensor_1.clone();
let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0);
assert_ne!(
tensor_3.to_data().as_slice::<f32>().unwrap(),
tensor_1.to_data().as_slice::<f32>().unwrap()
);
}

#[test]
fn should_not_update_inplace_after_slice() {
let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
let tensor_2 = tensor_1.clone();
let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);
assert_ne!(
tensor_3.to_data().as_slice::<f32>().unwrap(),
tensor_1.to_data().as_slice::<f32>().unwrap()
);
}

#[test]
fn should_support_qtensor_strategy() {
let tensor = TchTensor::<f32, 1>::from_data(
TensorData::from([-1.8, -1.0, 0.0, 0.5]),
tch::Device::Cpu,
);
let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
let qparams = QuantizationParametersPrimitive {
scale: TchTensor::from_data(TensorData::from([0.009_019_608]), tch::Device::Cpu),
offset: Some(TchTensor::from_data(
TensorData::from([72]),
tch::Device::Cpu,
)),
};
let qtensor: TchQTensor<i8, 1> = LibTorch::quantize(tensor, &scheme, qparams);
assert_eq!(qtensor.scheme(), &scheme);
assert_eq!(
qtensor.strategy(),
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
);
}
}