/// Computes the L-BFGS search direction `-H · gradient` with the two-loop recursion.
///
/// `s_history` and `y_history` each hold `m` consecutive rows of length `d`
/// (row `i` occupies indices `[i * d, (i + 1) * d)`). Row `m - 1` is treated as the
/// most recent pair: it is visited first in the backward loop and supplies the
/// initial inverse-Hessian scaling `gamma = s·y / y·y`.
///
/// Pairs whose curvature `s·y` is not above `1e-10` are skipped (their `rho` is 0,
/// which turns both loop updates for that pair into no-ops). In particular,
/// negative-curvature pairs are rejected: using them makes the implicit inverse
/// Hessian indefinite and can produce an ascent direction. If the newest pair does
/// not give a positive `gamma`, the scaling falls back to `1.0`.
///
/// With `m == 0` the result is plain steepest descent, `-gradient`.
///
/// The result is written to `direction`, which is also used as the working vector,
/// so no length-`d` scratch buffers are allocated.
///
/// # Panics
///
/// Panics if `gradient.len() != d`, `direction.len() != d`, or either history slice
/// does not have length `m * d`.
pub fn lbfgs_direction_scalar(
    gradient: &[f32],
    s_history: &[f32],
    y_history: &[f32],
    m: usize,
    d: usize,
    direction: &mut [f32],
) {
    /// Curvature threshold: pairs (and the `gamma` scaling) at or below it are ignored.
    const CURVATURE_EPS: f32 = 1e-10;

    /// Sequential left-to-right dot product; the fixed order keeps results reproducible.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
    }

    /// Row `i` (length `d`) of a row-major history buffer.
    fn row(buf: &[f32], i: usize, d: usize) -> &[f32] {
        &buf[i * d..(i + 1) * d]
    }

    assert_eq!(gradient.len(), d, "gradient length mismatch");
    assert_eq!(s_history.len(), m * d, "s_history length mismatch");
    assert_eq!(y_history.len(), m * d, "y_history length mismatch");
    assert_eq!(direction.len(), d, "direction length mismatch");

    // `direction` starts as q = -gradient and is refined in place into r.
    for (q, &g) in direction.iter_mut().zip(gradient) {
        *q = -g;
    }
    if m == 0 {
        return;
    }

    let mut alpha = vec![0.0_f32; m];
    let mut rho = vec![0.0_f32; m];

    // Backward pass: newest pair (row m - 1) to oldest (row 0).
    for i in (0..m).rev() {
        let s = row(s_history, i, d);
        let y = row(y_history, i, d);
        let ys = dot(y, s);
        // Curvature condition s·y > 0; a failing pair is dropped (rho = 0 => no-op).
        rho[i] = if ys > CURVATURE_EPS { 1.0 / ys } else { 0.0 };
        alpha[i] = rho[i] * dot(s, direction);
        let a = alpha[i];
        for (q, &yj) in direction.iter_mut().zip(y) {
            *q -= a * yj;
        }
    }

    // Initial inverse Hessian H0 = gamma * I, scaled from the newest pair. A
    // non-positive s·y would flip the direction, so fall back to the identity.
    let last_s = row(s_history, m - 1, d);
    let last_y = row(y_history, m - 1, d);
    let sy = dot(last_s, last_y);
    let yy = dot(last_y, last_y);
    let gamma = if sy > CURVATURE_EPS && yy > CURVATURE_EPS {
        sy / yy
    } else {
        1.0
    };
    for r in direction.iter_mut() {
        *r *= gamma;
    }

    // Forward pass: oldest pair to newest.
    for (i, (&rho_i, &alpha_i)) in rho.iter().zip(&alpha).enumerate() {
        let s = row(s_history, i, d);
        let y = row(y_history, i, d);
        let beta = rho_i * dot(y, direction);
        let diff = alpha_i - beta;
        for (r, &sj) in direction.iter_mut().zip(s) {
            *r += sj * diff;
        }
    }
}
/// AVX2-gated entry point for the L-BFGS two-loop recursion.
///
/// This currently forwards every argument unchanged to [`lbfgs_direction_scalar`],
/// so its output is identical to the scalar path. There is no hand-written AVX2
/// code yet; the `avx2` target feature is only enabled for this function's body.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by checking
/// `is_x86_feature_detected!("avx2")` first. Calling a `#[target_feature]` function
/// on hardware that lacks the feature is undefined behavior.
///
/// # Panics
///
/// Panics under the same length-mismatch conditions as [`lbfgs_direction_scalar`].
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn lbfgs_direction_avx2(
    gradient: &[f32],
    s_history: &[f32],
    y_history: &[f32],
    m: usize,
    d: usize,
    direction: &mut [f32],
) {
    lbfgs_direction_scalar(gradient, s_history, y_history, m, d, direction);
}
/// Returns PTX source (ISA 8.5, target `sm_90`) for `lbfgs_direction_kernel`.
///
/// The kernel is a structural placeholder. Each thread computes
/// `idx = ctaid.x * ntid.x + tid.x` and, for `idx < d`, writes
/// `direction[idx] = -gradient[idx]`. It loads `m` but never reads the history
/// buffers, so for `m > 0` its output differs from [`lbfgs_direction_scalar`]: it
/// always yields steepest descent.
///
/// Scratch registers use an `%r_` prefix so they do not redeclare the predefined
/// PTX special registers `%tid`, `%ntid` and `%ctaid` that they are loaded from.
pub fn lbfgs_direction_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
// L-BFGS two-loop recursion kernel (structural).
// The two-loop recursion is inherently sequential across history entries,
// but dot products within each loop iteration can be parallelized.
// Params: gradient_ptr, s_history_ptr, y_history_ptr, direction_ptr, m, d
.visible .entry lbfgs_direction_kernel(
.param .u64 gradient_ptr,
.param .u64 s_history_ptr,
.param .u64 y_history_ptr,
.param .u64 direction_ptr,
.param .u32 m,
.param .u32 d
)
{
.reg .u32 %r_tid, %r_ntid, %r_ctaid, %idx, %m, %d;
.reg .u64 %g_base, %s_base, %y_base, %dir_base, %addr;
.reg .f32 %gval, %neg_g;
.reg .pred %p;
mov.u32 %r_tid, %tid.x;
mov.u32 %r_ntid, %ntid.x;
mov.u32 %r_ctaid, %ctaid.x;
mad.lo.u32 %idx, %r_ctaid, %r_ntid, %r_tid;
ld.param.u32 %d, [d];
setp.ge.u32 %p, %idx, %d;
@%p bra DONE;
ld.param.u64 %g_base, [gradient_ptr];
ld.param.u64 %dir_base, [direction_ptr];
ld.param.u32 %m, [m];
// Writes direction[idx] = -gradient[idx] for every m (history is ignored).
// (Full two-loop recursion requires sequential reduction - not implemented here)
cvt.u64.u32 %addr, %idx;
shl.b64 %addr, %addr, 2;
add.u64 %addr, %g_base, %addr;
ld.global.f32 %gval, [%addr];
neg.f32 %neg_g, %gval;
cvt.u64.u32 %addr, %idx;
shl.b64 %addr, %addr, 2;
add.u64 %addr, %dir_base, %addr;
st.global.f32 [%addr], %neg_g;
DONE:
ret;
}
"#
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lbfgs_steepest_descent() {
        let gradient = [1.0_f32, -2.0, 3.0, -4.0];
        let s_history: [f32; 0] = [];
        let y_history: [f32; 0] = [];
        let mut direction = [0.0_f32; 4];
        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 0, 4, &mut direction);
        assert_eq!(direction, [-1.0, 2.0, -3.0, 4.0]);
    }

    #[test]
    fn test_lbfgs_single_history_entry() {
        let d = 3;
        let gradient = [1.0_f32, 0.0, 0.0];
        let s_history = [1.0_f32, 0.0, 0.0];
        let y_history = [2.0_f32, 0.0, 0.0];
        let mut direction = [0.0_f32; 3];
        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 1, d, &mut direction);
        assert!(
            (direction[0] - (-0.5)).abs() < 1e-6,
            "direction[0] = {}",
            direction[0]
        );
        assert!(direction[1].abs() < 1e-6, "direction[1] = {}", direction[1]);
        assert!(direction[2].abs() < 1e-6, "direction[2] = {}", direction[2]);
    }

    #[test]
    fn test_lbfgs_two_pairs_recover_inverse_hessian() {
        // f(x) = 0.5 * x^T diag(2, 4) x with axis-aligned steps. The two exact
        // curvature pairs (oldest in row 0, newest in row 1) let the recursion
        // reproduce -H^{-1} g = [-1, -1] exactly. This exercises the pair ordering
        // of both loops and the gamma scaling taken from the newest pair.
        let d = 2;
        let gradient = [2.0_f32, 4.0];
        let s_history = [1.0_f32, 0.0, 0.0, 1.0];
        let y_history = [2.0_f32, 0.0, 0.0, 4.0];
        let mut direction = [0.0_f32; 2];
        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 2, d, &mut direction);
        assert_eq!(direction, [-1.0, -1.0]);
    }

    #[test]
    fn test_lbfgs_direction_is_descent() {
        let d = 4;
        let gradient = [1.0_f32, 2.0, 3.0, 4.0];
        let s_history = [0.1_f32, 0.2, 0.3, 0.4];
        let y_history = [0.5_f32, 0.6, 0.7, 0.8];
        let mut direction = [0.0_f32; 4];
        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 1, d, &mut direction);
        let dot: f32 = gradient
            .iter()
            .zip(direction.iter())
            .map(|(g, d)| g * d)
            .sum();
        assert!(
            dot < 0.0,
            "direction must be a descent direction, got g.d = {dot}"
        );
    }

    #[test]
    #[should_panic(expected = "gradient length mismatch")]
    fn test_lbfgs_gradient_mismatch() {
        let mut direction = [0.0_f32; 3];
        lbfgs_direction_scalar(&[1.0, 2.0], &[], &[], 0, 3, &mut direction);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_lbfgs_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let d = 8;
        let gradient: Vec<f32> = (0..d).map(|i| (i as f32 + 1.0) * 0.5).collect();
        let s_history: Vec<f32> = (0..d).map(|i| (i as f32) * 0.1 + 0.1).collect();
        let y_history: Vec<f32> = (0..d).map(|i| (i as f32) * 0.2 + 0.2).collect();
        let mut scalar_dir = vec![0.0_f32; d];
        let mut avx2_dir = vec![0.0_f32; d];
        lbfgs_direction_scalar(&gradient, &s_history, &y_history, 1, d, &mut scalar_dir);
        // SAFETY: AVX2 support was checked with `is_x86_feature_detected!` above.
        unsafe {
            lbfgs_direction_avx2(&gradient, &s_history, &y_history, 1, d, &mut avx2_dir);
        }
        assert_eq!(scalar_dir, avx2_dir);
    }

    #[test]
    fn test_lbfgs_ptx_version() {
        let ptx = lbfgs_direction_ptx();
        assert!(
            ptx.contains(".version 8.5"),
            "PTX must declare .version 8.5"
        );
    }

    #[test]
    fn test_lbfgs_ptx_target() {
        let ptx = lbfgs_direction_ptx();
        assert!(ptx.contains(".target sm_90"), "PTX must target sm_90");
    }

    #[test]
    fn test_lbfgs_ptx_entry() {
        let ptx = lbfgs_direction_ptx();
        assert!(
            ptx.contains(".entry lbfgs_direction_kernel"),
            "PTX must have .entry"
        );
    }

    #[test]
    fn test_lbfgs_ptx_ret() {
        let ptx = lbfgs_direction_ptx();
        assert!(ptx.contains("ret;"), "PTX must have ret;");
    }

    #[test]
    fn test_lbfgs_ptx_balanced_braces() {
        let ptx = lbfgs_direction_ptx();
        let opens = ptx.chars().filter(|&c| c == '{').count();
        let closes = ptx.chars().filter(|&c| c == '}').count();
        assert_eq!(
            opens, closes,
            "PTX must have balanced braces: {opens} opens vs {closes} closes"
        );
    }
}