vyre 0.3.0

GPU bytecode condition engine
Documentation
use std::path::Path;

use crate::bytecode::Program;
use crate::error::{Error, Result};
use crate::index::CompiledRuleIndex;
use crate::pattern::{CompiledPattern, PatternMapping, RuleEntry};

impl CompiledRuleIndex {
    /// Save the compiled index to disk.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let bytes = self.to_bytes()?;
        std::fs::write(path.as_ref(), bytes).map_err(|e| Error::Serialization {
            message: format!("failed to write index: {e}"),
        })
    }

    /// Load a compiled index from disk.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let bytes = std::fs::read(path.as_ref()).map_err(|e| Error::Serialization {
            message: format!("failed to read index: {e}"),
        })?;
        Self::from_bytes(&bytes)
    }

    pub(crate) fn to_bytes(&self) -> Result<Vec<u8>> {
        let mut out = Vec::new();
        out.extend_from_slice(b"RFIDX001");
        write_rules(&mut out, &self.rules);
        write_patterns(&mut out, &self.patterns);
        write_mapping(&mut out, &self.mapping);
        write_programs(&mut out, &self.programs);
        write_pattern_set(&mut out, &self.patterns);
        Ok(out)
    }

    pub(crate) fn from_bytes(bytes: &[u8]) -> Result<Self> {
        const MAGIC: &[u8] = b"RFIDX001";
        if bytes.len() < MAGIC.len() || &bytes[..MAGIC.len()] != MAGIC {
            return Err(Error::Serialization { message: "invalid magic header".to_string() });
        }

        let mut reader = Reader::new(&bytes[MAGIC.len()..]);
        let rules = reader.read_rules()?;
        let patterns = reader.read_patterns()?;
        let mapping = reader.read_mapping()?;
        let programs = reader.read_programs()?;
        let pattern_set = reader.read_pattern_set()?.build()?;
        Ok(Self::build(rules, patterns, mapping, programs, pattern_set))
    }
}

/// Append the rule section to `out`.
///
/// Layout: a little-endian `u64` rule count, then for each rule its name,
/// a `u64` tag count followed by the tags, and a `u64` string count followed
/// by the strings. Every string is written with `write_string`.
fn write_rules(out: &mut Vec<u8>, rules: &[RuleEntry]) {
    let rule_count = rules.len() as u64;
    out.extend_from_slice(&rule_count.to_le_bytes());

    for rule in rules {
        write_string(out, &rule.name);

        let tag_count = rule.tags.len() as u64;
        out.extend_from_slice(&tag_count.to_le_bytes());
        for tag in &rule.tags {
            write_string(out, tag);
        }

        let string_count = rule.strings.len() as u64;
        out.extend_from_slice(&string_count.to_le_bytes());
        for value in &rule.strings {
            write_string(out, value);
        }
    }
}

/// Append the pattern section to `out`.
///
/// Layout: a little-endian `u64` pattern count, then for each pattern its
/// `pattern_id`, `rule_id` and `string_id` (little-endian), its `identifier`
/// and `source` strings, and one byte holding the `is_regex` flag (0 or 1).
fn write_patterns(out: &mut Vec<u8>, patterns: &[CompiledPattern]) {
    let pattern_count = patterns.len() as u64;
    out.extend_from_slice(&pattern_count.to_le_bytes());

    for pattern in patterns {
        out.extend_from_slice(&pattern.pattern_id.to_le_bytes());
        out.extend_from_slice(&pattern.rule_id.to_le_bytes());
        out.extend_from_slice(&pattern.string_id.to_le_bytes());
        write_string(out, &pattern.identifier);
        write_string(out, &pattern.source);
        out.push(u8::from(pattern.is_regex));
    }
}

/// Append the mapping section to `out`.
///
/// Layout: a little-endian `u64` count of `pattern_to_rules` entries followed
/// by both values of every entry, then `rule_list` and `string_local_ids`,
/// each written as a `u64` count followed by its little-endian values.
fn write_mapping(out: &mut Vec<u8>, mapping: &PatternMapping) {
    /// Write a `u64` element count followed by every value in little-endian.
    fn write_u32_list(out: &mut Vec<u8>, values: &[u32]) {
        out.extend_from_slice(&(values.len() as u64).to_le_bytes());
        for value in values {
            out.extend_from_slice(&value.to_le_bytes());
        }
    }

    let entry_count = mapping.pattern_to_rules.len() as u64;
    out.extend_from_slice(&entry_count.to_le_bytes());
    for [first, second] in &mapping.pattern_to_rules {
        out.extend_from_slice(&first.to_le_bytes());
        out.extend_from_slice(&second.to_le_bytes());
    }

    write_u32_list(out, &mapping.rule_list);
    write_u32_list(out, &mapping.string_local_ids);
}

/// Append the program section to `out`.
///
/// Layout: a little-endian `u64` program count, then for each program a
/// `u64` byte length followed by the bytes returned by `Program::to_bytes`.
fn write_programs(out: &mut Vec<u8>, programs: &[Program]) {
    let program_count = programs.len() as u64;
    out.extend_from_slice(&program_count.to_le_bytes());

    for program in programs {
        let encoded = program.to_bytes();
        let length_prefix = (encoded.len() as u64).to_le_bytes();
        out.extend_from_slice(&length_prefix);
        out.extend_from_slice(&encoded);
    }
}

