aprender-core 0.34.0

Next-generation machine learning library in pure Rust
// `avx2-fma-dot-v1` algorithm-level PARTIAL discharge for the 5
// AVX2+FMA dot-product falsifiers (scalar equivalence, empty/unit base
// case, decoder dimension match, commutativity, NaN propagation).
//
// Contract: `contracts/avx2-fma-dot-v1.yaml`.
// Refs: Intel 64 Optimization Manual §11.6 FMA, Agner Fog (2024)
// Instruction Tables (vfmadd231ps: 4-5c latency, 0.5c throughput).

/// Absolute tolerance for the SIMD-vs-scalar equivalence gates
/// (FALSIFY-DOT-001 and FALSIFY-DOT-003).
///
/// Taken verbatim from the contract's `proof_obligations` line
/// `tolerance: 4.0`. It is applied as a strict absolute bound,
/// `|simd - scalar| < 4.0`.
///
/// NOTE(review): the contract heading says "within 4 ULP", but at magnitude
/// ~1.0 four ULP is only 4 * 2^-23 ≈ 4.77e-7. That makes this absolute bound
/// far *looser* than a ULP-based check, not more conservative. Confirm which
/// reading the contract intends.
pub const AC_DOT_SIMD_TOLERANCE_ABS: f32 = 4.0;

/// Strict upper bound on `|dot(a, b) - dot(b, a)|` for FALSIFY-DOT-004,
/// per the contract's `proof_obligations`.
pub const AC_DOT_COMMUT_TOLERANCE: f32 = 1.0e-6;

/// Whisper-tiny decoder model width (`d_model`).
/// One of the two canonical lengths checked by FALSIFY-DOT-003.
pub const AC_DOT_WHISPER_D_MODEL: usize = 384;
/// Whisper-tiny decoder feed-forward width (`d_ff`).
/// The second canonical length checked by FALSIFY-DOT-003.
pub const AC_DOT_WHISPER_D_FF: usize = 1536;

// =============================================================================
// FALSIFY-DOT-001 — SIMD matches scalar (within 4 ULP)
// =============================================================================

/// Outcome of FALSIFY-DOT-001: does the SIMD dot product agree with the
/// scalar reference?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DotScalarEquivalenceVerdict {
    /// Both finite with `|dot_simd - dot_scalar| < AC_DOT_SIMD_TOLERANCE_ABS`,
    /// or both non-finite without opposite-signed infinities.
    Pass,
    /// Difference at/above tolerance, a finite/non-finite mismatch, or
    /// `+inf` vs `-inf`.
    Fail,
}

/// Compares a SIMD dot-product result against the scalar reference.
///
/// * Both finite: `Pass` iff the absolute difference is strictly below
///   [`AC_DOT_SIMD_TOLERANCE_ABS`].
/// * Exactly one non-finite: `Fail`. The difference is NaN or infinite,
///   which never compares below the tolerance.
/// * Both non-finite: NaN outcomes are treated as out of scope and pass
///   here (NaN propagation is checked by FALSIFY-DOT-005). Two infinities
///   pass only when they have the same sign.
#[must_use]
pub fn verdict_from_dot_scalar_equivalence(simd: f32, scalar: f32) -> DotScalarEquivalenceVerdict {
    if !simd.is_finite() && !scalar.is_finite() {
        // `+inf` vs `-inf` is a genuine sign divergence. Gate 005 only looks
        // at NaN, so without this check the mismatch would slip through every
        // gate.
        let opposite_infinities = simd.is_infinite() && scalar.is_infinite() && simd != scalar;
        return if opposite_infinities {
            DotScalarEquivalenceVerdict::Fail
        } else {
            DotScalarEquivalenceVerdict::Pass
        };
    }
    // Positive `<` comparison: a NaN/inf difference (one side non-finite)
    // is false here and therefore falls through to `Fail`.
    if (simd - scalar).abs() < AC_DOT_SIMD_TOLERANCE_ABS {
        DotScalarEquivalenceVerdict::Pass
    } else {
        DotScalarEquivalenceVerdict::Fail
    }
}

// =============================================================================
// FALSIFY-DOT-002 — empty and unit base cases
// =============================================================================

/// Outcome of FALSIFY-DOT-002: the empty and single-element base cases.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DotBaseCaseVerdict {
    /// `dot([], []) == 0.0` and `dot([x], [y])` equals `x * y` exactly
    /// (a NaN product must be matched by a NaN result).
    Pass,
    /// The empty case is non-zero, or the unit case differs from `x * y`.
    Fail,
}

