use std::alloc;
use std::collections::{BTreeMap, HashMap};
use phon_ir::ir::{
BorrowOp, BytesOp, DefaultOp, EnumOp, EnumVariantOp, Lowered, MapOp, MemOp, MemProgram,
OpaqueOp, OptionOp, PointerOp, ResultOp, SeqOp, SetOp, SkipOp, fuse,
};
use phon_ir::{
Access, Construct, Descriptor, EnumAccess, MapStorage, Presence, RecordAccess, ResultAccess,
SequenceAccess, SequenceStorage, SetAccess, SetStorage, Tag, VariantAccess,
};
use phon_schema::bytes::{Reader, write_u8, write_u32};
use phon_schema::{
DecodeError, Field, Primitive, SchemaId, SchemaKind, SchemaRef, Value, Variant, VariantPayload,
read_value, write_value,
};
use crate::compact::{self, CompactError, Registry, Resolved, alignment, pad_to, skip_pad};
use crate::compat::{self, FieldMatch, VariantMatch, incompatible};
/// Result type used throughout this lowering module; every failure is a [`CompactError`].
type Result<T> = std::result::Result<T, CompactError>;
/// Byte width of a fixed-width primitive.
///
/// Returns `None` for `String`, `Bytes`, `Never`, `DateTime`, `Uuid` and
/// `QName`, which this module does not treat as fixed-width scalars.
fn fixed_size(p: Primitive) -> Option<usize> {
    let width = match p {
        Primitive::String
        | Primitive::Bytes
        | Primitive::Never
        | Primitive::DateTime
        | Primitive::Uuid
        | Primitive::QName => return None,
        Primitive::Unit => 0,
        Primitive::Bool | Primitive::U8 | Primitive::I8 => 1,
        Primitive::U16 | Primitive::I16 => 2,
        Primitive::U32 | Primitive::I32 | Primitive::F32 | Primitive::Char => 4,
        Primitive::U64 | Primitive::I64 | Primitive::F64 => 8,
        Primitive::U128 | Primitive::I128 => 16,
    };
    Some(width)
}
/// Computes the `min_wire` value recorded on sequence/set ops for `element`.
///
/// Returns `0` when every op of the element program is a zero-sized `Scalar`
/// (including an empty program), and `1` otherwise.
fn elem_min_wire(element: &MemProgram) -> usize {
    let occupies_wire = element
        .iter()
        .any(|op| !matches!(op, MemOp::Scalar { size: 0, .. }));
    if occupies_wire { 1 } else { 0 }
}
/// Lowers `descriptor` into a fused encode program whose offsets are relative
/// to the start of the described value (base offset `0`).
///
/// # Errors
/// Propagates any error from lowering the descriptor tree.
pub fn lower(descriptor: &Descriptor, reg: &Registry) -> Result<MemProgram> {
    let mut ops = Vec::new();
    lower_node(descriptor, reg, 0, &mut ops)?;
    let program = fuse(ops);
    Ok(program)
}
pub fn lower_typed(
descriptor: &Descriptor,
descriptor_blocks: &HashMap<SchemaId, Descriptor>,
reg: &Registry,
) -> Result<Lowered> {
let mut root = Vec::new();
lower_node(descriptor, reg, 0, &mut root)?;
let mut blocks = BTreeMap::new();
for (id, body) in descriptor_blocks {
let mut ops = Vec::new();
lower_node(body, reg, 0, &mut ops)?;
blocks.insert(*id, fuse(ops));
}
Ok(Lowered {
program: fuse(root),
blocks,
})
}
/// Lowers one descriptor node into encode ops appended to `out`.
///
/// `base` is the byte offset of this node relative to the pointer the enclosing
/// program is run against; emitted ops carry offsets derived from it.
/// `Access::Recurse` nodes are not expanded: they emit a `CallBlock` naming the
/// schema id whose separately lowered block should be run at `base`.
///
/// # Errors
/// `Unsupported` for access/schema pairings not handled here (and for thunk
/// construction, type-var recursion, variable-length scalars, non-vtable
/// storage); `TypeMismatch` / `Malformed` when the descriptor disagrees with
/// the resolved schema.
fn lower_node(d: &Descriptor, reg: &Registry, base: usize, out: &mut MemProgram) -> Result<()> {
    // Recursion is lowered as a call into a block keyed by the concrete schema id.
    if matches!(d.access, Access::Recurse) {
        let schema = match &d.schema {
            SchemaRef::Concrete { id, .. } => *id,
            SchemaRef::Var { .. } => {
                return Err(CompactError::Unsupported(
                    "typed: recursion via type-var ref",
                ));
            }
        };
        out.push(MemOp::CallBlock {
            schema,
            offset: base,
        });
        return Ok(());
    }
    match (&d.access, compact::resolve(reg, &d.schema)?) {
        (Access::Scalar, Resolved::Primitive(p)) => {
            let size = fixed_size(p).ok_or(CompactError::Unsupported(
                "typed: variable-length scalar field",
            ))?;
            if d.layout.size == size {
                // Memory width equals wire width: copy the bytes as-is.
                out.push(MemOp::Scalar {
                    offset: base,
                    size,
                    align: alignment(p),
                });
            } else if matches!(p, Primitive::U64 | Primitive::I64)
                && matches!(d.layout.size, 1 | 2 | 4 | 8)
            {
                // A 64-bit wire integer stored in a narrower native slot;
                // the encoder widens it (sign-extending when `signed`).
                out.push(MemOp::NativeInt {
                    offset: base,
                    mem_size: d.layout.size,
                    signed: matches!(p, Primitive::I64),
                });
            } else {
                return Err(CompactError::Unsupported(
                    "typed: scalar memory width differs from wire width",
                ));
            }
            Ok(())
        }
        (Access::Record(ra), Resolved::Composite(kind)) => {
            let arity = match &kind {
                SchemaKind::Struct { fields, .. } => fields.len(),
                SchemaKind::Tuple { elements } => elements.len(),
                _ => {
                    return Err(CompactError::TypeMismatch {
                        expected: "struct or tuple for a record descriptor",
                    });
                }
            };
            if arity != ra.fields.len() {
                return Err(CompactError::Malformed(
                    "descriptor/schema field count mismatch",
                ));
            }
            match &ra.construct {
                Construct::InPlace => {}
                Construct::Thunk(_) => {
                    return Err(CompactError::Unsupported("typed: thunk construction"));
                }
            }
            // In-place records are flattened: each field is lowered at its own
            // offset inside the same program.
            for fa in &ra.fields {
                lower_node(&fa.descriptor, reg, base + fa.offset, out)?;
            }
            Ok(())
        }
        (Access::Sequence(seq), Resolved::Composite(SchemaKind::List { .. })) => {
            let SequenceStorage::Vtable(thunks) = &seq.storage else {
                return Err(CompactError::Unsupported(
                    "typed: only vtable-backed owned sequences so far",
                ));
            };
            let stride = seq.element.layout.size;
            let elem_align = seq.element.layout.align;
            // The element program is lowered at offset 0: it runs against each
            // element's own address rather than against `base`.
            let mut element = Vec::new();
            lower_node(&seq.element, reg, 0, &mut element)?;
            let element = fuse(element);
            // A single scalar that fills its stride exactly can be copied as
            // one contiguous byte run.
            let bulk = matches!(
                element.as_slice(),
                [MemOp::Scalar { offset: 0, size, align }]
                if *size == stride && stride.is_multiple_of(*align)
            );
            if bulk {
                out.push(MemOp::Bytes(Box::new(BytesOp {
                    field_offset: base,
                    stride,
                    elem_align,
                    validate: validate_any,
                    thunks: *thunks,
                })));
            } else {
                let min_wire = elem_min_wire(&element);
                out.push(MemOp::Sequence(Box::new(SeqOp {
                    field_offset: base,
                    element,
                    stride,
                    elem_align,
                    min_wire,
                    thunks: *thunks,
                })));
            }
            Ok(())
        }
        (Access::Set(set), Resolved::Composite(SchemaKind::Set { .. })) => {
            lower_set(set, reg, base, out)
        }
        (
            Access::Array {
                element,
                count,
                stride,
            },
            Resolved::Composite(SchemaKind::Array { dimensions, .. }),
        ) => {
            require_fixed_array_count(*count, &dimensions)?;
            lower_fixed_array(element, *count, *stride, reg, base, out)
        }
        (
            Access::Sequence(seq),
            Resolved::Primitive(p @ (Primitive::String | Primitive::Bytes)),
        ) => {
            match &seq.storage {
                // Borrowed string/bytes storage.
                SequenceStorage::BorrowedVtable(thunks) => {
                    out.push(MemOp::Borrow(Box::new(BorrowOp {
                        field_offset: base,
                        stride: 1,
                        elem_align: 1,
                        thunks: *thunks,
                    })));
                    Ok(())
                }
                // Owned storage; strings get a UTF-8 validator, bytes accept anything.
                SequenceStorage::Vtable(thunks) => {
                    out.push(MemOp::Bytes(Box::new(BytesOp {
                        field_offset: base,
                        stride: 1,
                        elem_align: 1,
                        validate: if matches!(p, Primitive::String) {
                            validate_utf8
                        } else {
                            validate_any
                        },
                        thunks: *thunks,
                    })));
                    Ok(())
                }
                _ => Err(CompactError::Unsupported(
                    "typed: string/bytes needs vtable thunks",
                )),
            }
        }
        (Access::Option(opt), Resolved::Composite(SchemaKind::Option { .. })) => {
            let Presence::Vtable(thunks) = &opt.presence else {
                return Err(CompactError::Unsupported(
                    "typed: option needs vtable presence thunks",
                ));
            };
            // The `Some` program runs against the pointer returned by the
            // option's `get_value` thunk, hence offset 0.
            let mut some = Vec::new();
            lower_node(&opt.some, reg, 0, &mut some)?;
            out.push(MemOp::Option(Box::new(OptionOp {
                field_offset: base,
                some: fuse(some),
                inner_size: opt.some.layout.size,
                inner_align: opt.some.layout.align,
                thunks: *thunks,
            })));
            Ok(())
        }
        (Access::Enum(ea), Resolved::Composite(SchemaKind::Enum { .. })) => {
            let Tag::Direct { offset, width } = &ea.tag else {
                return Err(CompactError::Unsupported(
                    "typed: only #[repr(int)] enums (direct discriminant) so far",
                ));
            };
            let mut variants = Vec::with_capacity(ea.variants.len());
            for va in &ea.variants {
                // Payload fields use `base`-relative offsets because the enum's
                // payload program is run against the same base pointer.
                let mut payload = Vec::new();
                for f in &va.payload.fields {
                    lower_node(&f.descriptor, reg, base + f.offset, &mut payload)?;
                }
                variants.push(EnumVariantOp {
                    wire_index: va.index,
                    selector: va.selector,
                    payload: fuse(payload),
                });
            }
            out.push(MemOp::Enum(Box::new(EnumOp {
                tag_offset: base + *offset,
                tag_width: *width,
                variants,
                writer_only: Vec::new(),
            })));
            Ok(())
        }
        (Access::Map(ma), Resolved::Composite(SchemaKind::Map { .. })) => {
            let MapStorage::Vtable(thunks) = &ma.storage else {
                return Err(CompactError::Unsupported("typed: map needs vtable thunks"));
            };
            // Key and value programs run against the entry pointers yielded by
            // the map iterator thunks, hence offset 0.
            let mut key = Vec::new();
            lower_node(&ma.key, reg, 0, &mut key)?;
            let mut value = Vec::new();
            lower_node(&ma.value, reg, 0, &mut value)?;
            out.push(MemOp::Map(Box::new(MapOp {
                field_offset: base,
                key: fuse(key),
                value: fuse(value),
                key_size: ma.key.layout.size,
                key_align: ma.key.layout.align,
                value_size: ma.value.layout.size,
                value_align: ma.value.layout.align,
                thunks: *thunks,
            })));
            Ok(())
        }
        (Access::Dynamic, Resolved::Composite(SchemaKind::Dynamic)) => {
            out.push(MemOp::Dynamic { field_offset: base });
            Ok(())
        }
        (Access::Result(ra), Resolved::Composite(SchemaKind::Enum { variants, .. })) => {
            out.push(MemOp::Result(Box::new(lower_result(
                ra, &variants, reg, base,
            )?)));
            Ok(())
        }
        // Pointers accept any resolved schema: the schema is checked again when
        // the pointee descriptor itself is lowered.
        (Access::Pointer(pa), _) => {
            let mut pointee = Vec::new();
            lower_node(&pa.pointee, reg, 0, &mut pointee)?;
            out.push(MemOp::Pointer(Box::new(PointerOp {
                field_offset: base,
                pointee: fuse(pointee),
                pointee_size: pa.pointee.layout.size,
                pointee_align: pa.pointee.layout.align,
                thunks: pa.thunks,
            })));
            Ok(())
        }
        (Access::Opaque(thunks), Resolved::Primitive(Primitive::Bytes)) => {
            out.push(MemOp::Opaque(Box::new(OpaqueOp {
                field_offset: base,
                thunks: *thunks,
            })));
            Ok(())
        }
        // NOTE(review): this message predates set/array/map/result/pointer/dynamic
        // support and no longer lists every handled shape.
        _ => Err(CompactError::Unsupported(
            "typed: only fixed scalars, in-place records, owned sequences, strings, options, #[repr(int)] enums, and opaque fields so far",
        )),
    }
}
pub fn lower_decode(
writer_root: SchemaId,
reader: &Descriptor,
reader_blocks: &HashMap<SchemaId, Descriptor>,
reg: &Registry,
) -> Result<Lowered> {
let mut out = Vec::new();
lower_decode_node(&SchemaRef::concrete(writer_root), reader, reg, 0, &mut out)?;
let mut blocks = BTreeMap::new();
for (id, body) in reader_blocks {
let mut ops = Vec::new();
lower_decode_node(&SchemaRef::concrete(*id), body, reg, 0, &mut ops)?;
blocks.insert(*id, fuse(ops));
}
Ok(Lowered {
program: fuse(out),
blocks,
})
}
fn lower_decode_node(
writer: &SchemaRef,
reader: &Descriptor,
reg: &Registry,
base: usize,
out: &mut MemProgram,
) -> Result<()> {
if matches!(reader.access, Access::Recurse) {
let schema = match &reader.schema {
SchemaRef::Concrete { id, .. } => *id,
SchemaRef::Var { .. } => {
return Err(CompactError::Unsupported(
"typed: recursion via type-var ref (decode)",
));
}
};
out.push(MemOp::CallBlock {
schema,
offset: base,
});
return Ok(());
}
let w = compact::resolve(reg, writer)?;
match (&reader.access, w) {
(Access::Scalar, Resolved::Primitive(wp)) => {
let Resolved::Primitive(rp) = compact::resolve(reg, &reader.schema)? else {
return Err(CompactError::TypeMismatch {
expected: "scalar reader schema for a scalar descriptor",
});
};
if wp != rp {
return Err(incompatible(format!("primitive {wp:?} is not {rp:?}")));
}
let size = fixed_size(wp).ok_or(CompactError::Unsupported(
"typed: variable-length scalar field",
))?;
out.push(MemOp::Scalar {
offset: base,
size,
align: alignment(wp),
});
Ok(())
}
(Access::Record(ra), Resolved::Composite(SchemaKind::Struct { fields: wf, .. })) => {
lower_decode_record(&wf, ra, &reader.schema, RecordKind::Struct, reg, base, out)
}
(Access::Record(ra), Resolved::Composite(SchemaKind::Tuple { elements })) => {
let wf = tuple_fields(elements);
lower_decode_record(&wf, ra, &reader.schema, RecordKind::Tuple, reg, base, out)
}
(Access::Enum(ea), Resolved::Composite(SchemaKind::Enum { variants: wv, .. })) => {
lower_decode_enum(&wv, ea, &reader.schema, reg, base, out)
}
(Access::Option(opt), Resolved::Composite(SchemaKind::Option { element: we })) => {
require_reader_option(&reader.schema, reg)?;
let Presence::Vtable(thunks) = &opt.presence else {
return Err(CompactError::Unsupported(
"typed: option needs vtable presence thunks",
));
};
let mut some = Vec::new();
lower_decode_node(&we, &opt.some, reg, 0, &mut some)?;
out.push(MemOp::Option(Box::new(OptionOp {
field_offset: base,
some: fuse(some),
inner_size: opt.some.layout.size,
inner_align: opt.some.layout.align,
thunks: *thunks,
})));
Ok(())
}
(Access::Sequence(seq), Resolved::Composite(SchemaKind::List { element: we })) => {
require_reader_list(&reader.schema, reg)?;
lower_decode_sequence(&we, seq, reg, base, out)
}
(Access::Set(set), Resolved::Composite(SchemaKind::Set { element: we })) => {
require_reader_set(&reader.schema, reg)?;
lower_decode_set(&we, set, reg, base, out)
}
(
Access::Array {
element,
count,
stride,
},
Resolved::Composite(SchemaKind::Array {
element: we,
dimensions: wd,
}),
) => {
let Resolved::Composite(SchemaKind::Array { dimensions: rd, .. }) =
compact::resolve(reg, &reader.schema)?
else {
return Err(incompatible("schema kinds differ"));
};
if wd != rd {
return Err(incompatible("array dimensions differ"));
}
require_fixed_array_count(*count, &wd)?;
lower_decode_fixed_array(&we, element, *count, *stride, reg, base, out)
}
(
Access::Sequence(seq),
Resolved::Primitive(p @ (Primitive::String | Primitive::Bytes)),
) => {
let Resolved::Primitive(rp) = compact::resolve(reg, &reader.schema)? else {
return Err(CompactError::TypeMismatch {
expected: "string/bytes reader schema",
});
};
if p != rp {
return Err(incompatible(format!("primitive {p:?} is not {rp:?}")));
}
match &seq.storage {
SequenceStorage::BorrowedVtable(thunks) => {
out.push(MemOp::Borrow(Box::new(BorrowOp {
field_offset: base,
stride: 1,
elem_align: 1,
thunks: *thunks,
})));
Ok(())
}
SequenceStorage::Vtable(thunks) => {
out.push(MemOp::Bytes(Box::new(BytesOp {
field_offset: base,
stride: 1,
elem_align: 1,
validate: if matches!(p, Primitive::String) {
validate_utf8
} else {
validate_any
},
thunks: *thunks,
})));
Ok(())
}
_ => Err(CompactError::Unsupported(
"typed: string/bytes needs vtable thunks",
)),
}
}
(Access::Map(ma), Resolved::Composite(SchemaKind::Map { key: wk, value: wv })) => {
require_reader_map(&reader.schema, reg)?;
let MapStorage::Vtable(thunks) = &ma.storage else {
return Err(CompactError::Unsupported("typed: map needs vtable thunks"));
};
let mut key = Vec::new();
lower_decode_node(&wk, &ma.key, reg, 0, &mut key)?;
let mut value = Vec::new();
lower_decode_node(&wv, &ma.value, reg, 0, &mut value)?;
out.push(MemOp::Map(Box::new(MapOp {
field_offset: base,
key: fuse(key),
value: fuse(value),
key_size: ma.key.layout.size,
key_align: ma.key.layout.align,
value_size: ma.value.layout.size,
value_align: ma.value.layout.align,
thunks: *thunks,
})));
Ok(())
}
(Access::Dynamic, Resolved::Composite(SchemaKind::Dynamic)) => {
require_reader_dynamic(&reader.schema, reg)?;
out.push(MemOp::Dynamic { field_offset: base });
Ok(())
}
(Access::Result(ra), Resolved::Composite(SchemaKind::Enum { variants: wv, .. })) => {
out.push(MemOp::Result(Box::new(lower_decode_result(
&wv, ra, reg, base,
)?)));
Ok(())
}
(Access::Pointer(pa), _) => {
let mut pointee = Vec::new();
lower_decode_node(writer, &pa.pointee, reg, 0, &mut pointee)?;
out.push(MemOp::Pointer(Box::new(PointerOp {
field_offset: base,
pointee: fuse(pointee),
pointee_size: pa.pointee.layout.size,
pointee_align: pa.pointee.layout.align,
thunks: pa.thunks,
})));
Ok(())
}
(Access::Opaque(thunks), Resolved::Primitive(Primitive::Bytes)) => {
require_reader_bytes(&reader.schema, reg)?;
out.push(MemOp::Opaque(Box::new(OpaqueOp {
field_offset: base,
thunks: *thunks,
})));
Ok(())
}
_ => Err(incompatible("writer and reader schema kinds differ")),
}
}
/// Lowers the decode op for an owned list field at `base`.
///
/// Elements are decoded from the writer's `writer_element` schema into the
/// layout of `seq.element`. When each element decodes to a single scalar that
/// fills its stride exactly, one `BytesOp` block copy is emitted; otherwise a
/// per-element `SeqOp` is emitted.
///
/// # Errors
/// `Unsupported` for non-vtable storage; propagates element lowering errors.
fn lower_decode_sequence(
    writer_element: &SchemaRef,
    seq: &SequenceAccess,
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    let thunks = match &seq.storage {
        SequenceStorage::Vtable(thunks) => *thunks,
        _ => {
            return Err(CompactError::Unsupported(
                "typed: only vtable-backed owned sequences so far",
            ));
        }
    };
    let layout = &seq.element.layout;
    let (stride, elem_align) = (layout.size, layout.align);

    let mut ops = Vec::new();
    lower_decode_node(writer_element, &seq.element, reg, 0, &mut ops)?;
    let element = fuse(ops);

    let packed_scalar = match element.as_slice() {
        [MemOp::Scalar { offset: 0, size, align }] => {
            *size == stride && stride.is_multiple_of(*align)
        }
        _ => false,
    };

    let op = if packed_scalar {
        MemOp::Bytes(Box::new(BytesOp {
            field_offset: base,
            stride,
            elem_align,
            validate: validate_any,
            thunks,
        }))
    } else {
        let min_wire = elem_min_wire(&element);
        MemOp::Sequence(Box::new(SeqOp {
            field_offset: base,
            element,
            stride,
            elem_align,
            min_wire,
            thunks,
        }))
    };
    out.push(op);
    Ok(())
}
/// Lowers the encode ops for an inline fixed-size array of `count` elements
/// spaced `stride` bytes apart, starting at `base`.
///
/// If one element lowers to a single scalar filling its stride exactly, the
/// whole array becomes one `Scalar` copy of `count * stride` bytes. Otherwise
/// every element is lowered individually at its own offset.
///
/// # Errors
/// `Malformed` on offset/size overflow; propagates element lowering errors.
fn lower_fixed_array(
    element: &Descriptor,
    count: usize,
    stride: usize,
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    let mut probe = Vec::new();
    lower_node(element, reg, 0, &mut probe)?;
    let probe = fuse(probe);

    let packed_align = match probe.as_slice() {
        [MemOp::Scalar { offset: 0, size, align }]
            if *size == stride && stride.is_multiple_of(*align) =>
        {
            Some(*align)
        }
        _ => None,
    };

    match packed_align {
        Some(align) => {
            out.push(MemOp::Scalar {
                offset: base,
                size: fixed_array_copy_size(count, stride)?,
                align,
            });
            Ok(())
        }
        None => (0..count).try_for_each(|i| {
            let at = array_element_offset(base, i, stride)?;
            lower_node(element, reg, at, out)
        }),
    }
}
/// Decode-side counterpart of `lower_fixed_array`.
///
/// Each element is read from the writer's `writer_element` schema. A single
/// stride-filling scalar element collapses into one bulk `Scalar` copy;
/// otherwise each of the `count` elements is lowered at its own offset.
///
/// # Errors
/// `Malformed` on offset/size overflow; propagates element lowering errors.
fn lower_decode_fixed_array(
    writer_element: &SchemaRef,
    element: &Descriptor,
    count: usize,
    stride: usize,
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    let mut probe = Vec::new();
    lower_decode_node(writer_element, element, reg, 0, &mut probe)?;
    let probe = fuse(probe);

    let packed_align = match probe.as_slice() {
        [MemOp::Scalar { offset: 0, size, align }]
            if *size == stride && stride.is_multiple_of(*align) =>
        {
            Some(*align)
        }
        _ => None,
    };

    match packed_align {
        Some(align) => {
            out.push(MemOp::Scalar {
                offset: base,
                size: fixed_array_copy_size(count, stride)?,
                align,
            });
            Ok(())
        }
        None => (0..count).try_for_each(|i| {
            let at = array_element_offset(base, i, stride)?;
            lower_decode_node(writer_element, element, reg, at, out)
        }),
    }
}
/// Total byte length of a bulk array copy (`count * stride`).
///
/// # Errors
/// `Malformed` if the product overflows `usize`.
fn fixed_array_copy_size(count: usize, stride: usize) -> Result<usize> {
    match count.checked_mul(stride) {
        Some(total) => Ok(total),
        None => Err(CompactError::Malformed("array bulk copy size overflow")),
    }
}
/// Offset of element `index` in an array starting at `base` with the given
/// `stride` (`base + index * stride`).
///
/// # Errors
/// `Malformed` if either the multiplication or the addition overflows.
fn array_element_offset(base: usize, index: usize, stride: usize) -> Result<usize> {
    index
        .checked_mul(stride)
        .and_then(|relative| base.checked_add(relative))
        .ok_or(CompactError::Malformed("array element offset overflow"))
}
/// Checks that a descriptor's array element `count` equals the product of the
/// schema's `dimensions`.
///
/// # Errors
/// Propagates errors from `compact::product`; `Malformed` if `count` does not
/// fit in `u64` or the two lengths differ.
fn require_fixed_array_count(count: usize, dimensions: &[u64]) -> Result<()> {
    let expected = compact::product(dimensions)?;
    let Ok(actual) = u64::try_from(count) else {
        return Err(CompactError::Malformed("descriptor array length overflows u64"));
    };
    if expected != actual {
        return Err(CompactError::Malformed(
            "descriptor/schema array length mismatch",
        ));
    }
    Ok(())
}
/// Lowers the encode op for an owned set field at `base`.
///
/// The element program is lowered at offset 0 because it runs against each
/// element pointer produced by the set's iterator thunks.
///
/// # Errors
/// Propagates element lowering errors.
fn lower_set(set: &SetAccess, reg: &Registry, base: usize, out: &mut MemProgram) -> Result<()> {
    let mut ops = Vec::new();
    lower_node(&set.element, reg, 0, &mut ops)?;
    let element = fuse(ops);
    let min_wire = elem_min_wire(&element);
    let SetStorage::Vtable(thunks) = &set.storage;
    let op = SetOp {
        field_offset: base,
        elem_size: set.element.layout.size,
        elem_align: set.element.layout.align,
        min_wire,
        thunks: *thunks,
        element,
    };
    out.push(MemOp::Set(Box::new(op)));
    Ok(())
}
/// Lowers the decode op for an owned set field at `base`, reading elements of
/// the writer's `writer_element` schema.
///
/// # Errors
/// Propagates element lowering/compatibility errors.
fn lower_decode_set(
    writer_element: &SchemaRef,
    set: &SetAccess,
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    let mut ops = Vec::new();
    lower_decode_node(writer_element, &set.element, reg, 0, &mut ops)?;
    let element = fuse(ops);
    let min_wire = elem_min_wire(&element);
    let SetStorage::Vtable(thunks) = &set.storage;
    let op = SetOp {
        field_offset: base,
        elem_size: set.element.layout.size,
        elem_align: set.element.layout.align,
        min_wire,
        thunks: *thunks,
        element,
    };
    out.push(MemOp::Set(Box::new(op)));
    Ok(())
}
/// Which kind of record schema a record descriptor is decoded against.
///
/// Plain fieldless tag that is passed around by value, so it derives `Copy`
/// along with the usual comparison/debug traits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RecordKind {
    /// A schema `Struct` with named fields.
    Struct,
    /// A schema `Tuple`; its elements are treated as fields named `"0"`, `"1"`, ….
    Tuple,
}
/// Lowers decode ops for an in-place record whose writer fields are `w_fields`.
///
/// The reader's fields come from `reader_schema` (struct or tuple per
/// `record_kind`). Each step yielded by `compat::match_fields` becomes one of
/// three things: a nested decode into a reader field (`Take`), a `SkipWire` for
/// a writer-only field (`Skip`), or a `Default` op for a reader field missing
/// from the writer.
///
/// # Errors
/// `Unsupported` for thunk construction; `Malformed` if descriptor and reader
/// schema field counts differ; `incompatible(..)` for a required reader field
/// absent from the writer; propagates nested errors.
fn lower_decode_record(
    w_fields: &[Field],
    ra: &RecordAccess,
    reader_schema: &SchemaRef,
    record_kind: RecordKind,
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    if matches!(ra.construct, Construct::Thunk(_)) {
        return Err(CompactError::Unsupported("typed: thunk construction"));
    }
    let r_named = reader_record_fields(reader_schema, record_kind, reg)?;
    if r_named.len() != ra.fields.len() {
        return Err(CompactError::Malformed(
            "descriptor/schema field count mismatch",
        ));
    }
    let steps = compat::match_fields(
        w_fields,
        &r_named,
        |ri, _| ra.fields[ri].default.is_some(),
        |rf| {
            incompatible(format!(
                "required reader field '{}' is absent from the writer",
                rf.name
            ))
        },
    )?;
    for step in steps {
        match step {
            FieldMatch::Skip { writer } => {
                let skip = skip_op(&writer.schema, reg)?;
                out.push(MemOp::SkipWire(Box::new(skip)));
            }
            FieldMatch::Take {
                writer,
                reader_index,
            } => {
                let field = &ra.fields[reader_index];
                lower_decode_node(&writer.schema, &field.descriptor, reg, base + field.offset, out)?;
            }
            FieldMatch::Default { reader_index } => {
                let field = &ra.fields[reader_index];
                match field.default {
                    Some(d) => out.push(MemOp::Default(Box::new(DefaultOp {
                        offset: base + field.offset,
                        ctx: d.ctx,
                        default: d.thunk,
                    }))),
                    None => {
                        return Err(incompatible(format!(
                            "required reader field '{}' is absent from the writer",
                            r_named[reader_index].name
                        )));
                    }
                }
            }
        }
    }
    Ok(())
}
/// Lowers the decode op for a `#[repr(int)]`-style enum with a direct tag.
///
/// Writer variants are matched to reader variants via
/// `compat::match_variants`. Matched variants get a decoded payload program;
/// writer-only variants are recorded by wire index in `writer_only`.
///
/// # Errors
/// `Unsupported` for non-direct tags; `TypeMismatch` if the reader schema is
/// not an enum; `Malformed` if descriptor and reader variant counts differ;
/// propagates payload errors.
fn lower_decode_enum(
    w_variants: &[Variant],
    ea: &EnumAccess,
    reader_schema: &SchemaRef,
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    let (tag_offset, tag_width) = match &ea.tag {
        Tag::Direct { offset, width } => (*offset, *width),
        _ => {
            return Err(CompactError::Unsupported(
                "typed: only #[repr(int)] enums (direct discriminant) so far",
            ));
        }
    };
    let r_named = reader_enum_variants(reader_schema, reg)?;
    if r_named.len() != ea.variants.len() {
        return Err(CompactError::Malformed(
            "descriptor/schema variant count mismatch",
        ));
    }
    let (mut variants, mut writer_only) = (Vec::new(), Vec::new());
    for step in compat::match_variants(w_variants, &r_named) {
        match step {
            VariantMatch::WriterOnly { writer } => writer_only.push(writer.index),
            VariantMatch::Take {
                writer,
                reader_index,
            } => {
                let access = &ea.variants[reader_index];
                let payload = lower_decode_payload(
                    &writer.payload,
                    access,
                    &r_named[reader_index].payload,
                    reg,
                    base,
                )?;
                variants.push(EnumVariantOp {
                    wire_index: writer.index,
                    selector: access.selector,
                    payload,
                });
            }
        }
    }
    out.push(MemOp::Enum(Box::new(EnumOp {
        tag_offset: base + tag_offset,
        tag_width,
        variants,
        writer_only,
    })));
    Ok(())
}
/// Lowers one matched enum variant's payload into a fused decode program.
///
/// Writer payload `w` and reader payload `r_schema_payload` must have the same
/// shape (unit, newtype, tuple or struct). Payload fields are lowered at
/// `base + field.offset`.
///
/// # Errors
/// `incompatible(..)` when payload shapes or tuple arities differ; `Malformed`
/// when a newtype variant's descriptor has no field; propagates nested errors.
fn lower_decode_payload(
    w: &VariantPayload,
    va: &VariantAccess,
    r_schema_payload: &VariantPayload,
    reg: &Registry,
    base: usize,
) -> Result<MemProgram> {
    let mut ops = Vec::new();
    let fields = &va.payload.fields;
    match (w, r_schema_payload) {
        (VariantPayload::Unit, VariantPayload::Unit) => {}
        (VariantPayload::Newtype(wr), VariantPayload::Newtype(_)) => {
            let Some(field) = fields.first() else {
                return Err(CompactError::Malformed(
                    "newtype variant has no payload field",
                ));
            };
            lower_decode_node(wr, &field.descriptor, reg, base + field.offset, &mut ops)?;
        }
        (VariantPayload::Tuple(wrs), VariantPayload::Tuple(rrs)) => {
            if wrs.len() != rrs.len() || wrs.len() != fields.len() {
                return Err(incompatible("variant tuple arity differs"));
            }
            wrs.iter().zip(fields).try_for_each(|(wr, field)| {
                lower_decode_node(wr, &field.descriptor, reg, base + field.offset, &mut ops)
            })?;
        }
        (VariantPayload::Struct(wfs), VariantPayload::Struct(rfs)) => {
            lower_decode_variant_struct(wfs, &va.payload, rfs, reg, base, &mut ops)?;
        }
        _ => return Err(incompatible("variant payload shapes differ")),
    }
    Ok(fuse(ops))
}
/// Lowers decode ops for a struct-shaped enum variant payload.
///
/// This follows the same take/skip/default scheme as record decoding. Reader
/// fields come from the reader schema's variant, `r_fields`.
///
/// # Errors
/// `Malformed` if descriptor and reader field counts differ; `incompatible(..)`
/// for a required reader field absent from the writer; propagates nested errors.
fn lower_decode_variant_struct(
    w_fields: &[Field],
    ra: &RecordAccess,
    r_fields: &[Field],
    reg: &Registry,
    base: usize,
    out: &mut MemProgram,
) -> Result<()> {
    if r_fields.len() != ra.fields.len() {
        return Err(CompactError::Malformed(
            "variant descriptor/schema field count mismatch",
        ));
    }
    let steps = compat::match_fields(
        w_fields,
        r_fields,
        |ri, _| ra.fields[ri].default.is_some(),
        |rf| {
            incompatible(format!(
                "required reader variant field '{}' is absent from the writer",
                rf.name
            ))
        },
    )?;
    for step in steps {
        match step {
            FieldMatch::Skip { writer } => {
                let skip = skip_op(&writer.schema, reg)?;
                out.push(MemOp::SkipWire(Box::new(skip)));
            }
            FieldMatch::Take {
                writer,
                reader_index,
            } => {
                let field = &ra.fields[reader_index];
                lower_decode_node(&writer.schema, &field.descriptor, reg, base + field.offset, out)?;
            }
            FieldMatch::Default { reader_index } => {
                let field = &ra.fields[reader_index];
                match field.default {
                    Some(d) => out.push(MemOp::Default(Box::new(DefaultOp {
                        offset: base + field.offset,
                        ctx: d.ctx,
                        default: d.thunk,
                    }))),
                    None => {
                        return Err(incompatible(format!(
                            "required reader variant field '{}' is absent from the writer",
                            r_fields[reader_index].name
                        )));
                    }
                }
            }
        }
    }
    Ok(())
}
/// Presents tuple elements as required fields named by their position
/// (`"0"`, `"1"`, …), so tuples can be matched like structs.
fn tuple_fields(elements: Vec<SchemaRef>) -> Vec<Field> {
    let mut fields = Vec::with_capacity(elements.len());
    for (position, schema) in elements.into_iter().enumerate() {
        fields.push(Field {
            name: position.to_string(),
            schema,
            required: true,
        });
    }
    fields
}
/// Resolves the reader's record schema `r` and returns its fields.
///
/// The result is the struct's fields, or tuple elements converted with
/// `tuple_fields`, depending on `record_kind`.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` if the resolved kind does
/// not match `record_kind`.
fn reader_record_fields(
    r: &SchemaRef,
    record_kind: RecordKind,
    reg: &Registry,
) -> Result<Vec<Field>> {
    let resolved = compact::resolve(reg, r)?;
    let fields = match (record_kind, resolved) {
        (RecordKind::Struct, Resolved::Composite(SchemaKind::Struct { fields, .. })) => fields,
        (RecordKind::Tuple, Resolved::Composite(SchemaKind::Tuple { elements })) => {
            tuple_fields(elements)
        }
        _ => return Err(incompatible("schema kinds differ")),
    };
    Ok(fields)
}
/// Resolves the reader schema `r` and returns its enum variants.
///
/// # Errors
/// Propagates resolution errors; `TypeMismatch` if `r` is not an enum.
fn reader_enum_variants(r: &SchemaRef, reg: &Registry) -> Result<Vec<Variant>> {
    if let Resolved::Composite(SchemaKind::Enum { variants, .. }) = compact::resolve(reg, r)? {
        return Ok(variants);
    }
    Err(CompactError::TypeMismatch {
        expected: "enum reader schema for an enum descriptor",
    })
}
/// Ensures the reader schema `r` resolves to a list.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` otherwise.
fn require_reader_list(r: &SchemaRef, reg: &Registry) -> Result<()> {
    let resolved = compact::resolve(reg, r)?;
    if matches!(resolved, Resolved::Composite(SchemaKind::List { .. })) {
        Ok(())
    } else {
        Err(incompatible("schema kinds differ"))
    }
}
/// Ensures the reader schema `r` resolves to a set.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` otherwise.
fn require_reader_set(r: &SchemaRef, reg: &Registry) -> Result<()> {
    let resolved = compact::resolve(reg, r)?;
    if matches!(resolved, Resolved::Composite(SchemaKind::Set { .. })) {
        Ok(())
    } else {
        Err(incompatible("schema kinds differ"))
    }
}
/// Ensures the reader schema `r` resolves to an option.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` otherwise.
fn require_reader_option(r: &SchemaRef, reg: &Registry) -> Result<()> {
    let resolved = compact::resolve(reg, r)?;
    if matches!(resolved, Resolved::Composite(SchemaKind::Option { .. })) {
        Ok(())
    } else {
        Err(incompatible("schema kinds differ"))
    }
}
/// Ensures the reader schema `r` resolves to a map.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` otherwise.
fn require_reader_map(r: &SchemaRef, reg: &Registry) -> Result<()> {
    let resolved = compact::resolve(reg, r)?;
    if matches!(resolved, Resolved::Composite(SchemaKind::Map { .. })) {
        Ok(())
    } else {
        Err(incompatible("schema kinds differ"))
    }
}
/// Ensures the reader schema `r` resolves to `Dynamic`.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` otherwise.
fn require_reader_dynamic(r: &SchemaRef, reg: &Registry) -> Result<()> {
    let resolved = compact::resolve(reg, r)?;
    if matches!(resolved, Resolved::Composite(SchemaKind::Dynamic)) {
        Ok(())
    } else {
        Err(incompatible("schema kinds differ"))
    }
}
/// Ensures the reader schema `r` resolves to the `Bytes` primitive.
///
/// # Errors
/// Propagates resolution errors; `incompatible(..)` otherwise.
fn require_reader_bytes(r: &SchemaRef, reg: &Registry) -> Result<()> {
    let resolved = compact::resolve(reg, r)?;
    if matches!(resolved, Resolved::Primitive(Primitive::Bytes)) {
        Ok(())
    } else {
        Err(incompatible("primitive Bytes is not reader schema"))
    }
}
/// Wire index of the first variant called `name`.
///
/// # Errors
/// `Malformed` when no variant has that name (used for `Ok`/`Err` lookup).
fn variant_index_by_name(variants: &[Variant], name: &str) -> Result<u32> {
    for variant in variants {
        if variant.name == name {
            return Ok(variant.index);
        }
    }
    Err(CompactError::Malformed(
        "Result schema missing Ok or Err variant",
    ))
}
fn lower_result(
ra: &ResultAccess,
variants: &[Variant],
reg: &Registry,
base: usize,
) -> Result<ResultOp> {
let ok_wire_index = variant_index_by_name(variants, "Ok")?;
let err_wire_index = variant_index_by_name(variants, "Err")?;
let mut ok = Vec::new();
lower_node(&ra.ok, reg, 0, &mut ok)?;
let mut err = Vec::new();
lower_node(&ra.err, reg, 0, &mut err)?;
Ok(ResultOp {
field_offset: base,
ok: fuse(ok),
ok_size: ra.ok.layout.size,
ok_align: ra.ok.layout.align,
ok_wire_index,
err: fuse(err),
err_size: ra.err.layout.size,
err_align: ra.err.layout.align,
err_wire_index,
thunks: ra.thunks,
})
}
/// Lowers the decode op for a `Result`-shaped field at `base`.
///
/// The `Ok`/`Err` arms are matched by name against the writer variants `wv`.
///
/// # Errors
/// `incompatible(..)` if the writer lacks an `Ok` or `Err` variant, or if an arm
/// is not a newtype; propagates arm lowering errors.
fn lower_decode_result(
    wv: &[Variant],
    ra: &ResultAccess,
    reg: &Registry,
    base: usize,
) -> Result<ResultOp> {
    let Some(ok_variant) = wv.iter().find(|v| v.name == "Ok") else {
        return Err(incompatible("writer Result schema missing Ok variant"));
    };
    let Some(err_variant) = wv.iter().find(|v| v.name == "Err") else {
        return Err(incompatible("writer Result schema missing Err variant"));
    };
    let ok = lower_decode_result_arm(&ok_variant.payload, &ra.ok, reg)?;
    let err = lower_decode_result_arm(&err_variant.payload, &ra.err, reg)?;
    Ok(ResultOp {
        field_offset: base,
        ok,
        ok_size: ra.ok.layout.size,
        ok_align: ra.ok.layout.align,
        ok_wire_index: ok_variant.index,
        err,
        err_size: ra.err.layout.size,
        err_align: ra.err.layout.align,
        err_wire_index: err_variant.index,
        thunks: ra.thunks,
    })
}
/// Lowers one `Result` arm: the writer payload must be a newtype, which is
/// decoded into `reader` at offset 0.
///
/// # Errors
/// `incompatible(..)` for a non-newtype payload; propagates nested errors.
fn lower_decode_result_arm(
    w: &VariantPayload,
    reader: &Descriptor,
    reg: &Registry,
) -> Result<MemProgram> {
    match w {
        VariantPayload::Newtype(wr) => {
            let mut ops = Vec::new();
            lower_decode_node(wr, reader, reg, 0, &mut ops)?;
            Ok(fuse(ops))
        }
        _ => Err(incompatible("Result arm payload must be a newtype")),
    }
}
fn skip_op(writer: &SchemaRef, reg: &Registry) -> Result<SkipOp> {
match compact::resolve(reg, writer)? {
Resolved::Primitive(p) => match p {
Primitive::String | Primitive::Bytes => Ok(SkipOp::Bytes {
stride: 1,
elem_align: 1,
}),
other => {
let size = fixed_size(other).ok_or(CompactError::Unsupported(
"skip: variable-length scalar (datetime/uuid/qname)",
))?;
Ok(SkipOp::Scalar {
size,
align: alignment(other),
})
}
},
Resolved::Composite(kind) => match kind {
SchemaKind::Struct { fields, .. } => {
let mut fs = Vec::with_capacity(fields.len());
for f in &fields {
fs.push(skip_op(&f.schema, reg)?);
}
Ok(SkipOp::Struct(fs))
}
SchemaKind::Tuple { elements } => {
let mut fs = Vec::with_capacity(elements.len());
for e in &elements {
fs.push(skip_op(e, reg)?);
}
Ok(SkipOp::Struct(fs))
}
SchemaKind::Enum { variants, .. } => {
let mut arms = Vec::with_capacity(variants.len());
for v in &variants {
let fields = match &v.payload {
VariantPayload::Unit => Vec::new(),
VariantPayload::Newtype(r) => vec![skip_op(r, reg)?],
VariantPayload::Tuple(rs) => {
let mut fs = Vec::with_capacity(rs.len());
for r in rs {
fs.push(skip_op(r, reg)?);
}
fs
}
VariantPayload::Struct(fields) => {
let mut fs = Vec::with_capacity(fields.len());
for f in fields {
fs.push(skip_op(&f.schema, reg)?);
}
fs
}
};
arms.push((v.index, fields));
}
Ok(SkipOp::Enum(arms))
}
SchemaKind::List { element } | SchemaKind::Set { element } => {
if let Resolved::Primitive(ep) = compact::resolve(reg, &element)?
&& let Some(size) = fixed_size(ep)
&& !matches!(ep, Primitive::String | Primitive::Bytes)
{
let align = alignment(ep);
if size % align == 0 {
return Ok(SkipOp::Bytes {
stride: size,
elem_align: align,
});
}
}
Ok(SkipOp::Seq(Box::new(skip_op(&element, reg)?)))
}
SchemaKind::Option { element } => Ok(SkipOp::Option(Box::new(skip_op(&element, reg)?))),
SchemaKind::Map { key, value } => Ok(SkipOp::Map(
Box::new(skip_op(&key, reg)?),
Box::new(skip_op(&value, reg)?),
)),
SchemaKind::Array { .. } => Err(CompactError::Unsupported("skip: fixed array")),
SchemaKind::Tensor { .. } => Err(CompactError::Unsupported("skip: tensor")),
SchemaKind::Channel { .. } => Err(CompactError::Unsupported("skip: channel")),
SchemaKind::External { .. } => Err(CompactError::Unsupported("skip: external")),
SchemaKind::Dynamic => Ok(SkipOp::Dynamic),
SchemaKind::Primitive(_) => {
Err(CompactError::Malformed(
"skip: primitive in composite position",
))
}
},
}
}
/// Reads `width` bytes at `ptr`, interprets them as a little-endian unsigned
/// integer, and zero-extends the result to `u64`.
///
/// NOTE(review): bytes are decoded with `from_le_bytes`, so using this on native
/// in-memory integers presumes a little-endian host.
///
/// # Safety
/// `ptr` must be valid for reads of `width` bytes, and `width` must be at most 8.
/// A larger `width` would write past the end of the 8-byte scratch buffer; this
/// is checked with `debug_assert!` in debug builds.
unsafe fn read_uint(ptr: *const u8, width: usize) -> u64 {
    debug_assert!(width <= 8, "read_uint: width {width} exceeds 8 bytes");
    let mut buf = [0u8; 8];
    // SAFETY: the caller guarantees `ptr` is readable for `width` bytes, and
    // `width <= 8` keeps the copy inside `buf`; the regions cannot overlap
    // because `buf` is a fresh local.
    unsafe { core::ptr::copy_nonoverlapping(ptr, buf.as_mut_ptr(), width) };
    u64::from_le_bytes(buf)
}
/// Writes the low `width` bytes of `val`, in little-endian order, to `ptr`.
///
/// # Safety
/// `ptr` must be valid for writes of `width` bytes, and `width` must be at most 8.
/// A larger `width` would read past the end of the 8-byte source array; this is
/// checked with `debug_assert!` in debug builds.
unsafe fn write_uint(ptr: *mut u8, width: usize, val: u64) {
    debug_assert!(width <= 8, "write_uint: width {width} exceeds 8 bytes");
    let bytes = val.to_le_bytes();
    // SAFETY: the caller guarantees `ptr` is writable for `width` bytes, and
    // `width <= 8` keeps the read inside `bytes`, a fresh local that cannot
    // overlap the destination.
    unsafe { core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, width) };
}
/// Sign-extends the low `width` bytes of `raw` to an `i64`.
///
/// Widths of 8 or more return `raw` reinterpreted as `i64`. Bits above
/// `width * 8` are ignored.
fn sign_extend(raw: u64, width: usize) -> i64 {
    match width {
        0..=7 => {
            // Move the value's sign bit into bit 63, then shift back
            // arithmetically so it is replicated into the vacated bits.
            let pad = 64 - 8 * width;
            ((raw << pad) as i64) >> pad
        }
        _ => raw as i64,
    }
}
/// Whether `value` is representable as a two's-complement integer of `width`
/// bytes. Always `true` for widths of 8 or more.
fn signed_fits_width(value: i64, width: usize) -> bool {
    match width {
        0..=7 => {
            // For `bits` = width * 8, the valid range is [-2^(bits-1), 2^(bits-1)).
            let limit = 1i64 << (8 * width - 1);
            (-limit..limit).contains(&value)
        }
        _ => true,
    }
}
/// Bit mask covering the low `width` bytes of a `u64`; all ones for widths of
/// 8 or more, zero for width 0.
fn width_mask(width: usize) -> u64 {
    match width {
        0..=7 => (1u64 << (8 * width)) - 1,
        _ => u64::MAX,
    }
}
/// Byte-run validator that accepts only well-formed UTF-8.
///
/// Used as the `validate` callback for owned `String` fields.
///
/// # Safety
/// When `len > 0`, `ptr` must be valid for reads of `len` bytes. When
/// `len == 0`, `ptr` is not read and may be null or dangling.
unsafe extern "C" fn validate_utf8(ptr: *const u8, len: usize) -> bool {
    if len == 0 {
        // An empty run is trivially valid UTF-8. Returning early also avoids
        // building a slice from a possibly-null pointer, which
        // `slice::from_raw_parts` forbids even for zero-length slices.
        return true;
    }
    // SAFETY: the caller guarantees `ptr` is readable for `len` (> 0) bytes.
    let bytes = unsafe { core::slice::from_raw_parts(ptr, len) };
    core::str::from_utf8(bytes).is_ok()
}
/// Byte-run validator that accepts every input without reading it.
///
/// Used as the `validate` callback for raw bytes and bulk scalar copies.
///
/// # Safety
/// Never dereferences its arguments, so any pointer/length pair is accepted.
unsafe extern "C" fn validate_any(_ptr: *const u8, _len: usize) -> bool {
    // No content check is required for these runs.
    true
}
#[must_use]
pub unsafe fn encode_with(lowered: &Lowered, base: *const u8) -> Vec<u8> {
let mut out = Vec::new();
unsafe { encode_program(&lowered.program, base, &mut out, &lowered.blocks) };
out
}
/// Convert an in-memory element or byte count into the `u32` length prefix
/// used on the wire.
///
/// # Panics
/// Panics if `len` exceeds `u32::MAX`. Truncating it (as a bare `as u32` would)
/// emits a prefix that disagrees with the payload that follows, silently
/// corrupting the stream.
fn wire_len_u32(len: usize) -> u32 {
    u32::try_from(len).expect("length does not fit the u32 wire length prefix")
}

