vyre 0.1.0

GPU bytecode condition engine
Documentation
use crate::bytecode::{Instruction, Opcode};
use crate::error::{Error, Result};

/// Four-byte signature (ASCII `"YBC0"`) written at the start of every
/// serialized program and checked when one is loaded.
pub const BYTECODE_MAGIC: [u8; 4] = [b'Y', b'B', b'C', b'0'];

/// A compiled rule program.
///
/// Convert to and from its byte form with [`Program::to_bytes`] and
/// [`Program::from_bytes`]; check structure with [`Program::validate`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Program {
    /// Instructions belonging to this rule, in program order.
    pub instructions: Vec<Instruction>,
}

impl Program {
    /// Serialize this program to bytes.
    ///
    /// # Examples
    /// ```
    /// use rulefire::{Instruction, Opcode, Program};
    ///
    /// let program = Program { instructions: vec![Instruction::new(Opcode::Halt, 0)] };
    /// assert!(program.to_bytes().starts_with(b"YBC0"));
    /// ```
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(8 + self.instructions.len() * 8);
        out.extend_from_slice(&BYTECODE_MAGIC);
        out.extend_from_slice(&(self.instructions.len() as u32).to_le_bytes());
        self.instructions.iter().for_each(|instruction| {
            out.extend_from_slice(&instruction.opcode.to_le_bytes());
            out.extend_from_slice(&instruction.operand.to_le_bytes());
        });
        out
    }

    /// Deserialize and validate a program from bytes.
    ///
    /// # Examples
    /// ```
    /// use rulefire::{Instruction, Opcode, Program};
    ///
    /// let bytes = Program { instructions: vec![Instruction::new(Opcode::Halt, 0)] }.to_bytes();
    /// assert_eq!(Program::from_bytes(&bytes).unwrap().instructions.len(), 1);
    /// ```
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < 8 || bytes[..4] != BYTECODE_MAGIC {
            return Err(Error::BytecodeValidation {
                message: "missing bytecode magic header".to_string(),
            });
        }
        let count = u32::from_le_bytes(bytes[4..8].try_into().map_err(|_| Error::BytecodeValidation {
            message: "truncated instruction count".to_string(),
        })?) as usize;
        let expected = 8 + count * std::mem::size_of::<Instruction>();
        if bytes.len() != expected {
            return Err(Error::BytecodeValidation {
                message: format!("bytecode size mismatch: got {}, expected {expected}", bytes.len()),
            });
        }

        let instructions = bytes[8..]
            .chunks_exact(8)
            .map(|chunk| {
                Ok(Instruction {
                    opcode: u32::from_le_bytes(chunk[..4].try_into().map_err(|_| Error::BytecodeValidation {
                        message: "invalid opcode bytes".to_string(),
                    })?),
                    operand: u32::from_le_bytes(chunk[4..8].try_into().map_err(|_| Error::BytecodeValidation {
                        message: "invalid operand bytes".to_string(),
                    })?),
                })
            })
            .collect::<Result<Vec<_>>>()?;
        let program = Self { instructions };
        program.validate()?;
        Ok(program)
    }

    /// Validate structural correctness.
    ///
    /// # Examples
    /// ```
    /// use rulefire::{Instruction, Opcode, Program};
    ///
    /// let program = Program { instructions: vec![Instruction::new(Opcode::Halt, 0)] };
    /// assert!(program.validate().is_ok());
    /// ```
    pub fn validate(&self) -> Result<()> {
        if self.instructions.is_empty() {
            return Err(Error::BytecodeValidation {
                message: "program is empty".to_string(),
            });
        }

        let mut loop_depth = 0u32;
        let mut has_halt = false;
        for (index, instruction) in self.instructions.iter().enumerate() {
            match instruction.kind()? {
                Opcode::ForAny | Opcode::ForAll | Opcode::ForN => {
                    let body_len = (instruction.operand >> 16) as usize;
                    let end_for_idx = index + body_len;
                    if body_len == 0 || end_for_idx >= self.instructions.len() {
                        return Err(Error::BytecodeValidation {
                            message: format!("FOR loop at {index} has invalid body_len {body_len}"),
                        });
                    }
                    if self.instructions[end_for_idx].kind()? != Opcode::EndFor {
                        return Err(Error::BytecodeValidation {
                            message: format!("FOR loop at {index} does not end with END_FOR"),
                        });
                    }
                    loop_depth = loop_depth.saturating_add(1);
                }
                Opcode::EndFor if loop_depth == 0 => {
                    return Err(Error::BytecodeValidation {
                        message: format!("END_FOR at {index} has no matching loop start"),
                    });
                }
                Opcode::EndFor => loop_depth -= 1,
                Opcode::Halt if index != self.instructions.len() - 1 => {
                    return Err(Error::BytecodeValidation {
                        message: format!("HALT must terminate the program, found at {index}"),
                    });
                }
                Opcode::Halt => has_halt = true,
                _ => {}
            }
        }

        if loop_depth != 0 {
            return Err(Error::BytecodeValidation {
                message: "unterminated FOR loop".to_string(),
            });
        }
        if !has_halt {
            return Err(Error::BytecodeValidation {
                message: "program must end with HALT instruction".to_string(),
            });
        }
        Ok(())
    }
}