vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! IR-specific mutations.

use crate::adversarial::mutations::catalog::BinOpKind;
use crate::spec::types::DataType;

/// Spell a `BinOpKind` as the variant name used after `BinOp::` in IR source text.
///
/// Every variant maps to its own identifier verbatim, so the returned text can be spliced
/// straight into a `BinOp::<name>` pattern.
fn binop_variant(op: BinOpKind) -> &'static str {
    match op {
        // Arithmetic.
        BinOpKind::Add => "Add",
        BinOpKind::Sub => "Sub",
        BinOpKind::Mul => "Mul",
        BinOpKind::Div => "Div",
        // Bitwise logic.
        BinOpKind::And => "And",
        BinOpKind::Or => "Or",
        BinOpKind::Xor => "Xor",
        // Shifts.
        BinOpKind::Shl => "Shl",
        BinOpKind::Shr => "Shr",
        // Equality comparisons.
        BinOpKind::Eq => "Eq",
        BinOpKind::Ne => "Ne",
        // Ordering comparisons.
        BinOpKind::Lt => "Lt",
        BinOpKind::Le => "Le",
        BinOpKind::Gt => "Gt",
        BinOpKind::Ge => "Ge",
    }
}

/// Spell a `DataType` as the variant name used after `DataType::` in source text.
///
/// `Array` carries fields; only the variant name is returned, the fields are ignored.
fn datatype_variant(dt: &DataType) -> &'static str {
    match *dt {
        // Integer and boolean scalars.
        DataType::Bool => "Bool",
        DataType::I32 => "I32",
        DataType::U32 => "U32",
        DataType::U64 => "U64",
        // Floating-point scalars.
        DataType::F16 => "F16",
        DataType::BF16 => "BF16",
        DataType::F32 => "F32",
        DataType::F64 => "F64",
        // Fixed-width vectors.
        DataType::Vec2U32 => "Vec2U32",
        DataType::Vec4U32 => "Vec4U32",
        // Aggregates.
        DataType::Array { .. } => "Array",
        DataType::Bytes => "Bytes",
        DataType::Tensor => "Tensor",
    }
}

/// Swap a `BinOp` variant in IR definitions.
///
/// Rewrites the first code occurrence of `BinOp::<from>` to `BinOp::<to>`, as located by
/// `lexical::replace_code` with a replacement budget of one.
#[inline]
pub fn apply_binop_swap(source: &str, from: BinOpKind, to: BinOpKind) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;

    let [needle, replacement] = [from, to].map(|kind| format!("BinOp::{}", binop_variant(kind)));
    replace_code(source, &needle, &replacement, 1)
}

/// Swap a `DataType` variant in signatures.
///
/// Rewrites the first code occurrence of `DataType::<from>` to `DataType::<to>` via
/// `lexical::replace_code` (one replacement at most).
#[inline]
pub fn apply_datatype_swap(source: &str, from: DataType, to: DataType) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;

    let [needle, replacement] =
        [from, to].map(|dt| format!("DataType::{}", datatype_variant(&dt)));
    replace_code(source, &needle, &replacement, 1)
}

/// Comment out the first `law:` occurrence.
///
/// The first code occurrence of `law:` gets a `// ` prefix, so everything from there to
/// the end of that line becomes a line comment. `_op` is accepted for signature parity with
/// the other mutators and is not used.
#[inline]
pub fn apply_delete_law(source: &str, _op: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical;

    let (key, commented_key) = ("law:", "// law:");
    lexical::replace_code(source, key, commented_key, 1)
}

/// Rename the first `reference_fn` occurrence.
///
/// Uses whole-word replacement (`lexical::replace_code_word`) so only the identifier
/// `reference_fn` itself is renamed to `reference_fn_wrong`. `_op` and `_wrong_op` are
/// accepted for signature parity and are not used.
#[inline]
pub fn apply_swap_reference_fn(source: &str, _op: &str, _wrong_op: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical;

    let (word, renamed) = ("reference_fn", "reference_fn_wrong");
    lexical::replace_code_word(source, word, renamed, 1)
}

/// Swap `BufferAccess::Storage` for `BufferAccess::Workgroup`.
///
/// Only the first code occurrence is rewritten.
#[inline]
pub fn apply_buffer_access_swap(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;

    const STORAGE: &str = "BufferAccess::Storage";
    const WORKGROUP: &str = "BufferAccess::Workgroup";
    replace_code(source, STORAGE, WORKGROUP, 1)
}

