aprender-contracts 0.31.1

Papers to Math to Contracts in Code — YAML contract parsing, validation, scaffold generation, and Kani harness codegen for provable Rust kernels
Documentation
//! L-BFGS two-loop recursion kernel.
//!
//! Matches `lbfgs-kernel-v1.yaml`.
//!
//! Computes the search direction using the L-BFGS two-loop recursion
//! with limited-memory stored curvature pairs (s, y).

/// Scalar reference implementation of the L-BFGS two-loop recursion.
///
/// Computes the approximate Newton direction from gradient and history of
/// curvature pairs (s, y).
///
/// - `gradient`: current gradient, length `d`
/// - `s_history`: stored `s = x_{k+1} - x_k` vectors, flattened `m x d`
/// - `y_history`: stored `y = g_{k+1} - g_k` vectors, flattened `m x d`
/// - `m`: number of stored curvature pairs (can be 0)
/// - `d`: parameter dimension
/// - `direction`: output search direction, length `d`
///
/// When `m == 0`, returns `direction = -gradient` (steepest descent).
///
/// # Panics
///
/// Panics if dimensions are inconsistent.
pub fn lbfgs_direction_scalar(
    gradient: &[f32],
    s_history: &[f32],
    y_history: &[f32],
    m: usize,
    d: usize,
    direction: &mut [f32],
) {
    assert_eq!(gradient.len(), d, "gradient length mismatch");
    assert_eq!(s_history.len(), m * d, "s_history length mismatch");
    assert_eq!(y_history.len(), m * d, "y_history length mismatch");
    assert_eq!(direction.len(), d, "direction length mismatch");

    // q = -gradient
    let mut q = vec![0.0_f32; d];
    for j in 0..d {
        q[j] = -gradient[j];
    }

    if m == 0 {
        direction.copy_from_slice(&q);
        return;
    }

    let mut alpha = vec![0.0_f32; m];
    let mut rho = vec![0.0_f32; m];

    // Forward loop (most recent first)
    for i in (0..m).rev() {
        let s = &s_history[i * d..(i + 1) * d];
        let y = &y_history[i * d..(i + 1) * d];

        // rho[i] = 1 / (y[i] . s[i])
        let mut ys = 0.0_f32;
        for j in 0..d {
            ys += y[j] * s[j];
        }
        rho[i] = if ys.abs() > 1e-10 { 1.0 / ys } else { 0.0 };

        // alpha[i] = rho[i] * s[i] . q
        let mut sq = 0.0_f32;
        for j in 0..d {
            sq += s[j] * q[j];
        }
        alpha[i] = rho[i] * sq;

        // q -= alpha[i] * y[i]
        for j in 0..d {
            q[j] -= alpha[i] * y[j];
        }
    }

    // Initial Hessian approximation: H_0 = gamma * I
    // gamma = (s[m-1] . y[m-1]) / (y[m-1] . y[m-1])
    let last_s = &s_history[(m - 1) * d..m * d];
    let last_y = &y_history[(m - 1) * d..m * d];
    let mut sy = 0.0_f32;
    let mut yy = 0.0_f32;
    for j in 0..d {
        sy += last_s[j] * last_y[j];
        yy += last_y[j] * last_y[j];
    }
    let gamma = if yy.abs() > 1e-10 { sy / yy } else { 1.0 };

    // r = gamma * q
    let mut r = vec![0.0_f32; d];
    for j in 0..d {
        r[j] = gamma * q[j];
    }

    // Backward loop (oldest first)
    for i in 0..m {
        let s = &s_history[i * d..(i + 1) * d];
        let y = &y_history[i * d..(i + 1) * d];

        // beta = rho[i] * y[i] . r
        let mut yr = 0.0_f32;
        for j in 0..d {
            yr += y[j] * r[j];
        }
        let beta = rho[i] * yr;

        // r += s[i] * (alpha[i] - beta)
        let diff = alpha[i] - beta;
        for j in 0..d {
            r[j] += s[j] * diff;
        }
    }

    direction.copy_from_slice(&r);
}

/// AVX2 implementation of the L-BFGS two-loop recursion.
///
/// Delegates to scalar.
///
/// # Safety
///
/// Requires AVX2 support on the target CPU.
///
/// # Panics
///
/// Same as [`lbfgs_direction_scalar`].
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn lbfgs_direction_avx2(
    gradient: &[f32],
    s_history: &[f32],
    y_history: &[f32],
    m: usize,
    d: usize,
    direction: &mut [f32],
) {
    lbfgs_direction_scalar(gradient, s_history, y_history, m, d, direction);
}

/// PTX assembly for the L-BFGS direction kernel.
///
/// 1D grid with vectorized dot products. Sequential two-loop structure
/// with parallel inner dot products.
pub fn lbfgs_direction_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// L-BFGS two-loop recursion kernel (structural).
// The two-loop recursion is inherently sequential across history entries,
// but dot products within each loop iteration can be parallelized.
// Params: gradient_ptr, s_history_ptr, y_history_ptr, direction_ptr, m, d
.visible .entry lbfgs_direction_kernel(
    .param .u64 gradient_ptr,
    .param .u64 s_history_ptr,
    .param .u64 y_history_ptr,
    .param .u64 direction_ptr,
    .param .u32 m,
    .param .u32 d
)
{
    .reg .u32 %tid, %ntid, %ctaid, %idx, %m, %d;
    .reg .u64 %g_base, %s_base, %y_base, %dir_base, %addr;
    .reg .f32 %gval, %neg_g;
    .reg .pred %p;

    mov.u32 %tid, %tid.x;
    mov.u32 %ntid, %ntid.x;
    mov.u32 %ctaid, %ctaid.x;
    mad.lo.u32 %idx, %ctaid, %ntid, %tid;
    ld.param.u32 %d, [d];
    setp.ge.u32 %p, %idx, %d;
    @%p bra DONE;

    ld.param.u64 %g_base, [gradient_ptr];
    ld.param.u64 %dir_base, [direction_ptr];
    ld.param.u32 %m, [m];

    // For m=0: direction[idx] = -gradient[idx]
    // (Full two-loop recursion requires sequential reduction - simplified here)
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %g_base, %addr;
    ld.global.f32 %gval, [%addr];
    neg.f32 %neg_g, %gval;

    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %dir_base, %addr;
    st.global.f32 [%addr], %neg_g;

DONE:
    ret;
}
"#
}

