use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
pub mod decode;
pub mod encode;
pub mod framing;
pub mod tags;
/// Cap on the number of buffer declarations in one program.
/// NOTE(review): the check that enforces this is not in this file section — confirm in `encode`/`decode`.
pub const MAX_BUFFERS: usize = 1 << 14;
/// Cap on the number of IR nodes in one program.
/// NOTE(review): enforcement not shown here — confirm in `encode`/`decode`.
pub const MAX_NODES: usize = 1_000_000;
/// Cap on argument count (enforcement not shown here — confirm); also the
/// multiplier behind [`MAX_OPAQUE_PAYLOAD_LEN`].
pub const MAX_ARGS: usize = 1 << 12;
/// Longest string the encoder accepts: a buffer name one byte longer than
/// this makes [`Program::to_wire`] fail.
pub const MAX_STRING_LEN: usize = 1024 * 1024;
/// Largest opaque-expression payload in bytes (`MAX_ARGS * 1024`, i.e. 4 MiB).
/// A payload of exactly this size encodes; one byte more is rejected at encode time.
pub const MAX_OPAQUE_PAYLOAD_LEN: usize = MAX_ARGS * 1024;
/// Deepest node nesting the decoder accepts; more deeply nested input fails to decode.
pub const MAX_DECODE_DEPTH: u32 = 64;
/// Largest wire blob (64 MiB) that [`Program::from_wire`] will attempt to decode.
pub const MAX_PROGRAM_BYTES: usize = 64 << 20;
/// Cursor state over a borrowed wire-format byte slice.
///
/// NOTE(review): not used in this file section; presumably consumed by the
/// `decode` submodule — confirm there.
pub(crate) struct Reader<'a> {
/// The complete input being read; borrowed, never copied.
pub bytes: &'a [u8],
/// Offset into `bytes` (presumably the next unread byte — confirm in `decode`).
pub pos: usize,
/// Current nesting depth; presumably checked against `MAX_DECODE_DEPTH` by the decoder — confirm.
pub depth: u32,
}
impl Program {
    /// Serializes this program into a newly allocated wire-format buffer.
    ///
    /// # Errors
    ///
    /// Returns [`crate::error::Error::WireFormatValidation`] carrying the
    /// encoder's message when the program cannot be encoded (for example, a
    /// buffer name longer than [`MAX_STRING_LEN`]).
    // The `Result`-returning methods here carry no `#[must_use]`: `Result` is
    // already `#[must_use]`, and repeating it trips `clippy::double_must_use`.
    #[inline]
    pub fn to_wire(&self) -> Result<Vec<u8>, crate::error::Error> {
        encode::to_wire(self).map_err(wire_err)
    }

    /// Encodes this program into the caller-supplied buffer `dst`, so callers
    /// can reuse one allocation across several encodes.
    ///
    /// NOTE(review): whether `dst` is cleared or appended to, and what it holds
    /// after a failure, is decided by `encode::to_wire_into` — confirm there.
    ///
    /// # Errors
    ///
    /// Same as [`Program::to_wire`].
    #[inline]
    pub fn to_wire_into(&self, dst: &mut Vec<u8>) -> Result<(), crate::error::Error> {
        encode::to_wire_into(self, dst).map_err(wire_err)
    }

    /// Infallible, lossy variant of [`Program::to_wire`].
    ///
    /// If encoding fails, the error is logged through `tracing::error!` and an
    /// empty `Vec` is returned, so the failure does not show up in the return
    /// type. Prefer [`Program::to_wire`] wherever the error can be handled.
    #[must_use]
    #[inline]
    pub fn to_bytes(&self) -> Vec<u8> {
        match self.to_wire() {
            Ok(bytes) => bytes,
            Err(error) => {
                tracing::error!(
                    error = %error,
                    "Program::to_bytes: wire encoding failed; returning empty bytes. \
                     Fix: call Program::to_wire and handle the validation error explicitly."
                );
                Vec::new()
            }
        }
    }

