/// Returns the WGSL prelude text.
///
/// The prelude contains:
/// * constants;
/// * the `Instruction`, `FileContext` and `FiredResults` types;
/// * the `@group(0)` resource bindings 0-9;
/// * helper functions for reading cached pattern matches and file bytes.
///
/// # Arguments
///
/// * `max_stack` - substituted for `__MAX_STACK__`, producing the `u32` literal
///   assigned to `STACK_SIZE`.
/// * `max_for_iterations` - substituted for `__MAX_FOR_ITERATIONS__`, producing the
///   `u32` literal assigned to `MAX_FOR_ITERATIONS`.
///
/// # Notes
///
/// The byte-range helpers `region_entropy_bucket` and `fnv1a_hash` check their range
/// as `length > file_size - offset`, not `offset + length > file_size`. Runtime `u32`
/// addition in WGSL wraps, so the sum form would accept a huge `length` whose sum
/// overflows. The helper would then loop for up to `length` iterations.
pub fn header(max_stack: u32, max_for_iterations: u32) -> String {
    let template = r#"
const STACK_SIZE: u32 = __MAX_STACK__u;
const MAX_FOR_ITERATIONS: u32 = __MAX_FOR_ITERATIONS__u;
const MAX_LOOP_DEPTH: u32 = 8u;
const ABORT_SENTINEL: u32 = 0xFFFFFFFFu;
struct Instruction {
opcode: u32,
operand: u32,
};
struct FileContext {
file_size: u32,
entropy_bucket: u32,
magic_u32: u32,
is_pe: u32,
is_dll: u32,
is_64bit: u32,
has_signature: u32,
num_sections: u32,
num_imports: u32,
entry_point_rva: u32,
unique_pattern_count: u32,
total_match_count: u32,
};
@group(0) @binding(0) var<storage, read> programs: array<Instruction>;
@group(0) @binding(1) var<storage, read> rule_program_spans: array<vec2<u32>>;
@group(0) @binding(2) var<storage, read> rule_bitmaps: array<u32>;
@group(0) @binding(3) var<storage, read> rule_counts: array<u32>;
@group(0) @binding(4) var<storage, read> rule_positions: array<u32>;
@group(0) @binding(5) var<uniform> params: vec4<u32>;
@group(0) @binding(6) var<storage, read> file_bytes: array<u32>;
@group(0) @binding(7) var<uniform> file_ctx: FileContext;
@group(0) @binding(8) var<storage, read> rule_lengths: array<u32>;
struct FiredResults {
count: atomic<u32>,
indices: array<u32, 1024>,
};
@group(0) @binding(9) var<storage, read_write> fired_results: FiredResults;
// params.x = rule_count
// params.y = max_strings_per_rule
// params.z = max_cached_positions
// params.w = file_size
fn count_index(rule_id: u32, string_id: u32) -> u32 {
return rule_id * params.y + string_id;
}
fn pos_index(rule_id: u32, string_id: u32, slot: u32) -> u32 {
return ((rule_id * params.y + string_id) * params.z) + slot;
}
fn pattern_is_valid(pattern_id: u32) -> bool {
return pattern_id < params.y;
}
fn cached_match_count(rule_id: u32, pattern_id: u32) -> u32 {
if (!pattern_is_valid(pattern_id)) {
return 0u;
}
return min(rule_counts[count_index(rule_id, pattern_id)], params.z);
}
fn match_position(rule_id: u32, pattern_id: u32, slot: u32) -> u32 {
if (slot >= cached_match_count(rule_id, pattern_id)) {
return ABORT_SENTINEL;
}
return rule_positions[pos_index(rule_id, pattern_id, slot)];
}
fn match_length(rule_id: u32, pattern_id: u32, slot: u32) -> u32 {
if (slot >= cached_match_count(rule_id, pattern_id)) {
return 0u;
}
return rule_lengths[pos_index(rule_id, pattern_id, slot)];
}
fn match_end(rule_id: u32, pattern_id: u32, slot: u32) -> u32 {
let start = match_position(rule_id, pattern_id, slot);
if (start == ABORT_SENTINEL) {
return ABORT_SENTINEL;
}
return start + match_length(rule_id, pattern_id, slot);
}
fn first_match_position(rule_id: u32, pattern_id: u32) -> u32 {
return match_position(rule_id, pattern_id, 0u);
}
fn read_byte(offset: u32) -> u32 {
if (offset >= file_ctx.file_size) {
return 0u;
}
let word_idx = offset / 4u;
if (word_idx >= arrayLength(&file_bytes)) {
return 0u;
}
let word = file_bytes[word_idx];
let shift = (offset % 4u) * 8u;
return (word >> shift) & 0xFFu;
}
fn popcount32(v: u32) -> u32 {
return countOneBits(v);
}
fn brace_depth(offset: u32) -> u32 {
if (offset > file_ctx.file_size) {
return 0u;
}
var depth = 0u;
var idx = 0u;
loop {
if (idx >= offset) {
break;
}
let byte = read_byte(idx);
if (byte == 123u) {
depth = depth + 1u;
} else if (byte == 125u && depth > 0u) {
depth = depth - 1u;
}
idx = idx + 1u;
}
return depth;
}
fn region_entropy_bucket(offset: u32, length: u32) -> u32 {
// `||` short-circuits, so `file_size - offset` is only evaluated once
// `offset < file_size`; this avoids the u32 wraparound of `offset + length`.
if (length == 0u || offset >= file_ctx.file_size || length > file_ctx.file_size - offset) {
return 0u;
}
var counts: array<u32, 256>;
var idx = 0u;
loop {
if (idx >= 256u) {
break;
}
counts[idx] = 0u;
idx = idx + 1u;
}
idx = 0u;
loop {
if (idx >= length) {
break;
}
let byte = read_byte(offset + idx);
counts[byte] = counts[byte] + 1u;
idx = idx + 1u;
}
var entropy = 0.0;
let length_f = f32(length);
idx = 0u;
loop {
if (idx >= 256u) {
break;
}
let count = counts[idx];
if (count != 0u) {
let p = f32(count) / length_f;
entropy = entropy - (p * log2(p));
}
idx = idx + 1u;
}
return u32(round(clamp((entropy / 8.0) * 255.0, 0.0, 255.0)));
}
fn fnv1a_hash(offset: u32, length: u32) -> u32 {
// Overflow-safe range check (see region_entropy_bucket).
if (length == 0u || offset >= file_ctx.file_size || length > file_ctx.file_size - offset) {
return 0u;
}
var hash = 2166136261u;
var idx = 0u;
loop {
if (idx >= length) {
break;
}
hash = (hash ^ read_byte(offset + idx)) * 16777619u;
idx = idx + 1u;
}
return hash;
}
"#;
    template
        .replace("__MAX_STACK__", &max_stack.to_string())
        .replace("__MAX_FOR_ITERATIONS__", &max_for_iterations.to_string())
}