use crate::pipeline::backend::{wrap_shader, ConformDispatchConfig, WgslBackend};
use crate::OpSpec;
/// A named WGSL source mutation that the probe applies to an op's shader.
///
/// `id` and `description` are copied into every [`WgslMutationResult`] produced for
/// this mutation. Because `kind` is private, code outside this module creates values
/// through constructors such as [`WgslMutation::flip_comparison`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WgslMutation {
/// Stable identifier, e.g. `"wgsl.flip_comparison"`; reported as `mutation_id`.
pub id: &'static str,
/// Human-readable summary of the rewrite; reported as `description`.
pub description: &'static str,
/// Which source rewrite this mutation performs.
kind: WgslMutationKind,
}
impl WgslMutation {
#[must_use]
pub const fn flip_comparison() -> Self {
Self {
id: "wgsl.flip_comparison",
description: "flip a WGSL comparison operator",
kind: WgslMutationKind::FlipComparison,
}
}
#[must_use]
pub const fn swap_bitop() -> Self {
Self {
id: "wgsl.swap_bitop",
description: "swap a WGSL bitwise operator",
kind: WgslMutationKind::SwapBitOp,
}
}
#[must_use]
pub const fn drop_instruction() -> Self {
Self {
id: "wgsl.drop_instruction",
description: "replace a WGSL return instruction with a neutral zero return",
kind: WgslMutationKind::DropInstruction,
}
}
fn apply(&self, source: &str) -> Result<String, String> {
self.kind.apply(source)
}
}
/// The concrete source rewrite performed by a [`WgslMutation`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WgslMutationKind {
/// Negate one comparison operator, trying `==`, `!=`, `>=`, `<=`, `>`, `<` in that order.
FlipComparison,
/// Rotate one bitwise operator: `^` -> `&`, `&` -> `|`, `|` -> `^` (tried in that order).
SwapBitOp,
/// Replace the first `return ...;` statement with `return 0u;`.
DropInstruction,
}
impl WgslMutationKind {
    /// Comparison operators paired with their logical negation.
    ///
    /// Order matters: two-character operators are tried before their one-character
    /// prefixes, and the first operator present in the source wins.
    const COMPARISON_FLIPS: &'static [(&'static str, &'static str)] = &[
        ("==", "!="),
        ("!=", "=="),
        (">=", "<"),
        ("<=", ">"),
        (">", "<="),
        ("<", ">="),
    ];

    /// Bitwise operators rotated XOR -> AND -> OR -> XOR, tried in this order.
    const BITOP_ROTATION: &'static [(&'static str, &'static str)] =
        &[("^", "&"), ("&", "|"), ("|", "^")];

    /// Rewrites `source` according to this mutation kind.
    ///
    /// # Errors
    ///
    /// Returns a reason string when `source` has no site this kind can rewrite.
    fn apply(self, source: &str) -> Result<String, String> {
        match self {
            Self::DropInstruction => replace_first_return(source),
            Self::SwapBitOp => replace_first_token(source, Self::BITOP_ROTATION),
            Self::FlipComparison => replace_first_token(source, Self::COMPARISON_FLIPS),
        }
    }
}
/// Verdict for a single mutation run by the WGSL mutation probe.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WgslMutationOutcome {
/// The mutant was detected: its wrapped shader failed naga parsing/validation, the
/// backend failed to dispatch it, or its output differed from the CPU reference.
Killed,
/// The mutant compiled, dispatched, and matched the CPU reference on every input.
Survived,
/// The mutation was not run: it could not be applied, or it left the WGSL unchanged.
Skipped {
/// Why the mutation was skipped.
reason: String,
},
}
/// Outcome of probing one mutation, as stored in a [`WgslMutationReport`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WgslMutationResult {
/// Copy of the mutation's [`WgslMutation::id`].
pub mutation_id: String,
/// Copy of the mutation's [`WgslMutation::description`].
pub description: String,
/// Whether the mutant was killed, survived, or skipped.
pub outcome: WgslMutationOutcome,
/// Human-readable explanation of the outcome, with a `Fix:` hint where applicable.
pub detail: String,
}
/// Per-op summary produced by the WGSL mutation probe.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WgslMutationReport {
/// `id` of the probed `OpSpec`.
pub op_id: String,
/// Name reported by the backend that dispatched the shaders.
pub backend: String,
/// One result per requested mutation, in the order the mutations were given.
pub results: Vec<WgslMutationResult>,
}
impl WgslMutationReport {
    /// Returns `true` when no mutation in the report survived.
    ///
    /// Killed and skipped mutations both count as passing; only
    /// [`WgslMutationOutcome::Survived`] fails the report. An empty report passes.
    #[must_use]
    #[inline]
    pub fn passed(&self) -> bool {
        !self.results.iter().any(Self::is_survivor)
    }

