use crate::coder::{Decoder, Encoder, Result, View};
use crate::derive::vec::{unsafe_wild_copy, VecDecoder, VecEncoder};
use crate::derive::{Decode, Encode};
use crate::error::err;
use crate::str::{StrDecoder, StrEncoder};
use arrayvec::{ArrayString, ArrayVec};
use core::mem::MaybeUninit;
// Encodes `ArrayString<N>` with the same `StrEncoder` used for `str` (one length plus raw
// UTF-8 bytes). Small capacities (N <= 64) get a branch-free fixed-size copy fast path.
impl<const N: usize> Encoder<ArrayString<N>> for StrEncoder {
#[inline(always)]
fn encode(&mut self, t: &ArrayString<N>) {
// Large capacities, and runs under Miri, use the generic `str` path.
// NOTE(review): presumably Miri is excluded because the fast path below reads past the
// end of the `&str` returned by `as_str()`, which Miri's provenance checks reject.
if N > 64 || cfg!(miri) {
self.encode(t.as_str());
return;
}
let s = t.as_str();
// N <= 64 here, so `s.len()` is guaranteed to be below 255.
self.0.lengths.encode_less_than_255(s.len());
let primitives = self.0.elements.as_primitive().unwrap();
// Reserve a full N bytes (not just `s.len()`): the copy below always writes N bytes.
primitives.reserve(N); let dst = primitives.end_ptr();
unsafe {
// SAFETY: `dst` has N reserved bytes. The source read of N bytes assumes
// `ArrayString<N>` keeps its bytes inline in an N-byte buffer starting at
// `as_str().as_ptr()` (arrayvec implementation detail — TODO confirm on upgrades).
// Bytes past `s.len()` are junk, but the end pointer only advances by `s.len()`,
// so they lie beyond the logical end of the output.
*(dst as *mut MaybeUninit<[u8; N]>) = *(s.as_ptr() as *const MaybeUninit<[u8; N]>);
primitives.set_end_ptr(dst.add(s.len()));
}
}
#[inline(never)]
fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a ArrayString<N>> + Clone) {
if N > 64 || cfg!(miri) {
self.encode_vectored(i.map(|t| t.as_str()));
return;
}
let primitives = self.0.elements.as_primitive().unwrap();
// Reserve N bytes per element up front so the loop can write through a raw pointer.
// NOTE(review): this panics if the iterator reports no upper bound, and writes past the
// reservation if the iterator yields more items than its upper bound claims. It relies
// on callers passing iterators with an accurate `size_hint`.
primitives.reserve(i.size_hint().1.unwrap() * N);
let mut dst = primitives.end_ptr();
for t in i {
let s = t.as_str();
self.0.lengths.encode_less_than_255(s.len());
unsafe {
// SAFETY: same fixed-N-byte copy as in `encode`. Each element advances `dst` by at
// most N, so the total stays within the `count * N` bytes reserved above.
*(dst as *mut MaybeUninit<[u8; N]>) = *(s.as_ptr() as *const MaybeUninit<[u8; N]>);
dst = dst.add(s.len());
}
}
// Publish the new end only once, after all bytes are written.
primitives.set_end_ptr(dst);
}
}
/// `ArrayString<N>` is encoded through the same `StrEncoder` that handles `str`.
impl<const N: usize> Encode for ArrayString<N> {
    type Encoder = StrEncoder;
}
/// Decoder for `ArrayString<N>`: a `StrDecoder` whose decoded lengths are checked against `N`.
#[derive(Default)]
pub struct ArrayStringDecoder<'a, const N: usize>(StrDecoder<'a>);

