// refloctopus 0.0.1
//
// Speedy reflection-based serde transcoder
#![cfg(feature = "use_serde")]

use crate::metadata::*;
use crate::{Db, ShapedOutputLocation};

use alloc::{borrow::Cow, string::ToString};

use serde::de::Error as deError;

// TODO: field names
/// Reflection-driven deserialization target: acts as both the `DeserializeSeed` and the
/// `Visitor` that fill the memory described by `dst` from a serde data source.
/// You're better off using `Db::deserialize`.
pub(crate) struct Deserialize<'db, 'data, 'visitor> {
    /// Destination pointer, its shape, and the field table used to interpret that shape.
    pub(crate) dst: ShapedOutputLocation<'db, 'data, 'visitor>,
    /// Reflection database; leaf shapes are deserialized through it.
    pub(crate) db: &'db Db<'db>,
    /// Type name handed to serde for struct/enum requests. Sub-field deserializers
    /// currently inherit the parent's name unchanged.
    pub(crate) name: &'static str,
}
impl<'db, 'data, 'visitor> Deserialize<'db, 'data, 'visitor> {
    /// Builds a child deserializer that targets `ptr` with the given `shape`, reusing this
    /// deserializer's database, type name, and field table.
    ///
    /// # Safety
    /// The returned deserializer writes through `ptr` as if it pointed to storage laid out as
    /// `shape` describes; the caller must guarantee that this holds.
    unsafe fn subfield(&self, shape: &'visitor DataShape<'visitor>, ptr: *mut u8) -> Self {
        let dst = ShapedOutputLocation {
            fields: self.dst.fields,
            ptr,
            shape,
            _data: Default::default(),
        };
        Self {
            db: self.db,
            name: self.name,
            dst,
        }
    }
}

impl<'db, 'data, 'de, 'visitor> serde::de::DeserializeSeed<'de>
    for Deserialize<'db, 'data, 'visitor>
{
    type Value = ();
    fn deserialize<D>(self, src: D) -> Result<(), D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        match self.dst.shape {
            DataShape::Leaf(id) => self
                .db
                .deserialize_leaf(src, *id, &self.dst)
                .map_err(|e| serde::de::Error::custom(e.to_string())),
            DataShape::Tuple(fields) => src.deserialize_tuple(fields.len as usize, self),
            DataShape::Struct(DeclKind::Unit, ..) => src.deserialize_unit_struct(self.name, self),
            DataShape::Struct(DeclKind::Tuple, _, fields) => {
                src.deserialize_tuple_struct(self.name, fields.len as usize, self)
            }
            DataShape::Struct(DeclKind::Struct, labels, _) => {
                src.deserialize_struct(self.name, labels, self)
            }
            DataShape::Struct(DeclKind::Newtype, ..) => {
                src.deserialize_newtype_struct(self.name, self)
            }

            DataShape::Enum(variant_labels_for_serde, ..) => {
                src.deserialize_enum(self.name, variant_labels_for_serde, self)
            }
            DataShape::FixedArray { len, .. } => src.deserialize_tuple(*len, self),
            DataShape::Builtin(RustBuiltin::U8) => src.deserialize_u8(self),
            DataShape::Builtin(RustBuiltin::I8) => src.deserialize_i8(self),
            DataShape::Builtin(RustBuiltin::U16) => src.deserialize_u16(self),
            DataShape::Builtin(RustBuiltin::I16) => src.deserialize_i16(self),
            DataShape::Builtin(RustBuiltin::U32) => src.deserialize_u32(self),
            DataShape::Builtin(RustBuiltin::I32) => src.deserialize_i32(self),
            DataShape::Builtin(RustBuiltin::U64) => src.deserialize_u64(self),
            DataShape::Builtin(RustBuiltin::I64) => src.deserialize_i64(self),
            DataShape::Builtin(RustBuiltin::F32) => src.deserialize_f32(self),
            DataShape::Builtin(RustBuiltin::F64) => src.deserialize_f64(self),
            DataShape::Builtin(RustBuiltin::I128) => src.deserialize_i128(self),
            DataShape::Builtin(RustBuiltin::U128) => src.deserialize_u128(self),
            DataShape::Builtin(RustBuiltin::BOOLIN) => src.deserialize_bool(self),
            DataShape::Builtin(RustBuiltin::CHAR) => src.deserialize_char(self),
            DataShape::Slice(_element_shape) => Err(deError::custom(
                "can't deserialize a slice, where do i store the data?",
            )),
            DataShape::Ref(..) => Err(deError::custom(
                "can't deserialize a ref, where do i store the data?",
            )),
        }
    }
}

