use crate::ir::serial::wire::tags::{
atomic_op_from_tag, bin_op_from_tag, data_type_from_tag, un_op_from_tag,
};
use crate::ir::serial::wire::{Expr, Node, Reader, MAX_ARGS, MAX_DECODE_DEPTH, MAX_NODES};
use crate::ir::DataType;
impl Reader<'_> {
#[inline]
pub fn nodes(&mut self) -> Result<Vec<Node>, String> {
let count = self.bounded_len(MAX_NODES, "node count")?;
let mut nodes = Vec::with_capacity(count);
for _ in 0..count {
nodes.push(self.node()?);
}
Ok(nodes)
}
#[inline]
pub(crate) fn node(&mut self) -> Result<Node, String> {
if self.depth >= MAX_DECODE_DEPTH {
return Err(format!(
"Fix: IR wire format exceeds maximum decode depth {MAX_DECODE_DEPTH}; flatten deeply nested Block/If/Loop structures or reject this untrusted blob."
));
}
self.depth += 1;
let result = self.node_inner();
self.depth -= 1;
result
}
fn node_inner(&mut self) -> Result<Node, String> {
match self.u8()? {
0 => Ok(Node::Let {
name: self.string()?,
value: self.expr()?,
}),
1 => Ok(Node::Assign {
name: self.string()?,
value: self.expr()?,
}),
2 => Ok(Node::Store {
buffer: self.string()?,
index: self.expr()?,
value: self.expr()?,
}),
3 => Ok(Node::If {
cond: self.expr()?,
then: self.nodes()?,
otherwise: self.nodes()?,
}),
4 => Ok(Node::Loop {
var: self.string()?,
from: self.expr()?,
to: self.expr()?,
body: self.nodes()?,
}),
5 => Ok(Node::Return),
6 => Ok(Node::Block(self.nodes()?)),
7 => Ok(Node::Barrier),
tag => Err(format!(
"Fix: unknown IR node tag {tag}; use a Program serializer compatible with this vyre version."
)),
}
}
#[inline]
pub(crate) fn expr(&mut self) -> Result<Expr, String> {
if self.depth >= MAX_DECODE_DEPTH {
return Err(format!(
"Fix: IR wire format exceeds maximum decode depth {MAX_DECODE_DEPTH}; flatten deeply nested Expr trees or reject this untrusted blob."
));
}
self.depth += 1;
let result = self.expr_inner();
self.depth -= 1;
result
}
#[inline]
pub(crate) fn data_type(&mut self) -> Result<DataType, String> {
let tag = self.u8()?;
if tag == 12 {
let element_size = usize::try_from(self.u32()?).map_err(|err| {
format!(
"Fix: array element_size cannot fit usize on this target ({err}); decode this VIR0 blob on a supported target or reject it."
)
})?;
return Ok(DataType::Array { element_size });
}
data_type_from_tag(tag)
}
fn expr_inner(&mut self) -> Result<Expr, String> {
match self.u8()? {
0 => Ok(Expr::LitU32(self.u32()?)),
1 => Ok(Expr::LitI32(self.i32()?)),
2 => Ok(Expr::LitBool(self.u8()? != 0)),
15 => Ok(Expr::LitF32(f32::from_bits(self.u32()?))),
3 => Ok(Expr::Var(self.string()?.into())),
4 => Ok(Expr::Load {
buffer: self.string()?.into(),
index: Box::new(self.expr()?),
}),
5 => Ok(Expr::BufLen {
buffer: self.string()?.into(),
}),
6 => Ok(Expr::InvocationId { axis: self.u8()? }),
7 => Ok(Expr::WorkgroupId { axis: self.u8()? }),
8 => Ok(Expr::LocalId { axis: self.u8()? }),
9 => Ok(Expr::BinOp {
op: bin_op_from_tag(self.u8()?)?,
left: Box::new(self.expr()?),
right: Box::new(self.expr()?),
}),
10 => Ok(Expr::UnOp {
op: un_op_from_tag(self.u8()?)?,
operand: Box::new(self.expr()?),
}),
11 => {
let op_id = self.string()?;
let count = self.bounded_len(MAX_ARGS, "call argument count")?;
let mut args = Vec::with_capacity(count);
for _ in 0..count {
args.push(self.expr()?);
}
Ok(Expr::Call { op_id, args })
}
12 => Ok(Expr::Select {
cond: Box::new(self.expr()?),
true_val: Box::new(self.expr()?),
false_val: Box::new(self.expr()?),
}),
13 => Ok(Expr::Cast {
target: self.data_type()?,
value: Box::new(self.expr()?),
}),
14 => Ok(Expr::Atomic {
op: atomic_op_from_tag(self.u8()?)?,
buffer: self.string()?.into(),
index: Box::new(self.expr()?),
expected: if self.u8()? == 0 {
None
} else {
Some(Box::new(self.expr()?))
},
value: Box::new(self.expr()?),
}),
16 => Ok(Expr::Fma {
a: Box::new(self.expr()?),
b: Box::new(self.expr()?),
c: Box::new(self.expr()?),
}),
tag => Err(format!(
"Fix: unknown IR expression tag {tag}; use a Program serializer compatible with this vyre version."
)),
}
}
}