#![cfg(feature = "use_serde")]
use crate::metadata::*;
use crate::{Db, ShapedOutputLocation};
use alloc::{borrow::Cow, string::ToString};
use serde::de::Error as deError;
/// Deserialization seed (and visitor) that writes the incoming value directly into raw
/// memory whose layout is described by a runtime [`DataShape`].
pub(crate) struct Deserialize<'db, 'data, 'visitor> {
    /// Database consulted for `DataShape::Leaf` values.
    pub(crate) db: &'db Db<'db>,
    /// Type name handed to serde's `deserialize_*_struct` / `deserialize_enum` calls.
    pub(crate) name: &'static str,
    /// Where the value is written, together with the shape describing that memory.
    pub(crate) dst: ShapedOutputLocation<'db, 'data, 'visitor>,
}

impl<'db, 'data, 'visitor> Deserialize<'db, 'data, 'visitor> {
    /// Builds a seed for a nested value located at `ptr` with layout `shape`.
    ///
    /// The new seed shares this seed's database, type name and field table.
    ///
    /// # Safety
    /// The returned seed writes through `ptr` without further checks, so `ptr` must point to
    /// writable storage laid out as `shape` describes.
    unsafe fn subfield(&self, shape: &'visitor DataShape<'visitor>, ptr: *mut u8) -> Self {
        let shared_fields = self.dst.fields;
        Self {
            db: self.db,
            name: self.name,
            dst: ShapedOutputLocation {
                shape,
                ptr,
                fields: shared_fields,
                _data: Default::default(),
            },
        }
    }
}
impl<'db, 'data, 'de, 'visitor> serde::de::DeserializeSeed<'de>
for Deserialize<'db, 'data, 'visitor>
{
type Value = ();
fn deserialize<D>(self, src: D) -> Result<(), D::Error>
where
D: serde::Deserializer<'de>,
{
match self.dst.shape {
DataShape::Leaf(id) => self
.db
.deserialize_leaf(src, *id, &self.dst)
.map_err(|e| serde::de::Error::custom(e.to_string())),
DataShape::Tuple(fields) => src.deserialize_tuple(fields.len as usize, self),
DataShape::Struct(DeclKind::Unit, ..) => src.deserialize_unit_struct(self.name, self),
DataShape::Struct(DeclKind::Tuple, _, fields) => {
src.deserialize_tuple_struct(self.name, fields.len as usize, self)
}
DataShape::Struct(DeclKind::Struct, labels, _) => {
src.deserialize_struct(self.name, labels, self)
}
DataShape::Struct(DeclKind::Newtype, ..) => {
src.deserialize_newtype_struct(self.name, self)
}
DataShape::Enum(variant_labels_for_serde, ..) => {
src.deserialize_enum(self.name, variant_labels_for_serde, self)
}
DataShape::FixedArray { len, .. } => src.deserialize_tuple(*len, self),
DataShape::Builtin(RustBuiltin::U8) => src.deserialize_u8(self),
DataShape::Builtin(RustBuiltin::I8) => src.deserialize_i8(self),
DataShape::Builtin(RustBuiltin::U16) => src.deserialize_u16(self),
DataShape::Builtin(RustBuiltin::I16) => src.deserialize_i16(self),
DataShape::Builtin(RustBuiltin::U32) => src.deserialize_u32(self),
DataShape::Builtin(RustBuiltin::I32) => src.deserialize_i32(self),
DataShape::Builtin(RustBuiltin::U64) => src.deserialize_u64(self),
DataShape::Builtin(RustBuiltin::I64) => src.deserialize_i64(self),
DataShape::Builtin(RustBuiltin::F32) => src.deserialize_f32(self),
DataShape::Builtin(RustBuiltin::F64) => src.deserialize_f64(self),
DataShape::Builtin(RustBuiltin::I128) => src.deserialize_i128(self),
DataShape::Builtin(RustBuiltin::U128) => src.deserialize_u128(self),
DataShape::Builtin(RustBuiltin::BOOLIN) => src.deserialize_bool(self),
DataShape::Builtin(RustBuiltin::CHAR) => src.deserialize_char(self),
DataShape::Slice(_element_shape) => Err(deError::custom(
"can't deserialize a slice, where do i store the data?",
)),
DataShape::Ref(..) => Err(deError::custom(
"can't deserialize a ref, where do i store the data?",
)),
}
}
}
/// Expands to one `Visitor::visit_*` method per `(type, method, RustBuiltin variant)` triple.
///
/// Each generated method accepts the value only when the destination shape is exactly
/// `DataShape::Builtin(RustBuiltin::<variant>)`, in which case it writes the value through
/// `self.dst.ptr`. Any other shape produces a custom error naming the primitive received.
macro_rules! visit {
    ($($prim:ident, $method:ident, $builtin:ident);*) => {
        $(
            fn $method<E: deError>(self, value: $prim) -> Result<Self::Value, E> {
                match self.dst.shape {
                    DataShape::Builtin(RustBuiltin::$builtin) => {
                        // SAFETY: the shape says the destination holds exactly this primitive.
                        unsafe { self.dst.ptr.cast::<$prim>().write(value) }
                        Ok(())
                    }
                    _ => Err(deError::custom(alloc::format!(
                        "unexpected {} when expecting {:?}",
                        stringify!($prim),
                        self.dst.shape,
                    ))),
                }
            }
        )*
    };
}
impl<'fields, 'db: 'fields, 'data, 'visitor, 'de> serde::de::Visitor<'de>
for Deserialize<'db, 'data, 'visitor>
{
type Value = ();
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: deError,
{
unsafe {
self.dst.ptr.cast::<()>().write(());
}
Ok(())
}
fn visit_map<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
match &self.dst.shape {
DataShape::Struct(DeclKind::Struct, labels_for_serde, fields) => {
let fields = unsafe { self.dst.fields.array(fields) };
while let Some(ix) = seq.next_key_seed(FieldIx {
labels_for_serde,
fields: &Cow::Borrowed(fields),
})? {
unsafe {
seq.next_value_seed(
self.subfield(&fields[ix].shape, self.dst.ptr.add(fields[ix].offset)),
)?;
}
}
}
_ => return Err(deError::custom("unexpected shape when visiting sequence")),
}
Ok(())
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
match self.dst.shape {
DataShape::Leaf(_) =>
{
return seq.next_element_seed(self).map(|_| ())
}
DataShape::FixedArray { shape, len, stride } => {
let shape = unsafe { self.dst.fields.shape(shape) };
for ix in 0..*len {
seq.next_element_seed(unsafe {
self.subfield(shape, self.dst.ptr.add(ix * stride))
})?;
}
}
DataShape::Tuple(fields)
| DataShape::Struct(DeclKind::Tuple, _, fields)
| DataShape::Struct(DeclKind::Struct, _, fields) => {
for &Field { shape, offset, .. } in unsafe { self.dst.fields.array(fields) } {
{
seq.next_element_seed(unsafe {
self.subfield(&shape, self.dst.ptr.add(offset))
})?;
}
}
}
_ => return Err(deError::custom("unexpected shape when visiting sequence")),
}
Ok(())
}
fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
where
A: serde::de::EnumAccess<'de>,
{
use serde::de::VariantAccess;
let dst = &self.dst;
match dst.shape {
DataShape::Enum(variant_labels_for_serde, variants) => {
let variants = unsafe { dst.fields.array(variants) };
match data.variant_seed(VariantIx(variant_labels_for_serde, variants)) {
Ok((ix, variant)) => {
let arm = &variants[ix];
let fields = unsafe { dst.fields.array(&arm.fields) };
unsafe {
dst.write_discriminant(arm.discriminant);
}
match arm.decl_kind {
DeclKind::Unit => variant.unit_variant(),
DeclKind::Struct => {
variant.struct_variant(variant_labels_for_serde, self)
}
DeclKind::Newtype => variant.newtype_variant_seed(unsafe {
self.subfield(&fields[0].shape, dst.ptr)
}),
DeclKind::Tuple => {
unsafe {
variant.tuple_variant(
arm.fields.len as usize,
self.subfield(&DataShape::Tuple(arm.fields), dst.ptr),
)
}
}
}
}
Err(e) => Err(e),
}
}
_shape => Err(deError::custom(
"visited an enum when not expecting an enum",
)),
}
}
fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::DeserializeSeed;
match self.dst.shape {
DataShape::Struct(DeclKind::Newtype, _, singular_field) => {
let field = unsafe { self.dst.fields.array(&singular_field) };
if field.len() != 1 {
return Err(deError::custom("newtype struct had too many fields"));
};
unsafe {
self.subfield(&field[0].shape, self.dst.ptr)
.deserialize(deserializer)
}
}
_ => Err(deError::custom("visit and shape disagree")),
}
}
visit!(u8, visit_u8, U8);
visit!(i8, visit_i8, I8);
visit!(u16, visit_u16, U16);
visit!(i16, visit_i16, I16);
visit!(u32, visit_u32, U32);
visit!(i32, visit_i32, I32);
visit!(u64, visit_u64, U64);
visit!(i64, visit_i64, I64);
visit!(u128, visit_u128, U128);
visit!(i128, visit_i128, I128);
visit!(f32, visit_f32, F32);
visit!(f64, visit_f64, F64);
visit!(bool, visit_bool, BOOLIN);
visit!(char, visit_char, CHAR);
fn expecting(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(fmt, "a sequence to fill the fields of {:?}", self.dst.shape)
}
}
/// Seed/visitor that resolves a serde struct-field identifier (a name or a position) to an
/// index into `fields`.
struct FieldIx<'db> {
    /// Expected-field list reported in `unknown_field` errors.
    labels_for_serde: &'static [&'static str],
    /// Fields of the struct being deserialized; the resolved index points into this slice.
    fields: &'db [Field<'db>],
}