/// Judges the two trivial dot-product cases.
///
/// `empty_result` is the kernel's output for two empty vectors; it must
/// compare equal to `0.0`. `unit_result` is the output for `[unit_x]` and
/// `[unit_y]`; it must equal `unit_x * unit_y` under exact `==`. Because that
/// comparison is IEEE `==`, `-0.0` and `+0.0` count as equal. A NaN product
/// is accepted only if the result is NaN as well.
#[must_use]
pub fn verdict_from_dot_base_case(empty_result: f32, unit_x: f32, unit_y: f32, unit_result: f32) -> DotBaseCaseVerdict {
    let empty_ok = empty_result == 0.0;

    let reference = unit_x * unit_y;
    // NaN never equals itself, so NaN-ness has to be matched explicitly.
    let unit_ok = if reference.is_nan() {
        unit_result.is_nan()
    } else {
        unit_result == reference
    };

    match (empty_ok, unit_ok) {
        (true, true) => DotBaseCaseVerdict::Pass,
        _ => DotBaseCaseVerdict::Fail,
    }
}

// =============================================================================
// FALSIFY-DOT-003 — decoder dimensions match scalar
// =============================================================================

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DotDecoderDimVerdict {
    /// SIMD matches scalar at d_model=384 AND d_ff=1536 within tolerance.
    Pass,
    /// Mismatch at either alignment boundary.
    Fail,
}

#[must_use]
pub fn verdict_from_dot_decoder_dim(
    simd_at_dmodel: f32,
    scalar_at_dmodel: f32,
    simd_at_dff: f32,
    scalar_at_dff: f32,
) -> DotDecoderDimVerdict {
    if (simd_at_dmodel - scalar_at_dmodel).abs() >= AC_DOT_SIMD_TOLERANCE_ABS {
        return DotDecoderDimVerdict::Fail;
    }
    if (simd_at_dff - scalar_at_dff).abs() >= AC_DOT_SIMD_TOLERANCE_ABS {
        return DotDecoderDimVerdict::Fail;
    }
    DotDecoderDimVerdict::Pass
}

// =============================================================================
// FALSIFY-DOT-004 — commutativity
// =============================================================================

/// Outcome of FALSIFY-DOT-004: is the dot product symmetric in its operands?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DotCommutativityVerdict {
    /// `|dot(a, b) - dot(b, a)|` is strictly below `AC_DOT_COMMUT_TOLERANCE`.
    Pass,
    /// The two orderings disagree (or the gap is NaN), so operand order
    /// leaked into the accumulation.
    Fail,
}

/// Judges commutativity from the two results `dot(a, b)` and `dot(b, a)`.
///
/// The gap is an absolute difference compared with strict `<`. A NaN gap
/// (from a NaN result, or from two equal infinities) therefore fails.
#[must_use]
pub fn verdict_from_dot_commutativity(dot_ab: f32, dot_ba: f32) -> DotCommutativityVerdict {
    let asymmetry = (dot_ab - dot_ba).abs();
    let symmetric = asymmetry < AC_DOT_COMMUT_TOLERANCE;
    if symmetric {
        DotCommutativityVerdict::Pass
    } else {
        DotCommutativityVerdict::Fail
    }
}

// =============================================================================
// FALSIFY-DOT-005 — NaN propagation
// =============================================================================

/// Outcome of FALSIFY-DOT-005: does a NaN in the input survive to the output?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DotNanPropagationVerdict {
    /// No NaN was lost: either the input had no NaN, or the output is NaN.
    Pass,
    /// The input contained NaN but the output is not NaN (dropped in a lane).
    Fail,
}