    /// Returns the results whose outcome is [`WgslMutationOutcome::Survived`],
    /// in report order.
    #[must_use]
    #[inline]
    pub fn survivors(&self) -> Vec<&WgslMutationResult> {
        self.results
            .iter()
            .filter(|result| Self::is_survivor(result))
            .collect()
    }

    /// Whether `result` records a mutant that went undetected.
    fn is_survivor(result: &WgslMutationResult) -> bool {
        matches!(result.outcome, WgslMutationOutcome::Survived)
    }
}
/// Runs [`wgsl_mutation_probe_with_inputs`] with inputs derived from `spec`.
///
/// If the op's `spec_table` has rows, each row's inputs are concatenated into one
/// input. Otherwise a small set of fixed byte patterns sized from the op signature
/// is used.
///
/// # Errors
///
/// Propagates every error from [`wgsl_mutation_probe_with_inputs`].
#[inline]
pub fn wgsl_mutation_probe(
    backend: &dyn WgslBackend,
    spec: &OpSpec,
    mutations: &[WgslMutation],
) -> Result<WgslMutationReport, String> {
    wgsl_mutation_probe_with_inputs(backend, spec, mutations, &default_inputs(spec))
}
/// Applies each mutation to the op's WGSL and checks whether `inputs` detect it.
///
/// Before any mutation runs, the unmutated shader must compile and match the CPU
/// reference on every input. Each mutation then yields one result, in order:
/// - skipped if it cannot be applied or produces identical WGSL;
/// - otherwise killed or survived, according to `run_mutation`.
///
/// # Errors
///
/// Returns an error if `inputs` is empty, or if the baseline shader fails to
/// compile, fails to dispatch, or disagrees with the CPU reference.
#[inline]
pub fn wgsl_mutation_probe_with_inputs(
    backend: &dyn WgslBackend,
    spec: &OpSpec,
    mutations: &[WgslMutation],
    inputs: &[Vec<u8>],
) -> Result<WgslMutationReport, String> {
    if inputs.is_empty() {
        return Err(format!(
            "{} WGSL mutation probe received no inputs. Fix: pass at least one parity input.",
            spec.id
        ));
    }

    let baseline = (spec.wgsl_fn)();
    // Mutants are only meaningful against a baseline that already matches the CPU reference.
    verify_original_parity(backend, spec, &baseline, inputs)?;

    let results = mutations
        .iter()
        .map(|mutation| match mutation.apply(&baseline) {
            Err(reason) => skip_result(mutation, &reason),
            Ok(mutated) if mutated == baseline => {
                skip_result(mutation, "mutation produced identical WGSL")
            }
            Ok(mutated) => run_mutation(backend, spec, mutation, &mutated, inputs),
        })
        .collect();

    Ok(WgslMutationReport {
        op_id: spec.id.to_string(),
        backend: backend.name().to_string(),
        results,
    })
}
fn run_mutation(
backend: &dyn WgslBackend,
spec: &OpSpec,
mutation: &WgslMutation,
mutated_wgsl: &str,
inputs: &[Vec<u8>],
) -> WgslMutationResult {
let config = dispatch_config(spec);
let shader = wrap_shader(mutated_wgsl, &config);
if let Err(error) = validate_wrapped_shader(&shader) {
return WgslMutationResult {
mutation_id: mutation.id.to_string(),
description: mutation.description.to_string(),
outcome: WgslMutationOutcome::Killed,
detail: format!(
"mutated WGSL failed naga recompile: {error}. Fix: keep shader mutations syntactically valid when measuring semantic survivors."
),
};
}
for input in inputs {
let expected = (spec.cpu_fn)(input);
let output_size = output_size(spec, &expected);
let actual = match backend.dispatch(&shader, input, output_size, config.clone()) {
Ok(actual) => actual,
Err(error) => {
return WgslMutationResult {
mutation_id: mutation.id.to_string(),
description: mutation.description.to_string(),
outcome: WgslMutationOutcome::Killed,
detail: format!(
"mutated WGSL failed backend dispatch: {error}. Fix: keep shader mutations syntactically valid when measuring semantic survivors."
),
};
}
};
if actual != expected {
return WgslMutationResult {
mutation_id: mutation.id.to_string(),
description: mutation.description.to_string(),
outcome: WgslMutationOutcome::Killed,
detail: format!(
"mutated WGSL diverged on input {:02x?}: expected {:02x?}, got {:02x?}",
input, expected, actual
),
};
}
}
WgslMutationResult {
mutation_id: mutation.id.to_string(),
description: mutation.description.to_string(),
outcome: WgslMutationOutcome::Survived,
detail: format!(
"{} survived WGSL mutation {}. Fix: add parity inputs that distinguish the mutated shader from the CPU reference.",
spec.id, mutation.id
),
}
}
/// Checks that the unmutated WGSL compiles and matches the CPU reference on every input.
///
/// Mutation verdicts are only meaningful against a baseline that already agrees with
/// `spec.cpu_fn`. Otherwise a divergence caused by the baseline itself would be
/// attributed to the mutation.
///
/// # Errors
///
/// Returns a message naming `spec.id` in any of these cases:
/// - the wrapped baseline shader fails naga parsing/validation;
/// - the backend fails to dispatch it for some input;
/// - its output differs from the CPU reference for some input.
fn verify_original_parity(
    backend: &dyn WgslBackend,
    spec: &OpSpec,
    original_wgsl: &str,
    inputs: &[Vec<u8>],
) -> Result<(), String> {
    let config = dispatch_config(spec);
    let shader = wrap_shader(original_wgsl, &config);
    validate_wrapped_shader(&shader).map_err(|error| {
        format!(
            "{} original WGSL fails naga recompile before mutation probing: {error}. Fix: repair baseline WGSL before running WGSL mutation probes.",
            spec.id
        )
    })?;
    for input in inputs {
        let expected = (spec.cpu_fn)(input);
        let output_size = output_size(spec, &expected);
        // Previously the raw backend error was propagated as-is. Attach the op id and
        // input so a baseline dispatch failure is identifiable, like the other errors here.
        let actual = backend
            .dispatch(&shader, input, output_size, config.clone())
            .map_err(|error| {
                format!(
                    "{} original WGSL failed backend dispatch on input {:02x?}: {error}. Fix: repair baseline dispatch before running WGSL mutation probes.",
                    spec.id, input
                )
            })?;
        if actual != expected {
            return Err(format!(
                "{} original WGSL does not match CPU reference on input {:02x?}: expected {:02x?}, got {:02x?}. Fix: repair baseline parity before running WGSL mutation probes.",
                spec.id, input, expected, actual
            ));
        }
    }
    Ok(())
}
/// Parses `shader` with naga's WGSL front end and validates the resulting module.
///
/// Validation runs with all validation flags enabled and no optional capabilities.
///
/// # Errors
///
/// Returns a `Fix:`-prefixed message describing the parse or validation failure.
fn validate_wrapped_shader(shader: &str) -> Result<(), String> {
    let module = match naga::front::wgsl::parse_str(shader) {
        Ok(module) => module,
        Err(error) => return Err(format!("Fix: WGSL shader fails naga parsing: {error}")),
    };
    let mut validator = naga::valid::Validator::new(
        naga::valid::ValidationFlags::all(),
        naga::valid::Capabilities::empty(),
    );
    match validator.validate(&module) {
        Ok(_) => Ok(()),
        Err(error) => Err(format!("Fix: WGSL shader fails naga validation: {error}")),
    }
}
/// Builds the dispatch configuration shared by baseline and mutant runs.
///
/// Takes the op's workgroup size (defaulting to 1) and calling convention. All other
/// settings are left at `ConformDispatchConfig::default()`.
fn dispatch_config(spec: &OpSpec) -> ConformDispatchConfig {
    let workgroup_size = spec.workgroup_size.unwrap_or(1);
    let convention = spec.convention.clone();
    ConformDispatchConfig {
        workgroup_size,
        convention,
        ..ConformDispatchConfig::default()
    }
}
/// Number of output bytes to request from the backend.
///
/// This is the op's declared `expected_output_bytes` when present, otherwise the
/// length of the CPU reference output.
fn output_size(spec: &OpSpec, expected: &[u8]) -> usize {
    match spec.expected_output_bytes {
        Some(declared) => declared,
        None => expected.len(),
    }
}
/// Derives parity inputs for an op.
///
/// When the op has a non-empty `spec_table`, each row's inputs are concatenated into
/// one input. Otherwise the inputs are the byte patterns 0x00, 0xFF, 0x55 and 0xAA,
/// each `max(min_input_bytes, 4)` long. If that length is at least 8, one more input
/// is added: two little-endian `u32`s (1 and 2) followed by zeros.
fn default_inputs(spec: &OpSpec) -> Vec<Vec<u8>> {
    if spec.spec_table.is_empty() {
        let len = spec.signature.min_input_bytes().max(4);
        let mut inputs: Vec<Vec<u8>> = [0x00u8, 0xFF, 0x55, 0xAA]
            .iter()
            .map(|&fill| vec![fill; len])
            .collect();
        if len >= 8 {
            let mut mixed = vec![0u8; len];
            mixed[..4].copy_from_slice(&1u32.to_le_bytes());
            mixed[4..8].copy_from_slice(&2u32.to_le_bytes());
            inputs.push(mixed);
        }
        inputs
    } else {
        spec.spec_table
            .iter()
            .map(|row| {
                row.inputs
                    .iter()
                    .flat_map(|input| input.iter().copied())
                    .collect()
            })
            .collect()
    }
}
/// Builds a `Skipped` result for `mutation`, recording `reason` in both the outcome
/// and the human-readable detail.
fn skip_result(mutation: &WgslMutation, reason: &str) -> WgslMutationResult {
    let detail = format!(
        "WGSL mutation skipped: {reason}. Fix: select mutations applicable to this shader."
    );
    let outcome = WgslMutationOutcome::Skipped {
        reason: reason.to_owned(),
    };
    WgslMutationResult {
        mutation_id: mutation.id.to_string(),
        description: mutation.description.to_string(),
        outcome,
        detail,
    }
}
/// Rewrites one standalone operator in `source` using the first applicable replacement.
///
/// `replacements` is tried in order. For the first `(from, to)` pair whose `from`
/// operator appears as a standalone token, that one occurrence is replaced by `to`.
///
/// An occurrence counts as standalone only when it is not fused with an adjacent
/// operator character. For example, `>` inside `->` or `>>`, `&` inside `&&` or `&=`,
/// and `<` inside `<=` are skipped. Without this check, flipping comparisons turned a
/// function's `-> u32` return arrow into `-<= u32`. That mutant fails naga validation
/// and is miscounted as killed. Template brackets such as `array<u32>` are still not
/// distinguished from comparisons.
///
/// # Errors
///
/// Returns an error when no `from` operator appears as a standalone token.
fn replace_first_token(source: &str, replacements: &[(&str, &str)]) -> Result<String, String> {
    for (from, to) in replacements {
        if let Some(index) = find_standalone_operator(source, from) {
            let mut out =
                String::with_capacity(source.len() + to.len().saturating_sub(from.len()));
            out.push_str(&source[..index]);
            out.push_str(to);
            out.push_str(&source[index + from.len()..]);
            return Ok(out);
        }
    }
    Err("target WGSL token was not present".to_string())
}

