use core::array;
use core::mem::MaybeUninit;
use bytes::Buf;
use crate::DecodeError;
use crate::encoding::DecodeContext;
use crate::encoding::WireType;
use crate::encoding::check_wire_type;
use crate::encoding::decode_varint;
use crate::encoding::skip_field;
use crate::traits::ArchivedProtoField;
use crate::traits::PrimitiveKind;
use crate::traits::ProtoArchive;
use crate::traits::ProtoDecode;
use crate::traits::ProtoDecoder;
use crate::traits::ProtoDefault;
use crate::traits::ProtoEncode;
use crate::traits::ProtoExt;
use crate::traits::ProtoFieldMerge;
use crate::traits::ProtoKind;
use crate::traits::ProtoShadowDecode;
use crate::traits::ProtoShadowEncode;
use crate::traits::buffer::RevWriter;
/// Reinterprets a fully initialized `[MaybeUninit<T>; N]` as `[T; N]` (stable-toolchain variant).
///
/// # Safety
///
/// Every element of `arr` must have been initialized; otherwise the returned array
/// contains uninitialized values, which is undefined behavior.
#[cfg(feature = "stable")]
#[inline]
// Taken by value to mirror the nightly `MaybeUninit::array_assume_init` signature,
// even though only a borrow is needed here.
#[allow(clippy::needless_pass_by_value)]
unsafe fn assume_init_array<T, const N: usize>(arr: [MaybeUninit<T>; N]) -> [T; N] {
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, so both array
    // types have identical layout, and the caller guarantees every element is
    // initialized. `MaybeUninit` has no drop glue, so dropping `arr` after this
    // bitwise copy cannot double-drop any element.
    unsafe { core::mem::transmute_copy::<[MaybeUninit<T>; N], [T; N]>(&arr) }
}
/// Reinterprets a fully initialized `[MaybeUninit<T>; N]` as `[T; N]` (nightly variant,
/// forwarding to the unstable `MaybeUninit::array_assume_init`).
///
/// # Safety
///
/// Every element of `arr` must have been initialized.
#[cfg(not(feature = "stable"))]
#[inline]
// Kept for parity with the stable variant's attribute set.
#[allow(clippy::needless_pass_by_value)]
unsafe fn assume_init_array<T, const N: usize>(arr: [MaybeUninit<T>; N]) -> [T; N] {
    // SAFETY: the initialization requirement is forwarded unchanged to our caller.
    unsafe { MaybeUninit::<T>::array_assume_init::<N>(arr) }
}
/// Wire shape of a fixed-size array.
///
/// When the element kind is `PrimitiveKind::U8` the array is a single `bytes` field.
/// Any other element kind makes it a repeated field of `T`.
impl<T, const N: usize> ProtoExt for [T; N]
where
    T: ProtoExt,
{
    const KIND: ProtoKind = if matches!(T::KIND, ProtoKind::Primitive(PrimitiveKind::U8)) {
        ProtoKind::Bytes
    } else {
        ProtoKind::Repeated(&T::KIND)
    };
    // `None` for byte arrays, `Some("Array")` for every repeated-field array.
    const _REPEATED_SUPPORT: Option<&'static str> =
        if matches!(T::KIND, ProtoKind::Primitive(PrimitiveKind::U8)) { None } else { Some("Array") };
}
impl<T: ProtoFieldMerge + ProtoDefault, const N: usize> ProtoDecoder for [T; N] {
    /// Dispatches one field of a message wrapping the array.
    ///
    /// Tag 1 is merged into `value`. Every other tag is skipped via `skip_field`.
    #[inline]
    fn merge_field(value: &mut Self, tag: u32, wire_type: WireType, buf: &mut impl Buf, ctx: DecodeContext) -> Result<(), DecodeError> {
        if tag == 1 {
            Self::merge(value, wire_type, buf, ctx)
        } else {
            skip_field(wire_type, tag, buf, ctx)
        }
    }