macro_rules! visit {
    ($($ty:ident, $visit:ident, $uppercase:ident);*) => {
        $(fn $visit<E: deError>(self, val: $ty) -> Result<Self::Value, E> {
            if !matches!(self.dst.shape, DataShape::Builtin(RustBuiltin::$uppercase)) {
                return Err(deError::custom(alloc::format!(
                    "unexpected {} when expecting {:?}",
                    stringify!($ty),
                    self.dst.shape,
                )));
            }
            unsafe { self.dst.ptr.cast::<$ty>().write(val) }
            Ok(())
        })*
    };
}

impl<'fields, 'db: 'fields, 'data, 'visitor, 'de> serde::de::Visitor<'de>
    for Deserialize<'db, 'data, 'visitor>
{
    type Value = ();

    fn visit_unit<E>(self) -> Result<Self::Value, E>
    where
        E: deError,
    {
        unsafe {
            self.dst.ptr.cast::<()>().write(());
        }
        Ok(())
    }

    fn visit_map<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::MapAccess<'de>,
    {
        match &self.dst.shape {
            DataShape::Struct(DeclKind::Struct, labels_for_serde, fields) => {
                let fields = unsafe { self.dst.fields.array(fields) };

                while let Some(ix) = seq.next_key_seed(FieldIx {
                    labels_for_serde,
                    fields: &Cow::Borrowed(fields),
                })? {
                    // SAFETY: i hope ix is correct!
                    unsafe {
                        seq.next_value_seed(
                            self.subfield(&fields[ix].shape, self.dst.ptr.add(fields[ix].offset)),
                        )?;
                    }
                }
                // TODO: track uninitialized fields
            }
            _ => return Err(deError::custom("unexpected shape when visiting sequence")),
        }
        Ok(())
    }

    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::SeqAccess<'de>,
    {
        match self.dst.shape {
            DataShape::Leaf(_) =>
            // unfortunately need to bounce through the SeqAccess to get a new Deserializer
            {
                return seq.next_element_seed(self).map(|_| ())
            }
            DataShape::FixedArray { shape, len, stride } => {
                // SAFETY: correctness of reflection data
                let shape = unsafe { self.dst.fields.shape(shape) };
                for ix in 0..*len {
                    seq.next_element_seed(unsafe {
                        self.subfield(shape, self.dst.ptr.add(ix * stride))
                    })?;
                }
            }
            DataShape::Tuple(fields)
            | DataShape::Struct(DeclKind::Tuple, _, fields)
            | DataShape::Struct(DeclKind::Struct, _, fields) => {
                for &Field { shape, offset, .. } in unsafe { self.dst.fields.array(fields) } {
                    {
                        seq.next_element_seed(unsafe {
                            // SAFETY: correctness of reflection data
                            self.subfield(&shape, self.dst.ptr.add(offset))
                        })?;
                    }
                }
            }
            _ => return Err(deError::custom("unexpected shape when visiting sequence")),
        }
        Ok(())
    }

    fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::EnumAccess<'de>,
    {
        use serde::de::VariantAccess;
        let dst = &self.dst;
        match dst.shape {
            DataShape::Enum(variant_labels_for_serde, variants) => {
                let variants = unsafe { dst.fields.array(variants) };
                match data.variant_seed(VariantIx(variant_labels_for_serde, variants)) {
                    Ok((ix, variant)) => {
                        let arm = &variants[ix];
                        let fields = unsafe { dst.fields.array(&arm.fields) };
                        unsafe {
                            dst.write_discriminant(arm.discriminant);
                        }
                        match arm.decl_kind {
                            DeclKind::Unit => variant.unit_variant(),
                            DeclKind::Struct => {
                                variant.struct_variant(variant_labels_for_serde, self)
                            }
                            DeclKind::Newtype => variant.newtype_variant_seed(unsafe {
                                self.subfield(&fields[0].shape, dst.ptr)
                            }),
                            DeclKind::Tuple => {
                                // SAFETY: we change the runtime type here so that the tuple visitor code isn't duplicated, but
                                // now that we've verified the discriminant, the field metadata is accurate when viewed as
                                // a plain tuple.
                                unsafe {
                                    variant.tuple_variant(
                                        arm.fields.len as usize,
                                        self.subfield(&DataShape::Tuple(arm.fields), dst.ptr),
                                    )
                                }
                            }
                        }
                    }
                    Err(e) => Err(e),
                }
            }
            _shape => Err(deError::custom(
                "visited an enum when not expecting an enum",
            )),
        }
    }

    fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::DeserializeSeed;
        match self.dst.shape {
            DataShape::Struct(DeclKind::Newtype, _, singular_field) => {
                let field = unsafe { self.dst.fields.array(&singular_field) };
                if field.len() != 1 {
                    return Err(deError::custom("newtype struct had too many fields"));
                };
                unsafe {
                    self.subfield(&field[0].shape, self.dst.ptr)
                        .deserialize(deserializer)
                }
            }

            _ => Err(deError::custom("visit and shape disagree")),
        }
    }

    visit!(u8, visit_u8, U8);
    visit!(i8, visit_i8, I8);
    visit!(u16, visit_u16, U16);
    visit!(i16, visit_i16, I16);
    visit!(u32, visit_u32, U32);
    visit!(i32, visit_i32, I32);
    visit!(u64, visit_u64, U64);
    visit!(i64, visit_i64, I64);
    visit!(u128, visit_u128, U128);
    visit!(i128, visit_i128, I128);
    visit!(f32, visit_f32, F32);
    visit!(f64, visit_f64, F64);
    visit!(bool, visit_bool, BOOLIN);
    visit!(char, visit_char, CHAR);

    fn expecting(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
        write!(fmt, "a sequence to fill the fields of {:?}", self.dst.shape)
    }
}

