use super::constants::WEIGHT_SCALE_BITS;
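/// Activation applied to feature-transformer outputs before they feed the
/// first dense layer. Implementors supply an `i16` path (accumulator values,
/// clamped against `qa`) and an `i32` path (affine outputs, rescaled by
/// `WEIGHT_SCALE_BITS`), plus the suffix that identifies the variant in a
/// network's architecture string.
///
/// A typical call site might look like this (sketch only; `acc_values` and
/// the `qa` value depend on the network's quantization constants):
///
/// ```ignore
/// let mut out = [0u8; 512];
/// CReLU::activate_i16_to_u8(&acc_values, &mut out, 255);
/// ```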
pub trait FtActivation: Clone + Copy + Default + Send + Sync + 'static {
const OUTPUT_DIM_DIVISOR: usize;
fn activate_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16);
fn activate_i32_to_u8(input: &[i32], output: &mut [u8]);
fn header_suffix() -> &'static str;
fn name() -> &'static str;
}
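/// Clipped ReLU: `clamp(x, 0, qa)` on the `i16` path and
/// `clamp(x >> WEIGHT_SCALE_BITS, 0, 127)` on the `i32` path.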
#[derive(Clone, Copy, Default)]
pub struct CReLU;
impl FtActivation for CReLU {
const OUTPUT_DIM_DIVISOR: usize = 1;
#[inline]
fn activate_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16) {
debug_assert_eq!(input.len(), output.len());
crelu_i16_to_u8(input, output, qa);
}
#[inline]
fn activate_i32_to_u8(input: &[i32], output: &mut [u8]) {
debug_assert_eq!(input.len(), output.len());
crelu_i32_to_u8(input, output);
}
fn header_suffix() -> &'static str {
""
}
fn name() -> &'static str {
"CReLU"
}
}
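/// Clamps each `i16` to `[0, qa]` and narrows to `u8`. SIMD paths (AVX2,
/// SSE2, WASM simd128) handle whole chunks; a scalar loop finishes the tail.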
fn crelu_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16) {
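    // `processed` counts elements already handled by a SIMD path below; it
    // only needs to be `mut` on targets where such a path is compiled in.
    // The same pattern recurs in the other activation kernels in this file.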
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
))]
let mut processed = 0;
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
)))]
let processed = 0;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
let num_chunks = input.len() / 16;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max_val = _mm256_set1_epi16(qa);
let in_ptr = input.as_ptr();
let out_ptr = output.as_mut_ptr();
for i in 0..num_chunks {
let v = _mm256_loadu_si256(in_ptr.add(i * 16) as *const __m256i);
let clamped = _mm256_min_epi16(_mm256_max_epi16(v, zero), max_val);
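                    // packus interleaves the two 128-bit lanes; permuting the
                    // 64-bit quadwords as (0, 2, 1, 3) puts all 16 result
                    // bytes back in order in the low 128 bits.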
let packed = _mm256_packus_epi16(clamped, clamped);
let result = _mm256_permute4x64_epi64(packed, 0b11011000);
_mm_storeu_si128(
out_ptr.add(i * 16) as *mut __m128i,
_mm256_castsi256_si128(result),
);
}
}
processed = num_chunks * 16;
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 16;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_val = _mm_set1_epi16(qa);
let in_ptr = input.as_ptr().add(processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let v0 = _mm_loadu_si128(in_ptr.add(i * 16) as *const __m128i);
let v1 = _mm_loadu_si128(in_ptr.add(i * 16 + 8) as *const __m128i);
let clamped0 = _mm_min_epi16(_mm_max_epi16(v0, zero), max_val);
let clamped1 = _mm_min_epi16(_mm_max_epi16(v1, zero), max_val);
let packed = _mm_packus_epi16(clamped0, clamped1);
_mm_storeu_si128(out_ptr.add(i * 16) as *mut __m128i, packed);
}
}
processed += num_chunks * 16;
}
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 8;
if num_chunks > 0 {
unsafe {
use std::arch::wasm32::*;
let zero = i16x8_splat(0);
let max_val = i16x8_splat(qa);
let in_ptr = input.as_ptr().add(processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let v = v128_load(in_ptr.add(i * 8) as *const v128);
let clamped = i16x8_min(i16x8_max(v, zero), max_val);
let packed = u8x16_narrow_i16x8(clamped, clamped);
v128_store64_lane::<0>(packed, out_ptr.add(i * 8) as *mut u64);
}
}
processed += num_chunks * 8;
}
}
for i in processed..input.len() {
output[i] = input[i].clamp(0, qa) as u8;
}
}
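/// Rescales each `i32` by `>> WEIGHT_SCALE_BITS`, clamps to `[0, 127]`, and
/// narrows to `u8`.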
fn crelu_i32_to_u8(input: &[i32], output: &mut [u8]) {
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2")
))]
let mut processed = 0;
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2")
)))]
let processed = 0;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
let num_chunks = input.len() / 32;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
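                // packs_epi32/packs_epi16 interleave the 128-bit lanes; this
                // permutation restores the original element order afterwards.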
let offsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
let in_ptr = input.as_ptr() as *const __m256i;
let out_ptr = output.as_mut_ptr() as *mut __m256i;
for i in 0..num_chunks {
let in0 = _mm256_loadu_si256(in_ptr.add(i * 4));
let in1 = _mm256_loadu_si256(in_ptr.add(i * 4 + 1));
let in2 = _mm256_loadu_si256(in_ptr.add(i * 4 + 2));
let in3 = _mm256_loadu_si256(in_ptr.add(i * 4 + 3));
let words0 =
_mm256_srai_epi16(_mm256_packs_epi32(in0, in1), WEIGHT_SCALE_BITS as i32);
let words1 =
_mm256_srai_epi16(_mm256_packs_epi32(in2, in3), WEIGHT_SCALE_BITS as i32);
let bytes = _mm256_max_epi8(_mm256_packs_epi16(words0, words1), zero);
let result = _mm256_permutevar8x32_epi32(bytes, offsets);
_mm256_storeu_si256(out_ptr.add(i), result);
}
}
processed = num_chunks * 32;
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 16;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
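                // With SSE4.1, negatives are clamped via `_mm_max_epi8`;
                // plain SSE2 emulates it with saturating add/sub of 0x80.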
#[cfg(target_feature = "sse4.1")]
let zero = _mm_setzero_si128();
#[cfg(not(target_feature = "sse4.1"))]
let k0x80s = _mm_set1_epi8(-128i8);
let in_ptr = input.as_ptr().add(processed) as *const __m128i;
let out_ptr = output.as_mut_ptr().add(processed) as *mut __m128i;
for i in 0..num_chunks {
let in0 = _mm_loadu_si128(in_ptr.add(i * 4));
let in1 = _mm_loadu_si128(in_ptr.add(i * 4 + 1));
let in2 = _mm_loadu_si128(in_ptr.add(i * 4 + 2));
let in3 = _mm_loadu_si128(in_ptr.add(i * 4 + 3));
let words0 =
_mm_srai_epi16(_mm_packs_epi32(in0, in1), WEIGHT_SCALE_BITS as i32);
let words1 =
_mm_srai_epi16(_mm_packs_epi32(in2, in3), WEIGHT_SCALE_BITS as i32);
let packedbytes = _mm_packs_epi16(words0, words1);
#[cfg(target_feature = "sse4.1")]
let result = _mm_max_epi8(packedbytes, zero);
#[cfg(not(target_feature = "sse4.1"))]
let result = _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s);
_mm_storeu_si128(out_ptr.add(i), result);
}
}
processed += num_chunks * 16;
}
}
for i in processed..input.len() {
let shifted = input[i] >> WEIGHT_SCALE_BITS;
output[i] = shifted.clamp(0, 127) as u8;
}
}
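/// Pairwise clipped ReLU: clamps both elements of a pair to `[0, qa]`,
/// multiplies them, and rescales, so the output is half as wide as the input.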
#[derive(Clone, Copy, Default)]
pub struct PairwiseCReLU;
impl FtActivation for PairwiseCReLU {
const OUTPUT_DIM_DIVISOR: usize = 2;
#[inline]
fn activate_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16) {
debug_assert_eq!(input.len(), output.len() * 2);
pairwise_crelu_i16_to_u8(input, output, qa);
}
#[inline]
fn activate_i32_to_u8(input: &[i32], output: &mut [u8]) {
debug_assert_eq!(input.len(), output.len());
crelu_i32_to_u8(input, output);
}
fn header_suffix() -> &'static str {
"-Pairwise"
}
fn name() -> &'static str {
"PairwiseCReLU"
}
}
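/// Applies pairwise CReLU separately to the two accumulator perspectives
/// (each `l1` elements wide), dispatching on `qa` to a const-specialized
/// kernel with `(QA, SHIFT, MAX_OUT)` of `(255, 9, 127)` or `(127, 7, 126)`.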
fn pairwise_crelu_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16) {
    let ft_out = input.len();
    let l1 = ft_out / 2;
    let quarter = l1 / 2;
    debug_assert_eq!(output.len(), l1, "output length must be L1 (= input.len() / 2)");
if qa >= 255 {
pairwise_crelu_i16_to_u8_inner::<255, 9, 127>(
&input[0..l1],
&mut output[0..quarter],
quarter,
);
pairwise_crelu_i16_to_u8_inner::<255, 9, 127>(
&input[l1..ft_out],
&mut output[quarter..l1],
quarter,
);
} else {
pairwise_crelu_i16_to_u8_inner::<127, 7, 126>(
&input[0..l1],
&mut output[0..quarter],
quarter,
);
pairwise_crelu_i16_to_u8_inner::<127, 7, 126>(
&input[l1..ft_out],
&mut output[quarter..l1],
quarter,
);
}
}
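/// Const-generic pairwise kernel: pairs `input[j]` with `input[j + half]`,
/// clamps both to `[0, QA]`, multiplies, shifts by `SHIFT`, and caps at
/// `MAX_OUT`. Only two parameter combinations are valid (checked below).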
fn pairwise_crelu_i16_to_u8_inner<const QA: i32, const SHIFT: i32, const MAX_OUT: i32>(
input: &[i16],
output: &mut [u8],
half: usize,
) {
const {
assert!(
(QA == 127 && SHIFT == 7 && MAX_OUT == 126)
|| (QA == 255 && SHIFT == 9 && MAX_OUT == 127),
"Invalid QA/SHIFT/MAX_OUT combination"
);
}
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse4.1"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
))]
let mut processed = 0usize;
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse4.1"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
)))]
let processed = 0usize;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
let num_chunks = half / 8;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max_clamp = _mm256_set1_epi32(QA);
let max_out = _mm256_set1_epi32(MAX_OUT);
let a_ptr = input.as_ptr();
let b_ptr = input.as_ptr().add(half);
let out_ptr = output.as_mut_ptr();
for i in 0..num_chunks {
let a_i16 = _mm_loadu_si128(a_ptr.add(i * 8) as *const __m128i);
let b_i16 = _mm_loadu_si128(b_ptr.add(i * 8) as *const __m128i);
let a = _mm256_cvtepi16_epi32(a_i16);
let b = _mm256_cvtepi16_epi32(b_i16);
let a_clamped = _mm256_min_epi32(_mm256_max_epi32(a, zero), max_clamp);
let b_clamped = _mm256_min_epi32(_mm256_max_epi32(b, zero), max_clamp);
let product = _mm256_mullo_epi32(a_clamped, b_clamped);
let result = _mm256_min_epi32(_mm256_srai_epi32(product, SHIFT), max_out);
let packed16 = _mm256_packs_epi32(result, result);
let packed8 = _mm256_packus_epi16(packed16, packed16);
let lo = _mm256_castsi256_si128(packed8);
let hi = _mm256_extracti128_si256(packed8, 1);
let combined = _mm_unpacklo_epi32(lo, hi);
_mm_storel_epi64(out_ptr.add(i * 8) as *mut __m128i, combined);
}
}
processed = num_chunks * 8;
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))]
{
let remaining = half - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_clamp = _mm_set1_epi32(QA);
let max_out = _mm_set1_epi32(MAX_OUT);
let a_ptr = input.as_ptr().add(processed);
let b_ptr = input.as_ptr().add(half + processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let a_i16 = _mm_loadl_epi64(a_ptr.add(i * 4) as *const __m128i);
let b_i16 = _mm_loadl_epi64(b_ptr.add(i * 4) as *const __m128i);
let a = _mm_cvtepi16_epi32(a_i16);
let b = _mm_cvtepi16_epi32(b_i16);
let a_clamped = _mm_min_epi32(_mm_max_epi32(a, zero), max_clamp);
let b_clamped = _mm_min_epi32(_mm_max_epi32(b, zero), max_clamp);
let product = _mm_mullo_epi32(a_clamped, b_clamped);
let result = _mm_min_epi32(_mm_srai_epi32(product, SHIFT), max_out);
let packed16 = _mm_packs_epi32(result, result);
let packed8 = _mm_packus_epi16(packed16, packed16);
let val = _mm_cvtsi128_si32(packed8) as u32;
std::ptr::copy_nonoverlapping(
&val as *const u32 as *const u8,
out_ptr.add(i * 4),
4,
);
}
}
processed += num_chunks * 4;
}
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "sse2",
not(target_feature = "sse4.1")
))]
{
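        // Plain SSE2 lacks 32-bit min/max and mullo: the clamps are emulated
        // with compare-and-select masks, and the multiply with
        // `_mm_mul_epu32` on even/odd lanes (both operands are non-negative
        // after clamping, so the unsigned multiply is exact).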
let remaining = half - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_clamp = _mm_set1_epi32(QA);
let max_out = _mm_set1_epi32(MAX_OUT);
let a_ptr = input.as_ptr().add(processed);
let b_ptr = input.as_ptr().add(half + processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let a_i16 = _mm_loadl_epi64(a_ptr.add(i * 4) as *const __m128i);
let b_i16 = _mm_loadl_epi64(b_ptr.add(i * 4) as *const __m128i);
let a_sign = _mm_cmpgt_epi16(zero, a_i16);
let a = _mm_unpacklo_epi16(a_i16, a_sign);
let b_sign = _mm_cmpgt_epi16(zero, b_i16);
let b = _mm_unpacklo_epi16(b_i16, b_sign);
let a_gt_zero = _mm_cmpgt_epi32(a, zero);
let a_max_zero = _mm_and_si128(a, a_gt_zero);
let a_lt_clamp = _mm_cmpgt_epi32(max_clamp, a_max_zero);
let a_clamped = _mm_or_si128(
_mm_and_si128(a_max_zero, a_lt_clamp),
_mm_andnot_si128(a_lt_clamp, max_clamp),
);
let b_gt_zero = _mm_cmpgt_epi32(b, zero);
let b_max_zero = _mm_and_si128(b, b_gt_zero);
let b_lt_clamp = _mm_cmpgt_epi32(max_clamp, b_max_zero);
let b_clamped = _mm_or_si128(
_mm_and_si128(b_max_zero, b_lt_clamp),
_mm_andnot_si128(b_lt_clamp, max_clamp),
);
let a_lo = a_clamped;
let b_lo = b_clamped;
let a_hi = _mm_srli_epi64(a_clamped, 32);
let b_hi = _mm_srli_epi64(b_clamped, 32);
let lo_product = _mm_mul_epu32(a_lo, b_lo);
let hi_product = _mm_mul_epu32(a_hi, b_hi);
let lo_shifted = _mm_shuffle_epi32(lo_product, 0b00_00_10_00);
let hi_shifted = _mm_shuffle_epi32(hi_product, 0b00_00_10_00);
let product = _mm_unpacklo_epi32(lo_shifted, hi_shifted);
let shifted = _mm_srai_epi32(product, SHIFT);
let result_lt_max = _mm_cmpgt_epi32(max_out, shifted);
let result = _mm_or_si128(
_mm_and_si128(shifted, result_lt_max),
_mm_andnot_si128(result_lt_max, max_out),
);
let packed16 = _mm_packs_epi32(result, result);
let packed8 = _mm_packus_epi16(packed16, packed16);
let val = _mm_cvtsi128_si32(packed8) as u32;
std::ptr::copy_nonoverlapping(
&val as *const u32 as *const u8,
out_ptr.add(i * 4),
4,
);
}
}
processed += num_chunks * 4;
}
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
let remaining = half - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::wasm32::*;
let zero = i32x4_splat(0);
let max_clamp = i32x4_splat(QA);
let max_out = i32x4_splat(MAX_OUT);
let a_ptr = input.as_ptr().add(processed);
let b_ptr = input.as_ptr().add(half + processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let a_i64 = v128_load64_zero(a_ptr.add(i * 4) as *const u64);
let b_i64 = v128_load64_zero(b_ptr.add(i * 4) as *const u64);
let a = i32x4_extend_low_i16x8(a_i64);
let b = i32x4_extend_low_i16x8(b_i64);
let a_clamped = i32x4_min(i32x4_max(a, zero), max_clamp);
let b_clamped = i32x4_min(i32x4_max(b, zero), max_clamp);
let product = i32x4_mul(a_clamped, b_clamped);
let result = i32x4_min(i32x4_shr(product, SHIFT as u32), max_out);
let narrow16 = i16x8_narrow_i32x4(result, result);
let narrow8 = u8x16_narrow_i16x8(narrow16, narrow16);
v128_store32_lane::<0>(narrow8, out_ptr.add(i * 4) as *mut u32);
}
}
processed += num_chunks * 4;
}
}
for j in processed..half {
let a = i32::from(input[j]).clamp(0, QA);
let b = i32::from(input[j + half]).clamp(0, QA);
output[j] = ((a * b) >> SHIFT).min(MAX_OUT) as u8;
}
}
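/// `i32` variant of pairwise CReLU (currently unused): rescale by
/// `WEIGHT_SCALE_BITS`, clamp to `[0, 127]`, multiply pairs, `>> 7`, cap at
/// 127.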
#[allow(dead_code)]
fn pairwise_crelu_i32_to_u8(input: &[i32], output: &mut [u8]) {
let half = input.len() / 2;
debug_assert_eq!(output.len(), half, "output length must be half of input length");
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse4.1"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
))]
let mut processed = 0usize;
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse4.1"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
)))]
let processed = 0usize;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
let num_chunks = half / 8;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max_val = _mm256_set1_epi32(127);
let a_ptr = input.as_ptr();
let b_ptr = input.as_ptr().add(half);
let out_ptr = output.as_mut_ptr();
for i in 0..num_chunks {
let a = _mm256_loadu_si256(a_ptr.add(i * 8) as *const __m256i);
let b = _mm256_loadu_si256(b_ptr.add(i * 8) as *const __m256i);
let a_shifted = _mm256_srai_epi32(a, WEIGHT_SCALE_BITS as i32);
let b_shifted = _mm256_srai_epi32(b, WEIGHT_SCALE_BITS as i32);
let a_clamped = _mm256_min_epi32(_mm256_max_epi32(a_shifted, zero), max_val);
let b_clamped = _mm256_min_epi32(_mm256_max_epi32(b_shifted, zero), max_val);
let product = _mm256_mullo_epi32(a_clamped, b_clamped);
let result = _mm256_min_epi32(_mm256_srai_epi32(product, 7), max_val);
let packed16 = _mm256_packs_epi32(result, result);
let packed8 = _mm256_packus_epi16(packed16, packed16);
let lo = _mm256_castsi256_si128(packed8);
let hi = _mm256_extracti128_si256(packed8, 1);
let combined = _mm_unpacklo_epi32(lo, hi);
_mm_storel_epi64(out_ptr.add(i * 8) as *mut __m128i, combined);
}
}
processed = num_chunks * 8;
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "sse4.1"))]
{
let remaining = half - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_val = _mm_set1_epi32(127);
let a_ptr = input.as_ptr().add(processed);
let b_ptr = input.as_ptr().add(half + processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let a = _mm_loadu_si128(a_ptr.add(i * 4) as *const __m128i);
let b = _mm_loadu_si128(b_ptr.add(i * 4) as *const __m128i);
let a_shifted = _mm_srai_epi32(a, WEIGHT_SCALE_BITS as i32);
let b_shifted = _mm_srai_epi32(b, WEIGHT_SCALE_BITS as i32);
let a_clamped = _mm_min_epi32(_mm_max_epi32(a_shifted, zero), max_val);
let b_clamped = _mm_min_epi32(_mm_max_epi32(b_shifted, zero), max_val);
let product = _mm_mullo_epi32(a_clamped, b_clamped);
let result = _mm_min_epi32(_mm_srai_epi32(product, 7), max_val);
let packed16 = _mm_packs_epi32(result, result);
let packed8 = _mm_packus_epi16(packed16, packed16);
let val = _mm_cvtsi128_si32(packed8) as u32;
std::ptr::copy_nonoverlapping(
&val as *const u32 as *const u8,
out_ptr.add(i * 4),
4,
);
}
}
processed += num_chunks * 4;
}
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "sse2",
not(target_feature = "sse4.1")
))]
{
let remaining = half - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_val = _mm_set1_epi32(127);
let a_ptr = input.as_ptr().add(processed);
let b_ptr = input.as_ptr().add(half + processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let a = _mm_loadu_si128(a_ptr.add(i * 4) as *const __m128i);
let b = _mm_loadu_si128(b_ptr.add(i * 4) as *const __m128i);
let a_shifted = _mm_srai_epi32(a, WEIGHT_SCALE_BITS as i32);
let b_shifted = _mm_srai_epi32(b, WEIGHT_SCALE_BITS as i32);
let a_gt_zero = _mm_cmpgt_epi32(a_shifted, zero);
let a_max_zero = _mm_and_si128(a_shifted, a_gt_zero);
let a_lt_max = _mm_cmpgt_epi32(max_val, a_max_zero);
let a_clamped = _mm_or_si128(
_mm_and_si128(a_max_zero, a_lt_max),
_mm_andnot_si128(a_lt_max, max_val),
);
let b_gt_zero = _mm_cmpgt_epi32(b_shifted, zero);
let b_max_zero = _mm_and_si128(b_shifted, b_gt_zero);
let b_lt_max = _mm_cmpgt_epi32(max_val, b_max_zero);
let b_clamped = _mm_or_si128(
_mm_and_si128(b_max_zero, b_lt_max),
_mm_andnot_si128(b_lt_max, max_val),
);
let a_lo = a_clamped;
let b_lo = b_clamped;
let a_hi = _mm_srli_epi64(a_clamped, 32);
let b_hi = _mm_srli_epi64(b_clamped, 32);
let lo_product = _mm_mul_epu32(a_lo, b_lo);
let hi_product = _mm_mul_epu32(a_hi, b_hi);
let lo_shifted = _mm_shuffle_epi32(lo_product, 0b00_00_10_00);
let hi_shifted = _mm_shuffle_epi32(hi_product, 0b00_00_10_00);
let product = _mm_unpacklo_epi32(lo_shifted, hi_shifted);
let shifted = _mm_srai_epi32(product, 7);
let result_lt_max = _mm_cmpgt_epi32(max_val, shifted);
let result = _mm_or_si128(
_mm_and_si128(shifted, result_lt_max),
_mm_andnot_si128(result_lt_max, max_val),
);
let packed16 = _mm_packs_epi32(result, result);
let packed8 = _mm_packus_epi16(packed16, packed16);
let val = _mm_cvtsi128_si32(packed8) as u32;
std::ptr::copy_nonoverlapping(
&val as *const u32 as *const u8,
out_ptr.add(i * 4),
4,
);
}
}
processed += num_chunks * 4;
}
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
let remaining = half - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::wasm32::*;
let zero = i32x4_splat(0);
let max_val = i32x4_splat(127);
let a_ptr = input.as_ptr().add(processed);
let b_ptr = input.as_ptr().add(half + processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let a = v128_load(a_ptr.add(i * 4) as *const v128);
let b = v128_load(b_ptr.add(i * 4) as *const v128);
let a_shifted = i32x4_shr(a, WEIGHT_SCALE_BITS as u32);
let b_shifted = i32x4_shr(b, WEIGHT_SCALE_BITS as u32);
let a_clamped = i32x4_min(i32x4_max(a_shifted, zero), max_val);
let b_clamped = i32x4_min(i32x4_max(b_shifted, zero), max_val);
let product = i32x4_mul(a_clamped, b_clamped);
let result = i32x4_min(i32x4_shr(product, 7), max_val);
let narrow16 = i16x8_narrow_i32x4(result, result);
let narrow8 = u8x16_narrow_i16x8(narrow16, narrow16);
v128_store32_lane::<0>(narrow8, out_ptr.add(i * 4) as *mut u32);
}
}
processed += num_chunks * 4;
}
}
for j in processed..half {
let a = (input[j] >> WEIGHT_SCALE_BITS).clamp(0, 127);
let b = (input[j + half] >> WEIGHT_SCALE_BITS).clamp(0, 127);
output[j] = ((a * b) >> 7).min(127) as u8;
}
}
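/// Squared clipped ReLU: the clamped value is squared and rescaled before
/// narrowing to `u8`.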
#[derive(Clone, Copy, Default)]
pub struct SCReLU;
impl FtActivation for SCReLU {
const OUTPUT_DIM_DIVISOR: usize = 1;
#[inline]
fn activate_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16) {
debug_assert_eq!(input.len(), output.len());
screlu_i16_to_u8(input, output, qa);
}
#[inline]
fn activate_i32_to_u8(input: &[i32], output: &mut [u8]) {
debug_assert_eq!(input.len(), output.len());
screlu_i32_to_u8(input, output);
}
fn header_suffix() -> &'static str {
"-SCReLU"
}
fn name() -> &'static str {
"SCReLU"
}
}
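/// Dispatches on `qa` to a const-specialized SCReLU kernel with
/// `(QA, SHIFT)` of `(255, 9)` or `(127, 7)`.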
fn screlu_i16_to_u8(input: &[i16], output: &mut [u8], qa: i16) {
debug_assert_eq!(input.len(), output.len(), "input and output must have same length");
if qa >= 255 {
screlu_i16_to_u8_inner::<255, 9>(input, output);
} else {
screlu_i16_to_u8_inner::<127, 7>(input, output);
}
}
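/// Const-generic SCReLU kernel: `min(clamp(x, 0, QA)^2 >> SHIFT, 127)`.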
fn screlu_i16_to_u8_inner<const QA: i32, const SHIFT: i32>(input: &[i16], output: &mut [u8]) {
const {
assert!(
(QA == 127 && SHIFT == 7) || (QA == 255 && SHIFT == 9),
"Invalid QA/SHIFT combination"
);
}
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
))]
let mut processed = 0;
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
)))]
let processed = 0;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
let num_chunks = input.len() / 8;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max_clamp = _mm256_set1_epi32(QA);
let max_out = _mm256_set1_epi32(127);
let in_ptr = input.as_ptr();
let out_ptr = output.as_mut_ptr();
for i in 0..num_chunks {
let v_i16 = _mm_loadu_si128(in_ptr.add(i * 8) as *const __m128i);
let v = _mm256_cvtepi16_epi32(v_i16);
let clamped = _mm256_min_epi32(_mm256_max_epi32(v, zero), max_clamp);
let squared = _mm256_mullo_epi32(clamped, clamped);
let result = _mm256_min_epi32(_mm256_srai_epi32(squared, SHIFT), max_out);
let packed16 = _mm256_packs_epi32(result, result);
let packed8 = _mm256_packus_epi16(packed16, packed16);
let lo = _mm256_castsi256_si128(packed8);
let hi = _mm256_extracti128_si256(packed8, 1);
let combined = _mm_unpacklo_epi32(lo, hi);
_mm_storel_epi64(out_ptr.add(i * 8) as *mut __m128i, combined);
}
}
processed = num_chunks * 8;
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_clamp = _mm_set1_epi32(QA);
let max_out = _mm_set1_epi32(127);
let in_ptr = input.as_ptr().add(processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let v_i16 = _mm_loadl_epi64(in_ptr.add(i * 4) as *const __m128i);
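                    // Without SSE4.1's `cvtepi16_epi32`, sign-extend by
                    // interleaving with a computed sign mask.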
#[cfg(target_feature = "sse4.1")]
let v = _mm_cvtepi16_epi32(v_i16);
#[cfg(not(target_feature = "sse4.1"))]
let v = {
let sign_mask = _mm_cmpgt_epi16(zero, v_i16);
_mm_unpacklo_epi16(v_i16, sign_mask)
};
let clamped = _mm_min_epi32(_mm_max_epi32(v, zero), max_clamp);
#[cfg(target_feature = "sse4.1")]
let squared = _mm_mullo_epi32(clamped, clamped);
#[cfg(not(target_feature = "sse4.1"))]
let squared = {
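                        // 32-bit multiply emulated with `_mm_mul_epu32` on
                        // even/odd lanes; operands are non-negative after the
                        // clamp, so the unsigned multiply is exact.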
let a_lo = clamped;
                        let a_hi = _mm_srli_epi64(clamped, 32);
                        let lo_lo = _mm_mul_epu32(a_lo, a_lo);
let hi_hi = _mm_mul_epu32(a_hi, a_hi);
let lo_lo_shifted = _mm_shuffle_epi32(lo_lo, 0b00_00_10_00);
let hi_hi_shifted = _mm_shuffle_epi32(hi_hi, 0b00_00_10_00);
_mm_unpacklo_epi32(lo_lo_shifted, hi_hi_shifted)
};
let shifted = _mm_srai_epi32(squared, SHIFT);
let result = _mm_min_epi32(shifted, max_out);
let packed16 = _mm_packs_epi32(result, result);
let packed8 = _mm_packus_epi16(packed16, packed16);
let val = _mm_cvtsi128_si32(packed8) as u32;
std::ptr::copy_nonoverlapping(
&val as *const u32 as *const u8,
out_ptr.add(i * 4),
4,
);
}
}
processed += num_chunks * 4;
}
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::wasm32::*;
let zero = i32x4_splat(0);
let max_clamp = i32x4_splat(QA);
let max_out = i32x4_splat(127);
let in_ptr = input.as_ptr().add(processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let v_i64 = v128_load64_zero(in_ptr.add(i * 4) as *const u64);
let v = i32x4_extend_low_i16x8(v_i64);
let clamped = i32x4_min(i32x4_max(v, zero), max_clamp);
let squared = i32x4_mul(clamped, clamped);
let shifted = i32x4_shr(squared, SHIFT as u32);
let result = i32x4_min(shifted, max_out);
let narrow16 = i16x8_narrow_i32x4(result, result);
let narrow8 = u8x16_narrow_i16x8(narrow16, narrow16);
v128_store32_lane::<0>(narrow8, out_ptr.add(i * 4) as *mut u32);
}
}
processed += num_chunks * 4;
}
}
for i in processed..input.len() {
let clamped = i32::from(input[i]).clamp(0, QA);
output[i] = ((clamped * clamped) >> SHIFT).min(127) as u8;
}
}
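/// SCReLU for `i32` affine outputs: rescale by `WEIGHT_SCALE_BITS`, clamp to
/// `[0, 127]`, square, divide by `SCRELU_QB`, cap at 127.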
fn screlu_i32_to_u8(input: &[i32], output: &mut [u8]) {
use super::constants::SCRELU_QB;
debug_assert_eq!(input.len(), output.len(), "input and output must have same length");
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
))]
let mut processed = 0;
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
)))]
let processed = 0;
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
let num_chunks = input.len() / 8;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max_clamp = _mm256_set1_epi32(127);
                // The fixed `>> 6` below assumes SCRELU_QB == 64, matching
                // the scalar tail's division by SCRELU_QB.
let in_ptr = input.as_ptr();
let out_ptr = output.as_mut_ptr();
for i in 0..num_chunks {
let v = _mm256_loadu_si256(in_ptr.add(i * 8) as *const __m256i);
let shifted = _mm256_srai_epi32(v, WEIGHT_SCALE_BITS as i32);
let clamped = _mm256_min_epi32(_mm256_max_epi32(shifted, zero), max_clamp);
let squared = _mm256_mullo_epi32(clamped, clamped);
let result = _mm256_min_epi32(_mm256_srli_epi32(squared, 6), max_clamp);
let packed16 = _mm256_packs_epi32(result, result);
let packed8 = _mm256_packus_epi16(packed16, packed16);
let lo = _mm256_castsi256_si128(packed8);
let hi = _mm256_extracti128_si256(packed8, 1);
let combined = _mm_unpacklo_epi32(lo, hi);
_mm_storel_epi64(out_ptr.add(i * 8) as *mut __m128i, combined);
}
}
processed = num_chunks * 8;
}
}
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max_clamp = _mm_set1_epi32(127);
let in_ptr = input.as_ptr().add(processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let v = _mm_loadu_si128(in_ptr.add(i * 4) as *const __m128i);
let shifted = _mm_srai_epi32(v, WEIGHT_SCALE_BITS as i32);
let clamped = _mm_min_epi32(_mm_max_epi32(shifted, zero), max_clamp);
#[cfg(target_feature = "sse4.1")]
let squared = _mm_mullo_epi32(clamped, clamped);
#[cfg(not(target_feature = "sse4.1"))]
let squared = {
let a_lo = clamped;
                        let a_hi = _mm_srli_epi64(clamped, 32);
                        let lo_lo = _mm_mul_epu32(a_lo, a_lo);
let hi_hi = _mm_mul_epu32(a_hi, a_hi);
let lo_lo_shifted = _mm_shuffle_epi32(lo_lo, 0b00_00_10_00);
let hi_hi_shifted = _mm_shuffle_epi32(hi_hi, 0b00_00_10_00);
_mm_unpacklo_epi32(lo_lo_shifted, hi_hi_shifted)
};
let result = _mm_min_epi32(_mm_srli_epi32(squared, 6), max_clamp);
let packed16 = _mm_packs_epi32(result, result);
let packed8 = _mm_packus_epi16(packed16, packed16);
let val = _mm_cvtsi128_si32(packed8) as u32;
std::ptr::copy_nonoverlapping(
&val as *const u32 as *const u8,
out_ptr.add(i * 4),
4,
);
}
}
processed += num_chunks * 4;
}
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
let remaining = input.len() - processed;
let num_chunks = remaining / 4;
if num_chunks > 0 {
unsafe {
use std::arch::wasm32::*;
let zero = i32x4_splat(0);
let max_clamp = i32x4_splat(127);
let in_ptr = input.as_ptr().add(processed);
let out_ptr = output.as_mut_ptr().add(processed);
for i in 0..num_chunks {
let v = v128_load(in_ptr.add(i * 4) as *const v128);
let shifted = i32x4_shr(v, WEIGHT_SCALE_BITS as u32);
let clamped = i32x4_min(i32x4_max(shifted, zero), max_clamp);
let squared = i32x4_mul(clamped, clamped);
let result = i32x4_min(u32x4_shr(squared, 6), max_clamp);
let narrow16 = i16x8_narrow_i32x4(result, result);
let narrow8 = u8x16_narrow_i16x8(narrow16, narrow16);
v128_store32_lane::<0>(narrow8, out_ptr.add(i * 4) as *mut u32);
}
}
processed += num_chunks * 4;
}
}
for i in processed..input.len() {
let shifted = input[i] >> WEIGHT_SCALE_BITS;
let clamped = shifted.clamp(0, 127);
let squared = clamped * clamped;
output[i] = (squared / SCRELU_QB).min(127) as u8;
}
}
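/// Infers the activation variant from a network architecture string by its
/// suffix; the combined `-SCReLU-Pairwise` suffix is checked first, and
/// anything without a known suffix is treated as plain CReLU.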
pub fn detect_activation_from_arch(arch_str: &str) -> &'static str {
if arch_str.contains("-SCReLU-Pairwise") {
"SCReLU-Pairwise"
} else if arch_str.contains(PairwiseCReLU::header_suffix()) {
PairwiseCReLU::name()
} else if arch_str.contains(SCReLU::header_suffix()) {
SCReLU::name()
} else {
CReLU::name()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_activation_from_arch() {
assert_eq!(detect_activation_from_arch("Features=HalfKA_hm[73305->512x2]"), "CReLU");
assert_eq!(
detect_activation_from_arch("Features=HalfKA_hm[73305->512x2]-SCReLU"),
"SCReLU"
);
assert_eq!(
detect_activation_from_arch("Features=HalfKA_hm[73305->512/2x2]-Pairwise"),
"PairwiseCReLU"
);
assert_eq!(
detect_activation_from_arch("Features=HalfKA_hm[73305->512/2x2]-PairwiseCReLU"),
"PairwiseCReLU"
);
assert_eq!(
detect_activation_from_arch("Features=HalfKA_hm[73305->512/2x2]-SCReLU-Pairwise"),
"SCReLU-Pairwise"
);
}
#[test]
fn test_crelu_i16_to_u8() {
let input = [0i16, 50, 127, 200, -10, -50];
let mut output = [0u8; 6];
CReLU::activate_i16_to_u8(&input, &mut output, 127);
assert_eq!(output[0], 0);
assert_eq!(output[1], 50);
assert_eq!(output[2], 127);
        assert_eq!(output[3], 127);
        assert_eq!(output[4], 0);
        assert_eq!(output[5], 0);
    }
#[test]
fn test_crelu_i32_to_u8() {
let input = [0i32, 64, 128, 8192, -64, 64 * 100];
let mut output = [0u8; 6];
CReLU::activate_i32_to_u8(&input, &mut output);
        assert_eq!(output[0], 0);
        assert_eq!(output[1], 1);
        assert_eq!(output[2], 2);
        assert_eq!(output[3], 127);
        assert_eq!(output[4], 0);
        assert_eq!(output[5], 100);
    }
#[test]
fn test_pairwise_crelu_i16_to_u8_qa127() {
let input = [64i16, 100, 127, 0, 64, 50, 127, 100];
let mut output = [0u8; 4];
PairwiseCReLU::activate_i16_to_u8(&input, &mut output, 127);
assert_eq!(output[0], 63);
assert_eq!(output[1], 0);
assert_eq!(output[2], 63);
assert_eq!(output[3], 39);
}
#[test]
fn test_pairwise_crelu_i16_to_u8_qa255() {
let input = [128i16, 200, 255, 0, 128, 100, 255, 200];
let mut output = [0u8; 4];
PairwiseCReLU::activate_i16_to_u8(&input, &mut output, 255);
assert_eq!(output[0], 63);
assert_eq!(output[1], 0);
assert_eq!(output[2], 63);
assert_eq!(output[3], 39);
}
#[test]
fn test_screlu_i16_to_u8() {
let input = [0i16, 50, 127, 200, -10];
let mut output = [0u8; 5];
SCReLU::activate_i16_to_u8(&input, &mut output, 127);
assert_eq!(output[0], 0);
assert_eq!(output[1], 19);
assert_eq!(output[2], 126);
assert_eq!(output[3], 126);
assert_eq!(output[4], 0);
}
#[test]
fn test_detect_activation() {
assert_eq!(detect_activation_from_arch("HalfKA_hm^512x2-8-96"), "CReLU");
assert_eq!(detect_activation_from_arch("HalfKA_hm^512x2-8-96-SCReLU"), "SCReLU");
assert_eq!(detect_activation_from_arch("HalfKP256x2-32-32-PairwiseCReLU"), "PairwiseCReLU");
}
#[test]
fn test_output_dim_divisor() {
assert_eq!(CReLU::OUTPUT_DIM_DIVISOR, 1);
assert_eq!(PairwiseCReLU::OUTPUT_DIM_DIVISOR, 2);
assert_eq!(SCReLU::OUTPUT_DIM_DIVISOR, 1);
}
#[test]
fn test_pairwise_crelu_i32_to_u8() {
let input = [0i32, 64, 128, 8192, -64, 64 * 100];
let mut output = [0u8; 6];
PairwiseCReLU::activate_i32_to_u8(&input, &mut output);
        assert_eq!(output[0], 0);
        assert_eq!(output[1], 1);
        assert_eq!(output[2], 2);
        assert_eq!(output[3], 127);
        assert_eq!(output[4], 0);
        assert_eq!(output[5], 100);
    }
#[test]
fn test_screlu_i32_to_u8() {
use crate::nnue::constants::SCRELU_QB;
        let input = [0i32, 64 * 50, 64 * 127, 64 * 200, -64];
let mut output = [0u8; 5];
SCReLU::activate_i32_to_u8(&input, &mut output);
        assert_eq!(output[0], 0);
        assert_eq!(output[1], (2500 / SCRELU_QB) as u8);
        assert_eq!(output[2], 127);
        assert_eq!(output[3], 127);
        assert_eq!(output[4], 0);
    }
#[test]
fn test_pairwise_crelu_actual_network_size() {
const L1: usize = 512;
        const QUARTER: usize = L1 / 2;
        let mut input = [0i16; L1 * 2];
        let mut output = [0u8; L1];
for i in 0..L1 {
let seed = i as u32;
let random = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let val = ((random >> 16) & 0xFF) as i16;
            input[i] = val;
            input[i + L1] = (val.wrapping_add(128)) & 0xFF;
        }
PairwiseCReLU::activate_i16_to_u8(&input, &mut output, 255);
for i in 0..QUARTER {
let a = (input[i] as i32).clamp(0, 255);
let b = (input[i + QUARTER] as i32).clamp(0, 255);
let expected = ((a * b) >> 9).min(127) as u8;
assert_eq!(
output[i], expected,
"STM mismatch at index {i}: expected {expected}, got {}, a={a}, b={b}",
output[i]
);
}
for i in 0..QUARTER {
let a = (input[L1 + i] as i32).clamp(0, 255);
let b = (input[L1 + i + QUARTER] as i32).clamp(0, 255);
let expected = ((a * b) >> 9).min(127) as u8;
assert_eq!(
output[QUARTER + i],
expected,
"NTM mismatch at index {i}: expected {expected}, got {}, a={a}, b={b}",
output[QUARTER + i]
);
}
}
#[test]
fn test_pairwise_crelu_simd_path() {
const L1: usize = 32;
        const QUARTER: usize = L1 / 2;
        let mut input = [0i16; L1 * 2];
let mut output = [0u8; L1];
for i in 0..L1 {
            input[i] = (i as i16) * 4;
            input[i + L1] = 100 - (i as i16) * 2;
        }
PairwiseCReLU::activate_i16_to_u8(&input, &mut output, 127);
for i in 0..QUARTER {
let a = (input[i] as i32).clamp(0, 127);
let b = (input[i + QUARTER] as i32).clamp(0, 127);
let expected = ((a * b) >> 7).min(127) as u8;
assert_eq!(
output[i], expected,
"STM mismatch at index {i}: expected {expected}, got {}",
output[i]
);
}
for i in 0..QUARTER {
let a = (input[L1 + i] as i32).clamp(0, 127);
let b = (input[L1 + i + QUARTER] as i32).clamp(0, 127);
let expected = ((a * b) >> 7).min(127) as u8;
assert_eq!(
output[QUARTER + i],
expected,
"NTM mismatch at index {i}: expected {expected}, got {}",
output[QUARTER + i]
);
}
}
#[test]
fn test_pairwise_crelu_i32_simd_path() {
const SIZE: usize = 64;
let mut input = [0i32; SIZE];
let mut output = [0u8; SIZE];
for (i, value) in input.iter_mut().enumerate() {
            *value = (i as i32) * 4 * 64;
        }
PairwiseCReLU::activate_i32_to_u8(&input, &mut output);
for (i, value) in input.iter().enumerate() {
let expected = (value >> WEIGHT_SCALE_BITS).clamp(0, 127) as u8;
assert_eq!(
output[i], expected,
"mismatch at index {i}: expected {expected}, got {}",
output[i]
);
}
}
}