/// Returns the byte offset of the first occurrence of `operator` in `source` that is
/// not part of a longer operator token.
fn find_standalone_operator(source: &str, operator: &str) -> Option<usize> {
    // Characters that, directly before the operator, fuse with it into a different token:
    // `->`, `>>`, `<<`, `&&`, `||`, `>>=`, and so on. `=` also covers an address-of
    // written right after an assignment, as in `=&x`.
    const FUSES_BEFORE: &[char] = &['<', '>', '=', '!', '&', '|', '^', '-'];
    // Characters that, directly after the operator, extend it: `<=`, `>>`, `&&`, `^=`, ...
    const FUSES_AFTER: &[char] = &['<', '>', '=', '&', '|', '^'];
    source
        .match_indices(operator)
        .map(|(index, _)| index)
        .find(|&index| {
            let before = source[..index].chars().next_back();
            let after = source[index + operator.len()..].chars().next();
            !matches!(before, Some(c) if FUSES_BEFORE.contains(&c))
                && !matches!(after, Some(c) if FUSES_AFTER.contains(&c))
        })
}
/// Replaces the first WGSL `return <expr>;` statement with `return 0u;`.
///
/// Only a real `return` keyword counts. Identifier tails such as `early_return` are
/// skipped; previously `let early_return = 1u;` was rewritten into invalid WGSL.
/// The keyword must also be followed by whitespace or `(`, so a bare `return;` is
/// left alone. The replacement value is the `u32` literal `0u`.
///
/// # Errors
///
/// Returns an error when no `return` statement is found, or when the statement has
/// no terminating `;`.
fn replace_first_return(source: &str) -> Result<String, String> {
    let Some(return_start) = find_return_keyword(source) else {
        return Err("WGSL return instruction was not present".to_string());
    };
    let Some(relative_end) = source[return_start..].find(';') else {
        return Err("WGSL return instruction had no semicolon".to_string());
    };
    // `;` is a single byte, so `+ 1` lands on the next char boundary.
    let return_end = return_start + relative_end + 1;
    let mut out = String::with_capacity(source.len());
    out.push_str(&source[..return_start]);
    out.push_str("return 0u;");
    out.push_str(&source[return_end..]);
    Ok(out)
}

