use super::fallback::SplitCandidate;
/// Number of `(gradient, hessian)` entries in the fixed-size cache array that
/// `copy_gh_interleaved_neon` writes into; it is therefore also the largest
/// `len` a single call to that function can handle.
pub const BLOCK_SIZE: usize = 2048;
pub use super::fallback::{find_best_split_scalar, find_best_split_simd};
/// Copies `len` gradient/hessian pairs, starting at index `start` of the input
/// slices, into the front of `gh_cache` as interleaved `(gradient, hessian)`
/// tuples.
///
/// Four pairs are processed per iteration with NEON loads, a zip, and two
/// stores; the `len % 4` leftover pairs are copied one at a time. Entries of
/// `gh_cache` at index `len` and beyond are left untouched.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// - `start + len` does not overflow and is `<= gradients.len()` and
///   `<= hessians.len()`;
/// - `len <= BLOCK_SIZE`;
/// - the executing CPU supports NEON (required by `#[target_feature]`).
///
/// These conditions are only verified by `debug_assert!`; violating them in a
/// release build results in out-of-bounds memory access.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn copy_gh_interleaved_neon(
    gradients: &[f32],
    hessians: &[f32],
    start: usize,
    len: usize,
    gh_cache: &mut [(f32, f32); BLOCK_SIZE],
) {
    use std::arch::aarch64::*;

    // Catch contract violations early in debug builds; release builds pay nothing.
    debug_assert!(
        len <= BLOCK_SIZE,
        "len {} exceeds BLOCK_SIZE {}",
        len,
        BLOCK_SIZE
    );
    debug_assert!(
        matches!(
            start.checked_add(len),
            Some(end) if end <= gradients.len() && end <= hessians.len()
        ),
        "range start={} len={} is out of bounds (gradients: {}, hessians: {})",
        start,
        len,
        gradients.len(),
        hessians.len()
    );

    let chunks = len / 4;
    let remainder = len % 4;

    // SAFETY (caller contract): `start + len` is within both slices, so every
    // offset in `0..len` from these pointers stays in bounds.
    let grad_ptr = gradients.as_ptr().add(start);
    let hess_ptr = hessians.as_ptr().add(start);

    // View the tuple array as a flat run of f32 values.
    // NOTE(review): this assumes `(f32, f32)` is laid out as two consecutive
    // `f32`s with field 0 first. That holds in practice, but tuple field order
    // is not formally guaranteed by the language — confirm this is acceptable.
    let cache_ptr = gh_cache.as_mut_ptr() as *mut f32;

    for i in 0..chunks {
        let offset = i * 4;
        let grads = vld1q_f32(grad_ptr.add(offset));
        let hess = vld1q_f32(hess_ptr.add(offset));

        // zip yields ([g0, h0, g1, h1], [g2, h2, g3, h3]).
        let interleaved = vzipq_f32(grads, hess);

        // Each pair occupies two f32 slots, hence the doubled offset.
        let dst = cache_ptr.add(offset * 2);
        vst1q_f32(dst, interleaved.0);
        vst1q_f32(dst.add(4), interleaved.1);
    }

    // Scalar tail for the last `len % 4` pairs.
    let rem_start = chunks * 4;
    for i in 0..remainder {
        let idx = rem_start + i;
        let g = *grad_ptr.add(idx);
        let h = *hess_ptr.add(idx);
        *gh_cache.get_unchecked_mut(idx) = (g, h);
    }
}
// Gated on aarch64 in addition to `test`: the function under test is only
// compiled for that architecture, so an ungated module would fail to build
// when running the test suite on any other target.
#[cfg(all(test, target_arch = "aarch64"))]
mod tests {
    use super::*;

    /// A length that is an exact multiple of the 4-lane vector width, so only
    /// the vectorized loop runs.
    #[test]
    fn test_copy_gh_interleaved_neon() {
        let gradients: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let hessians: Vec<f32> = (0..32).map(|i| i as f32 * 0.01).collect();
        let mut gh_cache = [(0.0f32, 0.0f32); BLOCK_SIZE];

        // SAFETY: 0 + 32 is within both 32-element slices and 32 <= BLOCK_SIZE.
        // NEON is assumed to be available on the aarch64 test host.
        unsafe {
            copy_gh_interleaved_neon(&gradients, &hessians, 0, 32, &mut gh_cache);
        }

        for (i, &(actual_g, actual_h)) in gh_cache.iter().enumerate().take(32) {
            let expected_g = i as f32 * 0.1;
            let expected_h = i as f32 * 0.01;
            assert!(
                (actual_g - expected_g).abs() < 1e-6,
                "Gradient mismatch at {}: expected {}, got {}",
                i,
                expected_g,
                actual_g
            );
            assert!(
                (actual_h - expected_h).abs() < 1e-6,
                "Hessian mismatch at {}: expected {}, got {}",
                i,
                expected_h,
                actual_h
            );
        }
    }

    /// A length of 35 = 8 full vectors + 3 leftover pairs, exercising the
    /// scalar tail loop as well.
    #[test]
    fn test_copy_gh_interleaved_neon_with_remainder() {
        let gradients: Vec<f32> = (0..35).map(|i| i as f32).collect();
        let hessians: Vec<f32> = (0..35).map(|i| (i as f32) + 100.0).collect();
        let mut gh_cache = [(0.0f32, 0.0f32); BLOCK_SIZE];

        // SAFETY: 0 + 35 is within both 35-element slices and 35 <= BLOCK_SIZE.
        // NEON is assumed to be available on the aarch64 test host.
        unsafe {
            copy_gh_interleaved_neon(&gradients, &hessians, 0, 35, &mut gh_cache);
        }

        for (i, &(g, h)) in gh_cache.iter().enumerate().take(35) {
            assert_eq!(g, i as f32, "Gradient mismatch at {}", i);
            assert_eq!(h, i as f32 + 100.0, "Hessian mismatch at {}", i);
        }
    }

    /// A non-zero `start` must read from `start..start + len` of the inputs
    /// while still writing from index 0 of the cache, and must leave entries
    /// past `len` untouched.
    #[test]
    fn test_copy_gh_interleaved_neon_with_offset() {
        let gradients: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let hessians: Vec<f32> = (0..20).map(|i| (i as f32) + 100.0).collect();
        let mut gh_cache = [(0.0f32, 0.0f32); BLOCK_SIZE];

        // SAFETY: 5 + 10 is within both 20-element slices and 10 <= BLOCK_SIZE.
        // NEON is assumed to be available on the aarch64 test host.
        unsafe {
            copy_gh_interleaved_neon(&gradients, &hessians, 5, 10, &mut gh_cache);
        }

        for (i, &(g, h)) in gh_cache.iter().enumerate().take(10) {
            assert_eq!(g, (i + 5) as f32, "Gradient mismatch at {}", i);
            assert_eq!(h, (i + 5) as f32 + 100.0, "Hessian mismatch at {}", i);
        }
        assert_eq!(gh_cache[10], (0.0, 0.0), "Entry past `len` was modified");
    }
}