rage_quant/
gemm_kernel.rs1use rayon::prelude::*;
2
3pub fn dot_f32(lhs: &[f32], rhs: &[f32]) -> f32 {
4 assert_eq!(lhs.len(), rhs.len(), "dot_f32 requiere slices del mismo largo");
5
6 #[cfg(target_arch = "x86_64")]
7 {
8 if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
9 return unsafe { dot_f32_avx2(lhs, rhs) };
11 }
12 }
13
14 dot_f32_scalar(lhs, rhs)
15}
16
17pub fn gemv_rows_f32(rows: &[f32], row_len: usize, vec: &[f32]) -> Vec<f32> {
18 assert_eq!(vec.len(), row_len, "gemv_rows_f32 requiere vector del mismo ancho que la fila");
19 assert_eq!(rows.len() % row_len, 0, "gemv_rows_f32 requiere rows divisible por row_len");
20
21 rows.par_chunks_exact(row_len)
22 .map(|row| dot_f32(row, vec))
23 .collect()
24}
25
/// Portable dot product: multiplies the slices element by element and sums the products.
///
/// Pairing stops at the end of the shorter slice, so extra trailing elements
/// of the longer one are ignored.
fn dot_f32_scalar(lhs: &[f32], rhs: &[f32]) -> f32 {
    let products = lhs.iter().zip(rhs).map(|(&x, &y)| x * y);
    products.sum()
}
29
/// AVX2 + FMA dot product of two `f32` slices, consuming eight lanes per step.
///
/// Only the first `min(lhs.len(), rhs.len())` elements of each slice take part
/// (the same truncation `Iterator::zip` applies), so no vector load can run
/// past the end of the shorter slice.
///
/// # Safety
///
/// The caller must ensure the running CPU supports both the `avx2` and `fma`
/// target features, e.g. by checking `is_x86_feature_detected!` first.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn dot_f32_avx2(lhs: &[f32], rhs: &[f32]) -> f32 {
    use std::arch::x86_64::{
        __m256, _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps,
        _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
    };

    /// Number of `f32` values held by one 256-bit register.
    const LANES: usize = 8;

    // Clamp both operands to the common length so the vector loop and the
    // scalar tail can never index past either slice.
    let len = lhs.len().min(rhs.len());
    let (lhs, rhs) = (&lhs[..len], &rhs[..len]);

    let lhs_blocks = lhs.chunks_exact(LANES);
    let rhs_blocks = rhs.chunks_exact(LANES);
    let lhs_tail = lhs_blocks.remainder();
    let rhs_tail = rhs_blocks.remainder();

    let mut acc: __m256 = _mm256_setzero_ps();
    for (a, b) in lhs_blocks.zip(rhs_blocks) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly eight `f32`s, so each
        // unaligned 256-bit load reads only memory owned by that chunk.
        let a = _mm256_loadu_ps(a.as_ptr());
        let b = _mm256_loadu_ps(b.as_ptr());
        acc = _mm256_fmadd_ps(a, b, acc);
    }

    // Horizontal reduction: fold the upper 128-bit half onto the lower one, then two
    // pairwise horizontal adds leave the sum of all eight lanes in lane 0.
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let sum128 = _mm_add_ps(lo, hi);
    let sum64 = _mm_hadd_ps(sum128, sum128);
    let sum32 = _mm_hadd_ps(sum64, sum64);
    let mut total = _mm_cvtss_f32(sum32);

    // Scalar tail for the (fewer than eight) elements left after the vector loop.
    for (a, b) in lhs_tail.iter().zip(rhs_tail) {
        total += a * b;
    }
    total
}
62
63pub fn gemm_f32_row_major(m: usize, n: usize, k: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
64 if lhs.len() != m * k {
65 panic!("lhs len {} != m*k {}", lhs.len(), m * k);
66 }
67 if rhs.len() != k * n {
68 panic!("rhs len {} != k*n {}", rhs.len(), k * n);
69 }
70
71 let mut out = vec![0.0f32; m * n];
72 unsafe {
75 gemm::gemm(
76 m,
77 n,
78 k,
79 out.as_mut_ptr(),
80 1,
81 n as isize,
82 false,
83 lhs.as_ptr(),
84 1,
85 k as isize,
86 rhs.as_ptr(),
87 1,
88 n as isize,
89 0.0f32,
90 1.0f32,
91 false,
92 false,
93 false,
94 gemm::Parallelism::Rayon(0),
95 );
96 }
97 out
98}