use vyre::VyreBackend;
use crate::spec::program::program_for_spec_input;
use crate::spec::types::OpSpec;
/// A single problem reported by the divergence gate.
///
/// Findings are produced by `run` and rendered for the enforce gate through their
/// `Display` implementation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DivergenceFinding {
    /// The spec's comparator rejected the backend output for `witness_input` under `scenario`.
    ParityMismatch {
        /// Identifier of the op under test.
        op_id: String,
        /// Probe that was active when the mismatch was observed.
        scenario: DivergenceScenario,
        /// Input bytes fed to both the CPU reference and the backend.
        witness_input: Vec<u8>,
        /// Output of the CPU reference implementation (`spec.cpu_fn`).
        cpu_output: Vec<u8>,
        /// Output returned by the backend.
        diverged_output: Vec<u8>,
        /// Comparator's explanation of the mismatch.
        message: String,
    },
    /// Dispatching under `scenario` failed, returned no buffers, or returned a buffer whose
    /// size differs from the CPU output.
    BackendDispatchFailed {
        /// Identifier of the op under test.
        op_id: String,
        /// Probe that was active when dispatch failed.
        scenario: DivergenceScenario,
        /// Rendered backend error.
        error: String,
    },
    /// The op could not be turned into a dispatchable program.
    UnsupportedSignature {
        /// Identifier of the op under test.
        op_id: String,
        /// Error text returned by the program builder.
        reason: String,
    },
}
/// Sort key for findings: the op id first, then the scenario discriminant.
///
/// `UnsupportedSignature` has no scenario and uses `u8::MAX`, so for a given op it sorts
/// after every scenario-specific finding.
fn divergence_key(finding: &DivergenceFinding) -> (String, u8) {
    let (id, rank) = match finding {
        DivergenceFinding::ParityMismatch { op_id, scenario, .. } => (op_id, *scenario as u8),
        DivergenceFinding::BackendDispatchFailed { op_id, scenario, .. } => {
            (op_id, *scenario as u8)
        }
        DivergenceFinding::UnsupportedSignature { op_id, .. } => (op_id, u8::MAX),
    };
    (id.clone(), rank)
}
/// Lane-divergence patterns probed by the divergence gate.
///
/// Discriminants are spelled out because `divergence_key` casts scenarios to `u8` to order
/// findings; they match declaration order, so the derived ordering is unchanged.
/// The per-variant notes describe the WGSL snippet returned by `scenario_body`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum DivergenceScenario {
    /// Branch on `lid.x < 32u`: one barrier on the taken side, two on the other.
    FiftyFiftySplit = 0,
    /// Branch on `lid.x < 21u`: one barrier on the taken side, two on the other.
    ThirtyThreeSixtySeven = 1,
    /// Branch on `lid.x < 16u`: one barrier on the taken side, two on the other.
    TwentyFiveSeventyFive = 2,
    /// Branch on `(lid.x & 3u) == 0u`: one barrier on the taken side, two on the other.
    LaneIdModuloFour = 3,
    /// Lane 0 returns early; the remaining lanes hit two barriers.
    EarlyExitLaneZero = 4,
    /// Lane 63 returns early; the remaining lanes hit two barriers.
    EarlyExitLastLane = 5,
    /// Branch on `input.data[0] == 0u` combined with even lanes.
    DataDependentBranch = 6,
    /// Lane 7 executes three extra barriers before the shared one.
    StallLaneSeven = 7,
    /// Both sides of `lid.x < 32u` execute one barrier each before the shared one.
    FullReconvergence = 8,
}
/// Runs the divergence probe for every spec against `backend` and returns the findings,
/// sorted by `(op_id, scenario)` via `divergence_key`.
///
/// For each spec, an all-zero witness input is built. Its length is the larger of the
/// signature's minimum input size and 8 bytes. The witness is fed once to the CPU reference
/// (`spec.cpu_fn`). It is then dispatched on the backend once per [`DivergenceScenario`]
/// through `dispatch_exact`, with a 64x1x1 workgroup. Findings are produced as follows:
///
/// - [`DivergenceFinding::UnsupportedSignature`] when `program_for_spec_input` rejects the
///   spec (the spec is then skipped);
/// - [`DivergenceFinding::BackendDispatchFailed`] when dispatch errors, returns no output
///   buffers, or returns a buffer whose size differs from the CPU output;
/// - [`DivergenceFinding::ParityMismatch`] when `spec.comparator` rejects the backend output.
///
/// NOTE(review): `scenario` is only used to label findings. The same `program` is dispatched
/// for every scenario and `scenario_body` is not called here, so all probes currently run
/// identical control flow. TODO confirm whether the scenario snippet is meant to be spliced
/// into the program before dispatch.
#[must_use]
pub(crate) fn run(specs: &[OpSpec], backend: &dyn VyreBackend) -> Vec<DivergenceFinding> {
    let mut findings = Vec::new();
    for spec in specs {
        let witness_input = vec![0u8; spec.signature.min_input_bytes().max(8)];
        let cpu_output = (spec.cpu_fn)(&witness_input);
        let mut program = match program_for_spec_input(spec, &witness_input) {
            Ok(program) => program,
            Err(reason) => {
                findings.push(DivergenceFinding::UnsupportedSignature {
                    op_id: spec.id.to_string(),
                    reason,
                });
                continue;
            }
        };
        program.set_workgroup_size([64, 1, 1]);
        // The dispatch inputs are identical for every scenario: build them once per spec
        // instead of cloning the witness for each dispatch.
        let inputs = [witness_input.clone()];
        for scenario in DivergenceScenario::all() {
            let diverged_output =
                match dispatch_exact(backend, &program, &inputs, cpu_output.len()) {
                    Ok(output) => output,
                    Err(error) => {
                        findings.push(DivergenceFinding::BackendDispatchFailed {
                            op_id: spec.id.to_string(),
                            scenario,
                            error: error.to_string(),
                        });
                        continue;
                    }
                };
            if let Err(message) = spec.comparator.compare(&diverged_output, &cpu_output) {
                findings.push(DivergenceFinding::ParityMismatch {
                    op_id: spec.id.to_string(),
                    scenario,
                    witness_input: witness_input.clone(),
                    cpu_output: cpu_output.clone(),
                    diverged_output,
                    message,
                });
            }
        }
    }
    // `divergence_key` allocates (it clones the op id). The cached sort computes each key
    // once instead of on every comparison, and it is stable, so ties keep discovery order.
    findings.sort_by_cached_key(divergence_key);
    findings
}
// Renders a finding as a human-readable line that ends with a `Fix:` hint; the enforce gate
// forwards this text verbatim as its finding message.
impl std::fmt::Display for DivergenceFinding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedSignature { op_id, reason } => write!(
                f,
                "{op_id}: unsupported signature for divergence gate. Fix: provide WGSL lowering so `{}` can be wrapped and dispatched under a divergence probe.",
                reason
            ),
            Self::BackendDispatchFailed {
                op_id,
                scenario,
                error,
            } => {
                let label = scenario.name();
                write!(
                    f,
                    "{op_id} [{}]: backend dispatch failed under divergence probe. Fix: make WGSL and backend scheduling robust to lane-divergent control flow: {error}",
                    label,
                )
            }
            Self::ParityMismatch {
                op_id,
                scenario,
                witness_input,
                cpu_output,
                diverged_output,
                message,
            } => {
                // Only buffer lengths are printed; the comparator's message carries the detail.
                let witness_len = witness_input.len();
                let cpu_len = cpu_output.len();
                let gpu_len = diverged_output.len();
                write!(
                    f,
                    "{op_id} [{}]: parity mismatch under divergence probe; witness={} CPU={} GPU={}. Fix: ensure every lane either reconverges with identical side effects and output bytes, including early-return and branch-split control flow.",
                    scenario.name(),
                    witness_len,
                    cpu_len,
                    gpu_len,
                )?;
                write!(f, " Detail: {message}")
            }
        }
    }
}
impl DivergenceScenario {
#[must_use]
pub const fn all() -> [Self; 9] {
[
Self::FiftyFiftySplit,
Self::ThirtyThreeSixtySeven,
Self::TwentyFiveSeventyFive,
Self::LaneIdModuloFour,
Self::EarlyExitLaneZero,
Self::EarlyExitLastLane,
Self::DataDependentBranch,
Self::StallLaneSeven,
Self::FullReconvergence,
]
}
#[must_use]
pub const fn name(self) -> &'static str {
match self {
Self::FiftyFiftySplit => "fifty_fifty_split",
Self::ThirtyThreeSixtySeven => "thirty_three_sixty_seven",
Self::TwentyFiveSeventyFive => "twenty_five_seventy_five",
Self::LaneIdModuloFour => "lane_id_modulo_four",
Self::EarlyExitLaneZero => "early_exit_lane_zero",
Self::EarlyExitLastLane => "early_exit_last_lane",
Self::DataDependentBranch => "data_dependent_branch",
Self::StallLaneSeven => "stall_lane_seven",
Self::FullReconvergence => "full_reconvergence",
}
}
}
impl VyreBackend for NullBackend {
    fn id(&self) -> &'static str {
        "null"
    }

    /// Echoes the first input buffer back as the single output buffer. The output is
    /// zero-padded to at least four bytes; with no inputs it is four zero bytes.
    /// The program and dispatch config are ignored.
    fn dispatch(
        &self,
        _program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        const MIN_OUTPUT_BYTES: usize = 4;
        let mut echoed = match inputs.first() {
            Some(first) => first.clone(),
            None => Vec::new(),
        };
        if echoed.len() < MIN_OUTPUT_BYTES {
            echoed.resize(MIN_OUTPUT_BYTES, 0);
        }
        Ok(vec![echoed])
    }
}
/// Backend stub that performs no GPU work; its `dispatch` echoes the first input buffer back.
#[derive(Default, Debug)]
pub struct NullBackend;
/// Returns the WGSL statement snippet that realises `scenario`'s lane-divergence pattern.
///
/// Every snippet branches on the lane id and executes a different number of
/// `workgroupBarrier()` calls on each side, or returns early on one lane. Most snippets
/// finish with a shared barrier.
///
/// Assumptions the snippets make about their surrounding shader (TODO confirm against the
/// wrapper that splices them in):
/// - `lid` is presumably the local invocation id.
/// - `input.data` is an input buffer of `u32` words.
/// - The enclosing function returns `u32`, since the early-exit snippets contain `return 0u;`.
///
/// NOTE(review): this helper is not referenced in the visible code.
fn scenario_body(scenario: DivergenceScenario) -> &'static str {
    match scenario {
        DivergenceScenario::FiftyFiftySplit => {
            "if (lid.x < 32u) {\n workgroupBarrier();\n } else {\n workgroupBarrier();\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
        DivergenceScenario::ThirtyThreeSixtySeven => {
            "if (lid.x < 21u) {\n workgroupBarrier();\n } else {\n workgroupBarrier();\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
        DivergenceScenario::TwentyFiveSeventyFive => {
            "if (lid.x < 16u) {\n workgroupBarrier();\n } else {\n workgroupBarrier();\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
        DivergenceScenario::LaneIdModuloFour => {
            "if ((lid.x & 3u) == 0u) {\n workgroupBarrier();\n } else {\n workgroupBarrier();\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
        DivergenceScenario::EarlyExitLaneZero => {
            "if (lid.x == 0u) {\n return 0u;\n }\n workgroupBarrier();\n workgroupBarrier();"
        }
        DivergenceScenario::EarlyExitLastLane => {
            "if (lid.x == 63u) {\n return 0u;\n }\n workgroupBarrier();\n workgroupBarrier();"
        }
        // The condition previously lacked its closing `)`, which made the snippet invalid WGSL.
        DivergenceScenario::DataDependentBranch => {
            "if ((input.data[0] == 0u) && ((lid.x & 1u) == 0u)) {\n workgroupBarrier();\n } else {\n workgroupBarrier();\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
        DivergenceScenario::StallLaneSeven => {
            "if (lid.x == 7u) {\n workgroupBarrier();\n workgroupBarrier();\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
        DivergenceScenario::FullReconvergence => {
            "if (lid.x < 32u) {\n workgroupBarrier();\n } else {\n workgroupBarrier();\n }\n workgroupBarrier();"
        }
    }
}
/// Dispatches `program` once on `backend` with the default config and returns the first
/// output buffer, which must be exactly `output_size` bytes long.
///
/// # Errors
///
/// Propagates the backend's dispatch error. Also returns a `vyre::BackendError` when the
/// backend produced no output buffers, or when the first buffer's length differs from
/// `output_size`.
fn dispatch_exact(
    backend: &dyn VyreBackend,
    program: &vyre::Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let buffers = backend.dispatch(program, inputs, &vyre::DispatchConfig::default())?;
    let Some(first) = buffers.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the divergence probe output as outputs[0].",
        ));
    };
    if first.len() == output_size {
        Ok(first)
    } else {
        Err(vyre::BackendError::new(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the divergence probe output declaration.",
            first.len()
        )))
    }
}
/// Enforce gate that runs the divergence probe over `ctx.specs` on `ctx.backend`.
pub struct DivergenceEnforcer;

impl crate::enforce::EnforceGate for DivergenceEnforcer {
    fn id(&self) -> &'static str {
        "divergence"
    }

    fn name(&self) -> &'static str {
        "divergence"
    }

    /// Returns a single aggregate finding when no backend is configured. Otherwise, runs the
    /// divergence probe and reports each finding's `Display` text.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            Some(backend) => {
                let messages: Vec<String> = run(ctx.specs, backend)
                    .iter()
                    .map(ToString::to_string)
                    .collect();
                crate::enforce::finding_result(self.id(), messages)
            }
            None => {
                let missing_backend =
                    "divergence: backend is required. Fix: provide a VyreBackend in EnforceCtx."
                        .to_string();
                vec![crate::enforce::aggregate_finding(
                    self.id(),
                    vec![missing_backend],
                )]
            }
        }
    }
}

/// Gate instance exported for registration.
pub const REGISTERED: DivergenceEnforcer = DivergenceEnforcer;