#![allow(clippy::use_self)]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_with::{serde_as, Bytes, DeserializeAs, SerializeAs};
use std::borrow::Cow;
use tch::{Device, Kind, Tensor};
/// Serde "remote" definition for [`tch::Kind`].
///
/// `Kind` is declared in another crate, so this mirror enum lets serde derive
/// the (de)serialization code for it. Use it on a field of type `Kind` with
/// `#[serde(with = "KindDef")]`.
///
/// The variants are serialized by name/index, so their spelling and order are
/// part of the serialized format.
// NOTE(review): serde's remote derive requires this list to match the variants
// of the `Kind` type in the tch version in use -- confirm when upgrading tch.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(remote = "Kind")]
pub enum KindDef {
Uint8,
Int8,
Int16,
Int,
Int64,
Half,
Float,
Double,
ComplexHalf,
ComplexFloat,
ComplexDouble,
Bool,
QInt8,
QUInt8,
QInt32,
BFloat16,
}
/// Serde "remote" definition for [`tch::Device`].
///
/// Mirrors the foreign `Device` enum so that serde can derive
/// (de)serialization for it; use with `#[serde(with = "DeviceDef")]` on a
/// field of type `Device`.
// NOTE(review): serde's remote derive requires these variants to match the
// `Device` type of the tch version in use -- confirm when upgrading tch.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(remote = "Device")]
pub enum DeviceDef {
/// The CPU device.
Cpu,
/// A CUDA device, identified by the wrapped index.
Cuda(usize),
}
/// Byte order (endianness) of multi-byte values in a raw data buffer.
///
/// Stored alongside serialized tensor data so that a reader can detect data
/// that was written on a machine with a different byte order.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ByteOrder {
/// Least-significant byte first.
LittleEndian,
/// Most-significant byte first.
BigEndian,
}
impl ByteOrder {
    /// Returns the byte order of the target this crate was compiled for.
    ///
    /// The decision is made from `cfg!(target_endian = ..)`, so it is a
    /// compile-time constant.
    #[must_use]
    pub const fn native() -> Self {
        // `target_endian` is either "little" or "big", so testing for "little"
        // and falling through to big-endian covers both cases.
        if cfg!(target_endian = "little") {
            Self::LittleEndian
        } else {
            Self::BigEndian
        }
    }
}
/// Serializable description of a [`Tensor`]: element kind, shape, gradient
/// flag, and the raw element bytes together with their byte order.
///
/// `shape` and `data` are [`Cow`]s so that a value can either own its buffers
/// (when built from a tensor) or borrow them (see the `#[serde(borrow)]`
/// attributes) when deserializing.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorDef<'a> {
/// Element type of the tensor; (de)serialized through [`KindDef`].
#[serde(with = "KindDef")]
pub kind: Kind,
/// Size of each tensor dimension; empty for a zero-dimensional tensor.
#[serde(borrow)]
pub shape: Cow<'a, [i64]>,
/// Whether the tensor requires gradient tracking.
pub requires_grad: bool,
/// Byte order in which `data` was written.
pub byte_order: ByteOrder,
/// Raw element bytes of the tensor, (de)serialized as a byte string.
#[serde_as(as = "Bytes")]
#[serde(borrow)]
pub data: Cow<'a, [u8]>,
}
impl<'a> From<&'_ Tensor> for TensorDef<'a> {
    /// Captures a tensor's kind, shape, gradient flag and raw element bytes
    /// in an owned `TensorDef`, tagged with the native byte order.
    ///
    /// # Panics
    ///
    /// Panics if the product of the tensor's dimensions does not fit in a
    /// `usize`.
    fn from(tensor: &Tensor) -> Self {
        let dims = tensor.size();
        let element_kind = tensor.kind();

        // Total number of elements is the product of all dimension sizes
        // (an empty shape yields 1, i.e. a scalar).
        let element_count: i64 = dims.iter().product();
        let element_count: usize = element_count.try_into().unwrap();

        // Zero-filled buffer sized for every element, then filled from the tensor.
        let mut bytes = vec![0; element_count * element_kind.elt_size_in_bytes()];
        tensor.copy_data_u8(&mut bytes, element_count);

        Self {
            kind: element_kind,
            shape: Cow::Owned(dims),
            requires_grad: tensor.requires_grad(),
            byte_order: ByteOrder::native(),
            data: Cow::Owned(bytes),
        }
    }
}
impl<'a> From<&TensorDef<'a>> for Tensor {
    /// Builds a tensor from the kind, shape, gradient flag and raw bytes
    /// stored in a `TensorDef`.
    ///
    /// # Panics
    ///
    /// Panics if
    /// * the data was written with a non-native byte order,
    /// * the shape contains a negative dimension or its element count
    ///   overflows `usize`, or
    /// * the length of `data` is not exactly
    ///   `product(shape) * kind.elt_size_in_bytes()`.
    fn from(t: &TensorDef<'a>) -> Self {
        assert_eq!(
            t.byte_order,
            ByteOrder::native(),
            "data has non-native byte order"
        );

        // Number of elements described by the shape. A zero-sized dimension
        // makes the tensor empty regardless of the other dimensions, so it is
        // handled first to avoid a spurious overflow.
        let num_elements = if t.shape.contains(&0) {
            0
        } else {
            t.shape
                .iter()
                .try_fold(1_usize, |acc, &dim| {
                    acc.checked_mul(usize::try_from(dim).ok()?)
                })
                .expect("shape has a negative dimension or too many elements")
        };
        let expected_len = num_elements
            .checked_mul(t.kind.elt_size_in_bytes())
            .expect("tensor size in bytes overflows usize");

        // The buffer is handed to `of_data_size` together with the shape, so
        // make sure it really holds as many bytes as the shape implies rather
        // than relying on the callee to notice a mismatch.
        assert_eq!(
            t.data.len(),
            expected_len,
            "data length does not match shape and kind"
        );

        Self::of_data_size(&t.data, &t.shape, t.kind).set_requires_grad(t.requires_grad)
    }
}
impl<'a> From<TensorDef<'a>> for Tensor {
    /// By-value convenience conversion; defers to the `From<&TensorDef>`
    /// implementation and therefore shares its panics.
    #[inline]
    fn from(t: TensorDef<'a>) -> Self {
        let by_ref: &TensorDef<'a> = &t;
        by_ref.into()
    }
}
impl<'a> SerializeAs<Tensor> for TensorDef<'a> {
fn serialize_as<S>(source: &Tensor, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
TensorDef::from(source).serialize(serializer)
}
}
impl<'de: 'a, 'a> DeserializeAs<'de, Tensor> for TensorDef<'a> {
fn deserialize_as<D>(deserializer: D) -> Result<Tensor, D::Error>
where
D: Deserializer<'de>,
{
let tensor_def: TensorDef = Deserialize::deserialize(deserializer)?;
Ok(tensor_def.into())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_tokens, Token};
    use std::ops::Deref;

    /// Newtype around `Tensor` that (de)serializes through `TensorDef`, used
    /// to exercise the `SerializeAs` / `DeserializeAs` implementations.
    #[serde_as]
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    #[serde(transparent)]
    pub struct STensor(#[serde_as(as = "TensorDef")] pub Tensor);

    impl From<Tensor> for STensor {
        fn from(tensor: Tensor) -> Self {
            STensor(tensor)
        }
    }

    impl From<STensor> for Tensor {
        fn from(stensor: STensor) -> Self {
            stensor.0
        }
    }

    impl Deref for STensor {
        type Target = Tensor;

        #[inline]
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    /// Variant name that the native `ByteOrder` serializes to.
    fn native_byte_order_name() -> &'static str {
        match ByteOrder::native() {
            ByteOrder::BigEndian => "BigEndian",
            ByteOrder::LittleEndian => "LittleEndian",
        }
    }

    /// Expected serde token stream for a `TensorDef` with the given kind
    /// variant name, shape, gradient flag and raw data bytes, tagged with the
    /// native byte order.
    fn tensor_def_tokens(
        kind: &'static str,
        shape: &[i64],
        requires_grad: bool,
        data: &'static [u8],
    ) -> Vec<Token> {
        let mut tokens = vec![
            Token::Struct {
                name: "TensorDef",
                len: 5,
            },
            Token::Str("kind"),
            Token::UnitVariant {
                name: "KindDef",
                variant: kind,
            },
            Token::Str("shape"),
            Token::Seq {
                len: Some(shape.len()),
            },
        ];
        tokens.extend(shape.iter().copied().map(Token::I64));
        tokens.extend_from_slice(&[
            Token::SeqEnd,
            Token::Str("requires_grad"),
            Token::Bool(requires_grad),
            Token::Str("byte_order"),
            Token::UnitVariant {
                name: "ByteOrder",
                variant: native_byte_order_name(),
            },
            Token::Str("data"),
            Token::BorrowedBytes(data),
            Token::StructEnd,
        ]);
        tokens
    }

    // Note: despite the historical name, the element type here is `i32`.
    #[test]
    fn ser_de_tokens_0d_u32_tensor() {
        let tensor = STensor(Tensor::of_slice(&[0x1234_5678_i32]).reshape(&[]));
        let bytes: &'static [u8] = match ByteOrder::native() {
            ByteOrder::BigEndian => &[0x12, 0x34, 0x56, 0x78],
            ByteOrder::LittleEndian => &[0x78, 0x56, 0x34, 0x12],
        };
        let tokens = tensor_def_tokens("Int", &[], false, bytes);
        assert_tokens(&tensor, &tokens);
    }

    #[test]
    fn ser_de_tokens_empty_f32_tensor() {
        let tensor = STensor(Tensor::of_slice::<f32>(&[]));
        let tokens = tensor_def_tokens("Float", &[0], false, &[]);
        assert_tokens(&tensor, &tokens);
    }

    #[test]
    fn ser_de_tokens_1d_f32_tensor_requires_grad() {
        let tensor = STensor(Tensor::of_slice::<f32>(&[1.0]).set_requires_grad(true));
        // 1.0_f32 has the bit pattern 0x3F80_0000; the stored bytes follow the
        // native byte order, so the expectation must as well.
        let bytes: &'static [u8] = match ByteOrder::native() {
            ByteOrder::BigEndian => &[63, 128, 0, 0],
            ByteOrder::LittleEndian => &[0, 0, 128, 63],
        };
        let tokens = tensor_def_tokens("Float", &[1], true, bytes);
        assert_tokens(&tensor, &tokens);
    }

    #[test]
    fn ser_de_tokens_2d_u8_tensor() {
        let tensor = STensor(Tensor::of_slice::<u8>(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]));
        // Single-byte elements are unaffected by byte order.
        let tokens = tensor_def_tokens("Uint8", &[2, 3], false, &[1, 2, 3, 4, 5, 6]);
        assert_tokens(&tensor, &tokens);
    }

    #[test]
    fn to_from_tensordef() {
        let t0 = Tensor::of_slice::<u8>(&[1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let td: TensorDef = (&t0).into();
        let t1: Tensor = td.into();
        assert_eq!(t0, t1);
    }

    // A transposed tensor must round-trip to an equal tensor.
    #[test]
    fn to_from_tensordef_transposed_stride() {
        let t0 = Tensor::of_slice::<u8>(&[1, 2, 3, 4, 5, 6, 7, 8, 9])
            .reshape(&[3, 3])
            .tr();
        let td: TensorDef = (&t0).into();
        let t1: Tensor = td.into();
        assert_eq!(t0, t1);
    }
}