use crate::adversarial::mutations::catalog::BinOpKind;
use crate::spec::types::DataType;
fn binop_variant(op: BinOpKind) -> &'static str {
match op {
BinOpKind::Add => "Add",
BinOpKind::Sub => "Sub",
BinOpKind::Mul => "Mul",
BinOpKind::Div => "Div",
BinOpKind::Shl => "Shl",
BinOpKind::Shr => "Shr",
BinOpKind::And => "And",
BinOpKind::Or => "Or",
BinOpKind::Xor => "Xor",
BinOpKind::Lt => "Lt",
BinOpKind::Le => "Le",
BinOpKind::Gt => "Gt",
BinOpKind::Ge => "Ge",
BinOpKind::Eq => "Eq",
BinOpKind::Ne => "Ne",
}
}
fn datatype_variant(dt: &DataType) -> &'static str {
match dt {
DataType::U32 => "U32",
DataType::I32 => "I32",
DataType::Bool => "Bool",
DataType::U64 => "U64",
DataType::Vec2U32 => "Vec2U32",
DataType::Vec4U32 => "Vec4U32",
DataType::Bytes => "Bytes",
DataType::Array { .. } => "Array",
DataType::F16 => "F16",
DataType::BF16 => "BF16",
DataType::F32 => "F32",
DataType::F64 => "F64",
DataType::Tensor => "Tensor",
}
}
/// Replaces the first `BinOp::<from>` found by `lexical::replace_code` with
/// `BinOp::<to>`. Returns `source` unchanged when nothing matches.
#[inline]
pub fn apply_binop_swap(source: &str, from: BinOpKind, to: BinOpKind) -> String {
    use crate::adversarial::mutations::catalog::lexical;
    let [needle, replacement] = [from, to].map(|kind| format!("BinOp::{}", binop_variant(kind)));
    lexical::replace_code(source, &needle, &replacement, 1)
}
/// Replaces the first `DataType::<from>` found by `lexical::replace_code`
/// with `DataType::<to>`. Returns `source` unchanged when nothing matches.
#[inline]
pub fn apply_datatype_swap(source: &str, from: DataType, to: DataType) -> String {
    use crate::adversarial::mutations::catalog::lexical;
    let [needle, replacement] =
        [&from, &to].map(|dt| format!("DataType::{}", datatype_variant(dt)));
    lexical::replace_code(source, &needle, &replacement, 1)
}
/// Disables the first `law:` occurrence by prefixing it with `// `.
///
/// `_op` is accepted only for signature parity with the other mutators and
/// is not used.
#[inline]
pub fn apply_delete_law(source: &str, _op: &str) -> String {
    const NEEDLE: &str = "law:";
    const REPLACEMENT: &str = "// law:";
    crate::adversarial::mutations::catalog::lexical::replace_code(source, NEEDLE, REPLACEMENT, 1)
}
/// Renames the first whole-word `reference_fn` to `reference_fn_wrong` via
/// `lexical::replace_code_word`.
///
/// `_op` and `_wrong_op` are accepted for signature parity and ignored.
#[inline]
pub fn apply_swap_reference_fn(source: &str, _op: &str, _wrong_op: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code_word;
    const ORIGINAL_NAME: &str = "reference_fn";
    const MUTATED_NAME: &str = "reference_fn_wrong";
    replace_code_word(source, ORIGINAL_NAME, MUTATED_NAME, 1)
}
/// Turns the first `BufferAccess::Storage` into `BufferAccess::Workgroup`.
#[inline]
pub fn apply_buffer_access_swap(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;
    let (from, to) = ("BufferAccess::Storage", "BufferAccess::Workgroup");
    replace_code(source, from, to, 1)
}
/// Neutralises the first `validate(` call by commenting out the callee name,
/// leaving `/* validate */(` so the argument list becomes a plain
/// parenthesised expression.
#[inline]
pub fn apply_remove_validation_rule(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;
    let (from, to) = ("validate(", "/* validate */(");
    replace_code(source, from, to, 1)
}
#[inline]
pub fn apply_bytecode_swap(source: &str, from_opcode: u8, to_opcode: u8) -> String {
let from_dec = format!("{}u8", from_opcode);
let to_dec = format!("{}u8", to_opcode);
let tmp = crate::adversarial::mutations::catalog::lexical::replace_code(
source, &from_dec, &to_dec, 1,
);
if tmp == source {
let from_hex = format!("0x{:02X}u8", from_opcode);
let to_hex = format!("0x{:02X}u8", to_opcode);
crate::adversarial::mutations::catalog::lexical::replace_code(&tmp, &from_hex, &to_hex, 1)
} else {
tmp
}
}
/// Renames the first whole-word `wgsl_op` to `wgsl_op_wrong` via
/// `lexical::replace_code_word`.
#[inline]
pub fn apply_wrong_wgsl_op(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code_word;
    const ORIGINAL_NAME: &str = "wgsl_op";
    const MUTATED_NAME: &str = "wgsl_op_wrong";
    replace_code_word(source, ORIGINAL_NAME, MUTATED_NAME, 1)
}
/// Replaces the first literal `workgroup_size: 64` with `workgroup_size: 1`.
#[inline]
pub fn apply_workgroup_size_change(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;
    let (from, to) = ("workgroup_size: 64", "workgroup_size: 1");
    replace_code(source, from, to, 1)
}
/// Wraps the number bound to `workgroup_stride` as `(<n> * factor)`.
#[inline]
pub fn apply_workgroup_stride_mul(source: &str, factor: u32) -> String {
    let suffix = format!("* {}", factor);
    rewrite_named_number(source, "workgroup_stride", &suffix)
}
/// Wraps the number bound to `workgroup_stride` as `(<n> / divisor)`.
///
/// A `divisor` of 0 is emitted verbatim (`/ 0`); no check is made here.
#[inline]
pub fn apply_workgroup_stride_div(source: &str, divisor: u32) -> String {
    let suffix = format!("/ {}", divisor);
    rewrite_named_number(source, "workgroup_stride", &suffix)
}
/// Wraps the number bound to `workgroup_size` as `(<n> * factor)`.
#[inline]
pub fn apply_workgroup_size_mul(source: &str, factor: u32) -> String {
    let suffix = format!("* {}", factor);
    rewrite_named_number(source, "workgroup_size", &suffix)
}
/// Wraps the number bound to `workgroup_size` as `(<n> / divisor)`.
///
/// A `divisor` of 0 is emitted verbatim (`/ 0`); no check is made here.
#[inline]
pub fn apply_workgroup_size_div(source: &str, divisor: u32) -> String {
    let suffix = format!("/ {}", divisor);
    rewrite_named_number(source, "workgroup_size", &suffix)
}
#[inline]
pub fn apply_workgroup_size_offset(source: &str, by: i32) -> String {
if by >= 0 {
rewrite_named_number(source, "workgroup_size", &format!("+ {by}"))
} else {
rewrite_named_number(
source,
"workgroup_size",
&format!("- {}", by.unsigned_abs()),
)
}
}
/// Wraps the numeric literal bound to `name` (as in `name: 64`) in
/// parentheses together with `suffix`. For example, `workgroup_size: 64`
/// with suffix `"* 2"` becomes `workgroup_size: (64 * 2)`.
///
/// An occurrence of `name` reported by `lexical::find_code` qualifies only
/// when it:
/// - is a whole identifier, so `max_workgroup_size` and `workgroup_size_x`
///   are skipped;
/// - is followed by optional ASCII whitespace and a single `:`, not `::`.
///
/// The first qualifying occurrence is used. The literal is a leading ASCII
/// digit plus any following ASCII alphanumerics or `_`. This wraps `1_024`,
/// `0x40` and `64u32` whole instead of splitting them mid-token.
///
/// Returns `source` unchanged when there is no qualifying occurrence, or
/// when its value does not start with a digit.
fn rewrite_named_number(source: &str, name: &str, suffix: &str) -> String {
    /// True for characters that can continue an identifier or numeric literal.
    fn is_word(ch: char) -> bool {
        ch.is_ascii_alphanumeric() || ch == '_'
    }

    /// Byte offset of the first non-whitespace character after `name:` for
    /// the first qualifying occurrence of `name`, or `None`.
    fn value_start(source: &str, name: &str) -> Option<usize> {
        use crate::adversarial::mutations::catalog::lexical::find_code;
        // An empty needle would match at every offset without advancing.
        if name.is_empty() {
            return None;
        }
        let is_space = |ch: char| ch.is_ascii_whitespace();
        let mut offset = 0;
        // NOTE(review): resuming `find_code` on the suffix after a match
        // assumes the position right after an in-code match is itself in
        // code. Confirm against `lexical::find_code`.
        while let Some(rel) = find_code(&source[offset..], name) {
            let pos = offset + rel;
            let after = pos + name.len();
            offset = after;
            // Reject matches embedded in a longer identifier on the left.
            if source[..pos].ends_with(is_word) {
                continue;
            }
            // Require `:` right after the name (whitespace allowed). This also
            // rejects longer identifiers on the right, e.g. `name_x:`.
            let Some(rest) = source[after..].trim_start_matches(is_space).strip_prefix(':')
            else {
                continue;
            };
            // `name::Item` is a path, not a field/binding.
            if rest.starts_with(':') {
                continue;
            }
            let value = rest.trim_start_matches(is_space);
            return Some(source.len() - value.len());
        }
        None
    }

    let Some(start) = value_start(source, name) else {
        return source.to_string();
    };
    let value = &source[start..];
    if !value.starts_with(|ch: char| ch.is_ascii_digit()) {
        return source.to_string();
    }
    let end = start + value.find(|ch: char| !is_word(ch)).unwrap_or(value.len());
    let expr = &source[start..end];
    format!("{}({expr} {suffix}){}", &source[..start], &source[end..])
}
/// Replaces the first `if index < len` guard with `if true`, disabling that
/// bounds check.
#[inline]
pub fn apply_remove_bounds_check(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;
    let (guard, always) = ("if index < len", "if true");
    replace_code(source, guard, always, 1)
}
/// Deletes the first ` & 0x1F` mask (including its leading space).
#[inline]
pub fn apply_remove_shift_mask(source: &str) -> String {
    use crate::adversarial::mutations::catalog::lexical::replace_code;
    const MASK: &str = " & 0x1F";
    replace_code(source, MASK, "", 1)
}