use crate::error::{Error, Result};
/// One rule condition in the flat encoding uploaded to the GPU.
///
/// The field order and `#[repr(C)]` layout mirror the four-`u32` `ParallelCondition` struct
/// declared in the WGSL produced by `build_parallel_eval_shader`.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelCondition {
    /// Raw [`ParallelConditionKind`] discriminant.
    pub opcode: u32,
    /// Left operand; the pattern id for the pattern-based kinds.
    pub lhs: u32,
    /// Right operand; the count threshold or file-size bound for comparison kinds.
    pub rhs: u32,
    /// Spare word. It is not read by the evaluators in this module, and
    /// [`ParallelCondition::new`] zeroes it.
    pub extra: u32,
}

impl ParallelCondition {
    /// Encodes a condition of the given `kind` with operands `lhs` and `rhs`.
    ///
    /// `extra` is always zero.
    pub const fn new(kind: ParallelConditionKind, lhs: u32, rhs: u32) -> Self {
        let opcode = kind as u32;
        Self { opcode, lhs, rhs, extra: 0 }
    }
}
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelConditionKind {
PatternExists = 1,
PatternCountGt = 2,
PatternCountGte = 3,
FileSizeLt = 4,
FileSizeLte = 5,
FileSizeGt = 6,
FileSizeGte = 7,
FileSizeEq = 8,
FileSizeNe = 9,
LiteralTrue = 10,
LiteralFalse = 11,
}
impl ParallelConditionKind {
pub fn from_u32(value: u32) -> Result<Self> {
match value {
1 => Ok(Self::PatternExists),
2 => Ok(Self::PatternCountGt),
3 => Ok(Self::PatternCountGte),
4 => Ok(Self::FileSizeLt),
5 => Ok(Self::FileSizeLte),
6 => Ok(Self::FileSizeGt),
7 => Ok(Self::FileSizeGte),
8 => Ok(Self::FileSizeEq),
9 => Ok(Self::FileSizeNe),
10 => Ok(Self::LiteralTrue),
11 => Ok(Self::LiteralFalse),
_ => Err(Error::BytecodeValidation {
message: format!("Fix: use a supported parallel condition opcode, got {value}"),
}),
}
}
}
/// One step of a rule's postfix boolean formula.
///
/// The `#[repr(C)]` layout mirrors the two-`u32` `FormulaInstruction` struct in the generated
/// WGSL.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FormulaInstruction {
    /// Raw [`FormulaOp`] discriminant.
    pub opcode: u32,
    /// Condition-result index for `PushResult`; [`FormulaInstruction::op`] sets it to `0`.
    pub operand: u32,
}

impl FormulaInstruction {
    /// Packs an opcode and operand into the wire form.
    const fn encode(opcode: FormulaOp, operand: u32) -> Self {
        Self {
            opcode: opcode as u32,
            operand,
        }
    }

    /// Instruction that pushes the result of condition `condition_index` onto the stack.
    pub const fn push_result(condition_index: u32) -> Self {
        Self::encode(FormulaOp::PushResult, condition_index)
    }

