use serde::de::{DeserializeSeed, Error, IgnoredAny, IntoDeserializer, SeqAccess, Visitor};
use serde::{Deserialize, Deserializer};
use std::borrow::BorrowMut;
use std::boxed::Box;
use std::vec;
use std::vec::Vec;
/// Generates `Visitor` callbacks for primitive numeric types.
///
/// Every `fn name(Type);` entry expands to
/// `fn name<E: Error>(self, value: Type) -> Result<Self::Value, E>`, whose body hands
/// the primitive to `self.deserialize_num`. `Error`, `Self::Value` and
/// `deserialize_num` are resolved where the macro is invoked.
macro_rules! forward_visitors {
    ($(fn $visit:ident ($prim:ty);)*) => {
        $(
            fn $visit<E: Error>(self, value: $prim) -> Result<Self::Value, E> {
                self.deserialize_num(value)
            }
        )*
    };
}
/// Per-dimension lengths of an n-dimensional array.
///
/// Implementors expose the lengths as a `[usize]` slice (through `BorrowMut`),
/// one entry per dimension.
pub trait Shape: BorrowMut<[usize]> {
    /// Smallest number of dimensions this shape type accepts.
    const MIN_DIMS: usize;
    /// Largest number of dimensions this shape type accepts.
    const MAX_DIMS: usize;

    /// Builds a shape with `dims` dimensions, every length set to zero.
    fn new_zeroed(dims: usize) -> Self;

    /// Length of dimension `dims`, or `None` when that index is out of range.
    fn dim_len(&self, dims: usize) -> Option<usize> {
        let lens: &[usize] = self.borrow();
        lens.get(dims).copied()
    }

    /// Stores `value` as the length of dimension `dims`.
    ///
    /// # Panics
    /// Panics if `dims` is out of range. In debug builds it also panics when that
    /// dimension's length is not currently zero: each length is meant to be
    /// assigned exactly once after `new_zeroed`.
    fn set_dim_len(&mut self, dims: usize, value: usize) {
        debug_assert_eq!(self.dim_len(dims), Some(0));
        let lens: &mut [usize] = self.borrow_mut();
        lens[dims] = value;
    }
}
/// Dynamically-ranked shape: accepts any number of dimensions.
impl Shape for Box<[usize]> {
    const MIN_DIMS: usize = 0;
    const MAX_DIMS: usize = usize::MAX;

    fn new_zeroed(dims: usize) -> Self {
        std::iter::repeat(0).take(dims).collect()
    }
}
/// Fixed-rank shape: exactly `DIMS` dimensions.
impl<const DIMS: usize> Shape for [usize; DIMS] {
    const MIN_DIMS: usize = DIMS;
    const MAX_DIMS: usize = DIMS;

    fn new_zeroed(dims: usize) -> Self {
        // MIN_DIMS == MAX_DIMS == DIMS, so a caller that respects those bounds
        // can only ask for exactly DIMS dimensions.
        debug_assert_eq!(dims, DIMS);
        [0usize; DIMS]
    }
}
/// Bounded-rank shape: from 0 up to `MAX_DIMS` dimensions (the `ArrayVec` capacity).
#[cfg(feature = "arrayvec")]
impl<const MAX_DIMS: usize> Shape for arrayvec::ArrayVec<usize, MAX_DIMS> {
    const MIN_DIMS: usize = 0;
    const MAX_DIMS: usize = MAX_DIMS;