/// Interpret a typed encode program, appending the wire form of the value at
/// `base` to `out`.
///
/// Every op addresses memory relative to `base`; containers are walked through
/// their thunks, and `CallBlock` re-enters the shared `blocks` table so
/// recursive schemas can be encoded.
///
/// # Safety
/// `base` must point to a live, initialized value laid out as `program`
/// expects, and the thunks embedded in `program`/`blocks` must match the
/// memory they are handed.
///
/// # Panics
/// Panics if a `CallBlock` names a schema missing from `blocks`, if an enum
/// discriminant matches no modelled variant, if a dynamic value cannot be
/// encoded, if a length does not fit the `u32` wire prefix, or on
/// `SkipWire`/`Default` ops (which typed lowering never emits).
unsafe fn encode_program(
    program: &MemProgram,
    base: *const u8,
    out: &mut Vec<u8>,
    blocks: &BTreeMap<SchemaId, MemProgram>,
) {
    for op in program {
        match op {
            MemOp::CallBlock { schema, offset } => {
                let block = blocks
                    .get(schema)
                    .expect("CallBlock references a lowered recursion block");
                unsafe { encode_program(block, base.add(*offset), out, blocks) };
            }
            MemOp::Scalar {
                offset,
                size,
                align,
            } => {
                pad_to(out, *align);
                let src = unsafe { core::slice::from_raw_parts(base.add(*offset), *size) };
                out.extend_from_slice(src);
            }
            MemOp::NativeInt {
                offset,
                mem_size,
                signed,
            } => {
                // Native-sized integers are widened to 8 wire bytes regardless
                // of their in-memory width.
                pad_to(out, 8);
                let raw = unsafe { read_uint(base.add(*offset), *mem_size) };
                if *signed {
                    out.extend_from_slice(&sign_extend(raw, *mem_size).to_le_bytes());
                } else {
                    out.extend_from_slice(&raw.to_le_bytes());
                }
            }
            MemOp::Sequence(s) => {
                let list = unsafe { base.add(s.field_offset) };
                let n = unsafe { (s.thunks.len)(s.thunks.ctx, list) };
                write_u32(out, wire_len_u32(n));
                let data = unsafe { (s.thunks.data)(s.thunks.ctx, list) };
                for i in 0..n {
                    unsafe { encode_program(&s.element, data.add(i * s.stride), out, blocks) };
                }
            }
            MemOp::Set(s) => {
                let set = unsafe { base.add(s.field_offset) };
                let n = unsafe { (s.thunks.len)(s.thunks.ctx, set) };
                write_u32(out, wire_len_u32(n));
                let it = unsafe { (s.thunks.iter_init)(s.thunks.ctx, set) };
                loop {
                    let mut value: *const u8 = core::ptr::null();
                    if !unsafe { (s.thunks.iter_next)(s.thunks.ctx, it, &mut value) } {
                        break;
                    }
                    unsafe { encode_program(&s.element, value, out, blocks) };
                }
                unsafe { (s.thunks.iter_dealloc)(s.thunks.ctx, it) };
            }
            MemOp::Bytes(b) => {
                let list = unsafe { base.add(b.field_offset) };
                let count = unsafe { (b.thunks.len)(b.thunks.ctx, list) };
                write_u32(out, wire_len_u32(count));
                // Padding is only emitted for non-empty payloads; the decoder
                // mirrors this with `skip_pad` under the same condition.
                if count > 0 {
                    pad_to(out, b.elem_align);
                }
                let data = unsafe { (b.thunks.data)(b.thunks.ctx, list) };
                let src = unsafe { core::slice::from_raw_parts(data, count * b.stride) };
                out.extend_from_slice(src);
            }
            MemOp::Borrow(b) => {
                let field = unsafe { base.add(b.field_offset) };
                let count = unsafe { (b.thunks.len)(b.thunks.ctx, field) };
                write_u32(out, wire_len_u32(count));
                if count > 0 {
                    pad_to(out, b.elem_align);
                }
                let data = unsafe { (b.thunks.data)(b.thunks.ctx, field) };
                let src = unsafe { core::slice::from_raw_parts(data, count * b.stride) };
                out.extend_from_slice(src);
            }
            MemOp::Option(o) => {
                let option = unsafe { base.add(o.field_offset) };
                if unsafe { (o.thunks.is_some)(o.thunks.ctx, option) } {
                    write_u8(out, 1);
                    let inner = unsafe { (o.thunks.get_value)(o.thunks.ctx, option) };
                    unsafe { encode_program(&o.some, inner, out, blocks) };
                } else {
                    write_u8(out, 0);
                }
            }
            MemOp::Enum(e) => {
                // Compare only the bits covered by the tag width.
                let disc = unsafe { read_uint(base.add(e.tag_offset), e.tag_width) };
                let mask = width_mask(e.tag_width);
                let variant = e
                    .variants
                    .iter()
                    .find(|v| (v.selector & mask) == (disc & mask))
                    .expect("enum discriminant matches no modelled variant (invalid value)");
                write_u32(out, variant.wire_index);
                // Variant payload ops address fields relative to the enum's own base.
                unsafe { encode_program(&variant.payload, base, out, blocks) };
            }
            MemOp::Map(m) => {
                let map = unsafe { base.add(m.field_offset) };
                let n = unsafe { (m.thunks.len)(m.thunks.ctx, map) };
                write_u32(out, wire_len_u32(n));
                let it = unsafe { (m.thunks.iter_init)(m.thunks.ctx, map) };
                loop {
                    let mut k: *const u8 = core::ptr::null();
                    let mut v: *const u8 = core::ptr::null();
                    if !unsafe { (m.thunks.iter_next)(m.thunks.ctx, it, &mut k, &mut v) } {
                        break;
                    }
                    unsafe { encode_program(&m.key, k, out, blocks) };
                    unsafe { encode_program(&m.value, v, out, blocks) };
                }
                unsafe { (m.thunks.iter_dealloc)(m.thunks.ctx, it) };
            }
            MemOp::Dynamic { field_offset } => {
                let v = unsafe { &*base.add(*field_offset).cast::<Value>() };
                write_value(out, v)
                    .expect("dynamic value is encodable by the self-describing codec");
            }
            MemOp::Result(rs) => {
                let result = unsafe { base.add(rs.field_offset) };
                if unsafe { (rs.thunks.is_ok)(rs.thunks.ctx, result) } {
                    write_u32(out, rs.ok_wire_index);
                    let ok = unsafe { (rs.thunks.get_ok)(rs.thunks.ctx, result) };
                    unsafe { encode_program(&rs.ok, ok, out, blocks) };
                } else {
                    write_u32(out, rs.err_wire_index);
                    let err = unsafe { (rs.thunks.get_err)(rs.thunks.ctx, result) };
                    unsafe { encode_program(&rs.err, err, out, blocks) };
                }
            }
            MemOp::Pointer(p) => {
                let pointer = unsafe { base.add(p.field_offset) };
                let pointee = unsafe { (p.thunks.borrow)(p.thunks.ctx, pointer) };
                unsafe { encode_program(&p.pointee, pointee, out, blocks) };
            }
            MemOp::Opaque(o) => {
                let field = unsafe { base.add(o.field_offset) };
                // Reserve a u32 length prefix, let the adapter append its
                // payload, then back-patch the prefix with the payload length.
                let len_pos = out.len();
                write_u32(out, 0);
                let start = out.len();
                unsafe { (o.thunks.encode)(o.thunks.ctx, field, core::ptr::from_mut(out)) };
                let inner_len = wire_len_u32(out.len() - start);
                out[len_pos..len_pos + 4].copy_from_slice(&inner_len.to_le_bytes());
            }
            MemOp::SkipWire(_) | MemOp::Default(_) => {
                unreachable!("typed encode never emits compat skip/default ops")
            }
        }
    }
}
/// Lower `descriptor` to a typed program and encode the value at `base` with it.
///
/// This lowers on every call. Callers that encode repeatedly can call
/// [`lower_typed`] once and reuse the result with [`encode_with`].
///
/// # Errors
/// Returns any error produced while lowering the descriptor.
///
/// # Safety
/// `base` must point to a live, initialized value described by `descriptor`
/// (and `descriptor_blocks` for recursive parts).
pub unsafe fn encode(
    base: *const u8,
    descriptor: &Descriptor,
    descriptor_blocks: &HashMap<SchemaId, Descriptor>,
    reg: &Registry,
) -> Result<Vec<u8>> {
    lower_typed(descriptor, descriptor_blocks, reg)
        // SAFETY: forwarded verbatim from this function's contract.
        .map(|program| unsafe { encode_with(&program, base) })
}
/// Decode `bytes` into the memory at `base` with an already-lowered typed
/// program, requiring the program to consume the whole input.
///
/// # Errors
/// Propagates any decode error from the program. Returns `TrailingBytes` when
/// input remains after the program finishes. In that case the value under
/// `base` has already been written and is not dropped here.
///
/// # Safety
/// `base` must be valid for writes of a value laid out as `lowered` expects.
pub unsafe fn decode_with(lowered: &Lowered, bytes: &[u8], base: *mut u8) -> Result<()> {
    let mut reader = Reader::new(bytes);
    // SAFETY: forwarded verbatim from this function's contract.
    unsafe { decode_program(&lowered.program, &mut reader, base, &lowered.blocks) }?;
    match reader.remaining() {
        0 => Ok(()),
        leftover => Err(CompactError::Decode(DecodeError::TrailingBytes(leftover))),
    }
}
/// Interpret a typed decode program. Reads wire bytes from `r` and writes the
/// decoded value into the memory at `base`.
///
/// Each op addresses memory relative to `base`. Container ops decode their
/// elements into heap buffers or scratch slots, then hand them to the
/// container's thunks. `CallBlock` re-enters the shared `blocks` table to
/// decode recursive schemas.
///
/// # Errors
/// Returns a [`CompactError`] in these cases:
/// - truncated or malformed input;
/// - unknown or writer-only variant indices;
/// - duplicate set elements or map keys;
/// - validator or borrow rejection (reported as invalid UTF-8);
/// - out-of-range native integers;
/// - element counts whose byte size overflows `usize`.
///
/// # Safety
/// `base` must be valid for writes of every slot the program addresses, and
/// the thunks embedded in `program`/`blocks` must match the memory they are
/// handed. On error, values already written under `base` (and sequence
/// elements decoded before the failure) are not dropped.
///
/// # Panics
/// Panics if a `CallBlock` names a schema missing from `blocks`.
unsafe fn decode_program(
    program: &MemProgram,
    r: &mut Reader,
    base: *mut u8,
    blocks: &BTreeMap<SchemaId, MemProgram>,
) -> Result<()> {
    for op in program {
        match op {
            MemOp::CallBlock { schema, offset } => {
                let block = blocks
                    .get(schema)
                    .expect("CallBlock references a lowered recursion block");
                unsafe { decode_program(block, r, base.add(*offset), blocks)? };
            }
            MemOp::Scalar {
                offset,
                size,
                align,
            } => {
                skip_pad(r, *align)?;
                let src = r.read_slice(*size)?;
                unsafe { core::ptr::copy_nonoverlapping(src.as_ptr(), base.add(*offset), *size) };
            }
            MemOp::NativeInt {
                offset,
                mem_size,
                signed,
            } => {
                // Native-sized integers always occupy 8 wire bytes; reject
                // values that do not fit the narrower in-memory width.
                skip_pad(r, 8)?;
                if *signed {
                    let value = r.read_i64()?;
                    if !signed_fits_width(value, *mem_size) {
                        return Err(DecodeError::Malformed(
                            "native-sized signed integer out of range",
                        )
                        .into());
                    }
                    unsafe { write_uint(base.add(*offset), *mem_size, value as u64) };
                } else {
                    let value = r.read_u64()?;
                    if *mem_size < 8 && value > width_mask(*mem_size) {
                        return Err(DecodeError::Malformed(
                            "native-sized unsigned integer out of range",
                        )
                        .into());
                    }
                    unsafe { write_uint(base.add(*offset), *mem_size, value) };
                }
            }
            MemOp::Sequence(s) => {
                let count = r.read_len(s.min_wire)?;
                // `count` comes from the wire. Check the byte size instead of
                // letting `count * stride` wrap, which would under-allocate the
                // buffer that the element loop below writes into.
                let total = count.checked_mul(s.stride).ok_or_else(|| {
                    CompactError::Decode(DecodeError::Malformed("sequence layout overflow"))
                })?;
                let (buffer, layout) = if total == 0 {
                    // Empty sequence or zero-sized elements: nothing to
                    // allocate, a dangling aligned pointer suffices.
                    (s.elem_align as *mut u8, None)
                } else {
                    let layout = alloc::Layout::from_size_align(total, s.elem_align)
                        .map_err(|_| {
                            CompactError::Decode(DecodeError::Malformed(
                                "sequence layout overflow",
                            ))
                        })?;
                    let buf = unsafe { alloc::alloc(layout) };
                    if buf.is_null() {
                        alloc::handle_alloc_error(layout);
                    }
                    (buf, Some(layout))
                };
                for i in 0..count {
                    if let Err(e) =
                        unsafe { decode_program(&s.element, r, buffer.add(i * s.stride), blocks) }
                    {
                        // Only the buffer is released; elements decoded so far
                        // are not dropped.
                        free_scratch(buffer, layout);
                        return Err(e);
                    }
                }
                let list = unsafe { base.add(s.field_offset) };
                // Capacity equals `count`, matching the layout allocated above.
                unsafe { (s.thunks.from_raw_parts)(s.thunks.ctx, list, buffer, count, count) };
            }
            MemOp::Set(s) => {
                let count = r.read_len(s.min_wire)?;
                let set = unsafe { base.add(s.field_offset) };
                unsafe { (s.thunks.init_with_capacity)(s.thunks.ctx, set, count) };
                for _ in 0..count {
                    let (scratch, layout) = alloc_scratch(s.elem_size, s.elem_align)?;
                    if let Err(e) = unsafe { decode_program(&s.element, r, scratch, blocks) } {
                        free_scratch(scratch, layout);
                        return Err(e);
                    }
                    let inserted = unsafe { (s.thunks.insert)(s.thunks.ctx, set, scratch) };
                    free_scratch(scratch, layout);
                    if !inserted {
                        return Err(CompactError::Decode(DecodeError::DuplicateElement));
                    }
                }
            }
            MemOp::Bytes(b) => {
                let count = r.read_len(b.stride.max(1))?;
                if count > 0 {
                    skip_pad(r, b.elem_align)?;
                }
                let total = count.checked_mul(b.stride).ok_or_else(|| {
                    CompactError::Decode(DecodeError::Malformed("bytes layout overflow"))
                })?;
                let src = r.read_slice(total)?;
                if !unsafe { (b.validate)(src.as_ptr(), total) } {
                    return Err(CompactError::Decode(DecodeError::InvalidUtf8));
                }
                let (buffer, cap) = if total == 0 {
                    (b.elem_align as *mut u8, 0usize)
                } else {
                    let layout =
                        alloc::Layout::from_size_align(total, b.elem_align).map_err(|_| {
                            CompactError::Decode(DecodeError::Malformed("bytes layout overflow"))
                        })?;
                    let buf = unsafe { alloc::alloc(layout) };
                    if buf.is_null() {
                        alloc::handle_alloc_error(layout);
                    }
                    unsafe { core::ptr::copy_nonoverlapping(src.as_ptr(), buf, total) };
                    (buf, count)
                };
                let list = unsafe { base.add(b.field_offset) };
                unsafe { (b.thunks.from_raw_parts)(b.thunks.ctx, list, buffer, count, cap) };
            }
            MemOp::Borrow(b) => {
                let count = r.read_len(b.stride.max(1))?;
                if count > 0 {
                    skip_pad(r, b.elem_align)?;
                }
                // A wrapped product would let the borrow claim `count`
                // elements while only a few bytes were bounds-checked.
                let total = count.checked_mul(b.stride).ok_or_else(|| {
                    CompactError::Decode(DecodeError::Malformed("borrowed slice length overflow"))
                })?;
                let src = r.read_slice(total)?;
                let field = unsafe { base.add(b.field_offset) };
                if !unsafe { (b.thunks.set_borrowed)(b.thunks.ctx, field, src.as_ptr(), count) } {
                    return Err(CompactError::Decode(DecodeError::InvalidUtf8));
                }
            }
            MemOp::Option(o) => {
                let option = unsafe { base.add(o.field_offset) };
                match r.read_u8()? {
                    0 => unsafe { (o.thunks.init_none)(o.thunks.ctx, option) },
                    1 => {
                        let (scratch, layout) = if o.inner_size == 0 {
                            (o.inner_align as *mut u8, None)
                        } else {
                            let layout =
                                alloc::Layout::from_size_align(o.inner_size, o.inner_align)
                                    .map_err(|_| {
                                        CompactError::Decode(DecodeError::Malformed(
                                            "option inner layout overflow",
                                        ))
                                    })?;
                            let buf = unsafe { alloc::alloc(layout) };
                            if buf.is_null() {
                                alloc::handle_alloc_error(layout);
                            }
                            (buf, Some(layout))
                        };
                        if let Err(e) = unsafe { decode_program(&o.some, r, scratch, blocks) } {
                            if let Some(layout) = layout {
                                unsafe { alloc::dealloc(scratch, layout) };
                            }
                            return Err(e);
                        }
                        unsafe { (o.thunks.init_some)(o.thunks.ctx, option, scratch) };
                        if let Some(layout) = layout {
                            unsafe { alloc::dealloc(scratch, layout) };
                        }
                    }
                    b => return Err(CompactError::Decode(DecodeError::InvalidBool(b))),
                }
            }
            MemOp::Enum(e) => {
                let wire_index = r.read_u32()?;
                let variant = match e.variants.iter().find(|v| v.wire_index == wire_index) {
                    Some(v) => v,
                    None if e.writer_only.contains(&wire_index) => {
                        return Err(CompactError::WriterOnlyVariant(wire_index));
                    }
                    None => return Err(CompactError::BadVariantIndex(wire_index)),
                };
                unsafe { write_uint(base.add(e.tag_offset), e.tag_width, variant.selector) };
                // Variant payload ops address fields relative to the enum's own base.
                unsafe { decode_program(&variant.payload, r, base, blocks)? };
            }
            MemOp::Map(m) => {
                let n = r.read_len(1)?;
                let map = unsafe { base.add(m.field_offset) };
                unsafe { (m.thunks.init_with_capacity)(m.thunks.ctx, map, n) };
                for _ in 0..n {
                    let (key_scratch, key_layout) = alloc_scratch(m.key_size, m.key_align)?;
                    let (value_scratch, value_layout) =
                        match alloc_scratch(m.value_size, m.value_align) {
                            Ok(s) => s,
                            Err(e) => {
                                free_scratch(key_scratch, key_layout);
                                return Err(e);
                            }
                        };
                    if let Err(e) = unsafe { decode_program(&m.key, r, key_scratch, blocks) } {
                        free_scratch(key_scratch, key_layout);
                        free_scratch(value_scratch, value_layout);
                        return Err(e);
                    }
                    if let Err(e) = unsafe { decode_program(&m.value, r, value_scratch, blocks) } {
                        free_scratch(key_scratch, key_layout);
                        free_scratch(value_scratch, value_layout);
                        return Err(e);
                    }
                    unsafe {
                        (m.thunks.insert)(m.thunks.ctx, map, key_scratch, value_scratch);
                    }
                    free_scratch(key_scratch, key_layout);
                    free_scratch(value_scratch, value_layout);
                }
                // Fewer entries than decoded pairs means a key repeated.
                if unsafe { (m.thunks.len)(m.thunks.ctx, map) } != n {
                    return Err(CompactError::Decode(DecodeError::DuplicateKey));
                }
            }
            MemOp::Dynamic { field_offset } => {
                let v = read_value(r)?;
                unsafe { core::ptr::write(base.add(*field_offset).cast::<Value>(), v) };
            }
            MemOp::Result(rs) => {
                let idx = r.read_u32()?;
                let result = unsafe { base.add(rs.field_offset) };
                if idx == rs.ok_wire_index {
                    unsafe {
                        decode_into_via_init(
                            &rs.ok,
                            rs.ok_size,
                            rs.ok_align,
                            r,
                            InitTarget {
                                ctx: rs.thunks.ctx,
                                handle: result,
                                init: rs.thunks.init_ok,
                            },
                            blocks,
                        )?
                    };
                } else if idx == rs.err_wire_index {
                    unsafe {
                        decode_into_via_init(
                            &rs.err,
                            rs.err_size,
                            rs.err_align,
                            r,
                            InitTarget {
                                ctx: rs.thunks.ctx,
                                handle: result,
                                init: rs.thunks.init_err,
                            },
                            blocks,
                        )?
                    };
                } else {
                    return Err(CompactError::BadVariantIndex(idx));
                }
            }
            MemOp::Pointer(p) => {
                unsafe {
                    decode_into_via_init(
                        &p.pointee,
                        p.pointee_size,
                        p.pointee_align,
                        r,
                        InitTarget {
                            ctx: p.thunks.ctx,
                            handle: base.add(p.field_offset),
                            init: p.thunks.init,
                        },
                        blocks,
                    )?
                };
            }
            MemOp::SkipWire(s) => phon_ir::ir::skip(r, s)?,
            MemOp::Default(d) => {
                unsafe { (d.default)(d.ctx, base.add(d.offset)) };
            }
            MemOp::Opaque(o) => {
                let len = r.read_len(1)?;
                let span = r.read_slice(len)?;
                let slot = unsafe { base.add(o.field_offset) };
                if !unsafe { (o.thunks.decode)(o.thunks.ctx, span.as_ptr(), len, slot) } {
                    return Err(CompactError::Decode(DecodeError::Malformed(
                        "opaque adapter rejected input",
                    )));
                }
            }
        }
    }
    Ok(())
}
/// Allocate an uninitialized scratch slot of `size` bytes aligned to `align`.
///
/// Zero-sized requests skip the allocator: they return a dangling pointer whose
/// address is `align`, together with a `None` layout. Release the slot with
/// `free_scratch`, passing back the returned layout.
///
/// # Errors
/// Returns a malformed-input decode error when `size`/`align` do not form a
/// valid [`alloc::Layout`].
///
/// If the allocator fails, this calls [`alloc::handle_alloc_error`], which
/// does not return.
fn alloc_scratch(size: usize, align: usize) -> Result<(*mut u8, Option<alloc::Layout>)> {
    if size == 0 {
        return Ok((align as *mut u8, None));
    }
    // NOTE(review): the message says "map" although Set/Result/Pointer scratch
    // also goes through here; kept as-is because it is observable error text.
    let layout = alloc::Layout::from_size_align(size, align).map_err(|_| {
        CompactError::Decode(DecodeError::Malformed("map scratch layout overflow"))
    })?;
    // SAFETY: `layout` has a non-zero size, as `alloc::alloc` requires.
    let scratch = unsafe { alloc::alloc(layout) };
    if scratch.is_null() {
        alloc::handle_alloc_error(layout);
    }
    Ok((scratch, Some(layout)))
}
/// Release a scratch slot obtained from `alloc_scratch`.
///
/// A `None` layout means nothing was allocated (the zero-sized case), so the
/// pointer is left alone. Otherwise `buf` must have been allocated with exactly
/// `layout` and not freed since. This fn is safe to call syntactically, but
/// that precondition is on the caller.
fn free_scratch(buf: *mut u8, layout: Option<alloc::Layout>) {
    let Some(layout) = layout else {
        return;
    };
    // SAFETY: `buf` was allocated with `layout` (caller contract above).
    unsafe { alloc::dealloc(buf, layout) };
}
/// Destination for a value decoded into scratch memory: an `init` thunk that
/// moves the scratch value into the container slot `handle`.
struct InitTarget {
    /// Opaque context pointer forwarded unchanged to `init`.
    ctx: *const (),
    /// The container slot (e.g. a `Result` or smart-pointer field) to initialize.
    handle: *mut u8,
    /// Thunk invoked as `init(ctx, handle, value)`, where `value` points at the
    /// freshly decoded scratch value.
    init: unsafe extern "C" fn(ctx: *const (), handle: *mut u8, value: *mut u8),
}
/// Decode `program` into a temporary scratch slot, then hand the result to
/// `target.init`.
///
/// The scratch slot is always released before returning. If decoding fails,
/// `target.init` is never called and the error is returned.
///
/// # Safety
/// `target.handle` must be a slot that `target.init` may initialize, and
/// `size`/`align` must describe the value `program` decodes.
unsafe fn decode_into_via_init(
    program: &MemProgram,
    size: usize,
    align: usize,
    r: &mut Reader,
    target: InitTarget,
    blocks: &BTreeMap<SchemaId, MemProgram>,
) -> Result<()> {
    let (scratch, layout) = alloc_scratch(size, align)?;
    // SAFETY: `scratch` is a fresh slot of `size` bytes aligned to `align`.
    let outcome = unsafe { decode_program(program, r, scratch, blocks) };
    if outcome.is_ok() {
        // SAFETY: the scratch slot now holds a fully decoded value.
        unsafe { (target.init)(target.ctx, target.handle, scratch) };
    }
    free_scratch(scratch, layout);
    outcome
}
/// Lower `descriptor` to a typed program and decode `bytes` into `base` with it.
///
/// This lowers on every call. Callers that decode repeatedly can call
/// [`lower_typed`] once and reuse the result with [`decode_with`].
///
/// # Errors
/// Returns lowering errors, decode errors, or `TrailingBytes` when input
/// remains after the value.
///
/// # Safety
/// `base` must be valid for writes of a value described by `descriptor`
/// (and `descriptor_blocks` for recursive parts).
pub unsafe fn decode(
    bytes: &[u8],
    descriptor: &Descriptor,
    descriptor_blocks: &HashMap<SchemaId, Descriptor>,
    reg: &Registry,
    base: *mut u8,
) -> Result<()> {
    lower_typed(descriptor, descriptor_blocks, reg)
        // SAFETY: forwarded verbatim from this function's contract.
        .and_then(|program| unsafe { decode_with(&program, bytes, base) })
}
/// Round-trip and compatibility tests for the typed encode/decode interpreter.
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{MaybeUninit, align_of, offset_of, size_of};
    use facet_value::{VArray, Value};
    use phon_ir::{FieldAccess, Layout, SeqThunks, SequenceAccess};
    use phon_schema::bytes::{write_i64, write_u64};
    use phon_schema::{Schema, SchemaId, SchemaRef, primitive_id};

    /// Rebuilds a `Vec<u32>` in place from a decoder-allocated buffer.
    unsafe extern "C" fn vu32_from_raw_parts(
        _ctx: *const (),
        list: *mut u8,
        ptr: *mut u8,
        len: usize,
        cap: usize,
    ) {
        let v = unsafe { Vec::<u32>::from_raw_parts(ptr.cast::<u32>(), len, cap) };
        unsafe { core::ptr::write(list.cast::<Vec<u32>>(), v) };
    }

    unsafe extern "C" fn vu32_len(_ctx: *const (), list: *const u8) -> usize {
        unsafe { (*list.cast::<Vec<u32>>()).len() }
    }

    unsafe extern "C" fn vu32_data(_ctx: *const (), list: *const u8) -> *const u8 {
        unsafe { (*list.cast::<Vec<u32>>()).as_ptr().cast::<u8>() }
    }

    fn vu32_thunks() -> SeqThunks {
        SeqThunks {
            ctx: core::ptr::null(),
            from_raw_parts: vu32_from_raw_parts,
            len: vu32_len,
            data: vu32_data,
        }
    }

    fn vec_u32_descriptor(schema: SchemaId) -> Descriptor {
        Descriptor {
            schema: SchemaRef::concrete(schema),
            layout: Layout {
                size: size_of::<Vec<u32>>(),
                align: align_of::<Vec<u32>>(),
            },
            access: Access::Sequence(SequenceAccess {
                element: Box::new(Descriptor {
                    schema: SchemaRef::concrete(primitive_id(Primitive::U32)),
                    layout: Layout { size: 4, align: 4 },
                    access: Access::Scalar,
                }),
                storage: SequenceStorage::Vtable(vu32_thunks()),
            }),
        }
    }

    /// In-memory struct whose fields are narrower than their u64/i64 schemas.
    #[repr(C)]
    #[derive(Debug, PartialEq)]
    struct NarrowNativeInts {
        count: u32,
        delta: i32,
    }

    fn narrow_native_int_schema(schema: SchemaId) -> Schema {
        Schema {
            id: schema,
            type_params: Vec::new(),
            kind: SchemaKind::Struct {
                name: "NarrowNativeInts".to_string(),
                fields: vec![
                    Field {
                        name: "count".to_string(),
                        schema: SchemaRef::concrete(primitive_id(Primitive::U64)),
                        required: true,
                    },
                    Field {
                        name: "delta".to_string(),
                        schema: SchemaRef::concrete(primitive_id(Primitive::I64)),
                        required: true,
                    },
                ],
            },
        }
    }

    fn narrow_native_int_descriptor(schema: SchemaId) -> Descriptor {
        Descriptor {
            schema: SchemaRef::concrete(schema),
            layout: Layout {
                size: size_of::<NarrowNativeInts>(),
                align: align_of::<NarrowNativeInts>(),
            },
            access: Access::Record(RecordAccess {
                fields: vec![
                    FieldAccess {
                        offset: offset_of!(NarrowNativeInts, count),
                        descriptor: Descriptor {
                            schema: SchemaRef::concrete(primitive_id(Primitive::U64)),
                            layout: Layout {
                                size: size_of::<u32>(),
                                align: align_of::<u32>(),
                            },
                            access: Access::Scalar,
                        },
                        default: None,
                    },
                    FieldAccess {
                        offset: offset_of!(NarrowNativeInts, delta),
                        descriptor: Descriptor {
                            schema: SchemaRef::concrete(primitive_id(Primitive::I64)),
                            layout: Layout {
                                size: size_of::<i32>(),
                                align: align_of::<i32>(),
                            },
                            access: Access::Scalar,
                        },
                        default: None,
                    },
                ],
                construct: Construct::InPlace,
            }),
        }
    }

    #[test]
    fn owned_vec_u32_roundtrips_and_matches_dynamic() {
        let list = Schema {
            id: SchemaId(1),
            type_params: Vec::new(),
            kind: SchemaKind::List {
                element: SchemaRef::concrete(primitive_id(Primitive::U32)),
            },
        };
        let reg = Registry::new([list]);
        let desc = vec_u32_descriptor(SchemaId(1));
        let values = [1u32, 2, 999, 0xDEAD_BEEF];
        let mut arr = VArray::new();
        for &v in &values {
            arr.push(Value::from(v));
        }
        let dyn_bytes = compact::to_bytes(&Value::from(arr), SchemaId(1), &reg).unwrap();
        let v: Vec<u32> = values.to_vec();
        let no_blocks = HashMap::new();
        let typed_bytes = unsafe {
            encode(
                core::ptr::from_ref(&v).cast::<u8>(),
                &desc,
                &no_blocks,
                &reg,
            )
        }
        .unwrap();
        assert_eq!(typed_bytes, dyn_bytes);
        let mut slot = MaybeUninit::<Vec<u32>>::uninit();
        unsafe {
            decode(
                &typed_bytes,
                &desc,
                &no_blocks,
                &reg,
                slot.as_mut_ptr().cast::<u8>(),
            )
        }
        .unwrap();
        let back = unsafe { slot.assume_init() };
        assert_eq!(back, values.to_vec());
    }

    #[test]
    fn native_int_memops_roundtrip_and_reject_out_of_range_values() {
        let schema = SchemaId(1);
        let reg = Registry::new([narrow_native_int_schema(schema)]);
        let desc = narrow_native_int_descriptor(schema);
        let no_blocks = HashMap::new();
        let lowered = lower_typed(&desc, &no_blocks, &reg).unwrap();
        assert_eq!(lowered.program.len(), 2);
        assert!(matches!(
            lowered.program[0],
            MemOp::NativeInt {
                mem_size: 4,
                signed: false,
                ..
            }
        ));
        assert!(matches!(
            lowered.program[1],
            MemOp::NativeInt {
                mem_size: 4,
                signed: true,
                ..
            }
        ));
        let value = NarrowNativeInts {
            count: 0xCAFE_F00D,
            delta: -42,
        };
        let bytes = unsafe { encode_with(&lowered, core::ptr::from_ref(&value).cast::<u8>()) };
        let mut expected = Vec::new();
        write_u64(&mut expected, u64::from(value.count));
        write_i64(&mut expected, i64::from(value.delta));
        assert_eq!(bytes, expected);
        let mut slot = MaybeUninit::<NarrowNativeInts>::uninit();
        unsafe { decode_with(&lowered, &bytes, slot.as_mut_ptr().cast::<u8>()) }.unwrap();
        assert_eq!(unsafe { slot.assume_init() }, value);

        let mut unsigned_out_of_range = Vec::new();
        write_u64(&mut unsigned_out_of_range, u64::from(u32::MAX) + 1);
        write_i64(&mut unsigned_out_of_range, 0);
        let mut slot = MaybeUninit::<NarrowNativeInts>::uninit();
        let err = unsafe {
            decode_with(
                &lowered,
                &unsigned_out_of_range,
                slot.as_mut_ptr().cast::<u8>(),
            )
        }
        .unwrap_err();
        assert!(matches!(
            err,
            CompactError::Decode(DecodeError::Malformed(
                "native-sized unsigned integer out of range"
            ))
        ));

        let mut signed_out_of_range = Vec::new();
        write_u64(&mut signed_out_of_range, 0);
        write_i64(&mut signed_out_of_range, i64::from(i32::MIN) - 1);
        let mut slot = MaybeUninit::<NarrowNativeInts>::uninit();
        let err = unsafe {
            decode_with(
                &lowered,
                &signed_out_of_range,
                slot.as_mut_ptr().cast::<u8>(),
            )
        }
        .unwrap_err();
        assert!(matches!(
            err,
            CompactError::Decode(DecodeError::Malformed(
                "native-sized signed integer out of range"
            ))
        ));
    }

    #[test]
    fn decode_compat_rejects_list_set_kind_mismatch() {
        let element = SchemaRef::concrete(primitive_id(Primitive::U32));
        let writer = Schema {
            id: SchemaId(1),
            type_params: Vec::new(),
            kind: SchemaKind::Set {
                element: element.clone(),
            },
        };
        let reader = Schema {
            id: SchemaId(2),
            type_params: Vec::new(),
            kind: SchemaKind::List { element },
        };
        let reg = Registry::new([writer, reader]);
        let desc = vec_u32_descriptor(SchemaId(2));
        let no_blocks = HashMap::new();
        let typed = lower_decode(SchemaId(1), &desc, &no_blocks, &reg);
        assert!(
            matches!(typed, Err(CompactError::Incompatible(_))),
            "typed compat accepted Set writer for List reader: {typed:?}"
        );
        let dynamic = crate::plan::build_plan(SchemaId(1), SchemaId(2), &reg);
        assert!(
            matches!(dynamic, Err(CompactError::Incompatible(_))),
            "dynamic compat unexpectedly accepted Set writer for List reader"
        );
    }
}