use facet::Facet;
use facet_core::{Def, ScalarType, StructKind, Type, UserType};
use facet_reflect::Peek;
use crate::encode;
use crate::error::SerializeError;
/// Handle to a size slot previously reserved with [`Writer::reserve_size_field`].
///
/// It is handed back to [`Writer::write_size_field`] to fill the slot in later.
/// For the `Vec<u8>` writer the wrapped `usize` is the byte offset at which the
/// reserved slot starts.
#[derive(Debug, Clone, Copy)]
pub struct SizeField(pub(crate) usize);
/// Byte sink that the serializer in this file writes into.
///
/// In addition to appending bytes, a writer can reserve a size slot up front and
/// fill it in once the size of the data that follows is known.
pub trait Writer {
/// Writes a single byte to the output.
fn write_byte(&mut self, byte: u8);
/// Writes all of `bytes` to the output.
fn write_bytes(&mut self, bytes: &[u8]);
/// Reports how many bytes have been written so far.
fn bytes_written(&self) -> usize;
/// Reserves a slot for a `u32` size and returns a handle identifying it.
///
/// The `Vec<u8>` implementation appends four zero bytes as the placeholder.
fn reserve_size_field(&mut self) -> SizeField;
/// Stores `value` into the slot identified by `handle`.
///
/// The `Vec<u8>` implementation overwrites the four reserved bytes with
/// `value` in little-endian order.
fn write_size_field(&mut self, handle: SizeField, value: u32);
}
impl Writer for Vec<u8> {
fn write_byte(&mut self, byte: u8) {
self.push(byte);
}
fn write_bytes(&mut self, bytes: &[u8]) {
self.extend_from_slice(bytes);
}
fn bytes_written(&self) -> usize {
self.len()
}
fn reserve_size_field(&mut self) -> SizeField {
let offset = self.len();
self.extend_from_slice(&[0u8; 4]);
SizeField(offset)
}
fn write_size_field(&mut self, handle: SizeField, value: u32) {
self[handle.0..handle.0 + 4].copy_from_slice(&value.to_le_bytes());
}
}
/// [`Writer`] extension that can additionally be handed byte slices borrowed
/// for the lifetime `'a`.
pub(crate) trait PostcardWriter<'a>: Writer {
/// Writes `bytes`, which remain borrowed for `'a`.
///
/// `CopyWriter` copies them into its inner writer.
/// NOTE(review): presumably other implementations may keep the reference
/// instead of copying — confirm against the implementors.
fn write_referenced_bytes(&mut self, bytes: &'a [u8]);
}
/// Adapter that mutably borrows a [`Writer`] and forwards every write to it,
/// copying referenced bytes into that writer.
pub(crate) struct CopyWriter<'w, W: Writer + ?Sized> {
/// The wrapped writer that receives all output.
inner: &'w mut W,
}
impl<'w, W: Writer + ?Sized> CopyWriter<'w, W> {
pub(crate) fn new(inner: &'w mut W) -> Self {
Self { inner }
}
}
/// Pure delegation: each method calls the same method on the wrapped writer.
impl<W: Writer + ?Sized> Writer for CopyWriter<'_, W> {
    /// Forwards the byte to the wrapped writer.
    fn write_byte(&mut self, byte: u8) {
        W::write_byte(&mut *self.inner, byte);
    }

    /// Forwards the slice to the wrapped writer.
    fn write_bytes(&mut self, bytes: &[u8]) {
        W::write_bytes(&mut *self.inner, bytes);
    }

    /// Reports the wrapped writer's byte count.
    fn bytes_written(&self) -> usize {
        W::bytes_written(&*self.inner)
    }

    /// Reserves the size slot in the wrapped writer and returns its handle.
    fn reserve_size_field(&mut self) -> SizeField {
        W::reserve_size_field(&mut *self.inner)
    }

    /// Fills in a slot that was reserved in the wrapped writer.
    fn write_size_field(&mut self, handle: SizeField, value: u32) {
        W::write_size_field(&mut *self.inner, handle, value);
    }
}
impl<'a, W: Writer + ?Sized> PostcardWriter<'a> for CopyWriter<'_, W> {
    /// Copies the referenced bytes into the wrapped writer; the `'a` borrow is
    /// not kept past this call.
    fn write_referenced_bytes(&mut self, bytes: &'a [u8]) {
        W::write_bytes(&mut *self.inner, bytes);
    }
}
/// Serializes `value` into a freshly allocated byte vector.
///
/// The value is reflected through [`Peek`] and written via a [`CopyWriter`]
/// wrapped around the output vector.
///
/// # Errors
///
/// Returns whatever [`SerializeError`] `serialize_peek` reports; the partially
/// filled buffer is discarded in that case.
pub fn to_vec<'a, T: Facet<'a>>(value: &T) -> Result<Vec<u8>, SerializeError> {
    let reflected = Peek::new(value);
    let mut buffer = Vec::new();
    {
        let mut writer = CopyWriter::new(&mut buffer);
        serialize_peek(reflected, &mut writer)?;
    }
    Ok(buffer)
}
/// Crate-internal entry point: serializes the reflected value `peek` into `out`.
///
/// Forwards directly to `serialize_peek_inner` and returns its result unchanged.
pub(crate) fn serialize_peek<'a>(
peek: Peek<'a, '_>,
out: &mut impl PostcardWriter<'a>,
) -> Result<(), SerializeError> {
serialize_peek_inner(peek, out)
}
fn serialize_peek_inner<'a>(
peek: Peek<'a, '_>,
out: &mut impl PostcardWriter<'a>,
) -> Result<(), SerializeError> {
let peek = peek.innermost_peek();
fn re(e: impl std::fmt::Display) -> SerializeError {
SerializeError::ReflectError(e.to_string())
}
if let Some(proxy_def) = peek.shape().proxy {
let proxy_shape = proxy_def.shape;
let proxy_layout = proxy_shape
.layout
.sized_layout()
.map_err(|_| SerializeError::ReflectError("proxy type must be sized".into()))?;
let proxy_uninit = facet_core::alloc_for_layout(proxy_layout);
#[allow(unsafe_code)]
let proxy_ptr = unsafe { (proxy_def.convert_out)(peek.data(), proxy_uninit) }
.map_err(SerializeError::ReflectError)?;
#[allow(unsafe_code)]
let proxy_peek = unsafe { Peek::unchecked_new(proxy_ptr.as_const(), proxy_shape) };
let result = serialize_peek_inner(proxy_peek, out);
#[allow(unsafe_code)]
unsafe {
let _ = proxy_shape.call_drop_in_place(proxy_ptr);
facet_core::dealloc_for_layout(proxy_ptr, proxy_layout);
}
return result;
}
if let Some(adapter) = peek.shape().opaque_adapter {
#[allow(unsafe_code)]
let mapped = unsafe { (adapter.serialize)(peek.data()) };
#[allow(unsafe_code)]
if let Some(bytes) =
unsafe { crate::raw::try_decode_passthrough_bytes(mapped.ptr, mapped.shape) }
{
out.write_bytes(&(bytes.len() as u32).to_le_bytes());
out.write_referenced_bytes(bytes);
return Ok(());
}
#[allow(unsafe_code)]
let mapped_peek = unsafe { Peek::unchecked_new(mapped.ptr, mapped.shape) };
let size_field = out.reserve_size_field();
let before = out.bytes_written();
serialize_peek_inner(mapped_peek, out)?;
let len = out.bytes_written() - before;
out.write_size_field(size_field, len as u32);
return Ok(());
}
if let Some(scalar_type) = peek.scalar_type() {
return serialize_scalar(peek, scalar_type, out);
}
match peek.shape().def {
Def::Option(_) => {
let opt = peek.into_option().map_err(re)?;
return match opt.value() {
Some(inner) => {
out.write_byte(0x01);
serialize_peek(inner, out)
}
None => {
out.write_byte(0x00);
Ok(())
}
};
}
Def::Result(_) => {
let res = peek.into_result().map_err(re)?;
return if let Some(ok_inner) = res.ok() {
encode::write_varint(out, 0);
serialize_peek(ok_inner, out)
} else if let Some(err_inner) = res.err() {
encode::write_varint(out, 1);
serialize_peek(err_inner, out)
} else {
Err(SerializeError::ReflectError(
"Result is neither Ok nor Err".into(),
))
};
}
Def::List(list_def) => {
if list_def.t().is_type::<u8>() {
let list = peek.into_list().map_err(re)?;
if let Some(bytes) = peek.as_bytes() {
encode::write_varint(out, bytes.len() as u64);
out.write_referenced_bytes(bytes);
} else {
let len = list.len();
let mut bytes = Vec::with_capacity(len);
for i in 0..len {
let elem = list
.get(i)
.ok_or_else(|| SerializeError::ReflectError("list index OOB".into()))?;
let byte = elem.get::<u8>().map_err(re)?;
bytes.push(*byte);
}
encode::write_varint(out, bytes.len() as u64);
out.write_bytes(&bytes);
}
} else {
let list = peek.into_list().map_err(re)?;
let len = list.len();
encode::write_varint(out, len as u64);
for elem in list.iter() {
serialize_peek(elem, out)?;
}
}
return Ok(());
}
Def::Array(_) => {
let list_like = peek.into_list_like().map_err(re)?;
for elem in list_like.iter() {
serialize_peek(elem, out)?;
}
return Ok(());
}
Def::Slice(slice_def) => {
let list_like = peek.into_list_like().map_err(re)?;
if slice_def.t().is_type::<u8>() {
if let Some(bytes) = list_like.as_bytes() {
encode::write_varint(out, bytes.len() as u64);
out.write_referenced_bytes(bytes);
} else {
let len = list_like.len();
let mut bytes = Vec::with_capacity(len);
for elem in list_like.iter() {
let byte = elem.get::<u8>().map_err(re)?;
bytes.push(*byte);
}
encode::write_varint(out, bytes.len() as u64);
out.write_bytes(&bytes);
}
} else {
let len = list_like.len();
encode::write_varint(out, len as u64);
for elem in list_like.iter() {
serialize_peek(elem, out)?;
}
}
return Ok(());
}
Def::Map(_) => {
let map = peek.into_map().map_err(re)?;
encode::write_varint(out, map.len() as u64);
for (key, value) in map.iter() {
serialize_peek(key, out)?;
serialize_peek(value, out)?;
}
return Ok(());
}
Def::Set(_) => {
let set = peek.into_set().map_err(re)?;
encode::write_varint(out, set.len() as u64);
for elem in set.iter() {
serialize_peek(elem, out)?;
}
return Ok(());
}
Def::Pointer(_) => {
let ptr = peek.into_pointer().map_err(re)?;
return match ptr.borrow_inner() {
Some(inner) => serialize_peek(inner, out),
None => Err(SerializeError::UnsupportedType("null pointer".into())),
};
}
_ => {}
}
match peek.shape().ty {
Type::User(UserType::Struct(struct_type)) => match struct_type.kind {
StructKind::Struct | StructKind::TupleStruct | StructKind::Tuple => {
let ps = peek.into_struct().map_err(re)?;
for i in 0..ps.field_count() {
let field_peek = ps.field(i).map_err(re)?;
serialize_peek_inner(field_peek, out)?;
}
Ok(())
}
StructKind::Unit => Ok(()),
},
Type::User(UserType::Enum(_)) => {
let pe = peek.into_enum().map_err(re)?;
let variant_index = pe.variant_index().map_err(re)?;
let variant = pe.active_variant().map_err(re)?;
encode::write_varint(out, variant_index as u64);
match variant.data.kind {
StructKind::Unit => {}
StructKind::TupleStruct | StructKind::Tuple | StructKind::Struct => {
for i in 0..variant.data.fields.len() {
let field_peek = pe.field(i).map_err(re)?.ok_or_else(|| {
SerializeError::ReflectError("missing variant field".into())
})?;
serialize_peek_inner(field_peek, out)?;
}
}
}
Ok(())
}
_ => Err(SerializeError::UnsupportedType(format!("{}", peek.shape()))),
}
}
/// Writes the scalar behind `peek` to `out` using the encoding selected by `scalar_type`.
///
/// * unit: no bytes
/// * `bool`: one byte, `0x01` or `0x00`
/// * `u8` / `i8`: one raw byte
/// * wider unsigned integers: `encode::write_varint` / `encode::write_varint_u128`
/// * wider signed integers: `encode::write_varint_signed` / `encode::write_varint_signed_i128`
/// * `f32` / `f64`: little-endian bytes
/// * `char` and string types: varint byte length followed by the UTF-8 bytes
///
/// # Errors
///
/// Returns `SerializeError::ReflectError` when the value cannot be read as the expected
/// Rust type, and `SerializeError::UnsupportedType` for any other scalar type.
fn serialize_scalar<'a>(
    peek: Peek<'a, '_>,
    scalar_type: ScalarType,
    out: &mut impl PostcardWriter<'a>,
) -> Result<(), SerializeError> {
    /// Converts a reflection failure into the serializer's error type.
    fn reflect_err(err: facet_reflect::ReflectError) -> SerializeError {
        SerializeError::ReflectError(err.to_string())
    }

    match scalar_type {
        ScalarType::Unit => {}
        ScalarType::Bool => {
            let flag = *peek.get::<bool>().map_err(reflect_err)?;
            out.write_byte(u8::from(flag));
        }
        ScalarType::Char => {
            // A char is written like a string: UTF-8 byte length, then the bytes.
            let ch = *peek.get::<char>().map_err(reflect_err)?;
            let mut utf8 = [0u8; 4];
            let encoded = ch.encode_utf8(&mut utf8).as_bytes();
            encode::write_varint(out, encoded.len() as u64);
            out.write_bytes(encoded);
        }
        ScalarType::U8 => {
            let value = *peek.get::<u8>().map_err(reflect_err)?;
            out.write_byte(value);
        }
        ScalarType::U16 => {
            let value = *peek.get::<u16>().map_err(reflect_err)?;
            encode::write_varint(out, u64::from(value));
        }
        ScalarType::U32 => {
            let value = *peek.get::<u32>().map_err(reflect_err)?;
            encode::write_varint(out, u64::from(value));
        }
        ScalarType::U64 => {
            let value = *peek.get::<u64>().map_err(reflect_err)?;
            encode::write_varint(out, value);
        }
        ScalarType::U128 => {
            let value = *peek.get::<u128>().map_err(reflect_err)?;
            encode::write_varint_u128(out, value);
        }
        ScalarType::USize => {
            let value = *peek.get::<usize>().map_err(reflect_err)?;
            encode::write_varint(out, value as u64);
        }
        ScalarType::I8 => {
            // Written as the raw two's-complement byte.
            let value = *peek.get::<i8>().map_err(reflect_err)?;
            out.write_byte(value as u8);
        }
        ScalarType::I16 => {
            let value = *peek.get::<i16>().map_err(reflect_err)?;
            encode::write_varint_signed(out, i64::from(value));
        }
        ScalarType::I32 => {
            let value = *peek.get::<i32>().map_err(reflect_err)?;
            encode::write_varint_signed(out, i64::from(value));
        }
        ScalarType::I64 => {
            let value = *peek.get::<i64>().map_err(reflect_err)?;
            encode::write_varint_signed(out, value);
        }
        ScalarType::I128 => {
            let value = *peek.get::<i128>().map_err(reflect_err)?;
            encode::write_varint_signed_i128(out, value);
        }
        ScalarType::ISize => {
            let value = *peek.get::<isize>().map_err(reflect_err)?;
            encode::write_varint_signed(out, value as i64);
        }
        ScalarType::F32 => {
            let value = *peek.get::<f32>().map_err(reflect_err)?;
            let raw = value.to_le_bytes();
            out.write_bytes(&raw);
        }
        ScalarType::F64 => {
            let value = *peek.get::<f64>().map_err(reflect_err)?;
            let raw = value.to_le_bytes();
            out.write_bytes(&raw);
        }
        ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
            let text = match peek.as_str() {
                Some(text) => text,
                None => {
                    return Err(SerializeError::ReflectError(
                        "failed to extract string".into(),
                    ));
                }
            };
            encode::write_varint(out, text.len() as u64);
            // String bytes go through the reference-aware path of the writer.
            out.write_referenced_bytes(text.as_bytes());
        }
        _ => {
            return Err(SerializeError::UnsupportedType(format!(
                "scalar type {scalar_type:?}"
            )));
        }
    }
    Ok(())
}