    /// Deserializes a program from wire-format bytes.
    ///
    /// Inputs longer than [`MAX_PROGRAM_BYTES`] are rejected before any parsing.
    /// If the input starts with `framing::MAGIC` followed by at least two more
    /// bytes, those two bytes are read as a little-endian `u16` and compared
    /// with `framing::WIRE_FORMAT_VERSION`. A version skew is therefore
    /// reported as its own error rather than as a generic decode failure.
    /// Input without the magic prefix goes straight to the decoder.
    ///
    /// # Errors
    ///
    /// * [`crate::error::Error::VersionMismatch`] when the framed version
    ///   differs from `framing::WIRE_FORMAT_VERSION`.
    /// * [`crate::error::Error::WireFormatValidation`] when the input is too
    ///   large or the decoder rejects it.
    #[inline]
    pub fn from_wire(bytes: &[u8]) -> Result<Self, crate::error::Error> {
        if bytes.len() > MAX_PROGRAM_BYTES {
            return Err(wire_err(format!(
                "Fix: wire blob is {} bytes, exceeding the {}-byte IR framing cap. Reject this input or split the Program before serialization.",
                bytes.len(),
                MAX_PROGRAM_BYTES
            )));
        }
        // Read the version field relative to the magic's real length rather than
        // at fixed offsets 4/5. That way a magic of any length neither indexes
        // out of bounds (length < 4) nor misreads the version (length > 4).
        let magic_len = framing::MAGIC.len();
        if bytes.len() >= magic_len + 2 && &bytes[..magic_len] == framing::MAGIC {
            let version = u16::from_le_bytes([bytes[magic_len], bytes[magic_len + 1]]);
            if version != framing::WIRE_FORMAT_VERSION {
                return Err(crate::error::Error::VersionMismatch {
                    expected: u32::from(framing::WIRE_FORMAT_VERSION),
                    found: u32::from(version),
                });
            }
        }
        decode::from_wire(bytes).map_err(wire_err)
    }

    /// Alias for [`Program::from_wire`].
    ///
    /// # Errors
    ///
    /// Same as [`Program::from_wire`].
    #[inline]
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, crate::error::Error> {
        Self::from_wire(bytes)
    }

