Skip to main content

rage_quant/
gemm_kernel.rs

1use rayon::prelude::*;
2
3pub fn dot_f32(lhs: &[f32], rhs: &[f32]) -> f32 {
4    assert_eq!(lhs.len(), rhs.len(), "dot_f32 requiere slices del mismo largo");
5
6    #[cfg(target_arch = "x86_64")]
7    {
8        if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
9            // SAFETY: se verifica soporte AVX2/FMA antes de invocar el kernel especializado.
10            return unsafe { dot_f32_avx2(lhs, rhs) };
11        }
12    }
13
14    dot_f32_scalar(lhs, rhs)
15}
16
17pub fn gemv_rows_f32(rows: &[f32], row_len: usize, vec: &[f32]) -> Vec<f32> {
18    assert_eq!(vec.len(), row_len, "gemv_rows_f32 requiere vector del mismo ancho que la fila");
19    assert_eq!(rows.len() % row_len, 0, "gemv_rows_f32 requiere rows divisible por row_len");
20
21    rows.par_chunks_exact(row_len)
22        .map(|row| dot_f32(row, vec))
23        .collect()
24}
25
/// Portable dot product: multiplies the slices pairwise and sums the products
/// in index order.
///
/// Pairing stops at the end of the shorter slice.
fn dot_f32_scalar(lhs: &[f32], rhs: &[f32]) -> f32 {
    std::iter::zip(lhs, rhs).map(|(x, y)| x * y).sum()
}
29
/// AVX2 + FMA kernel for the dot product of two `f32` slices.
///
/// Eight lanes are processed per iteration with a fused multiply-add into a
/// single vector accumulator; the accumulator is then reduced horizontally and
/// any leftover elements (fewer than eight) are folded in with scalar
/// arithmetic.
///
/// Only the common prefix (`min(lhs.len(), rhs.len())` elements) is read, so a
/// length mismatch can never cause an out-of-bounds load; this mirrors the
/// truncating behaviour of the scalar fallback.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` and `fma`
/// features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn dot_f32_avx2(lhs: &[f32], rhs: &[f32]) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps,
        _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
    };

    /// Number of `f32` lanes in one 256-bit register.
    const LANES: usize = 8;

    // Clamp both operands to the shared length so every load below is in bounds
    // even if the caller passes slices of different sizes.
    let len = lhs.len().min(rhs.len());
    let (lhs, rhs) = (&lhs[..len], &rhs[..len]);

    let lhs_chunks = lhs.chunks_exact(LANES);
    let rhs_chunks = rhs.chunks_exact(LANES);
    let lhs_tail = lhs_chunks.remainder();
    let rhs_tail = rhs_chunks.remainder();

    let mut acc = _mm256_setzero_ps();
    for (a, b) in lhs_chunks.zip(rhs_chunks) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly eight `f32`s, so
        // each unaligned 256-bit load reads only initialized, in-bounds memory.
        let va = _mm256_loadu_ps(a.as_ptr());
        let vb = _mm256_loadu_ps(b.as_ptr());
        acc = _mm256_fmadd_ps(va, vb, acc);
    }

    // Horizontal reduction: fold the upper 128-bit half onto the lower one, then
    // two pairwise horizontal adds leave the full sum in lane 0.
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let sum128 = _mm_add_ps(lo, hi);
    let sum64 = _mm_hadd_ps(sum128, sum128);
    let sum32 = _mm_hadd_ps(sum64, sum64);
    let mut total = _mm_cvtss_f32(sum32);

    // Scalar tail for the elements that did not fill a whole register.
    for (a, b) in lhs_tail.iter().zip(rhs_tail) {
        total += a * b;
    }
    total
}
62
63pub fn gemm_f32_row_major(m: usize, n: usize, k: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
64    if lhs.len() != m * k {
65        panic!("lhs len {} != m*k {}", lhs.len(), m * k);
66    }
67    if rhs.len() != k * n {
68        panic!("rhs len {} != k*n {}", rhs.len(), k * n);
69    }
70
71    let mut out = vec![0.0f32; m * n];
72    // SAFETY: todos los punteros apuntan a buffers contiguos de tamano valido
73    // en layout row-major y no se solapan entre si.
74    unsafe {
75        gemm::gemm(
76            m,
77            n,
78            k,
79            out.as_mut_ptr(),
80            1,
81            n as isize,
82            false,
83            lhs.as_ptr(),
84            1,
85            k as isize,
86            rhs.as_ptr(),
87            1,
88            n as isize,
89            0.0f32,
90            1.0f32,
91            false,
92            false,
93            false,
94            gemm::Parallelism::Rayon(0),
95        );
96    }
97    out
98}