pub mod fallback;
#[cfg(target_arch = "x86_64")]
pub mod x86;
#[cfg(target_arch = "aarch64")]
pub mod arm;
#[cfg(target_arch = "x86_64")]
pub use x86::{find_best_split_scalar, find_best_split_simd, SplitCandidate};
#[cfg(target_arch = "x86_64")]
pub use x86::{unpack_4bit, unpack_4bit_scalar};
#[cfg(target_arch = "aarch64")]
pub use arm::{find_best_split_scalar, find_best_split_simd};
#[cfg(target_arch = "aarch64")]
pub use fallback::SplitCandidate;
#[cfg(target_arch = "aarch64")]
pub use fallback::{unpack_4bit, unpack_4bit_scalar};
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
pub use fallback::{find_best_split_scalar, find_best_split_simd, SplitCandidate};
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
pub use fallback::{unpack_4bit, unpack_4bit_scalar};
use std::sync::OnceLock;

/// Process-wide cache of the detected SIMD capability level.
/// Initialized lazily on the first call to `simd_level` and never re-probed.
static SIMD_LEVEL: OnceLock<SimdLevel> = OnceLock::new();
/// SIMD capability tier detected at runtime, used to choose which kernel
/// implementation (scalar or vectorized) each dispatch function calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// No vector extensions detected; portable scalar kernels are used.
    Scalar,
    /// x86-64 with AVX2 available.
    Avx2,
    /// x86-64 with AVX-512F available (dispatchers treat this as "AVX2 or better").
    Avx512,
    /// aarch64 NEON.
    Neon,
}
impl SimdLevel {
    /// Probe the CPU at runtime for the widest supported vector extension,
    /// preferring AVX-512F over AVX2 over scalar.
    #[cfg(target_arch = "x86_64")]
    fn detect() -> Self {
        if std::arch::is_x86_feature_detected!("avx512f") {
            return SimdLevel::Avx512;
        }
        if std::arch::is_x86_feature_detected!("avx2") {
            return SimdLevel::Avx2;
        }
        SimdLevel::Scalar
    }

    /// NEON is part of the baseline aarch64 ISA, so no runtime probe is needed.
    #[cfg(target_arch = "aarch64")]
    fn detect() -> Self {
        SimdLevel::Neon
    }

