extern crate arrayfire;
extern crate serde;
use arrayfire::{Array, DType, Dim4, HasAfEnum};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::de::{SeqAccess, Visitor};
use serde::ser::SerializeTuple;
use std::fmt;
pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
for<'a> Ser<'a, T>: Serialize,
S: Serializer,
{
Ser::new(value).serialize(serializer)
}
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
De<T>: Deserialize<'de>,
D: Deserializer<'de>,
{
De::deserialize(deserializer).map(De::into_inner)
}
/// Borrowing wrapper that carries the `Serialize` impls for wrapped types
/// which cannot implement serde's traits directly (foreign type, foreign trait).
pub struct Ser<'a, T>(&'a T)
where
    T: 'a;

impl<'a, T> Ser<'a, T>
where
    Ser<'a, T>: serde::Serialize,
{
    /// Wraps a reference to `value`; only available for types that have a
    /// `Serialize` impl on the wrapper.
    pub fn new(value: &'a T) -> Ser<'a, T> {
        Ser(value)
    }
}
/// Owning wrapper that carries the `Deserialize` impls for wrapped types
/// which cannot implement serde's traits directly.
pub struct De<T>(T);

impl<'de, T> De<T>
where
    De<T>: Deserialize<'de>,
{
    /// Consumes the wrapper and returns the deserialized value.
    pub fn into_inner(self) -> T {
        let De(inner) = self;
        inner
    }
}
struct Serde<T>(pub T);
impl<T> Serialize for Serde<T>
where
for<'de> De<T>: Deserialize<'de>,
for<'a> Ser<'a, T>: Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
Ser::new(&self.0).serialize(serializer)
}
}
impl<'b, T> Deserialize<'b> for Serde<T>
where
for<'de> De<T>: Deserialize<'de>,
for<'a> Ser<'a, T>: Serialize,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'b>,
{
De::deserialize(deserializer).map(De::into_inner).map(Serde)
}
}
impl<'a> Serialize for Ser<'a, Dim4> {
    /// Writes the four extents returned by `Dim4::get` as a fixed-size
    /// tuple, in index order 0..4.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let dims = self.0.get();
        let mut tuple = serializer.serialize_tuple(4)?;
        for extent in dims.iter() {
            tuple.serialize_element(extent)?;
        }
        tuple.end()
    }
}
impl<'de> Deserialize<'de> for De<Dim4> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Dim4Visitor;
impl<'de> Visitor<'de> for Dim4Visitor {
type Value = De<Dim4>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "tuple as a seq of 4 elements")
}
fn visit_seq<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
where
V: SeqAccess<'de>,
{
let d0: u64 = visitor.next_element()?.expect("has element");
let d1: u64 = visitor.next_element()?.expect("has element");
let d2: u64 = visitor.next_element()?.expect("has element");
let d3: u64 = visitor.next_element()?.expect("has element");
let dim = Dim4::new(&[d0, d1, d2, d3]);
Ok(De(dim))
}
}
deserializer.deserialize_tuple(4, Dim4Visitor)
}
}
impl<'a> Serialize for Ser<'a, DType> {
    /// Writes the dtype as its enum discriminant, cast to `u8`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let dtype: DType = *self.0;
        serializer.serialize_u8(dtype as u8)
    }
}
impl<'de> Deserialize<'de> for De<DType> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct DTypeVisitor;
impl<'de> Visitor<'de> for DTypeVisitor {
type Value = De<DType>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "u8")
}
fn visit_u8<E>(self, value: u8) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let dtype: DType = unsafe { std::mem::transmute(i32::from(value)) };
Ok(De(dtype))
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let dtype: DType = unsafe { std::mem::transmute(value as i32) };
Ok(De(dtype))
}
}
deserializer.deserialize_u8(DTypeVisitor)
}
}
impl<'a> Serialize for Ser<'a, Array> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let array: &Array = self.0;
let dim = array.dims();
let dtype: DType = array.get_type();
let mut tup = serializer.serialize_tuple(3)?;
tup.serialize_element(&Ser::new(&dtype))?;
tup.serialize_element(&Ser::new(&dim))?;
fn get_data<T: HasAfEnum>(array: &Array) -> Vec<T> {
let mut data: Vec<T> = Vec::with_capacity(array.elements());
unsafe {
data.set_len(array.elements());
}
array.host(&mut data.as_mut_slice());
data
}
match dtype {
DType::F32 => tup.serialize_element(&get_data::<f32>(array))?,
DType::F64 => tup.serialize_element(&get_data::<f64>(array))?,
DType::S16 => tup.serialize_element(&get_data::<i16>(array))?,
DType::S32 => tup.serialize_element(&get_data::<i32>(array))?,
DType::S64 => tup.serialize_element(&get_data::<i64>(array))?,
DType::U16 => tup.serialize_element(&get_data::<u16>(array))?,
DType::U32 => tup.serialize_element(&get_data::<u32>(array))?,
DType::U64 => tup.serialize_element(&get_data::<u64>(array))?,
DType::B8 => tup.serialize_element(&get_data::<bool>(array))?,
_ => panic!("unimplemented serialization for complex types!"),
}
tup.end()
}
}
impl<'de> Deserialize<'de> for De<Array> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ArrayVisitor;
impl<'de> Visitor<'de> for ArrayVisitor {
type Value = De<Array>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "struct ArrayStruct")
}
fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
where
V: SeqAccess<'de>,
{
let dtype: De<DType> = seq.next_element()?.expect("has element");
let dim: De<Dim4> = seq.next_element()?.expect("has element");
fn get_array<T: HasAfEnum>(data: Option<Vec<T>>, dim: &Dim4) -> Array {
let data: Vec<T> = data.expect("has vector of elements");
Array::new::<T>(data.as_slice(), *dim)
}
match dtype.0 {
DType::F32 => Ok(De(get_array::<f32>(seq.next_element()?, &dim.0))),
DType::F64 => Ok(De(get_array::<f64>(seq.next_element()?, &dim.0))),
DType::S16 => Ok(De(get_array::<i16>(seq.next_element()?, &dim.0))),
DType::S32 => Ok(De(get_array::<i32>(seq.next_element()?, &dim.0))),
DType::S64 => Ok(De(get_array::<i64>(seq.next_element()?, &dim.0))),
DType::U16 => Ok(De(get_array::<u16>(seq.next_element()?, &dim.0))),
DType::U32 => Ok(De(get_array::<u32>(seq.next_element()?, &dim.0))),
DType::U64 => Ok(De(get_array::<u64>(seq.next_element()?, &dim.0))),
DType::B8 => Ok(De(get_array::<bool>(seq.next_element()?, &dim.0))),
_ => panic!("unimplemented deserialization for complex types!"),
}
}
}
deserializer.deserialize_tuple(3, ArrayVisitor)
}
}