use serde::de::{Deserialize, Deserializer, Error as _};
use serde::ser::{Serialize, SerializeTuple as _, Serializer};
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::{
fmt::Debug,
ops::{Index, IndexMut},
};
/// A fixed-length container that is indexable by `Idx` and can be constructed
/// element-by-element from that index.
pub trait ArrayLike<Idx>:
    Index<Idx, Output = Self::Elem> + IndexMut<Idx, Output = Self::Elem>
{
    /// The type of each stored element (also the `Output` of indexing).
    type Elem;

    /// The number of elements the container holds.
    const LEN: usize;

    /// Builds the container, obtaining each element by calling `f` with its index.
    fn from_fn<F: FnMut(Idx) -> Self::Elem>(f: F) -> Self;
}

impl<T, const N: usize> ArrayLike<usize> for [T; N] {
    type Elem = T;

    const LEN: usize = N;

    /// Delegates to `core::array::from_fn`: element `i` is the value returned by `f(i)`.
    fn from_fn<F: FnMut(usize) -> T>(f: F) -> Self {
        let built: [T; N] = core::array::from_fn(f);
        built
    }
}
/// Newtype around a plain fixed-size array `[T; N]`, carrying this file's serde impls.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of `[T; N]`.
/// This file relies on that guarantee when it reinterprets `&[T; N]` as
/// `&ArrayWrap<T, N>`, so the attribute must not be removed.
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct ArrayWrap<T, const N: usize>(pub [T; N]);
impl<T: Serialize, const N: usize> Serialize for ArrayWrap<T, N> {
    /// Emits the wrapped array as a tuple of length `N`, one element per slot in index order.
    ///
    /// Any error from the serializer is returned immediately.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut tuple = serializer.serialize_tuple(N)?;
        self.0
            .iter()
            .try_for_each(|element| tuple.serialize_element(element))?;
        tuple.end()
    }
}
impl<'de, T, const N: usize> Deserialize<'de> for ArrayWrap<T, N>
where
    T: Deserialize<'de>,
{
    /// Deserializes exactly `N` elements from a tuple into an `ArrayWrap<T, N>`.
    ///
    /// # Errors
    ///
    /// * `invalid_length` if the input ends before `N` elements were read.
    /// * Any error produced while deserializing an individual element, forwarded unchanged.
    ///
    /// Elements that were already deserialized are dropped on every early exit,
    /// including an error return or a panic inside an element's `Deserialize` impl.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Partially initialised array: slots `0..len` hold live values.
        ///
        /// Dropping the guard drops exactly those values, so an early exit neither
        /// leaks them nor drops uninitialised memory.
        struct Partial<T, const N: usize> {
            slots: [MaybeUninit<T>; N],
            len: usize,
        }

        impl<T, const N: usize> Drop for Partial<T, N> {
            fn drop(&mut self) {
                for slot in &mut self.slots[..self.len] {
                    // SAFETY: each slot below `len` was written with `MaybeUninit::new`,
                    // and ownership is only handed out after `len` has been reset to 0.
                    unsafe { slot.assume_init_drop() };
                }
            }
        }

        struct Visitor<T, const N: usize>(PhantomData<[T; N]>);

        impl<'de, T: Deserialize<'de>, const N: usize> serde::de::Visitor<'de> for Visitor<T, N> {
            type Value = ArrayWrap<T, N>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(formatter, "Array of Length {}", N)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut partial = Partial::<T, N> {
                    slots: [const { MaybeUninit::<T>::uninit() }; N],
                    len: 0,
                };
                while partial.len < N {
                    match seq.next_element::<T>()? {
                        Some(val) => {
                            partial.slots[partial.len] = MaybeUninit::new(val);
                            partial.len += 1;
                        }
                        // serde convention: report how many elements were read and what was expected.
                        None => return Err(A::Error::invalid_length(partial.len, &self)),
                    }
                }
                // SAFETY: the loop only exits normally once `len == N`, so every slot holds an
                // initialised `T`. `MaybeUninit<T>` has the same layout as `T`, hence
                // `[MaybeUninit<T>; N]` has the same layout as `[T; N]`.
                let array =
                    unsafe { std::mem::transmute_copy::<[MaybeUninit<T>; N], [T; N]>(&partial.slots) };
                // The values now belong to `array`; disarm the guard so they are not dropped twice.
                partial.len = 0;
                Ok(ArrayWrap(array))
            }
        }

        deserializer.deserialize_tuple(N, Visitor::<T, N>(PhantomData))
    }
}
pub fn serialize<S, T, const N: usize>(data: &[T; N], ser: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: Serialize,
{
let arr: &ArrayWrap<T, N> = unsafe { std::mem::transmute(data) };
arr.serialize(ser)
}
pub fn deserialize<'de, D, T, const N: usize>(deserialize: D) -> Result<[T; N], D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
ArrayWrap::<T, N>::deserialize(deserialize).map(|val| val.0)
}
/// Helpers for (de)serializing `Vec<[T; N]>`. The vector is encoded as a sequence
/// whose elements are `N`-tuples.
pub mod vec {
    use super::ArrayWrap;
    use serde::{Deserialize, Deserializer, Serialize, Serializer, de, ser::SerializeSeq};
    use std::{fmt, marker::PhantomData, mem};

