// vyre 0.4.0 — GPU compute intermediate representation with a standard operation library.
//
// Rule set program builder.

use crate::ops::rule::ast::{RuleCondition, RuleFormula};
use crate::ops::rule::{file_size_eq, file_size_gt, file_size_gte, file_size_lt, file_size_lte, file_size_ne, literal_false, literal_true, pattern_count_gt, pattern_count_gte, pattern_exists};
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};


/// Workgroup size handed to `Program::new` by [`build_rule_program`]:
/// 64 invocations along the first dimension and 1 along the other two.
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];

/// Compile a whole rule set into a single IR [`Program`].
///
/// Each tuple is `(formula, rule_id)`: the formula's boolean verdict is written as
/// `0` or `1` into `verdicts[rule_id]`. A rule whose id is not below the length of
/// the `verdicts` buffer is skipped at run time instead of being stored.
///
/// The program uses the buffer layout from [`rule_buffers`] and the workgroup
/// size [`WORKGROUP_SIZE`].
///
/// # Examples
///
/// ```
/// use vyre::ops::rule::{build_rule_program, RuleCondition, RuleFormula};
///
/// let formula = RuleFormula::condition(RuleCondition::LiteralTrue);
/// let program = build_rule_program(&[(formula, 3)]);
/// assert!(program.has_buffer("verdicts"));
/// ```
#[must_use]
pub fn build_rule_program(rules: &[(RuleFormula, u32)]) -> Program {
    let buffers = rule_buffers();
    let body = rule_nodes(rules);
    Program::new(buffers, WORKGROUP_SIZE, body)
}

/// Buffer declarations shared by every rule program.
///
/// Bindings 0 through 4 are read-only `u32` inputs, in this order:
/// `rule_ids`, `pattern_ids`, `rule_bitmaps`, `rule_counts`, `file_size`.
/// Binding 5 is the `u32` output buffer `verdicts`.
///
/// NOTE(review): `rule_ids` and `pattern_ids` are declared here, but no expression
/// built in this module loads them. They are presumably part of a fixed binding
/// layout that callers rely on; confirm before changing.
pub fn rule_buffers() -> Vec<BufferDecl> {
    const INPUTS: [&str; 5] = [
        "rule_ids",
        "pattern_ids",
        "rule_bitmaps",
        "rule_counts",
        "file_size",
    ];

    // Input bindings are numbered by their position in `INPUTS`.
    let mut decls: Vec<BufferDecl> = INPUTS
        .iter()
        .copied()
        .zip(0..)
        .map(|(name, binding)| BufferDecl::read(name, binding, DataType::U32))
        .collect();

    // The single output follows directly after the last input binding.
    decls.push(BufferDecl::output("verdicts", 5, DataType::U32));
    decls
}

/// Lower every `(formula, rule_id)` pair into one guarded store node.
///
/// Each node stores the formula's expression into `verdicts[rule_id]`, but only
/// when `rule_id < buf_len("verdicts")`, so out-of-range ids never write.
/// Nodes are emitted in the same order as `rules`. If an id repeats, both stores
/// are emitted, and presumably the later one wins at run time.
pub fn rule_nodes(rules: &[(RuleFormula, u32)]) -> Vec<Node> {
    let mut nodes = Vec::with_capacity(rules.len());
    for (formula, rule_id) in rules {
        let id = *rule_id;
        let in_bounds = Expr::lt(Expr::u32(id), Expr::buf_len("verdicts"));
        let verdict = formula_expr(formula, id);
        let store = Node::store("verdicts", Expr::u32(id), verdict);
        nodes.push(Node::if_then(in_bounds, vec![store]));
    }
    nodes
}

/// Translate a [`RuleFormula`] tree into an IR [`Expr`].
///
/// `And`, `Or` and `Not` map to `Expr::and`, `Expr::or` and `Expr::not`.
/// Leaf conditions are lowered by [`condition_expr`]. `rule_id` is passed
/// unchanged to every leaf.
pub fn formula_expr(formula: &RuleFormula, rule_id: u32) -> Expr {
    match formula {
        RuleFormula::Condition(leaf) => condition_expr(leaf, rule_id),
        RuleFormula::Not(inner) => {
            let operand = formula_expr(inner, rule_id);
            Expr::not(operand)
        }
        RuleFormula::And(lhs, rhs) => {
            let lhs_expr = formula_expr(lhs, rule_id);
            let rhs_expr = formula_expr(rhs, rule_id);
            Expr::and(lhs_expr, rhs_expr)
        }
        RuleFormula::Or(lhs, rhs) => {
            let lhs_expr = formula_expr(lhs, rule_id);
            let rhs_expr = formula_expr(rhs, rule_id);
            Expr::or(lhs_expr, rhs_expr)
        }
    }
}