    /// Returns this program's 32-byte content hash by delegating to
    /// [`Program::fingerprint`].
    #[must_use]
    pub fn content_hash(&self) -> [u8; 32] {
        self.fingerprint()
    }
}
fn wire_err(message: String) -> crate::error::Error {
crate::error::Error::WireFormatValidation { message }
}
/// Writes `value` into `buf` using the wire-format data-type encoder
/// (`tags::data_type_tag::put_data_type`). As the name indicates, the bytes
/// are meant as fingerprint input.
///
/// # Errors
///
/// Returns the encoder's error, converted into a `String`, when `value`
/// cannot be encoded.
pub fn append_data_type_fingerprint(buf: &mut Vec<u8>, value: &DataType) -> Result<(), String> {
    // `?` applies the same `From` conversion into `String` that the error needs.
    tags::data_type_tag::put_data_type(buf, value)?;
    Ok(())
}
/// Writes the node list `nodes` into `buf` using the wire-format node encoder
/// (`encode::put_nodes`). As the name indicates, the bytes are meant as
/// fingerprint input.
///
/// # Errors
///
/// Returns the encoder's error, converted into a `String`, when any node
/// cannot be encoded.
pub fn append_node_list_fingerprint(buf: &mut Vec<u8>, nodes: &[Node]) -> Result<(), String> {
    // `?` applies the same `From` conversion into `String` that the error needs.
    encode::put_nodes(buf, nodes)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Node, Program};

    /// `to_bytes` must degrade to an empty buffer, not panic, when the encoder
    /// rejects the program. Here the trigger is a buffer name one byte over
    /// `MAX_STRING_LEN`.
    #[test]
    fn to_bytes_returns_empty_on_wire_error() {
        let long_name = "x".repeat(MAX_STRING_LEN + 1);
        let program = Program::wrapped(
            vec![BufferDecl::storage(
                &long_name,
                0,
                BufferAccess::ReadOnly,
                DataType::U32,
            )],
            [1, 1, 1],
            vec![],
        );
        assert!(program.to_wire().is_err());
        assert!(program.to_bytes().is_empty());
    }

    /// Runs the depth-cap check on a worker thread with an 8 MiB stack.
    /// Presumably this is because encoding/decoding a deeply nested program
    /// uses deep recursion that could exceed the default test-thread stack.
    #[test]
    fn decode_depth_cap_rejects_deeply_nested_blocks() {
        std::thread::Builder::new()
            .stack_size(8 * 1024 * 1024)
            .spawn(run_decode_depth_cap)
            .expect("Fix: spawn test worker")
            .join()
            .expect("Fix: decode-depth-cap worker panicked");
    }

    fn run_decode_depth_cap() {
        // One empty block wrapped MAX_DECODE_DEPTH more times gives
        // MAX_DECODE_DEPTH + 1 levels of nesting: one past the decoder's cap.
        let mut inner = Node::Block(vec![]);
        for _ in 0..MAX_DECODE_DEPTH {
            inner = Node::Block(vec![inner]);
        }
        let program = Program::wrapped(
            vec![BufferDecl::read_write("out", 0, DataType::U32)],
            [1, 1, 1],
            vec![inner],
        );
        let bytes = program
            .to_wire()
            .expect("Fix: building a (MAX_DECODE_DEPTH+1)-nested program must still encode");
        let decoded = Program::from_wire(&bytes);
        assert!(
            decoded.is_err(),
            "decoding a program deeper than MAX_DECODE_DEPTH must fail; got Ok"
        );
        let err = decoded.unwrap_err().to_string();
        assert!(
            err.contains("Fix:"),
            "depth-exceed error must carry a `Fix:` hint, got: {err}"
        );
    }

    /// The opaque-expression payload cap is enforced at encode time. A payload
    /// of exactly `MAX_OPAQUE_PAYLOAD_LEN` bytes encodes; one byte more fails
    /// with an error that names the limit.
    #[test]
    fn opaque_payload_limit_is_symmetric() {
        use crate::ir::{Expr, ExprNode};
        use std::any::Any;

        /// Test-only extension node whose wire payload is the wrapped bytes.
        #[derive(Debug)]
        struct BigOpaque(Vec<u8>);

        impl ExprNode for BigOpaque {
            fn extension_kind(&self) -> &'static str {
                "test.big"
            }
            fn debug_identity(&self) -> &str {
                "test.big"
            }
            fn result_type(&self) -> Option<DataType> {
                Some(DataType::U32)
            }
            fn cse_safe(&self) -> bool {
                false
            }
            fn stable_fingerprint(&self) -> [u8; 32] {
                [0; 32]
            }
            fn validate_extension(&self) -> Result<(), String> {
                Ok(())
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn wire_payload(&self) -> Vec<u8> {
                self.0.clone()
            }
        }

        let expr_ok = Expr::opaque(BigOpaque(vec![0u8; MAX_OPAQUE_PAYLOAD_LEN]));
        let program_ok = Program::wrapped(
            vec![BufferDecl::read_write("out", 0, DataType::U32)],
            [1, 1, 1],
            vec![Node::let_bind("_", expr_ok)],
        );
        assert!(
            program_ok.to_wire().is_ok(),
            "at-limit opaque payload ({MAX_OPAQUE_PAYLOAD_LEN} bytes) must encode"
        );

        let expr_over = Expr::opaque(BigOpaque(vec![0u8; MAX_OPAQUE_PAYLOAD_LEN + 1]));
        let program_over = Program::wrapped(
            vec![BufferDecl::read_write("out", 0, DataType::U32)],
            [1, 1, 1],
            vec![Node::let_bind("_", expr_over)],
        );
        let err = program_over
            .to_wire()
            .expect_err("opaque payload exceeding MAX_OPAQUE_PAYLOAD_LEN must fail at encode");
        let msg = err.to_string();
        assert!(
            msg.contains("MAX_OPAQUE_PAYLOAD_LEN") || msg.contains(&MAX_OPAQUE_PAYLOAD_LEN.to_string()),
            "error should mention the limit, got: {msg}"
        );
    }
}