/// Neutralise the first `validate(` call.
///
/// The callee name is wrapped in a block comment (`/* validate */(`), leaving the argument
/// list in place. Only the first code occurrence is touched.
#[inline]
pub fn apply_remove_validation_rule(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical;

    let call_site = "validate(";
    let neutralised = "/* validate */(";
    lexical::replace_code(source, call_site, neutralised, 1)
}

/// Swap two opcode literals in the bytecode converter.
#[inline]
pub fn apply_bytecode_swap(source: &str, from_opcode: u8, to_opcode: u8) -> String {
    let from_dec = format!("{}u8", from_opcode);
    let to_dec = format!("{}u8", to_opcode);
    let tmp = crate::adversarial::mutations::catalog::lexical::replace_code(
        source, &from_dec, &to_dec, 1,
    );
    if tmp == source {
        let from_hex = format!("0x{:02X}u8", from_opcode);
        let to_hex = format!("0x{:02X}u8", to_opcode);
        crate::adversarial::mutations::catalog::lexical::replace_code(&tmp, &from_hex, &to_hex, 1)
    } else {
        tmp
    }
}

/// Corrupt the first `wgsl_op` occurrence.
///
/// Renames the identifier `wgsl_op` to `wgsl_op_wrong` using whole-word replacement
/// (`lexical::replace_code_word`), at most once.
#[inline]
pub fn apply_wrong_wgsl_op(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical;

    let (word, corrupted) = ("wgsl_op", "wgsl_op_wrong");
    lexical::replace_code_word(source, word, corrupted, 1)
}

/// Change the first `workgroup_size: 64` occurrence to `1`.
///
/// Performs a literal text replacement of `workgroup_size: 64` with `workgroup_size: 1` on
/// the first code occurrence only.
#[inline]
pub fn apply_workgroup_size_change(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;

    const DEFAULT_SIZE: &str = "workgroup_size: 64";
    const SINGLE_LANE: &str = "workgroup_size: 1";
    replace_code(source, DEFAULT_SIZE, SINGLE_LANE, 1)
}

/// Multiply the first `workgroup_stride` expression.
///
/// Delegates to `rewrite_named_number` with the operator suffix `* factor`, turning the
/// matched number `n` into `(n * factor)`.
#[inline]
pub fn apply_workgroup_stride_mul(source: &str, factor: u32) -> String {
    let op_suffix = format!("* {}", factor);
    rewrite_named_number(source, "workgroup_stride", op_suffix.as_str())
}

/// Divide the first `workgroup_stride` expression.
///
/// Delegates to `rewrite_named_number` with the operator suffix `/ divisor`, turning the
/// matched number `n` into `(n / divisor)`. A zero `divisor` is passed through as-is.
#[inline]
pub fn apply_workgroup_stride_div(source: &str, divisor: u32) -> String {
    let op_suffix = format!("/ {}", divisor);
    rewrite_named_number(source, "workgroup_stride", op_suffix.as_str())
}

/// Multiply the first `workgroup_size` expression.
///
/// Delegates to `rewrite_named_number` with the operator suffix `* factor`, turning the
/// matched number `n` into `(n * factor)`.
#[inline]
pub fn apply_workgroup_size_mul(source: &str, factor: u32) -> String {
    let op_suffix = format!("* {}", factor);
    rewrite_named_number(source, "workgroup_size", op_suffix.as_str())
}

/// Divide the first `workgroup_size` expression.
///
/// Delegates to `rewrite_named_number` with the operator suffix `/ divisor`, turning the
/// matched number `n` into `(n / divisor)`. A zero `divisor` is passed through as-is.
#[inline]
pub fn apply_workgroup_size_div(source: &str, divisor: u32) -> String {
    let op_suffix = format!("/ {}", divisor);
    rewrite_named_number(source, "workgroup_size", op_suffix.as_str())
}

/// Offset the first `workgroup_size` expression.
#[inline]
pub fn apply_workgroup_size_offset(source: &str, by: i32) -> String {
    if by >= 0 {
        rewrite_named_number(source, "workgroup_size", &format!("+ {by}"))
    } else {
        rewrite_named_number(
            source,
            "workgroup_size",
            &format!("- {}", by.unsigned_abs()),
        )
    }
}