impl<'de> serde::de::Visitor<'de> for FieldIx<'_> {
    type Value = usize;

    /// Accepts a positional field index below `fields.len()`.
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: deError,
    {
        // Compare in `u64`, where `len as u64` is lossless. Casting `v as usize` first
        // would truncate on 32-bit targets and could map a huge index onto a real field.
        if v < self.fields.len() as u64 {
            Ok(v as usize)
        } else {
            Err(E::unknown_field(&v.to_string(), self.labels_for_serde))
        }
    }

    /// Accepts the name of one of the named fields.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: deError,
    {
        self.fields
            .iter()
            .position(|f| f.name == Some(v))
            .ok_or_else(|| E::unknown_field(v, self.labels_for_serde))
    }

    /// Byte-string form of `visit_str`. Unnamed fields never match, not even an empty key.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: deError,
    {
        self.fields
            .iter()
            .position(|f| f.name.map(str::as_bytes) == Some(v))
            .ok_or_else(|| {
                E::unknown_field(
                    core::str::from_utf8(v).unwrap_or("<non-utf8 fieldname>"),
                    self.labels_for_serde,
                )
            })
    }

    fn expecting(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
        fmt.write_str("a struct field")
    }
}

impl<'de> serde::de::DeserializeSeed<'de> for FieldIx<'_> {
    type Value = usize;

    /// Asks the deserializer for an identifier and resolves it with the visitor above.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_identifier(self)
    }
}
/// Seed/visitor that resolves a serde enum-variant identifier (a name or a position) to an
/// index into the enum's arms.
///
/// Field `.0` is the variant-label list reported in `unknown_variant` errors. Field `.1`
/// holds the enum's arms; the resolved index points into it.
struct VariantIx<'db>(&'static [&'static str], &'db [EnumArm<'db>]);

impl<'de> serde::de::Visitor<'de> for VariantIx<'_> {
    type Value = usize;

    /// Accepts a positional variant index below the number of arms.
    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: deError,
    {
        // Compare in `u64` so a truncating `as usize` cannot alias a valid arm on
        // 32-bit targets.
        if v < self.1.len() as u64 {
            Ok(v as usize)
        } else {
            Err(E::unknown_variant(&v.to_string(), self.0))
        }
    }

    /// Accepts a variant label.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: deError,
    {
        self.1
            .iter()
            .position(|arm| arm.label == v)
            .ok_or_else(|| E::unknown_variant(v, self.0))
    }

    /// Byte-string form of `visit_str`. Reports `unknown_variant` (not `unknown_field`),
    /// matching the other two paths.
    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
    where
        E: deError,
    {
        self.1
            .iter()
            .position(|arm| arm.label.as_bytes() == v)
            .ok_or_else(|| {
                E::unknown_variant(
                    core::str::from_utf8(v).unwrap_or("<non-utf8 variant name>"),
                    self.0,
                )
            })
    }

    fn expecting(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
        fmt.write_str("an enum variant")
    }
}

impl<'de> serde::de::DeserializeSeed<'de> for VariantIx<'_> {
    type Value = usize;

    /// Asks the deserializer for an identifier and resolves it with the visitor above.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_identifier(self)
    }
}