vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::serial::wire::tags::{
    atomic_op_from_tag, bin_op_from_tag, data_type_from_tag, un_op_from_tag,
};
use crate::ir::serial::wire::{Expr, Node, Reader, MAX_ARGS, MAX_DECODE_DEPTH, MAX_NODES};
use crate::ir::DataType;

impl Reader<'_> {
    /// Decodes a counted sequence of statement nodes.
    ///
    /// The element count is obtained from `bounded_len`, called with
    /// `MAX_NODES` as its limit, and that many nodes are then decoded in
    /// order with `node()`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `bounded_len` or by decoding any
    /// element; nodes decoded before the failure are discarded.
    #[inline]
    pub fn nodes(&mut self) -> Result<Vec<Node>, String> {
        // The count is taken from the blob before any element has been
        // decoded, so reserving `count` slots up front would let a few bytes
        // of input force a large allocation (and this call nests). Reserve at
        // most this many slots; the vector grows as real elements arrive.
        const NODE_PREALLOC_CAP: usize = 1024;

        let count = self.bounded_len(MAX_NODES, "node count")?;
        let mut nodes = Vec::with_capacity(count.min(NODE_PREALLOC_CAP));
        for _ in 0..count {
            nodes.push(self.node()?);
        }
        Ok(nodes)
    }

    /// Decodes a single statement node under the shared recursion guard.
    ///
    /// # Errors
    ///
    /// Fails without reading any input when `self.depth` has already reached
    /// `MAX_DECODE_DEPTH`; otherwise returns whatever `node_inner` returns.
    #[inline]
    pub(crate) fn node(&mut self) -> Result<Node, String> {
        // Recursion guard: every `node()` enter increments depth, every
        // exit decrements. Nested decode stops at `MAX_DECODE_DEPTH`.
        if self.depth >= MAX_DECODE_DEPTH {
            return Err(format!(
                "Fix: IR wire format exceeds maximum decode depth {MAX_DECODE_DEPTH}; flatten deeply nested Block/If/Loop structures or reject this untrusted blob."
            ));
        }
        self.depth += 1;
        // Hold the result instead of using `?` so the counter is restored on
        // the error path as well as on success.
        let result = self.node_inner();
        self.depth -= 1;
        result
    }

    /// Reads a one-byte node tag and decodes the matching `Node` payload.
    ///
    /// Tags handled here: 0 `Let`, 1 `Assign`, 2 `Store`, 3 `If`, 4 `Loop`,
    /// 5 `Return`, 6 `Block`, 7 `Barrier`.
    ///
    /// # Errors
    ///
    /// Returns an error for any other tag, or when reading a payload field
    /// fails.
    fn node_inner(&mut self) -> Result<Node, String> {
        // Struct-literal fields are evaluated in the order they are written,
        // so the field order in each arm is the order the payload is read.
        match self.u8()? {
            0 => Ok(Node::Let {
                name: self.string()?,
                value: self.expr()?,
            }),
            1 => Ok(Node::Assign {
                name: self.string()?,
                value: self.expr()?,
            }),
            2 => Ok(Node::Store {
                buffer: self.string()?,
                index: self.expr()?,
                value: self.expr()?,
            }),
            3 => Ok(Node::If {
                cond: self.expr()?,
                then: self.nodes()?,
                otherwise: self.nodes()?,
            }),
            4 => Ok(Node::Loop {
                var: self.string()?,
                from: self.expr()?,
                to: self.expr()?,
                body: self.nodes()?,
            }),
            5 => Ok(Node::Return),
            6 => Ok(Node::Block(self.nodes()?)),
            7 => Ok(Node::Barrier),
            tag => Err(format!(
                "Fix: unknown IR node tag {tag}; use a Program serializer compatible with this vyre version."
            )),
        }
    }

    /// Decodes a single expression under the shared recursion guard.
    ///
    /// # Errors
    ///
    /// Fails without reading any input when `self.depth` has already reached
    /// `MAX_DECODE_DEPTH`; otherwise returns whatever `expr_inner` returns.
    #[inline]
    pub(crate) fn expr(&mut self) -> Result<Expr, String> {
        // Recursion guard for arbitrarily nested Expr trees (BinOp, UnOp,
        // Select, Cast, Call arg lists, etc). Shares the same depth
        // counter and budget as `node()` so a hostile blob can't evade
        // the limit by alternating statement and expression levels.
        if self.depth >= MAX_DECODE_DEPTH {
            return Err(format!(
                "Fix: IR wire format exceeds maximum decode depth {MAX_DECODE_DEPTH}; flatten deeply nested Expr trees or reject this untrusted blob."
            ));
        }
        self.depth += 1;
        // As in `node()`: no `?` here, so the decrement always runs.
        let result = self.expr_inner();
        self.depth -= 1;
        result
    }

    /// Decodes a `DataType` from a one-byte tag.
    ///
    /// Tag 12 is handled here: it is followed by a `u32` that becomes the
    /// `element_size` of `DataType::Array`. Every other tag is passed to
    /// `data_type_from_tag`.
    ///
    /// # Errors
    ///
    /// Returns an error when a read fails, when the array element size does
    /// not fit in `usize`, or when `data_type_from_tag` rejects the tag.
    #[inline]
    pub(crate) fn data_type(&mut self) -> Result<DataType, String> {
        let tag = self.u8()?;
        if tag == 12 {
            let element_size = usize::try_from(self.u32()?).map_err(|err| {
                format!(
                    "Fix: array element_size cannot fit usize on this target ({err}); decode this VIR0 blob on a supported target or reject it."
                )
            })?;
            return Ok(DataType::Array { element_size });
        }
        data_type_from_tag(tag)
    }

    /// Reads a one-byte expression tag and decodes the matching `Expr`.
    ///
    /// Tags handled here: 0 `LitU32`, 1 `LitI32`, 2 `LitBool`, 3 `Var`,
    /// 4 `Load`, 5 `BufLen`, 6 `InvocationId`, 7 `WorkgroupId`, 8 `LocalId`,
    /// 9 `BinOp`, 10 `UnOp`, 11 `Call`, 12 `Select`, 13 `Cast`, 14 `Atomic`,
    /// 15 `LitF32`, 16 `Fma`.
    ///
    /// # Errors
    ///
    /// Returns an error for any other tag, or when reading a payload field
    /// or nested expression fails.
    fn expr_inner(&mut self) -> Result<Expr, String> {
        // Struct-literal fields are evaluated in the order they are written,
        // so the field order in each arm is the order the payload is read.
        match self.u8()? {
            0 => Ok(Expr::LitU32(self.u32()?)),
            1 => Ok(Expr::LitI32(self.i32()?)),
            // Any non-zero byte decodes as `true`.
            2 => Ok(Expr::LitBool(self.u8()? != 0)),
            // The float is carried as its raw 32-bit pattern.
            15 => Ok(Expr::LitF32(f32::from_bits(self.u32()?))),
            3 => Ok(Expr::Var(self.string()?.into())),
            4 => Ok(Expr::Load {
                buffer: self.string()?.into(),
                index: Box::new(self.expr()?),
            }),
            5 => Ok(Expr::BufLen {
                buffer: self.string()?.into(),
            }),
            6 => Ok(Expr::InvocationId { axis: self.u8()? }),
            7 => Ok(Expr::WorkgroupId { axis: self.u8()? }),
            8 => Ok(Expr::LocalId { axis: self.u8()? }),
            9 => Ok(Expr::BinOp {
                op: bin_op_from_tag(self.u8()?)?,
                left: Box::new(self.expr()?),
                right: Box::new(self.expr()?),
            }),
            10 => Ok(Expr::UnOp {
                op: un_op_from_tag(self.u8()?)?,
                operand: Box::new(self.expr()?),
            }),
            11 => {
                // Same reasoning as in `nodes()`: the argument count is read
                // before any argument exists, so cap the up-front reservation
                // and let the vector grow while arguments actually decode.
                const ARG_PREALLOC_CAP: usize = 64;

                let op_id = self.string()?;
                let count = self.bounded_len(MAX_ARGS, "call argument count")?;
                let mut args = Vec::with_capacity(count.min(ARG_PREALLOC_CAP));
                for _ in 0..count {
                    args.push(self.expr()?);
                }
                Ok(Expr::Call { op_id, args })
            }
            12 => Ok(Expr::Select {
                cond: Box::new(self.expr()?),
                true_val: Box::new(self.expr()?),
                false_val: Box::new(self.expr()?),
            }),
            13 => Ok(Expr::Cast {
                target: self.data_type()?,
                value: Box::new(self.expr()?),
            }),
            14 => Ok(Expr::Atomic {
                op: atomic_op_from_tag(self.u8()?)?,
                buffer: self.string()?.into(),
                index: Box::new(self.expr()?),
                // One presence byte: 0 means no `expected` operand follows,
                // any other value means one expression follows.
                expected: if self.u8()? == 0 {
                    None
                } else {
                    Some(Box::new(self.expr()?))
                },
                value: Box::new(self.expr()?),
            }),
            16 => Ok(Expr::Fma {
                a: Box::new(self.expr()?),
                b: Box::new(self.expr()?),
                c: Box::new(self.expr()?),
            }),
            tag => Err(format!(
                "Fix: unknown IR expression tag {tag}; use a Program serializer compatible with this vyre version."
            )),
        }
    }
}