impl<'a, const N: usize> View<'a> for ArrayStringDecoder<'a, N> {
    /// Populates the inner `StrDecoder` with `length` strings. Fails with
    /// "invalid ArrayString" if any of them is longer than the capacity `N`.
    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
        self.0.populate(input, length)?;
        // SAFETY: the lengths were just populated with `length` entries by the call above
        // (assumed to be the precondition of `any_greater_than` — TODO confirm).
        let over_capacity = unsafe { self.0.lengths.any_greater_than::<N>(length) };
        if over_capacity {
            err("invalid ArrayString")
        } else {
            Ok(())
        }
    }
}
// Decodes one `ArrayString<N>` per call from the populated `StrDecoder`.
impl<'a, const N: usize> Decoder<'a, ArrayString<N>> for ArrayStringDecoder<'a, N> {
#[inline(always)]
fn decode_in_place(&mut self, out: &mut MaybeUninit<ArrayString<N>>) {
let s: &str = self.0.decode();
// Start from an empty string, then fill its buffer.
let array_string = out.write(ArrayString::new());
if N > 64 || cfg!(miri) {
// SAFETY: `populate` rejected any string longer than N, so this push cannot fail.
unsafe { array_string.try_push_str(s).unwrap_unchecked() };
return;
}
// Skip the copy for empty strings. NOTE(review): presumably an empty `s` is not
// guaranteed to point at readable bytes (see the `array_string_bug` test, which decodes
// truncated input).
if s.is_empty() {
return;
}
unsafe {
let src = s.as_ptr();
let dst = array_string.as_mut_str().as_mut_ptr();
let n = s.len();
// NOTE(review): `unsafe_wild_copy!` is defined elsewhere. From its `[u8; N]` chunk type
// it presumably copies whole N-byte chunks, so it may read past `n` bytes of `src` and
// write up to N bytes into the string's N-byte buffer — TODO confirm.
unsafe_wild_copy!([u8; N], src, dst, n);
// SAFETY: the first `s.len()` bytes are now a copy of valid UTF-8 `s`, and
// `s.len() <= N` was checked in `populate`.
array_string.set_len(s.len());
}
}
}
/// `ArrayString<N>` decodes through [`ArrayStringDecoder`], which enforces capacity `N`.
impl<'a, const N: usize> Decode<'a> for ArrayString<N> {
    type Decoder = ArrayStringDecoder<'a, N>;
}
/// Returns the contents of `t` as a slice, while informing the optimizer that its length
/// never exceeds the capacity `N`.
#[inline(always)]
fn as_slice_assert_len<T, const N: usize>(t: &ArrayVec<T, N>) -> &[T] {
    let slice = t.as_slice();
    if slice.len() <= N {
        slice
    } else {
        // SAFETY: an `ArrayVec<T, N>` never holds more than `N` elements, so this branch is
        // unreachable. Marking it lets the compiler assume `len <= N` downstream.
        unsafe { core::hint::unreachable_unchecked() }
    }
}
/// `ArrayVec<T, N>` is encoded exactly like the slice it contains.
impl<T: Encode, const N: usize> Encoder<ArrayVec<T, N>> for VecEncoder<T> {
    #[inline(always)]
    fn encode(&mut self, t: &ArrayVec<T, N>) {
        let slice = as_slice_assert_len(t);
        self.encode(slice);
    }

    /// Forwards a batch of `ArrayVec`s to the slice encoder as a batch of slices.
    #[inline(always)]
    fn encode_vectored<'a>(&mut self, i: impl Iterator<Item = &'a ArrayVec<T, N>> + Clone)
    where
        ArrayVec<T, N>: 'a,
    {
        let slices = i.map(as_slice_assert_len);
        self.encode_vectored(slices);
    }
}
/// `ArrayVec<T, N>` shares `VecEncoder<T>`, so it is encoded like any other sequence of `T`.
impl<T: Encode, const N: usize> Encode for ArrayVec<T, N> {
    type Encoder = VecEncoder<T>;
}
pub struct ArrayVecDecoder<'a, T: Decode<'a>, const N: usize>(VecDecoder<'a, T>);
impl<'a, T: Decode<'a>, const N: usize> Default for ArrayVecDecoder<'a, T, N> {
fn default() -> Self {
Self(Default::default())
}
}
impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayVecDecoder<'a, T, N> {
    /// Populates the inner `VecDecoder` with `length` vectors. Fails with "invalid ArrayVec"
    /// if any of them holds more than `N` elements.
    fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
        self.0.populate(input, length)?;
        // SAFETY: the lengths were just populated with `length` entries by the call above
        // (assumed to be the precondition of `any_greater_than` — TODO confirm).
        let over_capacity = unsafe { self.0.lengths.any_greater_than::<N>(length) };
        if over_capacity {
            err("invalid ArrayVec")
        } else {
            Ok(())
        }
    }
}
impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, ArrayVec<T, N>> for ArrayVecDecoder<'a, T, N> {
    /// Decodes one `ArrayVec<T, N>` into `out`. It reads the vector's length, then decodes that
    /// many elements directly into the vector's inline storage.
    #[inline(always)]
    fn decode_in_place(&mut self, out: &mut MaybeUninit<ArrayVec<T, N>>) {
        let array_vec = out.write(ArrayVec::new());
        let len = self.0.lengths.decode();
        let slots = array_vec.as_mut_ptr();
        for index in 0..len {
            // SAFETY: `populate` rejected every length greater than `N`, so `index < len <= N`
            // is within the vector's capacity. The slot is uninitialized, which
            // `MaybeUninit<T>` permits.
            let slot = unsafe { &mut *slots.add(index).cast::<MaybeUninit<T>>() };
            self.0.elements.decode_in_place(slot);
        }
        // SAFETY: elements `0..len` were all initialized by the loop above, and `len <= N`.
        unsafe { array_vec.set_len(len) };
    }
}
/// `ArrayVec<T, N>` decodes through [`ArrayVecDecoder`], which rejects lengths above `N`.
impl<'a, T: Decode<'a>, const N: usize> Decode<'a> for ArrayVec<T, N> {
    type Decoder = ArrayVecDecoder<'a, T, N>;
}
#[cfg(test)]
mod tests {
    use crate::{decode, encode};
    use alloc::vec::Vec;
    use arrayvec::{ArrayString, ArrayVec};

