use super::put_node;
use crate::ir_inner::model::program::{BufferDecl, CacheLocality, MemoryKind};
use crate::ir_inner::model::types::{BufferAccess, DataType};
use crate::perf::PerfScope;
use crate::serial::wire::encode::WireEncodeErr;
use crate::serial::wire::framing::{
put_string, put_u32, put_u8, FLAG_OPAQUE_ENDIAN_FIXED, MAGIC, WIRE_FORMAT_VERSION,
};
use crate::serial::wire::tags::access_tag::access_tag;
use crate::serial::wire::Program;
use std::cell::RefCell;
/// Operation id stamped on the first node record of the nodes section; that record's payload
/// is the program metadata written by `put_metadata_payload`.
const METADATA_OP_ID: &str = "vyre.program.metadata";
// Per-thread scratch buffers reused across encode calls so each call can clear and refill
// an existing allocation instead of creating a new one.
thread_local! {
// Holds the encoded body in `to_wire_with_buffer_order_into` while it is hashed and copied out.
static BODY_SCRATCH: RefCell<Vec<u8>> = RefCell::new(Vec::with_capacity(4096));
// Holds one node payload at a time in `put_nodes_section`.
static NODE_PAYLOAD_SCRATCH: RefCell<Vec<u8>> = RefCell::new(Vec::with_capacity(256));
// Holds one region's shape and hints payloads at a time in `put_memory_regions`.
static REGION_PAYLOAD_SCRATCH: RefCell<RegionPayloadScratch> = RefCell::new(RegionPayloadScratch {
shape: Vec::with_capacity(16),
hints: Vec::with_capacity(16),
});
}
/// Pair of reusable byte buffers handed to `put_memory_regions_with_scratch`.
struct RegionPayloadScratch {
/// Staging area for a single region's shape payload before it is length-prefixed.
shape: Vec<u8>,
/// Staging area for a single region's hints payload before it is length-prefixed.
hints: Vec<u8>,
}
/// Encodes `program` into a freshly allocated wire-format byte vector.
///
/// Buffers are emitted in the order reported by `program.buffers()`, and the output vector is
/// pre-sized from `estimated_wire_capacity`.
///
/// # Errors
///
/// Returns the encoder's diagnostic converted to a `String` when encoding fails.
#[inline]
#[must_use]
pub fn to_wire(program: &Program) -> Result<Vec<u8>, String> {
    let capacity_hint = estimated_wire_capacity(program, program.buffers());
    let mut encoded = Vec::with_capacity(capacity_hint);
    match to_wire_with_buffer_order_into(program, program.buffers(), &mut encoded) {
        Ok(()) => Ok(encoded),
        Err(err) => Err(String::from(err)),
    }
}
/// Encodes `program` and appends the wire bytes to `dst`, leaving existing contents in place.
///
/// Buffers are emitted in the order reported by `program.buffers()`.
///
/// # Errors
///
/// Returns the encoder's diagnostic converted to a `String` when encoding fails.
#[inline]
pub fn to_wire_into(program: &Program, dst: &mut Vec<u8>) -> Result<(), String> {
    if let Err(err) = to_wire_with_buffer_order_into(program, program.buffers(), dst) {
        return Err(String::from(err));
    }
    Ok(())
}
/// Encodes `program` with an explicit buffer ordering and appends the result to `dst`.
///
/// The appended bytes are, in order: `MAGIC`, the little-endian bytes of
/// `WIRE_FORMAT_VERSION`, the little-endian bytes of `FLAG_OPAQUE_ENDIAN_FIXED`, the BLAKE3
/// digest of the body, and the body itself (nodes section, memory regions, output set).
///
/// The body is assembled in the thread-local `BODY_SCRATCH`; `dst` is only touched after every
/// fallible step has succeeded, so a failed call appends nothing.
///
/// # Errors
///
/// Fails when `reject_non_roundtrippable_shapes` rejects the input or when any section
/// encoder reports an error.
#[inline]
pub fn to_wire_with_buffer_order_into(
    program: &Program,
    buffers: &[BufferDecl],
    dst: &mut Vec<u8>,
) -> Result<(), WireEncodeErr> {
    let timing = PerfScope::start("vyre-foundation", "foundation.wire.encode");
    // A validation failure returns here, before the explicit `finish()` below is reached.
    reject_non_roundtrippable_shapes(program, buffers)?;

    let outcome = BODY_SCRATCH.with(|cell| -> Result<(), WireEncodeErr> {
        let mut body = cell.borrow_mut();
        body.clear();
        body.reserve(estimated_body_capacity(program, buffers));

        put_nodes_section(&mut body, program, buffers)?;
        put_memory_regions(&mut body, buffers)?;
        crate::serial::output_set::OutputSet::encode_from_buffers_into(buffers, &mut body)
            .map_err(WireEncodeErr::from)?;

        // The digest covers the body only, never the header fields.
        let digest = blake3::hash(&body);
        let header_len = MAGIC.len() + 2 + 2 + 32;
        dst.reserve(header_len + body.len());
        dst.extend_from_slice(MAGIC);
        dst.extend_from_slice(&WIRE_FORMAT_VERSION.to_le_bytes());
        dst.extend_from_slice(&FLAG_OPAQUE_ENDIAN_FIXED.to_le_bytes());
        dst.extend_from_slice(digest.as_bytes());
        dst.extend_from_slice(&body);
        Ok(())
    });

    let _ = timing.finish();
    outcome
}
/// Capacity hint for a complete wire image: the fixed header plus the estimated body size.
///
/// All arithmetic saturates, so the hint can never overflow.
#[inline]
fn estimated_wire_capacity(program: &Program, buffers: &[BufferDecl]) -> usize {
    // Version (2) + flags (2) + digest (32), matching the reserve performed by the encoder.
    let header_len = MAGIC.len().saturating_add(2 + 2 + 32);
    header_len.saturating_add(estimated_body_capacity(program, buffers))
}
/// Heuristic capacity hint for the encoded body.
///
/// Budgets 48 bytes per entry node, 40 + 2 bytes plus the name length per buffer, and a flat
/// 256 bytes of slack. The result is only used to pre-size vectors; it is not a bound.
#[inline]
fn estimated_body_capacity(program: &Program, buffers: &[BufferDecl]) -> usize {
    let name_bytes: usize = buffers.iter().map(|decl| decl.name().len()).sum();
    let node_budget = program.entry().len().saturating_mul(48);
    let buffer_count = buffers.len();

    let mut estimate = node_budget;
    estimate = estimate.saturating_add(buffer_count.saturating_mul(40));
    estimate = estimate.saturating_add(name_bytes);
    estimate = estimate.saturating_add(buffer_count.saturating_mul(2));
    estimate.saturating_add(256)
}
/// Rejects programs whose shapes the encoder refuses to serialize.
///
/// # Errors
///
/// Returns an error when:
/// - any workgroup dimension is 0;
/// - a buffer with `BufferAccess::Workgroup` access has count 0;
/// - a pipeline live-out buffer has count 0;
/// - a buffer's output byte range has `start > end`;
/// - a buffer's output byte range ends past `count * element size` (skipped when count is 0
///   or when that product saturates to `u64::MAX`).
///
/// Diagnostics are built in a fixed 256-byte `ArrayString`. Over-long buffer names are clipped
/// instead of overflowing the array, because `ArrayString::push_str` panics on overflow and a
/// validation failure must stay a recoverable error.
fn reject_non_roundtrippable_shapes(
    program: &Program,
    buffers: &[BufferDecl],
) -> Result<(), WireEncodeErr> {
    for (axis, size) in program.workgroup_size().into_iter().enumerate() {
        if size == 0 {
            return Err(WireEncodeErr::fmt_usize(
                "Fix: workgroup_size[",
                axis,
                "] is 0. Encode only programs whose workgroup dimensions are >= 1.",
            ));
        }
    }
    for buffer in buffers {
        if buffer.count() == 0 && buffer.access() == BufferAccess::Workgroup {
            let text = bounded_shape_diagnostic(
                "Fix: workgroup buffer `",
                buffer.name(),
                &["` has count 0. Encode only positive-length shared-memory buffers."],
            );
            return Err(WireEncodeErr::Dynamic(text));
        }
        if buffer.count() == 0 && buffer.is_pipeline_live_out() {
            let text = bounded_shape_diagnostic(
                "Fix: live-out buffer `",
                buffer.name(),
                &["` has count 0. Encode only positive-length externally-visible buffers."],
            );
            return Err(WireEncodeErr::Dynamic(text));
        }
        if let Some(range) = buffer.output_byte_range() {
            // NOTE(review): an element type whose `size_bytes()` is `None` is treated as
            // zero-width, so for a non-zero count any range with a non-zero end is rejected.
            // Confirm this is intended for unsized element types.
            let elem_size = buffer.element().size_bytes().unwrap_or(0) as u64;
            let count = buffer.count() as u64;
            // Count 0 means "size unknown": `u64::MAX` disables the upper-bound check below.
            let full_size = if count == 0 {
                u64::MAX
            } else {
                count.saturating_mul(elem_size)
            };
            let start = range.start as u64;
            let end = range.end as u64;
            if start > end {
                // Two digit buffers: both formatted numbers must be alive at the same time.
                let mut start_digits = itoa::Buffer::new();
                let mut end_digits = itoa::Buffer::new();
                let text = bounded_shape_diagnostic(
                    "Fix: buffer `",
                    buffer.name(),
                    &[
                        "` output byte range has start (",
                        start_digits.format(range.start),
                        ") > end (",
                        end_digits.format(range.end),
                        "). Encode only valid ranges.",
                    ],
                );
                return Err(WireEncodeErr::Dynamic(text));
            }
            if end > full_size && full_size != u64::MAX {
                let mut end_digits = itoa::Buffer::new();
                let mut size_digits = itoa::Buffer::new();
                let text = bounded_shape_diagnostic(
                    "Fix: buffer `",
                    buffer.name(),
                    &[
                        "` output byte range end (",
                        end_digits.format(range.end),
                        ") exceeds full buffer size (",
                        size_digits.format(full_size),
                        "). Encode only ranges that fit within the declared buffer size.",
                    ],
                );
                return Err(WireEncodeErr::Dynamic(text));
            }
        }
    }
    Ok(())
}

