vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
//! Catalog entry for `detect_email`.

use crate::ops::security_detection::detector_support::{spans, ByteSpan, DetectionError};

/// Embedded operation spec formerly stored in metadata/spec.toml.
///
/// Declares the operation id `security_detection.detect_email` and the WGSL
/// intrinsic entry point `security_detection_detect_email`. That is the name of
/// the `@compute` function returned by `lowering::source`. The spec also gives a
/// `[64, 1, 1]` workgroup size, which matches the shader's `@workgroup_size(64)`.
///
/// NOTE(review): under TOML table rules, every key after the `[signature]`
/// header (`laws`, `equivalence_classes`, `workgroup_size`, `tags`,
/// `fixtures_dir`) belongs to the `signature` table, not the document root.
/// Confirm the spec consumer looks for them there.
///
/// NOTE(review): `inputs` lists two `Bytes` operands, while the CPU entry point
/// `detect_email` takes a single byte slice. Confirm what the second operand is.
pub const SPEC_TOML: &str = r#"schema_version = 1
id = "security_detection.detect_email"
archetype = "match-bytes-pattern"
display_name = "Detect Email"
summary = "Returns offset-length spans for RFC-5322-ish email addresses."
category = "C"

[intrinsic]
wgsl = "security_detection_detect_email"

[signature]
inputs = ["Bytes", "Bytes"]
output = "Bytes"

laws = []
equivalence_classes = ["valid_domain", "missing_tld_dot", "token_boundary", "t47_cap"]
workgroup_size = [64, 1, 1]
tags = ["security-detection", "email", "ioc", "t47"]
fixtures_dir = "fixtures/"
"#;

/// Embedded reference vectors formerly stored in fixtures/reference-vectors.toml.
///
/// Each `[[case]]` pairs an input string with the byte spans (`offset`, `len`)
/// the detector is expected to report:
/// - `positive_email`: `security@example.com` starts at byte 8 of the input
///   (after `"Contact "`) and is 20 bytes long.
/// - `negative_no_tld_dot`: the domain `example` contains no `.`, so no span
///   is expected.
pub const REFERENCE_VECTORS_TOML: &str = r#"[[case]]
name = "positive_email"
input = "Contact security@example.com today"
expected_spans = [{ offset = 8, len = 20 }]

[[case]]
name = "negative_no_tld_dot"
input = "Contact security@example today"
expected_spans = []
"#;

/// WGSL lowering source for this detector.
pub mod lowering {
    /// Return the detector-specific WGSL source.
    ///
    /// The returned text is a complete WGSL module. Its compute entry point is
    /// `security_detection_detect_email` with `@workgroup_size(64)`.
    ///
    /// Bindings (all in `@group(0)`):
    /// - `binding(0)` `input`: a read-only `array<u32>` whose elements are
    ///   compared against ASCII byte values (e.g. `64u` for `@`).
    ///   NOTE(review): this looks like one input byte per `u32` element.
    ///   Confirm against the host-side packing.
    /// - `binding(1)` `output`: an atomic `count`, followed by `data`. Span `i`
    ///   is stored as `data[2*i] = offset` and `data[2*i + 1] = len`.
    /// - `binding(2)` `params`: a uniform holding `input_len` and `max_spans`.
    ///
    /// Each invocation inspects one position `at = global_invocation_id.x`. It
    /// only proceeds when all of these hold:
    /// - `input[at]` is `@`;
    /// - `at > 0`;
    /// - at least three bytes follow the `@`.
    ///
    /// It then extends left over local-part bytes (letters, digits, `.` and
    /// the RFC 5322 `atext` punctuation) and right over domain bytes
    /// (`[A-Za-z0-9.-]`). A span is emitted when all of these hold:
    /// - the local part is non-empty;
    /// - the domain is at least three bytes long and contains a `.`;
    /// - the bytes just outside both ends are boundaries (anything other than
    ///   `[A-Za-z0-9_-]`) or the buffer edge.
    ///
    /// `count` is incremented for every match, including matches beyond
    /// `max_spans`, but only the first `max_spans` slots are written.
    /// Presumably the host detects truncation by comparing `count` with
    /// `max_spans`. Spans are appended in the order invocations win the
    /// `atomicAdd`, so the output is not guaranteed to be sorted by offset.
    ///
    /// NOTE(review): `.` and `-` are valid domain bytes even in the last
    /// position, so `a@example.com.` reports the trailing dot as part of the
    /// span. Confirm this matches the CPU reference `spans::email_spans`.
    #[must_use]
    pub const fn source() -> &'static str {
        r#"struct Params {
    input_len: u32,
    max_spans: u32,
    _pad0: u32,
    _pad1: u32,
}