    #[test]
    fn array_string() {
        let mut v = ArrayString::<2>::default();
        v.push('0');
        v.push('1');
        let b = encode(&v);
        assert!(decode::<ArrayString<1>>(&b).is_err());
        assert_eq!(decode::<ArrayString<2>>(&b).unwrap(), v);
        assert_eq!(decode::<ArrayString<3>>(&b).unwrap().as_str(), v.as_str());
        assert!(decode::<ArrayString<0>>(&encode(&ArrayString::<0>::default())).is_ok());
    }

    /// Round-trips several small-capacity strings inside a `Vec`. This is expected to drive
    /// the vectored encode path and the per-element fast decode path (N <= 64). It covers
    /// empty, full-capacity and multi-byte UTF-8 contents.
    #[test]
    fn array_string_vectored() {
        let v: Vec<ArrayString<8>> = ["", "a", "h\u{e9}llo", "12345678"]
            .iter()
            .map(|&s| ArrayString::from(s).unwrap())
            .collect();
        let b = encode(&v);
        assert_eq!(decode::<Vec<ArrayString<8>>>(&b).unwrap(), v);
        // "12345678" does not fit in 7 bytes.
        assert!(decode::<Vec<ArrayString<7>>>(&b).is_err());
    }

    /// Capacities above 64 bypass the fast paths and fall back to plain `str` handling.
    #[test]
    fn array_string_large_capacity() {
        let mut v = ArrayString::<100>::new();
        for _ in 0..100 {
            v.push('x');
        }
        let b = encode(&v);
        assert_eq!(decode::<ArrayString<100>>(&b).unwrap(), v);
        assert_eq!(decode::<ArrayString<200>>(&b).unwrap().as_str(), v.as_str());
        assert!(decode::<ArrayString<99>>(&b).is_err());
        // A small-capacity decoder must also reject the over-long string.
        assert!(decode::<ArrayString<64>>(&b).is_err());
    }

    #[test]
    fn array_vec() {
        let mut v = ArrayVec::<u8, 2>::default();
        v.push(0);
        v.push(1);
        let b = encode(&v);
        assert!(decode::<ArrayVec<u8, 1>>(&b).is_err());
        assert_eq!(decode::<ArrayVec<u8, 2>>(&b).unwrap(), v);
        assert_eq!(
            decode::<ArrayVec<u8, 3>>(&b).unwrap().as_slice(),
            v.as_slice()
        );
        assert_eq!(
            decode::<ArrayVec<u8, 500>>(&b).unwrap().as_slice(),
            v.as_slice()
        );
        assert!(decode::<ArrayVec<u8, 0>>(&encode(&ArrayVec::<u8, 0>::default())).is_ok());
        let mut v = ArrayVec::<u8, 500>::default();
        for i in 0..500 {
            v.push(i as u8);
        }
        let b = encode(&v);
        assert!(decode::<ArrayVec<u8, 499>>(&b).is_err());
        assert_eq!(decode::<ArrayVec<u8, 500>>(&b).unwrap(), v);
    }

    /// Round-trips several `ArrayVec`s of different lengths (empty, partial, full) inside a
    /// `Vec`, exercising the vectored encode path and per-element capacity checks.
    #[test]
    fn array_vec_vectored() {
        let mut v = Vec::new();
        for len in [0usize, 1, 4] {
            let mut av = ArrayVec::<u8, 4>::new();
            for i in 0..len {
                av.push(i as u8);
            }
            v.push(av);
        }
        let b = encode(&v);
        assert_eq!(decode::<Vec<ArrayVec<u8, 4>>>(&b).unwrap(), v);
        // The full-capacity element does not fit in 3 slots.
        assert!(decode::<Vec<ArrayVec<u8, 3>>>(&b).is_err());
    }

    #[test]
    fn array_string_bug() {
        type T = ArrayString<1>;
        let mut v = T::default();
        v.push(' ');
        let mut buffer = crate::Buffer::new();
        buffer
            .decode::<Vec<T>>(&encode::<Vec<T>>(&vec![v])[..2])
            .unwrap_err();
        buffer.decode::<Vec<T>>(&encode::<Vec<T>>(&vec![])).unwrap();
    }
}