/// Append the pattern-set section to `out`.
///
/// Layout: a little-endian `u64` pattern count, then for each pattern its
/// `source` string and one byte holding the `is_regex` flag (0 or 1).
fn write_pattern_set(out: &mut Vec<u8>, patterns: &[CompiledPattern]) {
    let pattern_count = patterns.len() as u64;
    out.extend_from_slice(&pattern_count.to_le_bytes());

    for pattern in patterns {
        write_string(out, &pattern.source);
        out.push(u8::from(pattern.is_regex));
    }
}

/// Append `value` to `out` as a little-endian `u64` byte length followed by
/// the string's UTF-8 bytes.
fn write_string(out: &mut Vec<u8>, value: &str) {
    let payload = value.as_bytes();
    let length_prefix = (payload.len() as u64).to_le_bytes();
    out.extend_from_slice(&length_prefix);
    out.extend_from_slice(payload);
}

struct Reader<'a> {
    bytes: &'a [u8],
    cursor: usize,
}

impl<'a> Reader<'a> {
    fn new(bytes: &'a [u8]) -> Self { Self { bytes, cursor: 0 } }
    fn read_rules(&mut self) -> Result<Vec<RuleEntry>> { self.read_rule_vec() }
    fn read_patterns(&mut self) -> Result<Vec<CompiledPattern>> { self.read_pattern_vec() }
    fn read_mapping(&mut self) -> Result<PatternMapping> { self.read_pattern_mapping() }
    fn read_programs(&mut self) -> Result<Vec<Program>> { self.read_program_vec() }
    fn read_pattern_set(&mut self) -> Result<warpstate::PatternSetBuilder> { self.read_pattern_builder() }

    fn read_u64(&mut self) -> Result<u64> {
        self.bytes.get(self.cursor..self.cursor + 8).ok_or_else(|| Error::Serialization {
            message: "truncated u64 read".to_string(),
        }).and_then(|slice| {
            self.cursor += 8;
            Ok(u64::from_le_bytes(slice.try_into().map_err(|_| Error::Serialization {
                message: "invalid u64".to_string(),
            })?))
        })
    }

    fn read_u32(&mut self) -> Result<u32> {
        self.bytes.get(self.cursor..self.cursor + 4).ok_or_else(|| Error::Serialization {
            message: "truncated u32 read".to_string(),
        }).and_then(|slice| {
            self.cursor += 4;
            Ok(u32::from_le_bytes(slice.try_into().map_err(|_| Error::Serialization {
                message: "invalid u32".to_string(),
            })?))
        })
    }

    fn read_string(&mut self) -> Result<String> {
        let len = self.read_u64()? as usize;
        let slice = self.bytes.get(self.cursor..self.cursor + len).ok_or_else(|| Error::Serialization {
            message: "truncated string read".to_string(),
        })?;
        self.cursor += len;
        String::from_utf8(slice.to_vec()).map_err(|_| Error::Serialization {
            message: "invalid UTF-8 in string".to_string(),
        })
    }

    fn read_bool(&mut self) -> Result<bool> {
        let value = *self.bytes.get(self.cursor).ok_or_else(|| Error::Serialization {
            message: "truncated bool read".to_string(),
        })?;
        self.cursor += 1;
        Ok(value != 0)
    }

    fn read_rule_vec(&mut self) -> Result<Vec<RuleEntry>> {
        (0..self.read_u64()? as usize).map(|_| {
            let name = self.read_string()?;
            let tags = (0..self.read_u64()? as usize).map(|_| self.read_string()).collect::<Result<Vec<_>>>()?;
            let strings = (0..self.read_u64()? as usize).map(|_| self.read_string()).collect::<Result<Vec<_>>>()?;
            Ok(RuleEntry { name, tags, strings })
        }).collect()
    }

    fn read_pattern_vec(&mut self) -> Result<Vec<CompiledPattern>> {
        (0..self.read_u64()? as usize).map(|_| {
            Ok(CompiledPattern {
                pattern_id: self.read_u32()?,
                rule_id: self.read_u32()?,
                string_id: self.read_u32()?,
                identifier: self.read_string()?,
                source: self.read_string()?,
                is_regex: self.read_bool()?,
            })
        }).collect()
    }

    fn read_pattern_mapping(&mut self) -> Result<PatternMapping> {
        let pattern_to_rules = (0..self.read_u64()? as usize).map(|_| Ok([self.read_u32()?, self.read_u32()?])).collect::<Result<Vec<_>>>()?;
        let rule_list = (0..self.read_u64()? as usize).map(|_| self.read_u32()).collect::<Result<Vec<_>>>()?;
        let string_local_ids = (0..self.read_u64()? as usize).map(|_| self.read_u32()).collect::<Result<Vec<_>>>()?;
        Ok(PatternMapping { pattern_to_rules, rule_list, string_local_ids })
    }

    fn read_program_vec(&mut self) -> Result<Vec<Program>> {
        (0..self.read_u64()? as usize).map(|_| {
            let len = self.read_u64()? as usize;
            let bytes = self.bytes.get(self.cursor..self.cursor + len).ok_or_else(|| Error::Serialization {
                message: "truncated program data".to_string(),
            })?;
            self.cursor += len;
            Program::from_bytes(bytes)
        }).collect()
    }

    fn read_pattern_builder(&mut self) -> Result<warpstate::PatternSetBuilder> {
        let mut builder = warpstate::PatternSet::builder();
        for _ in 0..self.read_u64()? as usize {
            let source = self.read_string()?;
            builder = if self.read_bool()? { builder.regex(&source) } else { builder.literal_bytes(source.into_bytes()) };
        }
        Ok(builder)
    }
}