/// Builds `head + name + tail...` inside a 256-byte `ArrayString` without panicking.
///
/// Only `name` is shortened when the pieces do not fit, so the guidance in `tail` survives.
/// When everything fits, the result equals the plain concatenation.
fn bounded_shape_diagnostic(head: &str, name: &str, tail: &[&str]) -> arrayvec::ArrayString<256> {
    let mut text = arrayvec::ArrayString::<256>::new();
    let fixed_len = tail
        .iter()
        .fold(head.len(), |sum, part| sum.saturating_add(part.len()));
    let name_budget = text.capacity().saturating_sub(fixed_len);
    push_clipped(&mut text, head);
    push_clipped(&mut text, clip_at_char_boundary(name, name_budget));
    for part in tail {
        push_clipped(&mut text, part);
    }
    text
}

/// Returns the longest prefix of `text` that is at most `max_len` bytes and ends on a UTF-8
/// character boundary.
fn clip_at_char_boundary(text: &str, max_len: usize) -> &str {
    if text.len() <= max_len {
        return text;
    }
    let mut cut = max_len;
    // Offset 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(cut) {
        cut -= 1;
    }
    &text[..cut]
}

/// Appends as much of `text` as still fits into `dst`; never exceeds the remaining capacity.
fn push_clipped(dst: &mut arrayvec::ArrayString<256>, text: &str) {
    let room = dst.capacity() - dst.len();
    dst.push_str(clip_at_char_boundary(text, room));
}
/// Writes the nodes section into `out`, staging each node payload in the thread-local
/// `NODE_PAYLOAD_SCRATCH` buffer.
///
/// # Errors
///
/// Propagates any error from `put_nodes_section_with_payload`.
fn put_nodes_section(
    out: &mut Vec<u8>,
    program: &Program,
    buffers: &[BufferDecl],
) -> Result<(), WireEncodeErr> {
    NODE_PAYLOAD_SCRATCH.with(|cell| {
        let mut guard = cell.borrow_mut();
        let staging: &mut Vec<u8> = &mut guard;
        put_nodes_section_with_payload(out, program, buffers, staging)
    })
}
/// Writes the nodes section into `out`, using `payload` as reusable staging space.
///
/// Layout: a LEB128 record count (entry nodes + 1), then the metadata record tagged
/// `METADATA_OP_ID`, then one record per entry node in program order. Every record is written
/// with an empty operand list.
///
/// # Errors
///
/// Fails when the record count cannot be represented as `u64` or when encoding the metadata
/// or any node fails.
fn put_nodes_section_with_payload(
    out: &mut Vec<u8>,
    program: &Program,
    buffers: &[BufferDecl],
    payload: &mut Vec<u8>,
) -> Result<(), WireEncodeErr> {
    // The extra record is the metadata record emitted ahead of the entry nodes.
    let record_count = u64::try_from(program.entry().len() + 1).map_err(|_| {
        WireEncodeErr::static_msg(
            "Fix: node count cannot fit u64; split the Program before serialization.",
        )
    })?;
    put_leb_u64(out, record_count);

    payload.clear();
    put_metadata_payload(payload, program, buffers)?;
    put_node_record(out, METADATA_OP_ID, payload, &[])?;

    program
        .entry()
        .iter()
        .try_for_each(|node| -> Result<(), WireEncodeErr> {
            payload.clear();
            put_node(payload, node)?;
            let op_id = crate::ir_inner::model::node::node_op_id(node);
            put_node_record(out, op_id, payload, &[])
        })
}
/// Appends one node record to `out`.
///
/// Layout: length-prefixed `op_id`, LEB128 payload length, payload bytes, LEB128 operand
/// count, then each operand as a LEB128 value.
///
/// # Errors
///
/// Fails when a length cannot be represented as `u64`.
fn put_node_record(
    out: &mut Vec<u8>,
    op_id: &str,
    payload: &[u8],
    operands: &[u32],
) -> Result<(), WireEncodeErr> {
    put_leb_str(out, op_id)?;

    let payload_len = u64::try_from(payload.len()).map_err(|_| {
        WireEncodeErr::static_msg("Fix: node payload length cannot fit u64; split the Program.")
    })?;
    put_leb_u64(out, payload_len);
    out.extend_from_slice(payload);

    let operand_count = u64::try_from(operands.len()).map_err(|_| {
        WireEncodeErr::static_msg("Fix: node operand count cannot fit u64; split the Program.")
    })?;
    put_leb_u64(out, operand_count);
    operands
        .iter()
        .for_each(|&operand| put_leb_u32(out, operand));
    Ok(())
}
/// Appends the program metadata payload to `out`.
///
/// Layout: the literal bytes `VYRE-META`; a presence byte followed by the entry op id when one
/// is set; each workgroup dimension; the non-composable-with-self flag; a LEB128 buffer count;
/// then, per buffer: name, binding, count, output flag, live-out flag, an optional output byte
/// range (presence byte plus LEB128 start and end), and the hints payload.
///
/// # Errors
///
/// Fails when a string cannot be written or a length/offset cannot be represented as `u64`.
fn put_metadata_payload(
    out: &mut Vec<u8>,
    program: &Program,
    buffers: &[BufferDecl],
) -> Result<(), WireEncodeErr> {
    out.extend_from_slice(b"VYRE-META");

    if let Some(op_id) = program.entry_op_id() {
        put_u8(out, 1);
        put_string(out, op_id)?;
    } else {
        put_u8(out, 0);
    }

    for dimension in program.workgroup_size() {
        put_u32(out, dimension);
    }
    put_u8(out, u8::from(program.is_non_composable_with_self()));

    let buffer_count = u64::try_from(buffers.len()).map_err(|_| {
        WireEncodeErr::static_msg("Fix: buffer metadata count cannot fit u64; split the Program.")
    })?;
    put_leb_u64(out, buffer_count);

    for decl in buffers {
        put_string(out, decl.name())?;
        put_u32(out, decl.binding());
        put_u32(out, decl.count());
        put_u8(out, u8::from(decl.is_output()));
        put_u8(out, u8::from(decl.is_pipeline_live_out()));

        if let Some(range) = decl.output_byte_range() {
            put_u8(out, 1);
            let range_start = u64::try_from(range.start).map_err(|_| {
                WireEncodeErr::static_msg(
                    "Fix: output range start cannot fit u64; split the output buffer.",
                )
            })?;
            put_leb_u64(out, range_start);
            let range_end = u64::try_from(range.end).map_err(|_| {
                WireEncodeErr::static_msg(
                    "Fix: output range end cannot fit u64; split the output buffer.",
                )
            })?;
            put_leb_u64(out, range_end);
        } else {
            put_u8(out, 0);
        }

        put_hints_payload(out, decl.hints())?;
    }
    Ok(())
}
/// Writes the memory-regions section into `out`, staging shape and hints payloads in the
/// thread-local `REGION_PAYLOAD_SCRATCH` buffers.
///
/// # Errors
///
/// Propagates any error from `put_memory_regions_with_scratch`.
fn put_memory_regions(out: &mut Vec<u8>, buffers: &[BufferDecl]) -> Result<(), WireEncodeErr> {
    REGION_PAYLOAD_SCRATCH.with(|cell| {
        let mut guard = cell.borrow_mut();
        // Reborrow through a plain reference so both fields can be lent out mutably at once.
        let staging = &mut *guard;
        put_memory_regions_with_scratch(out, buffers, &mut staging.shape, &mut staging.hints)
    })
}
/// Writes the memory-regions section into `out`, using `shape` and `hints` as reusable staging
/// buffers.
///
/// Layout: a LEB128 region count, then per buffer: its index as a LEB128 id, memory-kind tag,
/// access tag, element data-type tag, a literal zero byte, a length-prefixed shape payload and
/// a length-prefixed hints payload. The shape payload holds the buffer count, followed by the
/// element size for `DataType::Array` or the opaque id for `DataType::Opaque`.
///
/// # Errors
///
/// Fails when an index or length does not fit its wire integer, or when the access or element
/// type has no wire tag.
fn put_memory_regions_with_scratch(
    out: &mut Vec<u8>,
    buffers: &[BufferDecl],
    shape: &mut Vec<u8>,
    hints: &mut Vec<u8>,
) -> Result<(), WireEncodeErr> {
    let region_count = u64::try_from(buffers.len()).map_err(|_| {
        WireEncodeErr::static_msg("Fix: memory-region count cannot fit u64; split the Program.")
    })?;
    put_leb_u64(out, region_count);

    for (index, decl) in buffers.iter().enumerate() {
        let region_id = u32::try_from(index).map_err(|_| {
            WireEncodeErr::fmt_usize(
                "Fix: memory-region id ",
                index,
                " cannot fit u32; split the Program.",
            )
        })?;
        put_leb_u32(out, region_id);

        put_u8(out, memory_kind_tag(decl.kind()));
        let access = access_tag(decl.access()).map_err(WireEncodeErr::from)?;
        put_u8(out, access);
        let element_tag = data_type_tag(&decl.element())?;
        put_u8(out, element_tag);
        put_u8(out, 0);

        // Shape payload: element count plus a type-specific trailer for arrays and opaques.
        shape.clear();
        put_leb_u64(shape, u64::from(decl.count()));
        match decl.element() {
            DataType::Array { element_size } => {
                let width = u64::try_from(element_size).map_err(|_| {
                    WireEncodeErr::static_msg(
                        "Fix: array element size cannot fit u64; cap the element size.",
                    )
                })?;
                put_leb_u64(shape, width);
            }
            DataType::Opaque(id) => put_leb_u64(shape, u64::from(id.as_u32())),
            _ => {}
        }
        let shape_len = u64::try_from(shape.len()).map_err(|_| {
            WireEncodeErr::static_msg(
                "Fix: shape payload length cannot fit u64; split the Program.",
            )
        })?;
        put_leb_u64(out, shape_len);
        out.extend_from_slice(shape.as_slice());

        hints.clear();
        put_hints_payload(hints, decl.hints())?;
        let hints_len = u64::try_from(hints.len()).map_err(|_| {
            WireEncodeErr::static_msg(
                "Fix: hints payload length cannot fit u64; split the Program.",
            )
        })?;
        put_leb_u64(out, hints_len);
        out.extend_from_slice(hints.as_slice());
    }
    Ok(())
}
/// Appends a buffer's memory hints to `out`.
///
/// Layout: a presence byte followed by the coalesce axis when one is set, the preferred
/// alignment, and a one-byte cache-locality tag (Streaming = 0, Temporal = 1, Random = 2).
///
/// # Errors
///
/// The current implementation always returns `Ok(())`.
fn put_hints_payload(
    out: &mut Vec<u8>,
    hints: crate::ir::MemoryHints,
) -> Result<(), WireEncodeErr> {
    if let Some(axis) = hints.coalesce_axis {
        put_u8(out, 1);
        put_u8(out, axis);
    } else {
        put_u8(out, 0);
    }

    put_u32(out, hints.preferred_alignment);

    let locality_tag: u8 = match hints.cache_locality {
        CacheLocality::Streaming => 0,
        CacheLocality::Temporal => 1,
        CacheLocality::Random => 2,
    };
    put_u8(out, locality_tag);
    Ok(())
}
/// Maps a `MemoryKind` to the one-byte tag written for each memory region.
fn memory_kind_tag(kind: MemoryKind) -> u8 {
    let tag: u8 = match kind {
        MemoryKind::Global => 0x00,
        MemoryKind::Shared => 0x01,
        MemoryKind::Uniform => 0x02,
        MemoryKind::Local => 0x03,
        MemoryKind::Readonly => 0x04,
        MemoryKind::Push => 0x05,
        MemoryKind::Persistent => 0x06,
    };
    tag
}
/// Maps a `DataType` to the one-byte element tag written for each memory region.
///
/// Arms are grouped by type family for readability; the tag values are what matter.
///
/// # Errors
///
/// Fails for any variant that has no assigned tag.
fn data_type_tag(value: &DataType) -> Result<u8, WireEncodeErr> {
    let tag: u8 = match value {
        // Unsigned integers.
        DataType::U8 => 0x0E,
        DataType::U16 => 0x0F,
        DataType::U32 => 0x01,
        DataType::U64 => 0x03,
        // Signed integers.
        DataType::I8 => 0x10,
        DataType::I16 => 0x11,
        DataType::I32 => 0x02,
        DataType::I64 => 0x12,
        // Floating point.
        DataType::F16 => 0x09,
        DataType::BF16 => 0x0A,
        DataType::F32 => 0x0B,
        DataType::F64 => 0x0C,
        // Booleans and raw bytes.
        DataType::Bool => 0x06,
        DataType::Bytes => 0x07,
        // Vectors, arrays and tensors.
        DataType::Vec2U32 => 0x04,
        DataType::Vec4U32 => 0x05,
        DataType::Vec { .. } => 0x14,
        DataType::Array { .. } => 0x08,
        DataType::Tensor => 0x0D,
        DataType::TensorShaped { .. } => 0x15,
        // Handles and opaque extension types.
        DataType::Handle(_) => 0x13,
        DataType::Opaque(_) => 0x80,
        _ => {
            return Err(WireEncodeErr::static_msg(
                "Fix: unknown DataType variant cannot be serialized into VYRE wire format.",
            ));
        }
    };
    Ok(tag)
}
/// Appends `value` to `out` as a LEB128 byte length followed by the UTF-8 bytes.
///
/// # Errors
///
/// Fails when the byte length cannot be represented as `u64`.
fn put_leb_str(out: &mut Vec<u8>, value: &str) -> Result<(), WireEncodeErr> {
    let bytes = value.as_bytes();
    let byte_len = u64::try_from(bytes.len()).map_err(|_| {
        WireEncodeErr::static_msg("Fix: string length cannot fit u64; shorten the identifier.")
    })?;
    put_leb_u64(out, byte_len);
    out.extend_from_slice(bytes);
    Ok(())
}
/// Appends `value` to `out` using the same LEB128 encoding as `put_leb_u64`.
fn put_leb_u32(out: &mut Vec<u8>, value: u32) {
    let widened: u64 = value.into();
    put_leb_u64(out, widened);
}
/// Appends `value` to `out` as an unsigned LEB128 integer.
///
/// The value is emitted in 7-bit groups, least significant first; every byte except the last
/// has its high bit set. Zero is encoded as a single `0x00` byte.
fn put_leb_u64(out: &mut Vec<u8>, mut value: u64) {
    // Emit continuation bytes while more than seven significant bits remain.
    while value >= 0x80 {
        out.push(((value & 0x7F) as u8) | 0x80);
        value >>= 7;
    }
    // The final group fits in seven bits, so the high bit stays clear.
    out.push(value as u8);
}
// Unit tests for the wire encoder: append semantics of `to_wire_into`, scratch-buffer reuse in
// the section helpers, and round-tripping/validation of the trailing output set.
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
// `to_wire_into` must append: 100 calls into one Vec must equal 100 concatenated `to_wire`
// results, byte for byte.
#[test]
fn to_wire_into_appends_byte_for_byte() {
let program = Program::wrapped(
vec![
BufferDecl::read_write("a", 0, DataType::U32),
BufferDecl::read("b", 1, DataType::U32),
],
[64, 1, 1],
vec![
Node::let_bind("idx", Expr::gid_x()),
Node::store("a", Expr::var("idx"), Expr::load("b", Expr::var("idx"))),
],
);
// Reference: every encode allocates its own Vec, results are concatenated.
let mut separate = Vec::new();
for _ in 0..100 {
separate.extend_from_slice(&to_wire(&program).unwrap());
}
// Under test: every encode appends to the same destination Vec.
let mut reused = Vec::new();
for _ in 0..100 {
to_wire_into(&program, &mut reused).unwrap();
}
assert_eq!(
separate, reused,
"100 separate to_wire calls must match 100 to_wire_into calls into the same buffer"
);
}
// The `*_with_payload` / `*_with_scratch` helpers must reuse caller-provided scratch buffers:
// a second run with the same buffers must leave their pointer and capacity unchanged.
#[test]
fn encode_section_helpers_reuse_caller_scratch() {
let program = Program::wrapped(
vec![
BufferDecl::read_write("a", 0, DataType::U32).with_count(64),
BufferDecl::read("b", 1, DataType::U32).with_count(64),
BufferDecl::read("mask", 2, DataType::Bool).with_count(64),
],
[64, 1, 1],
vec![
Node::let_bind("idx", Expr::gid_x()),
Node::store("a", Expr::var("idx"), Expr::load("b", Expr::var("idx"))),
],
);
// Scratch buffers are created with generous capacity up front.
let mut out = Vec::with_capacity(2048);
let mut payload = Vec::with_capacity(2048);
put_nodes_section_with_payload(&mut out, &program, program.buffers(), &mut payload)
.expect("node section must encode");
// Snapshot the allocation after the first run, then encode again and compare.
let payload_ptr = payload.as_ptr();
let payload_capacity = payload.capacity();
out.clear();
put_nodes_section_with_payload(&mut out, &program, program.buffers(), &mut payload)
.expect("node section must encode a second time");
assert_eq!(payload.as_ptr(), payload_ptr);
assert_eq!(payload.capacity(), payload_capacity);
// Same check for the shape/hints scratch used by the memory-regions section.
let mut shape = Vec::with_capacity(64);
let mut hints = Vec::with_capacity(64);
put_memory_regions_with_scratch(&mut out, program.buffers(), &mut shape, &mut hints)
.expect("memory regions must encode");
let shape_ptr = shape.as_ptr();
let hints_ptr = hints.as_ptr();
let shape_capacity = shape.capacity();
let hints_capacity = hints.capacity();
out.clear();
put_memory_regions_with_scratch(&mut out, program.buffers(), &mut shape, &mut hints)
.expect("memory regions must encode a second time");
assert_eq!(shape.as_ptr(), shape_ptr);
assert_eq!(hints.as_ptr(), hints_ptr);
assert_eq!(shape.capacity(), shape_capacity);
assert_eq!(hints.capacity(), hints_capacity);
}
// The output set is the last thing written to the body: it must survive a round trip, and a
// tampered copy must be rejected by the decoder with an error that names the output set.
#[test]
fn output_set_is_serialized_and_validated() {
let program = Program::wrapped(
vec![
BufferDecl::read("input", 0, DataType::U32).with_count(4),
BufferDecl::output("out", 1, DataType::U32).with_count(4),
BufferDecl::read_write("scratch_out", 2, DataType::U32).with_count(4),
],
[64, 1, 1],
vec![Node::store(
"out",
Expr::u32(0),
Expr::load("input", Expr::u32(0)),
)],
);
let encoded = to_wire(&program).expect("Fix: output-set program must encode");
// The trailing three bytes are expected to be [2, 1, 2] — presumably a count of two
// followed by buffer indices 1 and 2; confirm against `OutputSet::encode_from_buffers_into`.
assert_eq!(
&encoded[encoded.len() - 3..],
&[2, 1, 2],
"OutputSet must list the two writable buffer indices in declaration order"
);
let decoded =
Program::from_wire(&encoded).expect("Fix: encoded output-set program must decode");
assert_eq!(decoded.output_buffer_indices(), &[1, 2]);
// Corrupt the final output-set byte, then re-hash bytes 40.. and patch the digest stored at
// bytes 8..40 — presumably so decoding fails in output-set validation rather than on a
// digest mismatch. The offsets assume a 40-byte header.
let mut tampered = encoded;
let last = tampered.len() - 1;
tampered[last] = 0;
let digest = blake3::hash(&tampered[40..]);
tampered[8..40].copy_from_slice(digest.as_bytes());
let err = Program::from_wire(&tampered)
.expect_err("tampered output-set must be rejected")
.to_string();
assert!(
err.contains("output-set"),
"decode error must name the corrupt OutputSet: {err}"
);
}
}