/// Flags NaN *loss* only.
///
/// The single failing combination is "input contains NaN, output is not
/// NaN". A NaN output from a NaN-free input is deliberately not judged
/// here; this gate only catches dropped NaNs, so that case is `Pass`.
#[must_use]
pub fn verdict_from_dot_nan_propagation(input_has_nan: bool, output_is_nan: bool) -> DotNanPropagationVerdict {
    let nan_was_dropped = input_has_nan && !output_is_nan;
    if nan_was_dropped {
        DotNanPropagationVerdict::Fail
    } else {
        DotNanPropagationVerdict::Pass
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Section 1: Provenance pins.
    // -------------------------------------------------------------------------
    #[test]
    fn provenance_simd_tolerance_4_0() {
        assert!((AC_DOT_SIMD_TOLERANCE_ABS - 4.0).abs() < f32::EPSILON);
    }

    #[test]
    fn provenance_commut_tolerance_1e_neg6() {
        assert!((AC_DOT_COMMUT_TOLERANCE - 1.0e-6).abs() < f32::EPSILON);
    }

    #[test]
    fn provenance_whisper_d_model_384() {
        assert_eq!(AC_DOT_WHISPER_D_MODEL, 384);
    }

    #[test]
    fn provenance_whisper_d_ff_1536() {
        assert_eq!(AC_DOT_WHISPER_D_FF, 1536);
    }

    // -------------------------------------------------------------------------
    // Section 2: DOT-001 SIMD-vs-scalar equivalence.
    // -------------------------------------------------------------------------
    #[test]
    fn fd001_pass_exact_match() {
        assert_eq!(
            verdict_from_dot_scalar_equivalence(123.456, 123.456),
            DotScalarEquivalenceVerdict::Pass
        );
    }

    #[test]
    fn fd001_pass_tiny_drift() {
        assert_eq!(
            verdict_from_dot_scalar_equivalence(123.4567, 123.4565),
            DotScalarEquivalenceVerdict::Pass
        );
    }

    #[test]
    fn fd001_fail_above_tolerance() {
        assert_eq!(
            verdict_from_dot_scalar_equivalence(100.0, 95.0),
            DotScalarEquivalenceVerdict::Fail
        );
    }

    #[test]
    fn fd001_fail_at_tolerance() {
        // Strict less-than.
        assert_eq!(
            verdict_from_dot_scalar_equivalence(0.0, 4.0),
            DotScalarEquivalenceVerdict::Fail
        );
    }

    #[test]
    fn fd001_fail_nan_vs_finite() {
        // A lone NaN yields a NaN difference, which never compares below tolerance.
        assert_eq!(
            verdict_from_dot_scalar_equivalence(f32::NAN, 1.0),
            DotScalarEquivalenceVerdict::Fail
        );
    }

    // -------------------------------------------------------------------------
    // Section 3: DOT-002 base case.
    // -------------------------------------------------------------------------
    #[test]
    fn fd002_pass_canonical_base() {
        assert_eq!(
            verdict_from_dot_base_case(0.0, 3.0, 4.0, 12.0),
            DotBaseCaseVerdict::Pass
        );
    }

    #[test]
    fn fd002_pass_unit_with_zeros() {
        assert_eq!(
            verdict_from_dot_base_case(0.0, 0.0, 0.0, 0.0),
            DotBaseCaseVerdict::Pass
        );
    }

    #[test]
    fn fd002_fail_empty_returned_nonzero() {
        assert_eq!(
            verdict_from_dot_base_case(0.5, 1.0, 1.0, 1.0),
            DotBaseCaseVerdict::Fail
        );
    }

    #[test]
    fn fd002_fail_unit_wrong() {
        // 3*4 should be 12; got 11.
        assert_eq!(
            verdict_from_dot_base_case(0.0, 3.0, 4.0, 11.0),
            DotBaseCaseVerdict::Fail
        );
    }

    #[test]
    fn fd002_pass_nan_product_matched_by_nan() {
        // NaN * 2 is NaN; the unit result must be NaN too.
        assert_eq!(
            verdict_from_dot_base_case(0.0, f32::NAN, 2.0, f32::NAN),
            DotBaseCaseVerdict::Pass
        );
    }

    #[test]
    fn fd002_fail_nan_product_not_propagated() {
        assert_eq!(
            verdict_from_dot_base_case(0.0, f32::NAN, 2.0, 0.0),
            DotBaseCaseVerdict::Fail
        );
    }

    // -------------------------------------------------------------------------
    // Section 4: DOT-003 decoder dimensions.
    // -------------------------------------------------------------------------
    #[test]
    fn fd003_pass_both_dims() {
        // 384-dim and 1536-dim agree.
        assert_eq!(
            verdict_from_dot_decoder_dim(123.45, 123.45, 1234.5, 1234.5),
            DotDecoderDimVerdict::Pass
        );
    }

    #[test]
    fn fd003_fail_dmodel_mismatch() {
        assert_eq!(
            verdict_from_dot_decoder_dim(100.0, 90.0, 1234.5, 1234.5),
            DotDecoderDimVerdict::Fail
        );
    }

    #[test]
    fn fd003_fail_dff_mismatch() {
        assert_eq!(
            verdict_from_dot_decoder_dim(123.45, 123.45, 1234.5, 1200.0),
            DotDecoderDimVerdict::Fail
        );
    }

    #[test]
    fn fd003_fail_at_tolerance() {
        // A difference exactly equal to the tolerance is a mismatch.
        assert_eq!(
            verdict_from_dot_decoder_dim(0.0, 4.0, 1.0, 1.0),
            DotDecoderDimVerdict::Fail
        );
    }

    // -------------------------------------------------------------------------
    // Section 5: DOT-004 commutativity.
    // -------------------------------------------------------------------------
    #[test]
    fn fd004_pass_exact_commutativity() {
        assert_eq!(
            verdict_from_dot_commutativity(42.0, 42.0),
            DotCommutativityVerdict::Pass
        );
    }

    #[test]
    fn fd004_pass_within_1ulp() {
        // A genuine 1-ULP gap. At 1.0 one ULP is f32::EPSILON ≈ 1.19e-7,
        // which is below the 1e-6 tolerance. (At 42.0 one ULP is ≈ 3.8e-6, so
        // `42.0 + 1e-7` would round back to exactly 42.0 and test nothing.)
        assert_eq!(
            verdict_from_dot_commutativity(1.0, 1.0 + f32::EPSILON),
            DotCommutativityVerdict::Pass
        );
    }

    #[test]
    fn fd004_fail_asymmetric_accumulation() {
        assert_eq!(
            verdict_from_dot_commutativity(42.0, 41.0),
            DotCommutativityVerdict::Fail
        );
    }

    // -------------------------------------------------------------------------
    // Section 6: DOT-005 NaN propagation.
    // -------------------------------------------------------------------------
    #[test]
    fn fd005_pass_nan_in_nan_out() {
        assert_eq!(
            verdict_from_dot_nan_propagation(true, true),
            DotNanPropagationVerdict::Pass
        );
    }

    #[test]
    fn fd005_pass_no_nan_anywhere() {
        assert_eq!(
            verdict_from_dot_nan_propagation(false, false),
            DotNanPropagationVerdict::Pass
        );
    }

    #[test]
    fn fd005_pass_nan_from_nowhere_out_of_scope() {
        // NaN produced from a NaN-free input is not NaN *loss*; this gate
        // deliberately ignores it.
        assert_eq!(
            verdict_from_dot_nan_propagation(false, true),
            DotNanPropagationVerdict::Pass
        );
    }

    #[test]
    fn fd005_fail_nan_in_finite_out() {
        // The exact regression class: NaN silently dropped.
        assert_eq!(
            verdict_from_dot_nan_propagation(true, false),
            DotNanPropagationVerdict::Fail
        );
    }

    // -------------------------------------------------------------------------
    // Section 7: Realistic — full healthy dot product passes all 5.
    // -------------------------------------------------------------------------
    #[test]
    fn realistic_healthy_dot_passes_all_5() {
        // SIMD agrees with scalar.
        assert_eq!(
            verdict_from_dot_scalar_equivalence(123.456, 123.4567),
            DotScalarEquivalenceVerdict::Pass
        );
        // Empty/unit base.
        assert_eq!(
            verdict_from_dot_base_case(0.0, 3.0, 4.0, 12.0),
            DotBaseCaseVerdict::Pass
        );
        // Decoder dims pass.
        assert_eq!(
            verdict_from_dot_decoder_dim(100.0, 100.0, 400.0, 400.0),
            DotDecoderDimVerdict::Pass
        );
        // Commutative.
        assert_eq!(
            verdict_from_dot_commutativity(42.0, 42.0),
            DotCommutativityVerdict::Pass
        );
        // NaN propagated.
        assert_eq!(
            verdict_from_dot_nan_propagation(true, true),
            DotNanPropagationVerdict::Pass
        );
    }

    #[test]
    fn realistic_pre_fix_all_5_failures() {
        // 001: SIMD diverged from scalar by > 4 abs.
        assert_eq!(
            verdict_from_dot_scalar_equivalence(100.0, 50.0),
            DotScalarEquivalenceVerdict::Fail
        );
        // 002: empty returned 1.0.
        assert_eq!(
            verdict_from_dot_base_case(1.0, 3.0, 4.0, 12.0),
            DotBaseCaseVerdict::Fail
        );
        // 003: alignment bug at d_ff=1536.
        assert_eq!(
            verdict_from_dot_decoder_dim(100.0, 100.0, 400.0, 350.0),
            DotDecoderDimVerdict::Fail
        );
        // 004: asymmetric accumulation.
        assert_eq!(
            verdict_from_dot_commutativity(42.0, 100.0),
            DotCommutativityVerdict::Fail
        );
        // 005: NaN dropped.
        assert_eq!(
            verdict_from_dot_nan_propagation(true, false),
            DotNanPropagationVerdict::Fail
        );
    }
}