/// Wrap the integer literal bound to `name` as `(<literal> <suffix>)`.
///
/// Walks the code occurrences of `name` reported by `lexical::find_code`. It rewrites the
/// first occurrence that satisfies all of the following:
/// - it is a whole identifier, not part of e.g. `max_workgroup_size`;
/// - it is followed, after optional ASCII whitespace, by a single `:` (a `::` path does
///   not count);
/// - after that colon and optional ASCII whitespace comes a token starting with a digit.
///
/// The literal extends over the leading digit plus any following ASCII alphanumerics or
/// underscores. Separators (`1_024`), radix prefixes (`0x40`) and type suffixes (`64u32`)
/// therefore stay inside the parentheses, e.g. `64u32` becomes `(64u32 * 2)`, and the
/// mutant remains well-formed.
///
/// Returns an unchanged copy of `source` when `name` is empty or no occurrence qualifies.
fn rewrite_named_number(source: &str, name: &str, suffix: &str) -> String {
    // Skip leading ASCII whitespace only, mirroring how the literal is delimited.
    fn skip_ascii_ws(text: &str) -> &str {
        text.trim_start_matches(|ch: char| ch.is_ascii_whitespace())
    }
    let is_ident = |ch: char| ch.is_ascii_alphanumeric() || ch == '_';

    if name.is_empty() {
        // An empty needle would match at every offset without advancing.
        return source.to_string();
    }

    let mut offset = 0;
    while let Some(rel) =
        crate::adversarial::mutations::catalog::lexical::find_code(&source[offset..], name)
    {
        let name_start = offset + rel;
        let name_end = name_start + name.len();
        offset = name_end;

        // Reject matches embedded in a longer identifier.
        if source[..name_start].chars().next_back().is_some_and(is_ident)
            || source[name_end..].chars().next().is_some_and(is_ident)
        {
            continue;
        }
        // Require `name: value` right here; a colon further away belongs to something else.
        let Some(after_colon) = skip_ascii_ws(&source[name_end..]).strip_prefix(':') else {
            continue;
        };
        if after_colon.starts_with(':') {
            continue;
        }
        let value = skip_ascii_ws(after_colon);
        if !value.starts_with(|ch: char| ch.is_ascii_digit()) {
            continue;
        }
        let literal_len = value.find(|ch: char| !is_ident(ch)).unwrap_or(value.len());
        // `value` is a suffix of `source`, so its start offset follows from the lengths.
        let start = source.len() - value.len();
        let end = start + literal_len;
        let literal = &source[start..end];
        return format!("{}({literal} {suffix}){}", &source[..start], &source[end..]);
    }
    source.to_string()
}

/// Replace the first bounds-check pattern with `if true`.
///
/// Only the exact text `if index < len` is recognised, and only its first code occurrence
/// is rewritten.
#[inline]
pub fn apply_remove_bounds_check(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;

    const GUARD: &str = "if index < len";
    const ALWAYS: &str = "if true";
    replace_code(source, GUARD, ALWAYS, 1)
}

/// Remove the first shift-mask pattern.
///
/// Searches code (via `lexical::find_code`) for ` & 0x1F`, then for ` & 0x1f`, and deletes
/// the first occurrence that forms a complete mask literal.
/// - If the match is followed by more hex digits or `_` (e.g. ` & 0x1FF`), the literal is
///   not `0x1F` and the occurrence is skipped. Removing it would leave a stray `F`.
/// - If the match is followed by an integer type suffix starting with `u` or `i` (e.g.
///   ` & 0x1Fu`, ` & 0x1Fu32`), the suffix is removed with the literal. Otherwise a dangling
///   suffix would be glued onto the preceding operand.
///
/// Returns an unchanged copy of `source` when no qualifying mask exists.
#[inline]
pub fn apply_remove_shift_mask(source: &str) -> String {
    for pattern in [" & 0x1F", " & 0x1f"] {
        let mut offset = 0;
        while let Some(rel) =
            crate::adversarial::mutations::catalog::lexical::find_code(&source[offset..], pattern)
        {
            let start = offset + rel;
            offset = start + pattern.len();
            let rest = &source[offset..];
            let tail_len = rest
                .find(|ch: char| !(ch.is_ascii_alphanumeric() || ch == '_'))
                .unwrap_or(rest.len());
            let tail = &rest[..tail_len];
            // `u`/`i` are not hex digits, so they can only start a type suffix.
            if tail.is_empty() || tail.starts_with('u') || tail.starts_with('i') {
                let end = offset + tail_len;
                return format!("{}{}", &source[..start], &source[end..]);
            }
        }
    }
    source.to_string()
}