use trueno::Vector;
/// Row-tile width for `simd_matmul`; output rows are walked in groups of
/// this size. NOTE(review): tiles are visited in order, so this currently
/// affects loop structure only — presumably a hook for cache blocking.
const TILE_SIZE: usize = 64;

/// Matrix–vector product: `output[row] = dot(input, weight_row)`.
///
/// `weight` is interpreted as a row-major `out_dim x in_dim` matrix and
/// `input` as a vector of `in_dim` elements. Each output element is one
/// SIMD dot product of `input` against a weight row.
///
/// # Panics
/// Panics if `weight` holds fewer than `in_dim * out_dim` elements, or if
/// the underlying SIMD dot product fails (e.g. a length mismatch between
/// `input` and a weight row).
#[must_use]
pub fn simd_matmul(input: &[f32], weight: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    // Validate up front so an undersized weight slice fails with a clear
    // message rather than an opaque slice-index panic mid-loop.
    // saturating_mul avoids a spurious overflow panic in the check itself.
    assert!(
        weight.len() >= in_dim.saturating_mul(out_dim),
        "weight has {} elements; need at least in_dim * out_dim = {}",
        weight.len(),
        in_dim.saturating_mul(out_dim)
    );
    let input_vec = Vector::from_slice(input);
    let mut output = vec![0.0; out_dim];
    for tile_start in (0..out_dim).step_by(TILE_SIZE) {
        let tile_end = (tile_start + TILE_SIZE).min(out_dim);
        for row in tile_start..tile_end {
            let row_start = row * in_dim;
            let row_vec = Vector::from_slice(&weight[row_start..row_start + in_dim]);
            output[row] = input_vec.dot(&row_vec).expect("dot product failed");
        }
    }
    output
}
/// SIMD dot product of two `f32` slices.
///
/// # Panics
/// Panics if the underlying vector dot product fails (the `trueno`
/// implementation returns an error, e.g. on mismatched lengths).
#[inline]
#[must_use]
pub fn simd_dot(a: &[f32], b: &[f32]) -> f32 {
    let lhs = Vector::from_slice(a);
    let rhs = Vector::from_slice(b);
    lhs.dot(&rhs).expect("dot product failed")
}
/// Element-wise in-place addition: `a[i] += b[i]`.
///
/// Elements are paired with `zip`, so the shorter slice bounds the work;
/// any excess elements in either slice are left untouched.
#[inline]
pub fn simd_add(a: &mut [f32], b: &[f32]) {
    a.iter_mut().zip(b).for_each(|(dst, &src)| *dst += src);
}
/// Element-wise in-place multiplication: `a[i] *= b[i]`.
///
/// Elements are paired with `zip`, so the shorter slice bounds the work;
/// any excess elements in either slice are left untouched.
#[inline]
pub fn simd_mul(a: &mut [f32], b: &[f32]) {
    a.iter_mut().zip(b).for_each(|(dst, &src)| *dst *= src);
}
/// Apply the SiLU activation in place, using trueno's scalar kernel
/// for each element.
#[inline]
pub fn simd_silu(data: &mut [f32]) {
    data.iter_mut().for_each(|v| *v = trueno::silu_scalar(*v));
}
/// Apply the GELU activation in place, using trueno's scalar kernel
/// for each element.
#[inline]
pub fn simd_gelu(data: &mut [f32]) {
    data.iter_mut().for_each(|v| *v = trueno::gelu_scalar(*v));
}
/// Numerically stable in-place softmax.
///
/// Shifts inputs by their maximum before exponentiating so `exp` cannot
/// overflow. If the exponential sum is not positive (e.g. every input was
/// `-inf`), the exponentiated values are left unnormalized.
pub fn simd_softmax(data: &mut [f32]) {
    if data.is_empty() {
        return;
    }
    // Stabilizing shift: softmax(x) == softmax(x - max(x)).
    let peak = data.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    let mut total = 0.0;
    for v in data.iter_mut() {
        *v = (*v - peak).exp();
        total += *v;
    }
    if total > 0.0 {
        // Normalize with a single reciprocal, multiplying each element.
        let scale = 1.0 / total;
        for v in data.iter_mut() {
            *v *= scale;
        }
    }
}
/// Decode a buffer of little-endian bf16 values into `f32`s.
///
/// Two input bytes yield one `f32`; a trailing odd byte is ignored.
/// Dispatches at compile time to the AVX2 kernel when the target was built
/// with that feature, otherwise to the scalar fallback.
#[must_use]
pub fn simd_bf16_to_f32(input: &[u8]) -> Vec<f32> {
    // Number of complete 2-byte elements (any trailing odd byte is dropped).
    let count = input.len() / 2;
    if count == 0 {
        return Vec::new();
    }
    // Exactly one of these cfg arms is compiled in; each is the function's
    // tail expression in its configuration.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        simd_bf16_to_f32_avx2(input, count)
    }
    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    {
        bf16_to_f32_fast(input, count)
    }
}
/// AVX2-accelerated bf16 → f32 conversion.
///
/// Processes eight bf16 values (16 input bytes) per SIMD iteration; any
/// trailing elements that do not fill an 8-wide chunk are converted with
/// the scalar shift path at the end.
///
/// Relies on `count <= input.len() / 2`, which the `simd_bf16_to_f32`
/// dispatcher guarantees by construction.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
fn simd_bf16_to_f32_avx2(input: &[u8], count: usize) -> Vec<f32> {
    use std::arch::x86_64::*;
    let mut output = vec![0.0f32; count];
    let chunks = count / 8;
    let remainder = count % 8;
    // SAFETY: each iteration reads 16 bytes at `in_offset` and writes 8 f32s
    // at `out_offset`. Since `chunks * 16 <= count * 2 <= input.len()` and
    // `chunks * 8 <= count == output.len()`, all loads and stores stay in
    // bounds; the unaligned (`loadu`/`storeu`) intrinsics impose no
    // alignment requirement.
    unsafe {
        for i in 0..chunks {
            let in_offset = i * 16;
            let out_offset = i * 8;
            // Load 8 little-endian bf16 values (16 bytes) into one register.
            let bf16_bytes = _mm_loadu_si128(input.as_ptr().add(in_offset) as *const __m128i);
            // Zero-extend the low four u16 lanes to u32, then shift left 16:
            // a bf16 bit pattern in the high half of a u32 is exactly the
            // corresponding f32's bit pattern.
            let lo = _mm_unpacklo_epi16(bf16_bytes, _mm_setzero_si128());
            let lo_shifted = _mm_slli_epi32(lo, 16);
            // Same widen-and-shift for the high four lanes (elements 4..8).
            let hi = _mm_unpackhi_epi16(bf16_bytes, _mm_setzero_si128());
            let hi_shifted = _mm_slli_epi32(hi, 16);
            _mm_storeu_ps(
                output.as_mut_ptr().add(out_offset),
                _mm_castsi128_ps(lo_shifted),
            );
            _mm_storeu_ps(
                output.as_mut_ptr().add(out_offset + 4),
                _mm_castsi128_ps(hi_shifted),
            );
        }
    }
    // Scalar tail: convert the final `remainder` elements one at a time
    // with the same widen-then-shift trick.
    let remainder_start = chunks * 8;
    for i in 0..remainder {
        let offset = (remainder_start + i) * 2;
        let bits = u16::from_le_bytes([input[offset], input[offset + 1]]) as u32;
        output[remainder_start + i] = f32::from_bits(bits << 16);
    }
    output
}
/// Scalar bf16 → f32 conversion fallback.
///
/// Widens each little-endian u16 into the high half of a u32, which is
/// exactly the bf16 value's f32 bit pattern. `count` is the expected
/// element count (`input.len() / 2`) and is used only as a capacity hint;
/// `chunks_exact(2)` bounds the actual work, so a trailing odd byte is
/// ignored. Defined unconditionally: the previous source carried two
/// byte-identical copies under complementary `target_feature` cfgs, which
/// this single definition replaces for every configuration.
fn bf16_to_f32_fast(input: &[u8], count: usize) -> Vec<f32> {
    let mut output = Vec::with_capacity(count);
    for chunk in input.chunks_exact(2) {
        let bits = u16::from_le_bytes([chunk[0], chunk[1]]) as u32;
        output.push(f32::from_bits(bits << 16));
    }
    output
}
/// Decode a buffer of little-endian IEEE-754 half-precision (f16) values
/// into `f32`s.
///
/// Two input bytes yield one `f32`; a trailing odd byte is ignored.
#[must_use]
pub fn simd_f16_to_f32(input: &[u8]) -> Vec<f32> {
    let mut out = Vec::with_capacity(input.len() / 2);
    for pair in input.chunks_exact(2) {
        let raw = u16::from_le_bytes([pair[0], pair[1]]);
        out.push(half::f16::from_bits(raw).to_f32());
    }
    out
}
/// Dot product of two bf16-encoded byte buffers.
///
/// Works in chunks of 64 elements: each chunk is decoded to `f32` via
/// `simd_bf16_to_f32` and accumulated with `simd_dot`. The shorter buffer
/// bounds the element count; a trailing odd byte is ignored.
#[must_use]
pub fn simd_bf16_dot(a: &[u8], b: &[u8]) -> f32 {
    const CHUNK_SIZE: usize = 64;
    // Number of complete bf16 elements present in both buffers.
    let count = a.len().min(b.len()) / 2;
    let mut acc = 0.0f32;
    let mut start = 0;
    while start < count {
        let end = (start + CHUNK_SIZE).min(count);
        let lo = start * 2;
        let hi = end * 2;
        let lhs = simd_bf16_to_f32(&a[lo..hi]);
        let rhs = simd_bf16_to_f32(&b[lo..hi]);
        acc += simd_dot(&lhs, &rhs);
        start = end;
    }
    acc
}
include!("simd_bf16_ops.rs");
include!("simd_bf16.rs");