    fn new_zeroed(dims: usize) -> Self {
        debug_assert!(dims <= MAX_DIMS);
        // `FromIterator` for `ArrayVec` fills a fresh vector via `extend`.
        core::iter::repeat(0).take(dims).collect()
    }
}
/// Mutable state threaded through the recursive visit of the nested input.
#[derive(Debug)]
struct Context<T, S> {
    /// Every scalar deserialized so far, in the order it was encountered.
    data: Vec<T>,
    /// `None` until the number of dimensions is known; afterwards one length per
    /// dimension (lengths start at zero and are filled in as sequences are counted).
    shape: Option<S>,
    /// Nesting depth of the node currently being visited (0 at the top level).
    current_dim: usize,
}
/// Scalar handling shared by the `Visitor` impl for `&mut Context`.
impl<'de, T: Deserialize<'de>, S: Shape> Context<T, S> {
    /// Checks that a scalar may appear at the current depth, allocating the shape
    /// when this is the first scalar seen.
    ///
    /// # Errors
    /// - `invalid_type` if the shape is already known and the current depth is
    ///   shallower than its number of dimensions (a sequence was expected here).
    /// - A custom error if this is the first scalar and its depth is below
    ///   `S::MIN_DIMS`.
    fn got_number<E: Error>(&mut self) -> Result<(), E> {
        let depth = self.current_dim;
        if let Some(shape) = &self.shape {
            let rank = shape.borrow().len();
            return if depth < rank {
                Err(E::invalid_type(
                    serde::de::Unexpected::Other("a single number"),
                    &"a sequence",
                ))
            } else {
                Ok(())
            };
        }
        // First scalar: the depth at which it sits fixes the number of dimensions.
        if depth < S::MIN_DIMS {
            return Err(E::custom(format_args!(
                "didn't reach the expected minimum dims {}, got {}",
                S::MIN_DIMS,
                depth,
            )));
        }
        self.shape = Some(S::new_zeroed(depth));
        Ok(())
    }

    /// Validates the position via `got_number`, then deserializes one `T` from
    /// `deserializer` and appends it to `data`.
    fn deserialize_num_from<D: Deserializer<'de>>(
        &mut self,
        deserializer: D,
    ) -> Result<(), D::Error> {
        self.got_number()?;
        T::deserialize(deserializer).map(|item| self.data.push(item))
    }

    /// Wraps a primitive visitor argument in serde's value deserializer so that `T`
    /// decides how to interpret it, then forwards to `deserialize_num_from`.
    fn deserialize_num<E: Error>(&mut self, arg: impl IntoDeserializer<'de, E>) -> Result<(), E> {
        let value_deserializer = arg.into_deserializer();
        self.deserialize_num_from(value_deserializer)
    }
}
/// Visits one node of the nested input: a sequence (one level of nesting) or a
/// scalar leaf. All state is kept in the shared `Context`.
impl<'de, T: Deserialize<'de>, S: Shape> Visitor<'de> for &mut Context<T, S> {
    type Value = ();

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("a sequence or a single number")
    }

    /// Handles one level of nesting at depth `self.current_dim`.
    ///
    /// Once the shape is known, the sequence must contain exactly the recorded length
    /// for this depth (this is what rejects ragged input). Before that, the elements
    /// are counted and the count becomes this depth's length.
    fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
        if let Some(shape) = &self.shape {
            // `None` means this depth is past the innermost dimension: a scalar was expected.
            let expected_len = shape.dim_len(self.current_dim).ok_or_else(|| {
                Error::invalid_type(serde::de::Unexpected::Seq, &"a single number")
            })?;
            self.current_dim += 1;
            for _ in 0..expected_len {
                seq.next_element_seed(&mut *self)?
                    .ok_or_else(|| Error::custom("expected more elements"))?;
            }
            if seq.next_element::<IgnoredAny>()?.is_some() {
                return Err(Error::custom("expected end of sequence"));
            }
            self.current_dim -= 1;
        } else {
            debug_assert!(self.shape.is_none());
            self.current_dim += 1;
            if self.current_dim > S::MAX_DIMS {
                return Err(Error::custom(format_args!(
                    "maximum dims of {} exceeded",
                    S::MAX_DIMS
                )));
            }
            let mut len = 0;
            while seq.next_element_seed(&mut *self)?.is_some() {
                len += 1;
            }
            if self.shape.is_none() {
                // Nothing inside this sequence fixed the shape. With a conforming
                // deserializer this happens only when the sequence is empty, because
                // every element either allocates the shape or fails. Previously this
                // fell through to the `expect` below and panicked on input such as
                // `[]`. Instead, treat this sequence as the innermost dimension:
                // `[]` -> shape `[0]`, `[[]]` -> shape `[1, 0]`, with the same
                // minimum-depth check a scalar would get.
                if self.current_dim < S::MIN_DIMS {
                    return Err(Error::custom(format_args!(
                        "didn't reach the expected minimum dims {}, got {}",
                        S::MIN_DIMS,
                        self.current_dim,
                    )));
                }
                self.shape = Some(S::new_zeroed(self.current_dim));
            }
            self.current_dim -= 1;
            let shape = self
                .shape
                .as_mut()
                .expect("internal error: shape should be allocated by now");
            shape.set_dim_len(self.current_dim, len);
        }
        Ok(())
    }

    forward_visitors! {
        fn visit_i8(i8);
        fn visit_i16(i16);
        fn visit_i32(i32);
        fn visit_i64(i64);
        fn visit_u8(u8);
        fn visit_u16(u16);
        fn visit_u32(u32);
        fn visit_u64(u64);
        fn visit_f32(f32);
        fn visit_f64(f64);
        fn visit_i128(i128);
        fn visit_u128(u128);
    }

    /// Treats a newtype wrapper as a scalar leaf and lets `T` deserialize from it.
    fn visit_newtype_struct<D: Deserializer<'de>>(
        self,
        deserializer: D,
    ) -> Result<Self::Value, D::Error> {
        self.deserialize_num_from(deserializer)
    }
}
impl<'de, T: Deserialize<'de>, S: Shape> DeserializeSeed<'de> for &mut Context<T, S> {
type Value = ();
fn deserialize<D>(self, deserializer: D) -> Result<(), D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(self)
}
}
/// Target of `deserialize`: a container that can be built from a shape plus a
/// flat list of elements.
pub trait MakeNDim {
    /// Element type read from each scalar leaf of the input.
    type Item;
    /// Representation of the per-dimension lengths.
    type Shape: Shape;