    /// Operand-less instruction for `opcode`; the operand is zero.
    pub const fn op(opcode: FormulaOp) -> Self {
        Self::encode(opcode, 0)
    }
}
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormulaOp {
PushResult = 1,
And = 2,
Or = 3,
Not = 4,
}
impl FormulaOp {
pub fn from_u32(value: u32) -> Result<Self> {
match value {
1 => Ok(Self::PushResult),
2 => Ok(Self::And),
3 => Ok(Self::Or),
4 => Ok(Self::Not),
_ => Err(Error::BytecodeValidation {
message: format!("Fix: use a supported postfix formula opcode, got {value}"),
}),
}
}
}
/// How rule conditions are evaluated.
///
/// `Sequential` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionStrategy {
    /// Default strategy.
    #[default]
    Sequential,
    /// Parallel strategy.
    ///
    /// NOTE(review): presumably selects the GPU path built by `build_parallel_eval_shader`;
    /// confirm against the dispatcher.
    Parallel,
}
/// Builds the WGSL compute shader that evaluates every rule's conditions in parallel.
///
/// # How the kernel runs
///
/// - One workgroup is dispatched per rule; `workgroup_id.x` is the rule index.
/// - Each lane evaluates at most one of the rule's conditions into the workgroup array
///   `condition_results`.
/// - After `workgroupBarrier()`, lane 0 reduces the rule's postfix formula.
/// - Lane 0 stores `1` in `verdicts[rule_id]` on a match. It stores `0` when there is no
///   match or the bytecode is malformed.
///
/// # Parameters
///
/// * `max_conditions`: per-rule condition capacity, also used as the workgroup size. It is
///   clamped to at least 1 because WGSL forbids zero-element arrays and a zero workgroup size.
/// * `max_formula_ops`: capacity of the reduction stack, clamped to at least 1.
///
/// # Out-of-range handling
///
/// - Only the first `min(condition_count, max_conditions)` conditions of a rule are evaluated.
/// - A formula that references any other condition index yields verdict `0`. It no longer
///   reads past the end of `condition_results`.
/// - `rule_bitmaps` is indexed with a fixed stride of 8 words (256 pattern bits) per rule.
///   Pattern ids that would land outside that window read as unmatched instead of reading the
///   next rule's bits.
///
/// NOTE(review): because the workgroup size equals `max_conditions`, values above the
/// adapter's `max_compute_invocations_per_workgroup` limit will fail pipeline creation. Confirm
/// that callers cap it.
pub fn build_parallel_eval_shader(max_conditions: u32, max_formula_ops: u32) -> String {
    const TEMPLATE: &str = r#"
const MAX_CONDITIONS: u32 = __MAX_CONDITIONS__u;
const MAX_FORMULA_OPS: u32 = __MAX_FORMULA_OPS__u;
// Fixed per-rule stride of `rule_bitmaps`, in 32-bit words (256 pattern bits per rule).
const BITMAP_WORDS_PER_RULE: u32 = 8u;

struct ParallelCondition {
    opcode: u32,
    lhs: u32,
    rhs: u32,
    extra: u32,
};

struct FormulaInstruction {
    opcode: u32,
    operand: u32,
};

struct FileContext {
    file_size: u32,
    entropy_bucket: u32,
    magic_u32: u32,
    is_pe: u32,
    is_dll: u32,
    is_64bit: u32,
    has_signature: u32,
    num_sections: u32,
    num_imports: u32,
    entry_point_rva: u32,
    unique_pattern_count: u32,
    total_match_count: u32,
};

struct Params {
    rule_count: u32,
    max_patterns: u32,
    _reserved0: u32,
    _reserved1: u32,
};

@group(0) @binding(0) var<storage, read> rule_condition_spans: array<vec2<u32>>;
@group(0) @binding(1) var<storage, read> rule_formula_spans: array<vec2<u32>>;
@group(0) @binding(2) var<storage, read> conditions: array<ParallelCondition>;
@group(0) @binding(3) var<storage, read> formula: array<FormulaInstruction>;
@group(0) @binding(4) var<storage, read> rule_bitmaps: array<u32>;
@group(0) @binding(5) var<storage, read> rule_counts: array<u32>;
@group(0) @binding(6) var<uniform> params: Params;
@group(0) @binding(7) var<uniform> file_ctx: FileContext;
@group(0) @binding(8) var<storage, read_write> verdicts: array<u32>;

var<workgroup> condition_results: array<u32, __MAX_CONDITIONS__>;
var<workgroup> reduce_stack: array<u32, __MAX_FORMULA_OPS__>;

fn pattern_exists(rule_id: u32, pattern_id: u32) -> bool {
    if (pattern_id >= params.max_patterns) {
        return false;
    }
    let word = pattern_id / 32u;
    // Never read past this rule's bitmap window into the next rule's bits.
    if (word >= BITMAP_WORDS_PER_RULE) {
        return false;
    }
    let bit = 1u << (pattern_id % 32u);
    return (rule_bitmaps[rule_id * BITMAP_WORDS_PER_RULE + word] & bit) != 0u;
}

fn pattern_count(rule_id: u32, pattern_id: u32) -> u32 {
    if (pattern_id >= params.max_patterns) {
        return 0u;
    }
    return rule_counts[rule_id * params.max_patterns + pattern_id];
}

fn eval_condition(rule_id: u32, condition: ParallelCondition) -> u32 {
    switch condition.opcode {
        case 1u: { return select(0u, 1u, pattern_exists(rule_id, condition.lhs)); }
        case 2u: { return select(0u, 1u, pattern_count(rule_id, condition.lhs) > condition.rhs); }
        case 3u: { return select(0u, 1u, pattern_count(rule_id, condition.lhs) >= condition.rhs); }
        case 4u: { return select(0u, 1u, file_ctx.file_size < condition.rhs); }
        case 5u: { return select(0u, 1u, file_ctx.file_size <= condition.rhs); }
        case 6u: { return select(0u, 1u, file_ctx.file_size > condition.rhs); }
        case 7u: { return select(0u, 1u, file_ctx.file_size >= condition.rhs); }
        case 8u: { return select(0u, 1u, file_ctx.file_size == condition.rhs); }
        case 9u: { return select(0u, 1u, file_ctx.file_size != condition.rhs); }
        case 10u: { return 1u; }
        case 11u: { return 0u; }
        default: { return 0u; }
    }
}

@compute @workgroup_size(__WORKGROUP_SIZE__)
fn main(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_invocation_id: vec3<u32>,
) {
    let rule_id = workgroup_id.x;
    let lane = local_invocation_id.x;
    // `rule_id` is the same for the whole workgroup, so this early exit keeps the barrier
    // below in uniform control flow.
    if (rule_id >= params.rule_count) {
        return;
    }
    let condition_span = rule_condition_spans[rule_id];
    let condition_start = condition_span.x;
    let condition_count = condition_span.y;
    // Only this many slots of `condition_results` ever hold evaluated conditions.
    let evaluated_count = min(condition_count, MAX_CONDITIONS);
    if (lane < MAX_CONDITIONS) {
        condition_results[lane] = 0u;
    }
    if (lane < evaluated_count) {
        condition_results[lane] = eval_condition(rule_id, conditions[condition_start + lane]);
    }
    workgroupBarrier();
    if (lane != 0u) {
        return;
    }
    let formula_span = rule_formula_spans[rule_id];
    let formula_start = formula_span.x;
    let formula_count = formula_span.y;
    var sp = 0u;
    for (var idx = 0u; idx < formula_count; idx = idx + 1u) {
        let inst = formula[formula_start + idx];
        switch inst.opcode {
            case 1u: {
                // Reject references to conditions that were never evaluated, including any
                // index beyond the workgroup array's capacity.
                if (inst.operand >= evaluated_count || sp >= MAX_FORMULA_OPS) {
                    verdicts[rule_id] = 0u;
                    return;
                }
                reduce_stack[sp] = condition_results[inst.operand];
                sp = sp + 1u;
            }
            case 2u: {
                if (sp < 2u) {
                    verdicts[rule_id] = 0u;
                    return;
                }
                sp = sp - 1u;
                reduce_stack[sp - 1u] = select(0u, 1u, reduce_stack[sp - 1u] != 0u && reduce_stack[sp] != 0u);
            }
            case 3u: {
                if (sp < 2u) {
                    verdicts[rule_id] = 0u;
                    return;
                }
                sp = sp - 1u;
                reduce_stack[sp - 1u] = select(0u, 1u, reduce_stack[sp - 1u] != 0u || reduce_stack[sp] != 0u);
            }
            case 4u: {
                if (sp == 0u) {
                    verdicts[rule_id] = 0u;
                    return;
                }
                reduce_stack[sp - 1u] = select(1u, 0u, reduce_stack[sp - 1u] != 0u);
            }
            default: {
                verdicts[rule_id] = 0u;
                return;
            }
        }
    }
    verdicts[rule_id] = select(0u, 1u, sp == 1u && reduce_stack[0] != 0u);
}
"#;

    // WGSL rejects zero-element arrays and a zero workgroup size, so clamp both capacities.
    let max_conditions = max_conditions.max(1).to_string();
    let max_formula_ops = max_formula_ops.max(1).to_string();
    TEMPLATE
        .replace("__MAX_CONDITIONS__", &max_conditions)
        .replace("__MAX_FORMULA_OPS__", &max_formula_ops)
        // One lane per condition slot.
        .replace("__WORKGROUP_SIZE__", &max_conditions)
}
pub fn evaluate_parallel_condition(
condition: ParallelCondition,
matched_patterns: &[bool],
pattern_counts: &[u32],
file_size: u32,
) -> Result<bool> {
let kind = ParallelConditionKind::from_u32(condition.opcode)?;
let pattern_state = |pattern_id: u32| -> bool { matched_patterns.get(pattern_id as usize).copied().unwrap_or(false) };
let count_state = |pattern_id: u32| -> u32 { pattern_counts.get(pattern_id as usize).copied().unwrap_or(0) };
Ok(match kind {
ParallelConditionKind::PatternExists => pattern_state(condition.lhs),
ParallelConditionKind::PatternCountGt => count_state(condition.lhs) > condition.rhs,
ParallelConditionKind::PatternCountGte => count_state(condition.lhs) >= condition.rhs,
ParallelConditionKind::FileSizeLt => file_size < condition.rhs,
ParallelConditionKind::FileSizeLte => file_size <= condition.rhs,
ParallelConditionKind::FileSizeGt => file_size > condition.rhs,
ParallelConditionKind::FileSizeGte => file_size >= condition.rhs,
ParallelConditionKind::FileSizeEq => file_size == condition.rhs,
ParallelConditionKind::FileSizeNe => file_size != condition.rhs,
ParallelConditionKind::LiteralTrue => true,
ParallelConditionKind::LiteralFalse => false,
})
}
pub fn reduce_postfix_formula(results: &[bool], formula: &[FormulaInstruction]) -> Result<bool> {
let mut stack = Vec::with_capacity(formula.len().max(1));
for instruction in formula {
match FormulaOp::from_u32(instruction.opcode)? {
FormulaOp::PushResult => {
let value = results.get(instruction.operand as usize).copied().ok_or_else(|| Error::BytecodeValidation {
message: format!(
"Fix: formula references missing condition result index {}",
instruction.operand
),
})?;
stack.push(value);
}
FormulaOp::And => {
let rhs = stack.pop().ok_or_else(|| Error::BytecodeValidation {
message: "Fix: postfix AND requires two stack values".to_string(),
})?;
let lhs = stack.pop().ok_or_else(|| Error::BytecodeValidation {
message: "Fix: postfix AND requires two stack values".to_string(),
})?;
stack.push(lhs && rhs);
}
FormulaOp::Or => {
let rhs = stack.pop().ok_or_else(|| Error::BytecodeValidation {
message: "Fix: postfix OR requires two stack values".to_string(),
})?;
let lhs = stack.pop().ok_or_else(|| Error::BytecodeValidation {
message: "Fix: postfix OR requires two stack values".to_string(),
})?;
stack.push(lhs || rhs);
}
FormulaOp::Not => {
let value = stack.pop().ok_or_else(|| Error::BytecodeValidation {
message: "Fix: postfix NOT requires one stack value".to_string(),
})?;
stack.push(!value);
}
}
}
if stack.len() != 1 {
return Err(Error::BytecodeValidation {
message: format!(
"Fix: postfix reduction must leave exactly one stack value, left {}",
stack.len()
),
});
}
Ok(stack[0])
}
#[cfg(test)]
mod tests {
    use super::{
        build_parallel_eval_shader, evaluate_parallel_condition, reduce_postfix_formula,
        ExecutionStrategy, FormulaInstruction, FormulaOp, ParallelCondition,
        ParallelConditionKind,
    };

