use crate::dtype::Element;
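/// Dot product of two f32 buffers, dispatching on the detected SIMD
/// level and falling back to a scalar loop when no vector kernel
/// applies.
///
/// # Safety
/// `a` and `b` must be valid for reads of `len` elements.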
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[inline]
unsafe fn simd_dot_f32(
a: *const f32,
b: *const f32,
len: usize,
level: super::simd::SimdLevel,
) -> f32 {
use super::simd::SimdLevel;
// The detected level guarantees the matching target features are
// available, so calling the #[target_feature] kernels here is sound.
match level {
SimdLevel::Avx512 => simd_dot_f32_avx512(a, b, len),
SimdLevel::Avx2Fma => simd_dot_f32_avx2(a, b, len),
_ => {
let mut sum = 0.0f32;
for i in 0..len {
sum += *a.add(i) * *b.add(i);
}
sum
}
}
}
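/// AVX-512 dot product. The main loop keeps two independent FMA
/// accumulators (32 floats per iteration) to hide FMA latency; a
/// single-accumulator loop and a scalar tail handle the remainder.
///
/// # Safety
/// The caller must ensure AVX-512F is available and that `a` and `b`
/// are valid for reads of `len` elements.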
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn simd_dot_f32_avx512(a: *const f32, b: *const f32, len: usize) -> f32 {
use std::arch::x86_64::*;
let mut offset = 0;
let mut acc0 = _mm512_setzero_ps();
let mut acc1 = _mm512_setzero_ps();
while offset + 32 <= len {
let av0 = _mm512_loadu_ps(a.add(offset));
let bv0 = _mm512_loadu_ps(b.add(offset));
acc0 = _mm512_fmadd_ps(av0, bv0, acc0);
let av1 = _mm512_loadu_ps(a.add(offset + 16));
let bv1 = _mm512_loadu_ps(b.add(offset + 16));
acc1 = _mm512_fmadd_ps(av1, bv1, acc1);
offset += 32;
}
acc0 = _mm512_add_ps(acc0, acc1);
while offset + 16 <= len {
let av = _mm512_loadu_ps(a.add(offset));
let bv = _mm512_loadu_ps(b.add(offset));
acc0 = _mm512_fmadd_ps(av, bv, acc0);
offset += 16;
}
let mut sum = _mm512_reduce_add_ps(acc0);
while offset < len {
sum += *a.add(offset) * *b.add(offset);
offset += 1;
}
sum
}
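/// NEON is baseline on aarch64, so this variant ignores the detected
/// SIMD level and always takes the NEON path.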
#[cfg(all(feature = "f16", target_arch = "aarch64"))]
#[inline]
unsafe fn simd_dot_f32(
a: *const f32,
b: *const f32,
len: usize,
_level: super::simd::SimdLevel,
) -> f32 {
simd_dot_f32_neon(a, b, len)
}
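/// NEON dot product mirroring the AVX kernels: two 4-lane FMA
/// accumulators (8 floats per iteration), a single-accumulator loop,
/// then a scalar tail.
///
/// # Safety
/// `a` and `b` must be valid for reads of `len` elements.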
#[cfg(all(feature = "f16", target_arch = "aarch64"))]
#[target_feature(enable = "neon")]
unsafe fn simd_dot_f32_neon(a: *const f32, b: *const f32, len: usize) -> f32 {
use std::arch::aarch64::*;
let mut offset = 0;
let mut acc0 = vdupq_n_f32(0.0);
let mut acc1 = vdupq_n_f32(0.0);
while offset + 8 <= len {
let av0 = vld1q_f32(a.add(offset));
let bv0 = vld1q_f32(b.add(offset));
acc0 = vfmaq_f32(acc0, av0, bv0);
let av1 = vld1q_f32(a.add(offset + 4));
let bv1 = vld1q_f32(b.add(offset + 4));
acc1 = vfmaq_f32(acc1, av1, bv1);
offset += 8;
}
acc0 = vaddq_f32(acc0, acc1);
while offset + 4 <= len {
let av = vld1q_f32(a.add(offset));
let bv = vld1q_f32(b.add(offset));
acc0 = vfmaq_f32(acc0, av, bv);
offset += 4;
}
let mut sum = vaddvq_f32(acc0);
while offset < len {
sum += *a.add(offset) * *b.add(offset);
offset += 1;
}
sum
}
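/// AVX2+FMA dot product: two 8-lane FMA accumulators (16 floats per
/// iteration), a single-accumulator loop, then a scalar tail.
///
/// # Safety
/// The caller must ensure AVX2 and FMA are available and that `a` and
/// `b` are valid for reads of `len` elements.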
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn simd_dot_f32_avx2(a: *const f32, b: *const f32, len: usize) -> f32 {
use std::arch::x86_64::*;
let mut offset = 0;
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
while offset + 16 <= len {
let av0 = _mm256_loadu_ps(a.add(offset));
let bv0 = _mm256_loadu_ps(b.add(offset));
acc0 = _mm256_fmadd_ps(av0, bv0, acc0);
let av1 = _mm256_loadu_ps(a.add(offset + 8));
let bv1 = _mm256_loadu_ps(b.add(offset + 8));
acc1 = _mm256_fmadd_ps(av1, bv1, acc1);
offset += 16;
}
acc0 = _mm256_add_ps(acc0, acc1);
while offset + 8 <= len {
let av = _mm256_loadu_ps(a.add(offset));
let bv = _mm256_loadu_ps(b.add(offset));
acc0 = _mm256_fmadd_ps(av, bv, acc0);
offset += 8;
}
// Horizontal reduction: fold the upper 128-bit half onto the lower,
// then pairwise-add down to a single lane.
let hi = _mm256_extractf128_ps(acc0, 1);
let lo = _mm256_castps256_ps128(acc0);
let sum128 = _mm_add_ps(lo, hi);
let shuf = _mm_movehdup_ps(sum128);
let sums = _mm_add_ps(sum128, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let sums2 = _mm_add_ss(sums, shuf2);
let mut sum = _mm_cvtss_f32(sums2);
while offset < len {
sum += *a.add(offset) * *b.add(offset);
offset += 1;
}
sum
}
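/// Computes `out = a * b_nk^T`, where `a` is an `m x k` row-major matrix,
/// `b_nk` holds `n` rows of length `k` (i.e. `b` stored transposed), and
/// `out` is `m x n` with row stride `ldc`. Dispatches to SIMD kernels for
/// `f32`/`f64`, to an f32-converting path for `f16`/`bf16`, and to the
/// scalar kernel for everything else.
///
/// # Safety
/// `a`, `b_nk`, and `out` must be valid for the extents implied by `m`,
/// `n`, `k`, and `ldc`, and `out` must not alias the inputs.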
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemv_bt_kernel<T: Element>(
a: *const T,
b_nk: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
ldc: usize,
) {
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
{
use super::simd::detect_simd;
use super::simd::matmul::gemv_bt;
use crate::dtype::DType;
match T::DTYPE {
DType::F32 => {
let level = detect_simd();
gemv_bt::gemv_bt_f32(
a as *const f32,
b_nk as *const f32,
out as *mut f32,
m,
n,
k,
ldc,
level,
);
return;
}
DType::F64 => {
let level = detect_simd();
gemv_bt::gemv_bt_f64(
a as *const f64,
b_nk as *const f64,
out as *mut f64,
m,
n,
k,
ldc,
level,
);
return;
}
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => {
gemv_bt_via_f32(a, b_nk, out, m, n, k, ldc);
return;
}
_ => {}
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
#[allow(unused_imports)]
use crate::dtype::DType;
match T::DTYPE {
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => {
gemv_bt_via_f32(a, b_nk, out, m, n, k, ldc);
return;
}
_ => {}
}
}
gemv_bt_scalar(a, b_nk, out, m, n, k, ldc);
}
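/// Scalar fallback: a straightforward dot product of each row of `a`
/// against each row of `b_nk`.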
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn gemv_bt_scalar<T: Element>(
a: *const T,
b_nk: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
ldc: usize,
) {
for row in 0..m {
let a_row = a.add(row * k);
let out_row = out.add(row * ldc);
for col in 0..n {
let b_row = b_nk.add(col * k);
let mut sum = T::zero();
for i in 0..k {
sum = sum + *a_row.add(i) * *b_row.add(i);
}
*out_row.add(col) = sum;
}
}
}
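/// `f16`/`bf16` path: rows are widened into scratch f32 buffers so the
/// dot products run in f32 for precision and SIMD throughput. Each row
/// of `b_nk` is reconverted once per output row; in exchange the
/// scratch buffers stay O(k).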
#[cfg(feature = "f16")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn gemv_bt_via_f32<T: Element>(
a: *const T,
b_nk: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
ldc: usize,
) {
let mut a_f32 = vec![0.0f32; k];
let mut b_f32 = vec![0.0f32; k];
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
let level = super::simd::detect_simd();
for row in 0..m {
let a_row = a.add(row * k);
batch_half_to_f32::<T>(a_row, a_f32.as_mut_ptr(), k);
let out_row = out.add(row * ldc);
for col in 0..n {
let b_row = b_nk.add(col * k);
batch_half_to_f32::<T>(b_row, b_f32.as_mut_ptr(), k);
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
{
let dot = simd_dot_f32(a_f32.as_ptr(), b_f32.as_ptr(), k, level);
*out_row.add(col) = T::from_f32(dot);
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
let mut sum = 0.0f32;
for i in 0..k {
sum += a_f32[i] * b_f32[i];
}
*out_row.add(col) = T::from_f32(sum);
}
}
}
}
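/// Widens a half-precision buffer to f32, using vectorized conversions
/// on x86_64 and the generic `Element::to_f32` elsewhere.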
#[cfg(feature = "f16")]
#[inline]
unsafe fn batch_half_to_f32<T: Element>(src: *const T, dst: *mut f32, len: usize) {
use crate::dtype::DType;
match T::DTYPE {
#[cfg(target_arch = "x86_64")]
DType::BF16 => {
batch_bf16_to_f32(src as *const u16, dst, len);
}
#[cfg(target_arch = "x86_64")]
DType::F16 => {
batch_f16_to_f32(src as *const u16, dst, len);
}
_ => {
for i in 0..len {
*dst.add(i) = (*src.add(i)).to_f32();
}
}
}
}
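/// Runtime dispatch for bf16 -> f32 widening.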
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[inline]
unsafe fn batch_bf16_to_f32(src: *const u16, dst: *mut f32, len: usize) {
if is_x86_feature_detected!("avx2") {
batch_bf16_to_f32_avx2(src, dst, len);
} else {
batch_bf16_to_f32_scalar(src, dst, len);
}
}
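/// bf16 is the upper half of an f32, so widening is a zero-extend to
/// 32 bits followed by a 16-bit left shift, eight lanes at a time.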
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn batch_bf16_to_f32_avx2(src: *const u16, dst: *mut f32, len: usize) {
use std::arch::x86_64::*;
let mut i = 0usize;
while i + 8 <= len {
let bf16_vals = _mm_loadu_si128(src.add(i) as *const __m128i);
let i32_vals = _mm256_cvtepu16_epi32(bf16_vals);
let f32_bits = _mm256_slli_epi32(i32_vals, 16);
_mm256_storeu_ps(dst.add(i), _mm256_castsi256_ps(f32_bits));
i += 8;
}
while i < len {
let bits = (*src.add(i) as u32) << 16;
*dst.add(i) = f32::from_bits(bits);
i += 1;
}
}
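/// Scalar bf16 -> f32 widening via bit reinterpretation.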
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
unsafe fn batch_bf16_to_f32_scalar(src: *const u16, dst: *mut f32, len: usize) {
for i in 0..len {
let bits = (*src.add(i) as u32) << 16;
*dst.add(i) = f32::from_bits(bits);
}
}
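/// Runtime dispatch for f16 -> f32 widening.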
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[inline]
unsafe fn batch_f16_to_f32(src: *const u16, dst: *mut f32, len: usize) {
if is_x86_feature_detected!("f16c") {
batch_f16_to_f32_f16c(src, dst, len);
} else {
batch_f16_to_f32_scalar(src, dst, len);
}
}
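/// F16C converts eight f16 values per `_mm256_cvtph_ps`; the tail falls
/// back to the `half` crate.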
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
#[target_feature(enable = "f16c", enable = "avx")]
unsafe fn batch_f16_to_f32_f16c(src: *const u16, dst: *mut f32, len: usize) {
use std::arch::x86_64::*;
let mut i = 0usize;
while i + 8 <= len {
let f16_vals = _mm_loadu_si128(src.add(i) as *const __m128i);
let f32_vals = _mm256_cvtph_ps(f16_vals);
_mm256_storeu_ps(dst.add(i), f32_vals);
i += 8;
}
while i < len {
*dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32();
i += 1;
}
}
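/// Scalar f16 -> f32 widening via the `half` crate.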
#[cfg(all(feature = "f16", target_arch = "x86_64"))]
unsafe fn batch_f16_to_f32_scalar(src: *const u16, dst: *mut f32, len: usize) {
for i in 0..len {
*dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32();
}
}
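/// Computes `out = a * b` for an `m x k` matrix `a` (row stride `lda`),
/// a `k x n` matrix `b` (row stride `ldb`), and an `m x n` output (row
/// stride `ldc`). Dispatches to SIMD kernels on x86_64 and falls back
/// to the scalar kernel otherwise.
///
/// # Safety
/// All pointers must be valid for the extents implied by the dimensions
/// and strides, and `out` must not alias `a` or `b`.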
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_kernel<T: Element>(
a: *const T,
b: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
#[cfg(target_arch = "x86_64")]
{
use super::simd::matmul;
use crate::dtype::DType;
match T::DTYPE {
DType::I32 => {
matmul::int32::matmul_i32(
a as *const i32,
b as *const i32,
out as *mut i32,
m,
n,
k,
lda,
ldb,
ldc,
);
return;
}
DType::F32 => {
matmul::matmul_f32(
a as *const f32,
b as *const f32,
out as *mut f32,
m,
n,
k,
lda,
ldb,
ldc,
);
return;
}
DType::F64 => {
matmul::matmul_f64(
a as *const f64,
b as *const f64,
out as *mut f64,
m,
n,
k,
lda,
ldb,
ldc,
);
return;
}
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => {
matmul::half_convert::matmul_via_f32(a, b, out, m, n, k, lda, ldb, ldc);
return;
}
_ => {}
}
}
matmul_scalar(a, b, out, m, n, k, lda, ldb, ldc);
}
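/// Scalar fallback using the i-k-j loop order, so the inner loop walks
/// `b` and `out` rows contiguously.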
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn matmul_scalar<T: Element>(
a: *const T,
b: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
for i in 0..m {
for j in 0..n {
*out.add(i * ldc + j) = T::zero();
}
}
for i in 0..m {
for kk in 0..k {
let a_val = *a.add(i * lda + kk);
for j in 0..n {
let b_val = *b.add(kk * ldb + j);
let out_ptr = out.add(i * ldc + j);
*out_ptr = *out_ptr + a_val * b_val;
}
}
}
}
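/// Like `matmul_kernel`, but seeds each output row with `bias` (length
/// `n`) before accumulating `a * b`.
///
/// # Safety
/// Same requirements as `matmul_kernel`, plus `bias` must be valid for
/// reads of `n` elements.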
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul_bias_kernel<T: Element>(
a: *const T,
b: *const T,
bias: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
#[cfg(target_arch = "x86_64")]
{
use super::simd::matmul;
use crate::dtype::DType;
match T::DTYPE {
DType::F32 => {
matmul::matmul_bias_f32(
a as *const f32,
b as *const f32,
bias as *const f32,
out as *mut f32,
m,
n,
k,
lda,
ldb,
ldc,
);
return;
}
DType::F64 => {
matmul::matmul_bias_f64(
a as *const f64,
b as *const f64,
bias as *const f64,
out as *mut f64,
m,
n,
k,
lda,
ldb,
ldc,
);
return;
}
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => {
matmul::half_convert::matmul_bias_via_f32(a, b, bias, out, m, n, k, lda, ldb, ldc);
return;
}
_ => {}
}
}
matmul_bias_scalar(a, b, bias, out, m, n, k, lda, ldb, ldc);
}
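/// Scalar fallback: initializes each output row from `bias`, then
/// accumulates with the same i-k-j loop order as `matmul_scalar`.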
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn matmul_bias_scalar<T: Element>(
a: *const T,
b: *const T,
bias: *const T,
out: *mut T,
m: usize,
n: usize,
k: usize,
lda: usize,
ldb: usize,
ldc: usize,
) {
for i in 0..m {
for j in 0..n {
*out.add(i * ldc + j) = *bias.add(j);
}
}
for i in 0..m {
for kk in 0..k {
let a_val = *a.add(i * lda + kk);
for j in 0..n {
let b_val = *b.add(kk * ldb + j);
let out_ptr = out.add(i * ldc + j);
*out_ptr = *out_ptr + a_val * b_val;
}
}
}
}