    /// Merges one encoded occurrence of the array field into `self`, overwriting
    /// slots starting from index 0.
    ///
    /// * Byte-like element kinds (`T::KIND.is_bytes_kind()`) must arrive as one
    ///   length-delimited payload of exactly `N` bytes.
    /// * Primitive / simple-enum elements in packed (length-delimited) form fill the
    ///   slots in order. A payload with fewer than `N` values leaves the trailing
    ///   slots unchanged.
    /// * Otherwise `N` values are read back-to-back from `buf` using `wire_type`.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] in these cases:
    /// * a wrong wire type for byte arrays;
    /// * a byte-array length other than `N`;
    /// * a declared length larger than the remaining input;
    /// * a packed payload carrying more than `N` values;
    /// * any element decode error.
    ///
    /// # Panics
    ///
    /// If `T::KIND` is `ProtoKind::Repeated(_)`.
    #[inline]
    fn merge(&mut self, wire_type: WireType, buf: &mut impl Buf, ctx: DecodeContext) -> Result<(), DecodeError> {
        if T::KIND.is_bytes_kind() {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let len = decode_varint(buf)? as usize;
            if len != N {
                return Err(DecodeError::new(format!(
                    "invalid length for fixed byte array: expected {N} got {len}"
                )));
            }
            if len > buf.remaining() {
                return Err(DecodeError::new("buffer underflow"));
            }
            // SAFETY: relies on `is_bytes_kind()` being true only for one-byte element
            // types (presumably `u8`, cf. the `PrimitiveKind::U8` mapping in
            // `ProtoExt::KIND`), so `N` elements span exactly `N` bytes.
            // NOTE(review): confirm against `ProtoKind::is_bytes_kind`.
            let bytes: &mut [u8] = unsafe { core::slice::from_raw_parts_mut(self.as_mut_ptr().cast::<u8>(), self.len()) };
            buf.copy_to_slice(bytes);
            return Ok(());
        }
        match T::KIND {
            ProtoKind::Primitive(_) | ProtoKind::SimpleEnum => {
                if wire_type == WireType::LengthDelimited {
                    let len = decode_varint(buf)? as usize;
                    // Reject a truncated payload up front, as the byte-array path does.
                    // Otherwise `take` would silently clamp to whatever input is left.
                    if len > buf.remaining() {
                        return Err(DecodeError::new("buffer underflow"));
                    }
                    let mut packed = buf.take(len);
                    for v in self.iter_mut() {
                        if !packed.has_remaining() {
                            break;
                        }
                        T::merge_value(v, T::WIRE_TYPE, &mut packed, ctx)?;
                    }
                    // Bytes left after N values have never been consumed from `buf`.
                    // Returning Ok would make the outer decoder parse them as the next
                    // key, so fail here. A debug assertion would instead panic on
                    // untrusted input.
                    if packed.has_remaining() {
                        return Err(DecodeError::new(format!(
                            "too many elements for fixed array: expected at most {N}"
                        )));
                    }
                } else {
                    // NOTE(review): this reads N values back-to-back after a single
                    // key. If a peer sends the unpacked repeated form (a key before every
                    // element), the intermediate keys would be read as values. Confirm
                    // callers only reach this path with key-less input.
                    for v in self.iter_mut() {
                        T::merge_value(v, wire_type, buf, ctx)?;
                    }
                }
                Ok(())
            }
            ProtoKind::String | ProtoKind::Bytes | ProtoKind::Message => {
                // NOTE(review): `ProtoArchive::archive` emits these kinds one element at a
                // time through `ArchivedProtoField::<TAG, T>::new_always`, which
                // presumably writes a key per element. This loop instead reads N payloads
                // after a single key. Confirm against `merge_value`'s contract.
                for v in self.iter_mut() {
                    T::merge_value(v, wire_type, buf, ctx)?;
                }
                Ok(())
            }
            ProtoKind::Repeated(_) => unreachable!(),
        }
    }
}
impl<T, const N: usize> ProtoDefault for [T; N]
where
    T: ProtoDefault,
{
    /// Builds the array by calling `T::proto_default()` once per slot, in index order.
    #[inline]
    fn proto_default() -> Self {
        let element_default = |_slot: usize| <T as ProtoDefault>::proto_default();
        array::from_fn(element_default)
    }
}
/// Arrays decode into an array of the element type's shadow (`T::ShadowDecoded`).
impl<T, const N: usize> ProtoDecode for [T; N]
where
    T: ProtoDecode,
    T::ShadowDecoded: ProtoDecoder + ProtoExt,
{
    type ShadowDecoded = [T::ShadowDecoded; N];
}
impl<T, U, const N: usize> ProtoShadowDecode<[U; N]> for [T; N]
where
    T: ProtoShadowDecode<U>,
{
    /// Converts every shadow element into its final type, in index order.
    ///
    /// On the first failure, the already-converted prefix is dropped and that error is
    /// returned. The shadow elements not yet visited are dropped afterwards, together
    /// with the consumed iterator.
    #[inline]
    fn to_sun(self) -> Result<[U; N], DecodeError> {
        let mut slots: [MaybeUninit<U>; N] = [const { MaybeUninit::uninit() }; N];
        for (filled, shadow) in self.into_iter().enumerate() {
            let converted = match shadow.to_sun() {
                Ok(converted) => converted,
                Err(err) => {
                    for slot in &mut slots[..filled] {
                        // SAFETY: slots `0..filled` were written by earlier iterations and
                        // are dropped exactly once here before bailing out.
                        unsafe { slot.assume_init_drop() };
                    }
                    return Err(err);
                }
            };
            slots[filled].write(converted);
        }
        // SAFETY: the loop ran to completion, so all N slots were written.
        Ok(unsafe { assume_init_array(slots) })
    }
}
impl<T, const N: usize> ProtoArchive for [T; N]
where
    T: ProtoArchive + ProtoExt,
{
    /// An array is default only when every element is default, so `N == 0` is always
    /// default.
    #[inline]
    fn is_default(&self) -> bool {
        !self.iter().any(|element| !<T as ProtoArchive>::is_default(element))
    }

    /// Archives the array into the reverse writer `w`.
    ///
    /// Elements are visited last-to-first, and each payload is written before its
    /// length prefix and key. `w` presumably prepends. With `TAG == 0` the byte and
    /// packed forms write only the payload.
    ///
    /// # Panics
    ///
    /// If `T::KIND` is `ProtoKind::Repeated(_)`.
    #[inline]
    fn archive<const TAG: u32>(&self, w: &mut impl RevWriter) {
        match T::KIND {
            // Byte-like elements: one blob of N bytes.
            _ if T::KIND.is_bytes_kind() => {
                // SAFETY: relies on `is_bytes_kind()` implying a one-byte element type
                // (presumably `u8`), so the array spans exactly N bytes.
                // NOTE(review): confirm.
                let raw: &[u8] =
                    unsafe { core::slice::from_raw_parts(self.as_ptr().cast::<u8>(), N) };
                w.put_slice(raw);
                if TAG != 0 {
                    w.put_varint(raw.len() as u64);
                    ArchivedProtoField::<TAG, Self>::put_key(w);
                }
            }
            // Scalars: packed payload, then (when keyed) its byte length and the key.
            ProtoKind::Primitive(_) | ProtoKind::SimpleEnum => {
                let payload_start = w.mark();
                for element in self.iter().rev() {
                    element.archive::<0>(w);
                }
                if TAG != 0 {
                    let packed_len = w.written_since(payload_start);
                    w.put_varint(packed_len as u64);
                    ArchivedProtoField::<TAG, Self>::put_key(w);
                }
            }
            // Length-delimited kinds: every element goes through `new_always`.
            ProtoKind::String | ProtoKind::Bytes | ProtoKind::Message => {
                for element in self.iter().rev() {
                    ArchivedProtoField::<TAG, T>::new_always(element, w);
                }
            }
            ProtoKind::Repeated(_) => unreachable!(),
        }
    }
}
/// Borrowed encoding view of a `[T; N]`, used as `<[T; N] as ProtoEncode>::Shadow`.
///
/// The elements are held as a slice. `N` lives only in the type so the `ProtoExt`
/// constants can be taken from `[T; N]`.
#[doc(hidden)]
pub struct ArrayShadow<'a, T, const N: usize>
where
    T: ProtoArchive + ProtoExt,
{
    slice: &'a [T],
}
/// The shadow reports exactly the wire shape of the array type it borrows from.
impl<T, const N: usize> ProtoExt for ArrayShadow<'_, T, N>
where
    T: ProtoArchive + ProtoExt,
{
    const _REPEATED_SUPPORT: Option<&'static str> = <[T; N] as ProtoExt>::_REPEATED_SUPPORT;
    const KIND: ProtoKind = <[T; N] as ProtoExt>::KIND;
}
impl<T, const N: usize> ProtoArchive for ArrayShadow<'_, T, N>
where
    T: ProtoArchive + ProtoExt,
{
    /// Default only when every borrowed element is default.
    #[inline]
    fn is_default(&self) -> bool {
        !self.slice.iter().any(|element| !<T as ProtoArchive>::is_default(element))
    }

    /// Delegates archiving to the borrowed slice's `archive` implementation.
    #[inline]
    fn archive<const TAG: u32>(&self, w: &mut impl RevWriter) {
        let elements = self.slice;
        elements.archive::<TAG>(w);
    }
}
/// Arrays encode through a borrowed [`ArrayShadow`] view.
///
/// Note that `for<'a> T: 'a` holds only when `T: 'static`.
impl<T, const N: usize> ProtoEncode for [T; N]
where
    T: ProtoEncode,
    for<'a> T: 'a + ProtoExt + ProtoArchive,
    for<'a> T::Shadow<'a>: ProtoArchive + ProtoExt,
{
    type Shadow<'a> = ArrayShadow<'a, T, N>;
}
impl<'a, T, const N: usize> ProtoShadowEncode<'a, [T; N]> for ArrayShadow<'a, T, N>
where
    T: ProtoArchive + ProtoExt,
{
    /// Borrows the array as a slice; no element is copied or cloned.
    #[inline]
    fn from_sun(value: &'a [T; N]) -> Self {
        // `&[T; N]` unsizes to `&[T]` at the field coercion site.
        Self { slice: value }
    }
}