    /// Shorthand for a `PushResult` instruction.
    fn push(index: u32) -> FormulaInstruction {
        FormulaInstruction::push_result(index)
    }

    /// Shorthand for an operand-less formula instruction.
    fn op(opcode: FormulaOp) -> FormulaInstruction {
        FormulaInstruction::op(opcode)
    }

    /// CPU reference pipeline: evaluate every condition, then reduce the postfix formula.
    ///
    /// Panics if either stage reports a validation error.
    fn eval_rule(
        conditions: &[ParallelCondition],
        formula: &[FormulaInstruction],
        matched_patterns: &[bool],
        pattern_counts: &[u32],
        file_size: u32,
    ) -> bool {
        let mut results = Vec::with_capacity(conditions.len());
        for &condition in conditions {
            let result =
                evaluate_parallel_condition(condition, matched_patterns, pattern_counts, file_size);
            results.push(result.unwrap());
        }
        reduce_postfix_formula(&results, formula).unwrap()
    }

    #[test]
    fn five_independent_conditions_produce_expected_verdict() {
        let conditions = [
            ParallelCondition::new(ParallelConditionKind::PatternExists, 0, 0),
            ParallelCondition::new(ParallelConditionKind::PatternCountGte, 1, 3),
            ParallelCondition::new(ParallelConditionKind::FileSizeLt, 0, 1024),
            ParallelCondition::new(ParallelConditionKind::PatternExists, 2, 0),
            ParallelCondition::new(ParallelConditionKind::FileSizeEq, 0, 512),
        ];
        // (((c0 AND c1) AND c2) AND NOT c3) OR c4
        let formula = [
            push(0),
            push(1),
            op(FormulaOp::And),
            push(2),
            op(FormulaOp::And),
            push(3),
            op(FormulaOp::Not),
            op(FormulaOp::And),
            push(4),
            op(FormulaOp::Or),
        ];
        assert!(eval_rule(&conditions, &formula, &[true, true, false], &[1, 3, 0], 512));
    }