    /// Upper bound, in bytes, on capacity pre-reserved from a sequence's size hint.
    ///
    /// Mirrors the cap used by serde's own `size_hint::cautious` helper.
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    /// Serializes `data` as a sequence of `data.len()` elements, each an `N`-tuple.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `ser`.
    pub fn serialize<S, T, const N: usize>(data: &Vec<[T; N]>, ser: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: Serialize,
    {
        let mut s = ser.serialize_seq(Some(data.len()))?;
        for array in data {
            // SAFETY: `ArrayWrap<T, N>` is `#[repr(transparent)]` over `[T; N]`, so both
            // reference types have identical layout and validity; the lifetime is unchanged.
            let array = unsafe { std::mem::transmute::<&[T; N], &ArrayWrap<T, N>>(array) };
            s.serialize_element(array)?;
        }
        s.end()
    }

    /// Deserializes a sequence of `N`-tuples into a `Vec<[T; N]>`.
    ///
    /// # Errors
    ///
    /// Forwards any error raised while reading the sequence or one of its arrays.
    pub fn deserialize<'de, D, T, const N: usize>(deserialize: D) -> Result<Vec<[T; N]>, D::Error>
    where
        D: Deserializer<'de>,
        T: Deserialize<'de>,
    {
        struct Visitor<T, const N: usize> {
            _marker: PhantomData<T>,
        }

        impl<'de, T, const N: usize> de::Visitor<'de> for Visitor<T, N>
        where
            T: Deserialize<'de>,
        {
            type Value = Vec<[T; N]>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(formatter, "a vector of arrays of size {}", N)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                let mut out: Vec<[T; N]> =
                    Vec::with_capacity(cautious_capacity::<[T; N]>(seq.size_hint()));
                while let Some(ArrayWrap(array)) = seq.next_element::<ArrayWrap<T, N>>()? {
                    out.push(array);
                }
                Ok(out)
            }
        }

        deserialize.deserialize_seq(Visitor::<T, N> {
            _marker: PhantomData,
        })
    }

    /// Converts a sequence size hint into a bounded initial capacity for a `Vec<E>`.
    ///
    /// The hint may be derived from the (possibly untrusted) input, for example a
    /// length prefix. Reserving it verbatim would let a tiny input request an enormous
    /// allocation, so the result is capped at `MAX_PREALLOC_BYTES` worth of elements.
    /// Zero-sized element types never need a reservation.
    fn cautious_capacity<E>(hint: Option<usize>) -> usize {
        match mem::size_of::<E>() {
            0 => 0,
            elem_size => hint.unwrap_or(0).min(MAX_PREALLOC_BYTES / elem_size),
        }
    }
}