vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
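//! Divergence conformance gate (VYRE_RELEASE_PLAN Phase 3.2): runs every op spec on a backend
//! once per `DivergenceScenario` and checks the output against the spec's CPU reference.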

use vyre::VyreBackend;

use crate::spec::program::program_for_spec_input;
use crate::spec::types::OpSpec;

/// A finding reported by the divergence gate for one op (and, where the
/// variant carries one, one [`DivergenceScenario`]).
///
/// Findings are rendered into enforcement messages through `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DivergenceFinding {
    /// The backend output disagreed with the CPU reference, according to the
    /// spec's comparator, while a divergence scenario was being probed.
    ParityMismatch {
        /// Id of the op whose output mismatched.
        op_id: String,
        /// Scenario being probed when the mismatch was observed.
        scenario: DivergenceScenario,
        /// Input bytes fed to both the CPU reference and the backend.
        witness_input: Vec<u8>,
        /// Output of the spec's CPU reference function for `witness_input`.
        cpu_output: Vec<u8>,
        /// First output buffer returned by the backend for the same input.
        diverged_output: Vec<u8>,
        /// Mismatch description returned by the spec's comparator.
        message: String,
    },
    /// The backend returned an error, no output buffer, or an output buffer of
    /// unexpected length while a scenario was being probed.
    BackendDispatchFailed {
        /// Op id.
        op_id: String,
        /// Scenario being probed when dispatch failed.
        scenario: DivergenceScenario,
        /// Stringified backend error.
        error: String,
    },
    /// The op could not be turned into a dispatchable program, so none of its
    /// scenarios were probed.
    UnsupportedSignature {
        /// Op id.
        op_id: String,
        /// Error text returned by `program_for_spec_input`; presumably it
        /// explains the missing WGSL lowering (verify against that function).
        reason: String,
    },
}

/// Sort key that orders findings by op id, then by scenario.
///
/// Scenario-specific findings rank by the scenario discriminant (declaration
/// order). `UnsupportedSignature` has no scenario and ranks as `u8::MAX`, so it
/// sorts after any scenario-specific finding for the same op.
fn divergence_key(finding: &DivergenceFinding) -> (String, u8) {
    let (op_id, rank) = match finding {
        DivergenceFinding::UnsupportedSignature { op_id, .. } => (op_id, u8::MAX),
        DivergenceFinding::ParityMismatch { op_id, scenario, .. }
        | DivergenceFinding::BackendDispatchFailed { op_id, scenario, .. } => {
            (op_id, *scenario as u8)
        }
    };
    (op_id.clone(), rank)
}

/// Lane-divergence pattern probed by the divergence gate.
///
/// Variants are declared in the same order `DivergenceScenario::all` returns
/// them, and the derived `Ord` follows declaration order, so sorting by
/// scenario is deterministic. The WGSL snippet for each variant is produced by
/// `scenario_body`; lane numbers below assume the 64-lane workgroup that `run`
/// configures.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum DivergenceScenario {
    /// Lanes `0..32` take one branch and the remaining lanes take the other.
    FiftyFiftySplit,
    /// ~33% / ~67% split: lanes `0..21` versus the rest.
    ThirtyThreeSixtySeven,
    /// 25% / 75% split: lanes `0..16` versus the rest.
    TwentyFiveSeventyFive,
    /// Branches on `lid.x % 4 == 0` (written as `lid.x & 3u`).
    LaneIdModuloFour,
    /// Lane zero returns before the barriers its peers execute.
    EarlyExitLaneZero,
    /// Lane 63 (the last lane of a 64-lane workgroup) returns before its peers' barriers.
    EarlyExitLastLane,
    /// Branch depends on the first input word and on lane parity.
    DataDependentBranch,
    /// Lane seven executes three extra barriers before the common one.
    StallLaneSeven,
    /// Both branches execute the same number of barriers, then reconverge.
    FullReconvergence,
}

/// Run divergence conformance checks for the supplied operation specs.
///
/// For each spec this function does the following:
/// - builds a zero-filled witness input (the signature's minimum input size,
///   but at least 8 bytes);
/// - evaluates the CPU reference once;
/// - lowers the spec into a program with workgroup size `[64, 1, 1]`;
/// - dispatches that program on `backend` once per entry of
///   [`DivergenceScenario::all`], comparing each result with the CPU output
///   through the spec's comparator.
///
/// A spec that cannot be lowered yields a single `UnsupportedSignature`
/// finding and is not dispatched. The returned findings are sorted by op id
/// and then scenario, so reports are deterministic.
///
/// NOTE(review): `scenario` currently only labels findings. The same `program`
/// is dispatched for every scenario, and `scenario_body` is not referenced
/// anywhere in this file, so no divergent control flow is actually injected.
/// TODO: wrap the program with `scenario_body(scenario)` once the lowering API
/// supports it.
#[must_use]
pub(crate) fn run(specs: &[OpSpec], backend: &dyn VyreBackend) -> Vec<DivergenceFinding> {
    let mut findings = Vec::new();

    for spec in specs {
        let witness_input = vec![0u8; spec.signature.min_input_bytes().max(8)];
        let cpu_output = (spec.cpu_fn)(&witness_input);
        let mut program = match program_for_spec_input(spec, &witness_input) {
            Ok(program) => program,
            Err(reason) => {
                findings.push(DivergenceFinding::UnsupportedSignature {
                    op_id: spec.id.to_string(),
                    reason,
                });
                continue;
            }
        };
        program.set_workgroup_size([64, 1, 1]);

        // View the witness as a one-element input list without copying the
        // buffer for every scenario.
        let inputs = std::slice::from_ref(&witness_input);

        for scenario in DivergenceScenario::all() {
            let diverged_output =
                match dispatch_exact(backend, &program, inputs, cpu_output.len()) {
                    Ok(output) => output,
                    Err(error) => {
                        findings.push(DivergenceFinding::BackendDispatchFailed {
                            op_id: spec.id.to_string(),
                            scenario,
                            error: error.to_string(),
                        });
                        continue;
                    }
                };

            if let Err(message) = spec.comparator.compare(&diverged_output, &cpu_output) {
                findings.push(DivergenceFinding::ParityMismatch {
                    op_id: spec.id.to_string(),
                    scenario,
                    witness_input: witness_input.clone(),
                    cpu_output: cpu_output.clone(),
                    diverged_output,
                    message,
                });
            }
        }
    }

    // `divergence_key` allocates (it clones the op id), so compute it once per
    // finding rather than once per comparison. This sort is stable, so
    // findings with equal keys also keep their discovery order.
    findings.sort_by_cached_key(divergence_key);
    findings
}

impl std::fmt::Display for DivergenceFinding {
    /// Render the finding as a single-line, actionable enforcement message of
    /// the form `<op_id> [<scenario>]: <what happened>. Fix: <hint>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParityMismatch {
                op_id,
                scenario,
                witness_input,
                cpu_output,
                diverged_output,
                message,
            } => write!(
                f,
                "{op_id} [{}]: parity mismatch under divergence probe; witness={} bytes, CPU={} bytes, GPU={} bytes. Fix: ensure every lane reconverges with identical side effects and output bytes, including under early-return and branch-split control flow. Detail: {message}",
                scenario.name(),
                witness_input.len(),
                cpu_output.len(),
                diverged_output.len(),
            ),
            Self::BackendDispatchFailed {
                op_id,
                scenario,
                error,
            } => write!(
                f,
                "{op_id} [{}]: backend dispatch failed under divergence probe. Fix: make WGSL and backend scheduling robust to lane-divergent control flow: {error}",
                scenario.name(),
            ),
            // `reason` is error text, not an identifier: report it as the
            // cause and name the op itself as the thing that needs lowering.
            Self::UnsupportedSignature { op_id, reason } => write!(
                f,
                "{op_id}: unsupported signature for divergence gate: {reason}. Fix: provide WGSL lowering so `{op_id}` can be wrapped and dispatched under a divergence probe."
            ),
        }
    }
}

impl DivergenceScenario {
    /// Return every scenario exactly once, in declaration order.
    ///
    /// The order is fixed: it matches the enum declaration and therefore the
    /// derived `Ord`, so dispatch order and finding order are reproducible.
    #[must_use]
    pub const fn all() -> [Self; 9] {
        [
            Self::FiftyFiftySplit,
            Self::ThirtyThreeSixtySeven,
            Self::TwentyFiveSeventyFive,
            Self::LaneIdModuloFour,
            Self::EarlyExitLaneZero,
            Self::EarlyExitLastLane,
            Self::DataDependentBranch,
            Self::StallLaneSeven,
            Self::FullReconvergence,
        ]
    }

    /// Stable snake_case scenario name, used inside finding messages.
    ///
    /// Each name is the snake_case spelling of the variant identifier.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::FiftyFiftySplit => "fifty_fifty_split",
            Self::ThirtyThreeSixtySeven => "thirty_three_sixty_seven",
            Self::TwentyFiveSeventyFive => "twenty_five_seventy_five",
            Self::LaneIdModuloFour => "lane_id_modulo_four",
            Self::EarlyExitLaneZero => "early_exit_lane_zero",
            Self::EarlyExitLastLane => "early_exit_last_lane",
            Self::DataDependentBranch => "data_dependent_branch",
            Self::StallLaneSeven => "stall_lane_seven",
            Self::FullReconvergence => "full_reconvergence",
        }
    }
}

/// Echo dispatch: ignores the program and config and returns exactly one
/// buffer holding a copy of the first input, zero-padded to at least 4 bytes.
/// With no inputs, the result is four zero bytes.
impl VyreBackend for NullBackend {
    fn id(&self) -> &'static str {
        "null"
    }

    fn dispatch(
        &self,
        _program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        // A missing first input is treated as an empty one.
        let first: &[u8] = inputs.first().map(Vec::as_slice).unwrap_or_default();
        let mut echoed = first.to_vec();
        // Zero-pad so every dispatch yields at least 4 bytes.
        echoed.resize(first.len().max(4), 0);
        Ok(vec![echoed])
    }
}

/// Echo backend used by the phase-3.2 scaffolding.
///
/// Despite the name, it is not a pure no-op. Its `VyreBackend::dispatch`
/// ignores the program and returns one output buffer holding the first input,
/// zero-padded to at least 4 bytes.
#[derive(Debug, Default)]
pub struct NullBackend;

/// WGSL snippet that realises `scenario` as lane-divergent control flow.
///
/// Branch conditions test `lid.x`, presumably the `local_invocation_id`
/// builtin (TODO confirm against the wrapper that splices this in). Split
/// points 32/21/16 and lane 63 assume the 64-lane workgroup configured in
/// `run`. The `EarlyExit*` snippets `return 0u;`, so the enclosing WGSL
/// function is assumed to return `u32`.
///
/// NOTE(review): several snippets call `workgroupBarrier()` under a
/// lane-dependent condition, or after a lane-dependent `return`. If `lid` is
/// the local invocation id, WGSL's uniformity analysis rejects such shaders,
/// so a conforming backend may refuse them rather than execute them. This
/// helper is also not referenced elsewhere in this file.
fn scenario_body(scenario: DivergenceScenario) -> &'static str {
    match scenario {
        DivergenceScenario::FiftyFiftySplit => {
            "if (lid.x < 32u) {\n        workgroupBarrier();\n    } else {\n        workgroupBarrier();\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
        DivergenceScenario::ThirtyThreeSixtySeven => {
            "if (lid.x < 21u) {\n        workgroupBarrier();\n    } else {\n        workgroupBarrier();\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
        DivergenceScenario::TwentyFiveSeventyFive => {
            "if (lid.x < 16u) {\n        workgroupBarrier();\n    } else {\n        workgroupBarrier();\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
        DivergenceScenario::LaneIdModuloFour => {
            "if ((lid.x & 3u) == 0u) {\n        workgroupBarrier();\n    } else {\n        workgroupBarrier();\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
        DivergenceScenario::EarlyExitLaneZero => {
            "if (lid.x == 0u) {\n        return 0u;\n    }\n    workgroupBarrier();\n    workgroupBarrier();"
        }
        DivergenceScenario::EarlyExitLastLane => {
            "if (lid.x == 63u) {\n        return 0u;\n    }\n    workgroupBarrier();\n    workgroupBarrier();"
        }
        // The condition needs its outer parenthesis closed; without it the
        // snippet is not valid WGSL.
        DivergenceScenario::DataDependentBranch => {
            "if ((input.data[0] == 0u) && ((lid.x & 1u) == 0u)) {\n        workgroupBarrier();\n    } else {\n        workgroupBarrier();\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
        DivergenceScenario::StallLaneSeven => {
            "if (lid.x == 7u) {\n        workgroupBarrier();\n        workgroupBarrier();\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
        DivergenceScenario::FullReconvergence => {
            "if (lid.x < 32u) {\n        workgroupBarrier();\n    } else {\n        workgroupBarrier();\n    }\n    workgroupBarrier();"
        }
    }
}

/// Dispatch `program` with the default config and return its first output
/// buffer, which must be exactly `output_size` bytes long.
///
/// Any buffers after the first are ignored.
///
/// # Errors
///
/// Propagates any error from `backend.dispatch`. Returns a `vyre::BackendError`
/// when the backend produced no buffers at all, or when the first buffer's
/// length differs from `output_size`.
fn dispatch_exact(
    backend: &dyn VyreBackend,
    program: &vyre::Program,
    inputs: &[Vec<u8>],
    output_size: usize,
) -> Result<Vec<u8>, vyre::BackendError> {
    let buffers = backend.dispatch(program, inputs, &vyre::DispatchConfig::default())?;
    let Some(first) = buffers.into_iter().next() else {
        return Err(vyre::BackendError::new(
            "backend returned zero output buffers. Fix: return the divergence probe output as outputs[0].",
        ));
    };
    if first.len() == output_size {
        Ok(first)
    } else {
        Err(vyre::BackendError::new(format!(
            "backend returned {} bytes, expected {output_size}. Fix: size the first output buffer from the divergence probe output declaration.",
            first.len()
        )))
    }
}

/// Registry entry for `divergence` enforcement.
///
/// Stateless unit struct: the gate reads the specs and backend from the
/// `EnforceCtx` passed to `EnforceGate::run`.
pub struct DivergenceEnforcer;

impl crate::enforce::EnforceGate for DivergenceEnforcer {
    /// Registry id of this gate.
    fn id(&self) -> &'static str {
        "divergence"
    }

    /// Display name; identical to the id.
    fn name(&self) -> &'static str {
        "divergence"
    }

    /// Probe every spec in `ctx` on `ctx.backend` and turn each divergence
    /// finding into a message. Without a backend nothing can be probed, so a
    /// single aggregate finding asking for one is reported instead.
    fn run(&self, ctx: &crate::enforce::EnforceCtx<'_>) -> Vec<crate::enforce::Finding> {
        match ctx.backend {
            Some(backend) => {
                // Bare `run` resolves to the module-level probe function, not this method.
                let messages: Vec<String> = run(ctx.specs, backend)
                    .iter()
                    .map(ToString::to_string)
                    .collect();
                crate::enforce::finding_result(self.id(), messages)
            }
            None => vec![crate::enforce::aggregate_finding(
                self.id(),
                vec![String::from(
                    "divergence: backend is required. Fix: provide a VyreBackend in EnforceCtx.",
                )],
            )],
        }
    }
}

/// Auto-registered `divergence` enforcer.
///
/// `DivergenceEnforcer` is a stateless unit struct, so this constant is a
/// complete instance of the gate.
pub const REGISTERED: DivergenceEnforcer = DivergenceEnforcer;