struct SpanOutput {
    count: atomic<u32>,
    data: array<u32>,
}

@group(0) @binding(0) var<storage, read> input: array<u32>;
@group(0) @binding(1) var<storage, read_write> output: SpanOutput;
@group(0) @binding(2) var<uniform> params: Params;

fn is_alpha(byte: u32) -> bool {
    return (byte >= 65u && byte <= 90u) || (byte >= 97u && byte <= 122u);
}

fn is_digit(byte: u32) -> bool {
    return byte >= 48u && byte <= 57u;
}

fn is_boundary(byte: u32) -> bool {
    return !(is_alpha(byte) || is_digit(byte) || byte == 95u || byte == 45u);
}

fn is_email_local(byte: u32) -> bool {
    return is_alpha(byte) || is_digit(byte) || byte == 33u || byte == 35u || byte == 36u ||
        byte == 37u || byte == 38u || byte == 39u || byte == 42u || byte == 43u ||
        byte == 45u || byte == 47u || byte == 61u || byte == 63u || byte == 94u ||
        byte == 95u || byte == 96u || byte == 123u || byte == 124u || byte == 125u ||
        byte == 126u || byte == 46u;
}

fn is_email_domain(byte: u32) -> bool {
    return is_alpha(byte) || is_digit(byte) || byte == 45u || byte == 46u;
}

fn rewind_local(at: u32) -> u32 {
    var start = at;
    loop {
        if (start == 0u || !is_email_local(input[start - 1u])) {
            break;
        }
        start = start - 1u;
    }
    return start;
}

fn advance_domain(start: u32) -> u32 {
    var end = start;
    loop {
        if (end >= params.input_len || !is_email_domain(input[end])) {
            break;
        }
        end = end + 1u;
    }
    return end;
}

fn domain_has_dot(start: u32, end: u32) -> bool {
    for (var index = start; index < end; index = index + 1u) {
        if (input[index] == 46u) {
            return true;
        }
    }
    return false;
}

fn emit_span(offset: u32, len: u32) {
    let slot = atomicAdd(&output.count, 1u);
    if (slot < params.max_spans) {
        output.data[slot * 2u] = offset;
        output.data[slot * 2u + 1u] = len;
    }
}

@compute @workgroup_size(64)
fn security_detection_detect_email(@builtin(global_invocation_id) gid: vec3<u32>) {
    let at = gid.x;
    if (at == 0u || at + 3u >= params.input_len || input[at] != 64u) {
        return;
    }
    let start = rewind_local(at);
    let end = advance_domain(at + 1u);
    let before_ok = start == 0u || is_boundary(input[start - 1u]);
    let after_ok = end >= params.input_len || is_boundary(input[end]);
    if (start < at && end > at + 3u && domain_has_dot(at + 1u, end) && before_ok && after_ok) {
        emit_span(start, end - start);
    }
}
"#
    }
}

/// Scan `input` for RFC-5322-ish email addresses and return their byte spans.
///
/// This is the CPU entry point for the `security_detection.detect_email`
/// operation. All matching is delegated to `spans::email_spans`, and its
/// result is returned unchanged.
///
/// # Errors
///
/// Propagates the [`DetectionError`] produced by `spans::email_spans`. Per this
/// function's existing contract, that happens when the input exceeds 64 MiB;
/// the error's message is expected to begin with `Fix: ...`.
pub fn detect_email(input: &[u8]) -> Result<Vec<ByteSpan>, DetectionError> {
    let found = spans::email_spans(input)?;
    Ok(found)
}

/// Compatibility surface for the previous generated implementation module.
///
/// Everything in this module is a re-export, so no new behavior is defined.
/// Old paths keep resolving to the same items:
/// - `implementation::detect_email` and `implementation::kernel::detect_email`
///   both name this file's `detect_email`.
/// - `implementation::lowering::wgsl::source` names this file's
///   `lowering::source`.
pub mod implementation {
    pub use super::detect_email;
    /// Compatibility module for callers that used the generated kernel path.
    pub mod kernel {
        // `super::super` climbs from `kernel` through `implementation` to this file's module.
        pub use super::super::detect_email;
    }

    /// Compatibility module for callers that used the generated lowering path.
    pub mod lowering {
        /// Compatibility module for callers that used `implementation::lowering::wgsl`.
        pub mod wgsl {
            // Three `super`s climb from `wgsl` out of `implementation`, to reach this
            // file's top-level `lowering` rather than the `implementation::lowering` shim.
            pub use super::super::super::lowering::source;
        }
    }
}