#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // Scalar tests
    // ---------------------------------------------------------------

    #[test]
    fn test_lbfgs_steepest_descent() {
        // m=0 -> direction = -gradient
        let gradient = [1.0_f32, -2.0, 3.0, -4.0];
        let s_history: [f32; 0] = [];
        let y_history: [f32; 0] = [];
        let mut direction = [0.0_f32; 4];

        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 0, 4, &mut direction);

        assert_eq!(direction, [-1.0, 2.0, -3.0, 4.0]);
    }

    #[test]
    fn test_lbfgs_single_history_entry() {
        let d = 3;
        let gradient = [1.0_f32, 0.0, 0.0];
        // s = x_{k+1} - x_k, y = g_{k+1} - g_k
        let s_history = [1.0_f32, 0.0, 0.0];
        let y_history = [2.0_f32, 0.0, 0.0];
        let mut direction = [0.0_f32; 3];

        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 1, d, &mut direction);

        // With m=1:
        // Forward: rho = 1/(y.s) = 1/2, alpha = rho * s.q = 0.5 * (-1) = -0.5
        //   q -= alpha * y = [-1,0,0] - (-0.5)*[2,0,0] = [0,0,0]
        // gamma = s.y / y.y = 2/4 = 0.5
        // r = gamma * q = [0,0,0]
        // Backward: beta = rho * y.r = 0, r += s*(alpha-beta) = [0,0,0] + 1*(-0.5) = [-0.5, 0, 0]
        assert!(
            (direction[0] - (-0.5)).abs() < 1e-6,
            "direction[0] = {}",
            direction[0]
        );
        assert!(direction[1].abs() < 1e-6, "direction[1] = {}", direction[1]);
        assert!(direction[2].abs() < 1e-6, "direction[2] = {}", direction[2]);
    }

    #[test]
    fn test_lbfgs_direction_is_descent() {
        // The L-BFGS direction should be a descent direction (d . g < 0)
        let d = 4;
        let gradient = [1.0_f32, 2.0, 3.0, 4.0];
        let s_history = [0.1_f32, 0.2, 0.3, 0.4];
        let y_history = [0.5_f32, 0.6, 0.7, 0.8];
        let mut direction = [0.0_f32; 4];

        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 1, d, &mut direction);

        let dot: f32 = gradient
            .iter()
            .zip(direction.iter())
            .map(|(g, d)| g * d)
            .sum();
        assert!(
            dot < 0.0,
            "direction must be a descent direction, got g.d = {dot}"
        );
    }

    #[test]
    #[should_panic(expected = "gradient length mismatch")]
    fn test_lbfgs_gradient_mismatch() {
        let mut direction = [0.0_f32; 3];
        lbfgs_direction_scalar(&[1.0, 2.0], &[], &[], 0, 3, &mut direction);
    }

    // ---------------------------------------------------------------
    // AVX2 tests
    // ---------------------------------------------------------------

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_lbfgs_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let d = 8;
        let gradient: Vec<f32> = (0..d).map(|i| (i as f32 + 1.0) * 0.5).collect();
        let s_history: Vec<f32> = (0..d).map(|i| (i as f32) * 0.1 + 0.1).collect();
        let y_history: Vec<f32> = (0..d).map(|i| (i as f32) * 0.2 + 0.2).collect();
        let mut scalar_dir = vec![0.0_f32; d];
        let mut avx2_dir = vec![0.0_f32; d];

        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 1, d, &mut scalar_dir);
        unsafe {
            lbfgs_direction_avx2(&gradient, &s_history, &y_history, 1, d, &mut avx2_dir);
        }

        assert_eq!(scalar_dir, avx2_dir);
    }

    // ---------------------------------------------------------------
    // PTX structural tests
    // ---------------------------------------------------------------

    #[test]
    fn test_lbfgs_ptx_version() {
        let ptx = lbfgs_direction_ptx();
        assert!(
            ptx.contains(".version 8.5"),
            "PTX must declare .version 8.5"
        );
    }

    #[test]
    fn test_lbfgs_ptx_target() {
        let ptx = lbfgs_direction_ptx();
        assert!(ptx.contains(".target sm_90"), "PTX must target sm_90");
    }

    #[test]
    fn test_lbfgs_ptx_entry() {
        let ptx = lbfgs_direction_ptx();
        assert!(
            ptx.contains(".entry lbfgs_direction_kernel"),
            "PTX must have .entry"
        );
    }

    #[test]
    fn test_lbfgs_ptx_ret() {
        let ptx = lbfgs_direction_ptx();
        assert!(ptx.contains("ret;"), "PTX must have ret;");
    }

    #[test]
    fn test_lbfgs_ptx_balanced_braces() {
        let ptx = lbfgs_direction_ptx();
        let opens = ptx.chars().filter(|&c| c == '{').count();
        let closes = ptx.chars().filter(|&c| c == '}').count();
        assert_eq!(
            opens, closes,
            "PTX must have balanced braces: {opens} opens vs {closes} closes"
        );
    }
}