    /// Builds `Self` from `shape` and the elements in input order: depth-first, so
    /// the last dimension's index varies fastest.
    fn from_shape_and_data(shape: Self::Shape, data: Vec<Self::Item>) -> Self;
}
/// Deserializes an n-dimensional array written as nested sequences of numbers
/// (e.g. `[[1, 2], [3, 4]]`) and builds it with `A::from_shape_and_data`.
///
/// The nesting depth of the first leaf fixes the number of dimensions. Every later
/// sequence at a given depth must have the same length as the first one seen at that
/// depth. A bare number is accepted as a 0-dimensional value when
/// `A::Shape::MIN_DIMS` allows it. Elements are passed on in input order.
///
/// # Errors
/// Returns `D::Error` if:
/// - the input is not nested sequences of numbers;
/// - the input is ragged;
/// - the nesting is shallower than `MIN_DIMS` or deeper than `MAX_DIMS`;
/// - an `A::Item` fails to deserialize;
/// - the deserializer returns without visiting any value.
pub fn deserialize<'de, A, D>(deserializer: D) -> Result<A, D::Error>
where
    A: MakeNDim,
    A::Item: Deserialize<'de>,
    D: Deserializer<'de>,
{
    let mut context = Context {
        data: Vec::new(),
        shape: None,
        current_dim: 0,
    };
    deserializer.deserialize_any(&mut context)?;
    // A conforming deserializer always reaches a visitor callback, which either
    // allocates the shape or fails. Report a non-conforming one as an error instead
    // of panicking inside a library function.
    let shape = context.shape.ok_or_else(|| {
        D::Error::custom("no value was visited; could not determine the array shape")
    })?;
    Ok(A::from_shape_and_data(shape, context.data))
}