/// Seed/visitor that resolves a struct key (by name or by position) to an index into `fields`.
struct FieldIx<'db> {
    /// Field metadata searched by name; its length also bounds positional keys.
    fields: &'db [Field<'db>],
    /// Labels reported as the expected set in serde's unknown-field errors.
    labels_for_serde: &'static [&'static str],
}

impl<'de> serde::de::Visitor<'de> for FieldIx<'_> {
    type Value = usize;

    /// Accepts a field by position, for formats that encode identifiers as integers.
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: deError,
    {
        // Compare in u64: with `v as usize` first, a large index could wrap to a small, valid
        // one on 32-bit targets. After the check, the cast is lossless.
        if v < self.fields.len() as u64 {
            Ok(v as usize)
        } else {
            Err(E::unknown_field(&v.to_string(), self.labels_for_serde))
        }
    }

    /// Accepts a field by name. Unknown names are reported with `unknown_field`, like
    /// `visit_bytes` and serde-derived code.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: deError,
    {
        // TODO: faster scan
        self.fields
            .iter()
            .position(|f| f.name == Some(v))
            .ok_or_else(|| E::unknown_field(v, self.labels_for_serde))
    }

    /// Accepts a field by name given as raw bytes. Unnamed fields never match, not even an
    /// empty key.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: deError,
    {
        self.fields
            .iter()
            .position(|f| f.name.map(str::as_bytes) == Some(v))
            .ok_or_else(|| {
                E::unknown_field(
                    core::str::from_utf8(v).unwrap_or("<non-utf8 fieldname>"),
                    self.labels_for_serde,
                )
            })
    }

    fn expecting(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
        fmt.write_str("a struct field")
    }
}
impl<'de> serde::de::DeserializeSeed<'de> for FieldIx<'_> {
    type Value = usize;

    /// Reads a struct key as an identifier and resolves it to a field index.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        <D as serde::Deserializer<'de>>::deserialize_identifier(deserializer, self)
    }
}
/// Seed/visitor that resolves an enum variant identifier to an index into the arm table.
/// `.0` is the variant label list reported in serde's unknown-variant errors; `.1` is the
/// enum's arms, searched by label and bounding positional identifiers.
struct VariantIx<'db>(&'static [&'static str], &'db [EnumArm<'db>]);
impl<'de> serde::de::Visitor<'de> for VariantIx<'_> {
    type Value = usize;
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: deError,
    {
        if (v as usize) < self.1.len() {
            Ok(v as usize)
        } else {
            Err(E::unknown_variant(&v.to_string(), self.0))
        }
    }
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: deError,
    {
        // TODO: faster scan
        for (ix, f) in self.1.iter().enumerate() {
            if f.label == v {
                return Ok(ix);
            }
        }
        Err(E::unknown_variant(v, self.0))
    }
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: deError,
    {
        for (ix, f) in self.1.iter().enumerate() {
            if f.label.as_bytes() == v {
                return Ok(ix);
            }
        }
        Err(E::unknown_field(
            core::str::from_utf8(v).unwrap_or("<non-utf8 fieldname>"),
            self.0,
        ))
    }
    fn expecting(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
        fmt.write_str("a struct field")
    }
}
impl<'de> serde::de::DeserializeSeed<'de> for VariantIx<'_> {
    type Value = usize;

    /// Reads an enum tag as an identifier and resolves it to an arm index.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        <D as serde::Deserializer<'de>>::deserialize_identifier(deserializer, self)
    }
}