/// Lower a single [`RuleCondition`] into a call to its rule op.
///
/// Pattern and literal conditions become one [`call`]. File-size comparisons go
/// through [`file_size_call`]. Each comparison is paired with the constant op
/// (`literal_true` or `literal_false`) that gives its answer when the `u64`
/// threshold does not fit in `u32`.
pub fn condition_expr(condition: &RuleCondition, rule_id: u32) -> Expr {
    let (threshold, op_id, overflow_op_id) = match condition {
        RuleCondition::PatternExists { pattern_id } => {
            return call(pattern_exists::OP_ID, rule_id, *pattern_id, 0);
        }
        RuleCondition::PatternCountGt {
            pattern_id,
            threshold,
        } => return call(pattern_count_gt::OP_ID, rule_id, *pattern_id, *threshold),
        RuleCondition::PatternCountGte {
            pattern_id,
            threshold,
        } => return call(pattern_count_gte::OP_ID, rule_id, *pattern_id, *threshold),
        RuleCondition::LiteralTrue => return call(literal_true::OP_ID, rule_id, 0, 0),
        RuleCondition::LiteralFalse => return call(literal_false::OP_ID, rule_id, 0, 0),
        // `file_size` is a u32 buffer. A threshold above u32::MAX therefore makes
        // `<`, `<=` and `!=` always true and `>`, `>=` and `==` always false.
        RuleCondition::FileSizeLt(limit) => (*limit, file_size_lt::OP_ID, literal_true::OP_ID),
        RuleCondition::FileSizeLte(limit) => (*limit, file_size_lte::OP_ID, literal_true::OP_ID),
        RuleCondition::FileSizeGt(limit) => (*limit, file_size_gt::OP_ID, literal_false::OP_ID),
        RuleCondition::FileSizeGte(limit) => {
            (*limit, file_size_gte::OP_ID, literal_false::OP_ID)
        }
        RuleCondition::FileSizeEq(limit) => (*limit, file_size_eq::OP_ID, literal_false::OP_ID),
        RuleCondition::FileSizeNe(limit) => (*limit, file_size_ne::OP_ID, literal_true::OP_ID),
    };
    file_size_call(threshold, op_id, overflow_op_id, rule_id)
}

/// Build a file-size comparison call, narrowing the `u64` threshold to `u32`.
///
/// If the threshold fits in `u32`, this calls `op_id` with it. Otherwise it calls
/// `overflow_op_id` with a zero threshold; that op is the constant answer the caller
/// chose for an out-of-range threshold. In both cases the pattern id is `0`.
pub fn file_size_call(threshold: u64, op_id: &str, overflow_op_id: &str, rule_id: u32) -> Expr {
    match u32::try_from(threshold) {
        Ok(narrowed) => call(op_id, rule_id, 0, narrowed),
        Err(_) => call(overflow_op_id, rule_id, 0, 0),
    }
}

/// Build an `Expr::call` to `op_id` using the six-argument rule-op signature.
///
/// Arguments, in order:
/// 1. `rule_id`
/// 2. `pattern_id`
/// 3. [`pattern_state`] for `pattern_id`
/// 4. [`pattern_count`] for `pattern_id`
/// 5. `file_size[0]`, or `0` when the `file_size` buffer is empty
/// 6. `threshold`
pub fn call(op_id: &str, rule_id: u32, pattern_id: u32, threshold: u32) -> Expr {
    Expr::call(
        op_id,
        vec![
            Expr::u32(rule_id),
            Expr::u32(pattern_id),
            pattern_state(pattern_id),
            pattern_count(pattern_id),
            guarded_file_size(),
            Expr::u32(threshold),
        ],
    )
}

/// `file_size[0]` guarded by `0 < buf_len("file_size")`, falling back to `0`.
///
/// This matches the bounds checks applied to every other buffer read in the
/// builder, so an empty `file_size` binding cannot cause an out-of-bounds load.
fn guarded_file_size() -> Expr {
    Expr::select(
        Expr::lt(Expr::u32(0), Expr::buf_len("file_size")),
        Expr::load("file_size", Expr::u32(0)),
        Expr::u32(0),
    )
}

/// Value of `rule_bitmaps[pattern_id]`, or `0` when `pattern_id` is not below
/// the buffer's length.
pub fn pattern_state(pattern_id: u32) -> Expr {
    guarded_u32_load("rule_bitmaps", pattern_id)
}

/// Value of `rule_counts[pattern_id]`, or `0` when `pattern_id` is not below
/// the buffer's length.
pub fn pattern_count(pattern_id: u32) -> Expr {
    guarded_u32_load("rule_counts", pattern_id)
}

/// Select `buffer[index]` when `index < buf_len(buffer)`, otherwise the constant `0`.
fn guarded_u32_load(buffer: &'static str, index: u32) -> Expr {
    let in_bounds = Expr::lt(Expr::u32(index), Expr::buf_len(buffer));
    let value = Expr::load(buffer, Expr::u32(index));
    Expr::select(in_bounds, value, Expr::u32(0))
}