    /// Unknown architecture: scalar kernels only.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    fn detect() -> Self {
        SimdLevel::Scalar
    }
}
/// Returns the cached SIMD capability level, running detection on first use.
#[inline]
pub fn simd_level() -> SimdLevel {
    *SIMD_LEVEL.get_or_init(|| SimdLevel::detect())
}
/// True when the AVX2 code paths may be taken (AVX-512 machines qualify too).
#[inline]
pub fn has_avx2() -> bool {
    let level = simd_level();
    level == SimdLevel::Avx2 || level == SimdLevel::Avx512
}
/// True only when AVX-512F was detected.
#[inline]
pub fn has_avx512() -> bool {
    simd_level() == SimdLevel::Avx512
}
/// True when running on aarch64 with NEON (always the case on that target).
#[inline]
pub fn has_neon() -> bool {
    simd_level() == SimdLevel::Neon
}
/// Find the best split over a 256-bin gradient/hessian histogram, dispatching
/// to the AVX2 kernel when available and to the scalar kernel otherwise.
///
/// Returns the winning candidate, or `None` if the kernel reports no
/// admissible split.
#[inline]
#[allow(clippy::too_many_arguments)]
pub fn find_best_split(
    hist_grads: &[f32; 256],
    hist_hess: &[f32; 256],
    hist_counts: &[u32; 256],
    total_gradient: f32,
    total_hessian: f32,
    total_count: u32,
    lambda: f32,
    min_samples_leaf: u32,
    min_hessian_leaf: f32,
) -> Option<SplitCandidate> {
    // Bundle node totals and regularization settings for the kernels.
    let params = fallback::SplitParams {
        total_gradient,
        total_hessian,
        total_count,
        lambda,
        min_samples_leaf,
        min_hessian_leaf,
    };
    // SAFETY: the AVX2 kernel is only entered after runtime detection
    // confirmed AVX2 support via `has_avx2`.
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        return unsafe { find_best_split_simd(hist_grads, hist_hess, hist_counts, params) };
    }
    // NOTE(review): on aarch64 a NEON `find_best_split_simd` is re-exported at
    // the top of this module but is never dispatched here — confirm whether
    // the scalar fallback below is intentional on that target.
    find_best_split_scalar(hist_grads, hist_hess, hist_counts, params)
}
/// Accumulate gradients/hessians/counts into a histogram for the rows listed
/// in `row_indices`, dispatching to the AVX2 kernel on capable x86-64 CPUs
/// and to the scalar kernel everywhere else.
///
/// # Safety
/// Caller must uphold the kernels' pointer contract: `row_indices` must be
/// valid for `num_rows` reads, every referenced row must be in bounds of
/// `feature_bins`/`gradients`/`hessians`, and the `hist_*` outputs must be
/// valid for writes (presumably 256 bins, since bin indices are `u8` — the
/// tests below use `[_; 256]` buffers).
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn histogram_accumulate(
    feature_bins: *const u8,
    row_indices: *const usize,
    num_rows: usize,
    gradients: *const f32,
    hessians: *const f32,
    hist_grads: *mut f32,
    hist_hess: *mut f32,
    hist_counts: *mut u32,
) {
    // Build the parameter bundle once instead of repeating the struct
    // literal in every dispatch arm.
    let params = fallback::HistogramAccumParams {
        feature_bins,
        row_indices,
        num_rows,
        gradients,
        hessians,
        hist_grads,
        hist_hess,
        hist_counts,
    };
    #[cfg(target_arch = "x86_64")]
    match simd_level() {
        // AVX-512 machines run the AVX2 kernel; there is no dedicated
        // AVX-512 histogram path.
        SimdLevel::Avx2 | SimdLevel::Avx512 => x86::histogram_accumulate_avx2(params),
        _ => fallback::histogram_accumulate_scalar(params),
    }
    // All non-x86 targets (including aarch64) use the scalar kernel.
    #[cfg(not(target_arch = "x86_64"))]
    fallback::histogram_accumulate_scalar(params);
}
#[inline]
pub unsafe fn histogram_accumulate_contiguous(
feature_bins: *const u8,
num_rows: usize,
gradients: *const f32,
hessians: *const f32,
hist_grads: *mut f32,
hist_hess: *mut f32,
hist_counts: *mut u32,
) {
#[cfg(target_arch = "x86_64")]
{
match simd_level() {
SimdLevel::Avx512 | SimdLevel::Avx2 => x86::histogram_accumulate_contiguous_avx2(
feature_bins,
num_rows,
gradients,
hessians,
hist_grads,
hist_hess,
hist_counts,
),
_ => fallback::histogram_accumulate_contiguous_scalar(
feature_bins,
num_rows,
gradients,
hessians,
hist_grads,
hist_hess,
hist_counts,
),
}
}
#[cfg(target_arch = "aarch64")]
{
fallback::histogram_accumulate_contiguous_scalar(
feature_bins,
num_rows,
gradients,
hessians,
hist_grads,
hist_hess,
hist_counts,
)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
fallback::histogram_accumulate_contiguous_scalar(
feature_bins,
num_rows,
gradients,
hessians,
hist_grads,
hist_hess,
hist_counts,
)
}
}
/// Element-wise add `other_grads` into `self_grads` across all 256 bins.
#[inline]
pub fn merge_histogram_grads(self_grads: &mut [f32; 256], other_grads: &[f32; 256]) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        // SAFETY: AVX2 support was confirmed by runtime feature detection.
        unsafe { x86::merge_histogram_grads_avx2(self_grads, other_grads) };
        return;
    }
    fallback::merge_histogram_grads_scalar(self_grads, other_grads);
}
/// Element-wise add `other_hess` into `self_hess` across all 256 bins.
#[inline]
pub fn merge_histogram_hess(self_hess: &mut [f32; 256], other_hess: &[f32; 256]) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        // SAFETY: AVX2 support was confirmed by runtime feature detection.
        unsafe { x86::merge_histogram_hess_avx2(self_hess, other_hess) };
        return;
    }
    fallback::merge_histogram_hess_scalar(self_hess, other_hess);
}
/// Element-wise add `other_counts` into `self_counts` across all 256 bins.
#[inline]
pub fn merge_histogram_counts(self_counts: &mut [u32; 256], other_counts: &[u32; 256]) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        // SAFETY: AVX2 support was confirmed by runtime feature detection.
        unsafe { x86::merge_histogram_counts_avx2(self_counts, other_counts) };
        return;
    }
    fallback::merge_histogram_counts_scalar(self_counts, other_counts);
}
/// Element-wise subtract `other_grads` from `self_grads` across all 256 bins
/// (used for the parent-minus-sibling histogram trick).
#[inline]
pub fn subtract_histogram_grads(self_grads: &mut [f32; 256], other_grads: &[f32; 256]) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        // SAFETY: AVX2 support was confirmed by runtime feature detection.
        unsafe { x86::subtract_histogram_grads_avx2(self_grads, other_grads) };
        return;
    }
    fallback::subtract_histogram_grads_scalar(self_grads, other_grads);
}
/// Element-wise subtract `other_hess` from `self_hess` across all 256 bins.
#[inline]
pub fn subtract_histogram_hess(self_hess: &mut [f32; 256], other_hess: &[f32; 256]) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        // SAFETY: AVX2 support was confirmed by runtime feature detection.
        unsafe { x86::subtract_histogram_hess_avx2(self_hess, other_hess) };
        return;
    }
    fallback::subtract_histogram_hess_scalar(self_hess, other_hess);
}
/// Element-wise subtract `other_counts` from `self_counts` across all 256 bins.
#[inline]
pub fn subtract_histogram_counts(self_counts: &mut [u32; 256], other_counts: &[u32; 256]) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        // SAFETY: AVX2 support was confirmed by runtime feature detection.
        unsafe { x86::subtract_histogram_counts_avx2(self_counts, other_counts) };
        return;
    }
    fallback::subtract_histogram_counts_scalar(self_counts, other_counts);
}
/// Rows processed per block; also the fixed length of the interleaved
/// (gradient, hessian) cache arrays filled by the `copy_gh_*` helpers below.
pub const BLOCK_SIZE: usize = 2048;
/// Copy `len` (gradient, hessian) pairs starting at row `start` into
/// `gh_cache`, interleaved as tuples, using the fastest available kernel.
///
/// Fix over the original: on aarch64 the NEON branch ended in an
/// unconditional `return`, leaving the trailing scalar call as unreachable
/// dead code (an `unreachable_code` warning). The scalar fallthrough is now
/// compiled out on aarch64 instead; dispatch behavior is unchanged on every
/// architecture.
///
/// # Safety
/// Declared `unsafe` to match the original contract — presumably the
/// arch-specific kernels require `start + len` to be within bounds of
/// `gradients`/`hessians` and `len <= BLOCK_SIZE`; confirm against the
/// kernel implementations.
#[inline]
pub unsafe fn copy_gh_interleaved(
    gradients: &[f32],
    hessians: &[f32],
    start: usize,
    len: usize,
    gh_cache: &mut [(f32, f32); BLOCK_SIZE],
) {
    #[cfg(target_arch = "x86_64")]
    if has_avx2() {
        x86::copy_gh_interleaved_avx2(gradients, hessians, start, len, gh_cache);
        return;
    }
    // aarch64 always has NEON; no fallthrough needed (or emitted) there.
    #[cfg(target_arch = "aarch64")]
    arm::copy_gh_interleaved_neon(gradients, hessians, start, len, gh_cache);
    // Scalar path for x86-64 without AVX2 and for all other architectures.
    #[cfg(not(target_arch = "aarch64"))]
    fallback::copy_gh_interleaved_scalar(gradients, hessians, start, len, gh_cache);
}
/// Gather (gradient, hessian) pairs for the rows named by `indices` into
/// `gh_cache`, interleaved as tuples.
///
/// Always delegates to the scalar implementation — no SIMD variant is
/// dispatched here (presumably because the indexed gather does not
/// vectorize profitably; confirm against the fallback kernel).
#[inline]
pub fn copy_gh_indexed(
    gradients: &[f32],
    hessians: &[f32],
    indices: &[usize],
    gh_cache: &mut [(f32, f32); BLOCK_SIZE],
) {
    fallback::copy_gh_indexed_scalar(gradients, hessians, indices, gh_cache);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared assertions for the 8-row fixture: bins 0..=2 must hold
    /// gradient sums 12/15/9 with counts 3/3/2.
    fn assert_fixture_histogram(grads: &[f32; 256], counts: &[u32; 256]) {
        let expected: [(f32, u32); 3] = [(12.0, 3), (15.0, 3), (9.0, 2)];
        for (bin, &(g, c)) in expected.iter().enumerate() {
            assert!((grads[bin] - g).abs() < 1e-5, "gradient sum for bin {bin}");
            assert_eq!(counts[bin], c, "count for bin {bin}");
        }
    }

    #[test]
    fn test_simd_detection() {
        let detected = simd_level();
        println!("Detected SIMD level: {:?}", detected);
        #[cfg(target_arch = "x86_64")]
        {
            println!("AVX2 available: {}", has_avx2());
            println!("AVX-512 available: {}", has_avx512());
        }
    }

    #[test]
    fn test_histogram_accumulate_basic() {
        let bins: Vec<u8> = vec![0, 1, 2, 0, 1, 2, 0, 1];
        let rows: Vec<usize> = (0..8).collect();
        let grads: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let hess = vec![1.0f32; 8];
        let mut hist_grads = [0.0f32; 256];
        let mut hist_hess = [0.0f32; 256];
        let mut hist_counts = [0u32; 256];
        unsafe {
            histogram_accumulate(
                bins.as_ptr(),
                rows.as_ptr(),
                8,
                grads.as_ptr(),
                hess.as_ptr(),
                hist_grads.as_mut_ptr(),
                hist_hess.as_mut_ptr(),
                hist_counts.as_mut_ptr(),
            );
        }
        assert_fixture_histogram(&hist_grads, &hist_counts);
    }

    #[test]
    fn test_histogram_accumulate_contiguous() {
        let bins: Vec<u8> = vec![0, 1, 2, 0, 1, 2, 0, 1];
        let grads: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let hess = vec![1.0f32; 8];
        let mut hist_grads = [0.0f32; 256];
        let mut hist_hess = [0.0f32; 256];
        let mut hist_counts = [0u32; 256];
        unsafe {
            histogram_accumulate_contiguous(
                bins.as_ptr(),
                8,
                grads.as_ptr(),
                hess.as_ptr(),
                hist_grads.as_mut_ptr(),
                hist_hess.as_mut_ptr(),
                hist_counts.as_mut_ptr(),
            );
        }
        assert_fixture_histogram(&hist_grads, &hist_counts);
    }
}