/// Returns the byte offset of the first `return` keyword that starts a `return <expr>;`
/// statement.
fn find_return_keyword(source: &str) -> Option<usize> {
    const KEYWORD: &str = "return";
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    source
        .match_indices(KEYWORD)
        .map(|(index, _)| index)
        .find(|&index| {
            let before = source[..index].chars().next_back();
            let after = source[index + KEYWORD.len()..].chars().next();
            // Reject tails of longer identifiers such as `early_return`, and require an
            // expression to follow the keyword.
            !matches!(before, Some(c) if is_ident_char(c))
                && matches!(after, Some(c) if c.is_whitespace() || c == '(')
        })
}
// Unit tests for the WGSL mutation probe, driven by a source-matching mock backend
// rather than a real GPU.
#[cfg(test)]
mod tests {
use super::{wgsl_mutation_probe_with_inputs, WgslMutation, WgslMutationOutcome};
use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
/// Mock backend that "executes" a shader by searching its source text for a binary
/// XOR, AND, or OR between two operands. It then applies that operator to the first
/// two little-endian `u32`s of the input.
struct XorAwareBackend;
impl WgslBackend for XorAwareBackend {
fn name(&self) -> &str {
"xor-aware-mock"
}
fn dispatch(
&self,
wgsl: &str,
input: &[u8],
output_size: usize,
_config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
// Only one binary u32 operation producing a single u32 (4 bytes) is supported.
if input.len() < 8 || output_size != 4 {
return Err(
"mock expects one binary u32 dispatch. Fix: use binary u32 parity inputs."
.to_string(),
);
}
let left = u32::from_le_bytes([input[0], input[1], input[2], input[3]]);
let right = u32::from_le_bytes([input[4], input[5], input[6], input[7]]);
// Each operator is recognised in either of two source shapes: direct
// `input.data[0u] OP input.data[1u]` indexing, or `_vyre_load_a(`/`_vyre_load_b(`
// helper calls.
let has_xor = wgsl.contains("input.data[0u] ^ input.data[1u]")
|| (wgsl.contains("_vyre_load_a(") && wgsl.contains(" ^ _vyre_load_b("));
let has_and = wgsl.contains("input.data[0u] & input.data[1u]")
|| (wgsl.contains("_vyre_load_a(") && wgsl.contains(" & _vyre_load_b("));
let has_or = wgsl.contains("input.data[0u] | input.data[1u]")
|| (wgsl.contains("_vyre_load_a(") && wgsl.contains(" | _vyre_load_b("));
// Checked in the order XOR, AND, OR; shader text matching none of them is an error.
let value = if has_xor {
left ^ right
} else if has_and {
left & right
} else if has_or {
left | right
} else {
return Err("mock cannot execute this WGSL. Fix: extend the mock interpreter for this shader.".to_string());
};
Ok(value.to_le_bytes().to_vec())
}
}
/// Swapping XOR for AND turns 0xAA ^ 0x55 (= 0xFF) into 0xAA & 0x55 (= 0x00). The
/// mutant therefore diverges from the CPU reference and must be reported as killed.
/// NOTE(review): assumes the xor spec's WGSL contains `^` in one of the shapes the
/// mock recognises; the baseline-parity `expect` below fails otherwise.
#[test]
fn wgsl_mutation_probe_kills_bitop_swap_through_dispatch() {
let spec = crate::spec::primitive::xor::spec();
let input = {
let mut bytes = Vec::new();
bytes.extend_from_slice(&0xAAu32.to_le_bytes());
bytes.extend_from_slice(&0x55u32.to_le_bytes());
bytes
};
let report = wgsl_mutation_probe_with_inputs(
&XorAwareBackend,
&spec,
&[WgslMutation::swap_bitop()],
&[input],
)
.expect("Fix: baseline XOR mock parity must pass before mutation probing");
// `passed()` also accepts skipped mutations, so the outcome is checked explicitly too.
assert!(
report.passed(),
"WGSL bitop mutation must be killed: {report:?}"
);
assert!(matches!(
report.results[0].outcome,
WgslMutationOutcome::Killed
));
}
}