    #[test]
    fn parallel_formula_matches_sequential_boolean_reduction() {
        let conditions = [
            ParallelCondition::new(ParallelConditionKind::PatternExists, 0, 0),
            ParallelCondition::new(ParallelConditionKind::PatternCountGt, 1, 1),
            ParallelCondition::new(ParallelConditionKind::FileSizeGte, 0, 2048),
            ParallelCondition::new(ParallelConditionKind::PatternExists, 2, 0),
        ];
        // (c0 AND c1) OR (c2 AND NOT c3)
        let formula = [
            push(0),
            push(1),
            op(FormulaOp::And),
            push(2),
            push(3),
            op(FormulaOp::Not),
            op(FormulaOp::And),
            op(FormulaOp::Or),
        ];
        let matched_patterns = [true, true, false];
        let pattern_counts = [1, 2, 0];
        let file_size = 4096;

        let left = matched_patterns[0] && pattern_counts[1] > 1;
        let right = file_size >= 2048 && !matched_patterns[2];
        let sequential = left || right;
        let parallel = eval_rule(&conditions, &formula, &matched_patterns, &pattern_counts, file_size);
        assert_eq!(parallel, sequential);
    }

    #[test]
    fn edge_case_condition_counts_reduce_correctly() {
        // No conditions: the formula is a lone NOT on an empty stack, which must be rejected.
        assert!(reduce_postfix_formula(&[], &[op(FormulaOp::Not)]).is_err());

        for count in [1u32, 32, 64] {
            // Alternating TRUE / FALSE literals, so the OR of all of them is true.
            let results = (0..count)
                .map(|index| {
                    let kind = if index % 2 == 0 {
                        ParallelConditionKind::LiteralTrue
                    } else {
                        ParallelConditionKind::LiteralFalse
                    };
                    evaluate_parallel_condition(ParallelCondition::new(kind, 0, 0), &[], &[], 0)
                })
                .collect::<Result<Vec<_>, _>>()
                .unwrap();
            // r0 r1 OR r2 OR ... r(count-1) OR
            let formula: Vec<FormulaInstruction> = std::iter::once(push(0))
                .chain((1..count).flat_map(|index| [push(index), op(FormulaOp::Or)]))
                .collect();
            assert!(reduce_postfix_formula(&results, &formula).unwrap());
        }
    }

    #[test]
    fn shader_builder_embeds_parallel_architecture_constants() {
        let shader = build_parallel_eval_shader(64, 128);
        for needle in [
            "@compute @workgroup_size(64)",
            "var<workgroup> condition_results",
            "workgroupBarrier();",
            "verdicts[rule_id]",
        ] {
            assert!(shader.contains(needle));
        }
    }

    #[test]
    fn execution_strategy_defaults_to_sequential() {
        assert_eq!(ExecutionStrategy::default(), ExecutionStrategy::Sequential);
        assert_eq!(ExecutionStrategy::Parallel, ExecutionStrategy::Parallel);
    }
}