// forgellm_runtime/kernels.rs

//! Optimized compute kernels.
//!
//! Provides SIMD-accelerated implementations of core operations.
//! Uses ARM NEON intrinsics on aarch64 and falls back to scalar code elsewhere.
//!
//! The primary bottleneck is matmul (matrix-vector multiply for single-token
//! generation). The NEON dot product processes 16 f32s per loop iteration,
//! spread across four independent `vfmaq_f32` accumulators.

/// NEON-accelerated dot product of the first `len` elements of `a` and `b`.
///
/// The main loop consumes 16 lanes per iteration into four independent
/// `float32x4_t` accumulators (so consecutive FMAs do not wait on each other),
/// reduces them pairwise plus a horizontal lane sum, and then finishes the
/// final `len % 16` elements with scalar code.
///
/// # Panics
///
/// Panics if `len` exceeds `a.len()` or `b.len()`. The vector loop performs
/// unchecked loads, so this check is what keeps the function sound for any
/// safe caller.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_f32_neon(a: &[f32], b: &[f32], len: usize) -> f32 {
    use std::arch::aarch64::*;

    assert!(
        len <= a.len() && len <= b.len(),
        "dot_f32_neon: len {len} exceeds slice lengths (a: {}, b: {})",
        a.len(),
        b.len()
    );

    let chunks = len / 16;
    let tail_start = chunks * 16;

    // SAFETY: NEON is part of the aarch64 baseline, so the intrinsics are always
    // available. Every load reads 4 f32s starting at `base + {0,4,8,12}` with
    // `base + 16 <= tail_start <= len`, and the assert above guarantees
    // `len <= a.len()` and `len <= b.len()`, so all reads are in bounds.
    let mut result = unsafe {
        let pa = a.as_ptr();
        let pb = b.as_ptr();
        let mut sum0 = vdupq_n_f32(0.0);
        let mut sum1 = vdupq_n_f32(0.0);
        let mut sum2 = vdupq_n_f32(0.0);
        let mut sum3 = vdupq_n_f32(0.0);

        for i in 0..chunks {
            let base = i * 16;
            sum0 = vfmaq_f32(sum0, vld1q_f32(pa.add(base)), vld1q_f32(pb.add(base)));
            sum1 = vfmaq_f32(sum1, vld1q_f32(pa.add(base + 4)), vld1q_f32(pb.add(base + 4)));
            sum2 = vfmaq_f32(sum2, vld1q_f32(pa.add(base + 8)), vld1q_f32(pb.add(base + 8)));
            sum3 = vfmaq_f32(sum3, vld1q_f32(pa.add(base + 12)), vld1q_f32(pb.add(base + 12)));
        }

        // Pairwise combine of the accumulators, then a horizontal sum of the lanes.
        vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)))
    };

    // Scalar tail for the last `len % 16` elements; bounds-checked and safe.
    for (x, y) in a[tail_start..len].iter().zip(&b[tail_start..len]) {
        result += x * y;
    }
    result
}

/// Portable scalar dot product over the first `len` elements of `a` and `b`.
///
/// Compiled on every target other than aarch64; it keeps the `_neon` name so
/// call sites stay identical across targets. Panics (via indexing) if `len`
/// exceeds either slice.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn dot_f32_neon(a: &[f32], b: &[f32], len: usize) -> f32 {
    (0..len).fold(0.0f32, |acc, idx| acc + a[idx] * b[idx])
}

67/// Optimized matrix-vector multiply using NEON dot products.
68///
69/// For single-token inference (m=1), computes dot products between
70/// the input vector and each weight row.
71///
72/// Weight layout: [n, k] (row-major), so weight row j is at offset j*k.
73pub fn matmul_vec(output: &mut [f32], input: &[f32], weight: &[f32], k: usize, n: usize) {
74    // Process 4 output rows at a time for ILP
75    let n_chunks = n / 4;
76    let n_remainder = n % 4;
77
78    for chunk in 0..n_chunks {
79        let j0 = chunk * 4;
80        output[j0] = dot_f32_neon(input, &weight[j0 * k..(j0 + 1) * k], k);
81        output[j0 + 1] = dot_f32_neon(input, &weight[(j0 + 1) * k..(j0 + 2) * k], k);
82        output[j0 + 2] = dot_f32_neon(input, &weight[(j0 + 2) * k..(j0 + 3) * k], k);
83        output[j0 + 3] = dot_f32_neon(input, &weight[(j0 + 3) * k..(j0 + 4) * k], k);
84    }
85
86    // Handle remaining output elements
87    let j_base = n_chunks * 4;
88    for r in 0..n_remainder {
89        let j = j_base + r;
90        output[j] = dot_f32_neon(input, &weight[j * k..(j + 1) * k], k);
91    }
92}
93
94/// General matrix multiply: output[m,n] = input[m,k] * weight^T[k,n]
95///
96/// For m=1 (single token), delegates to the optimized vector version.
97pub fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
98    if m == 1 {
99        matmul_vec(output, input, weight, k, n);
100    } else {
101        for i in 0..m {
102            let in_row = &input[i * k..(i + 1) * k];
103            let out_row = &mut output[i * n..(i + 1) * n];
104            matmul_vec(out_row, in_row, weight, k, n);
105        }
106    }
107}
108
109/// NEON-accelerated RMSNorm: output = (input / rms(input)) * weight
110pub fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
111    let n = input.len();
112
113    // NEON dot product for sum of squares
114    let sum_sq = dot_f32_neon(input, input, n);
115    let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();
116
117    // NEON-accelerated normalization + weight multiply
118    rms_norm_apply(output, input, weight, inv_rms);
119}
120
/// NEON implementation of `output[i] = input[i] * inv_rms * weight[i]` for every
/// `i < input.len()`: 4-wide multiplies, then a scalar loop for the `len % 4` tail.
///
/// # Panics
///
/// Panics if `weight` or `output` is shorter than `input`. The vector loop uses
/// unchecked loads/stores sized by `input`, so this check is required for soundness.
#[cfg(target_arch = "aarch64")]
fn rms_norm_apply(output: &mut [f32], input: &[f32], weight: &[f32], inv_rms: f32) {
    use std::arch::aarch64::*;

    let n = input.len();
    assert!(
        weight.len() >= n && output.len() >= n,
        "rms_norm_apply: weight ({}) and output ({}) must be at least input length ({n})",
        weight.len(),
        output.len()
    );
    let chunks = n / 4;
    let tail_start = chunks * 4;

    // SAFETY: NEON is part of the aarch64 baseline. Each iteration touches
    // offsets `base..base + 4` with `base + 4 <= tail_start <= n`, and the assert
    // above guarantees `input`, `weight` and `output` all hold at least `n` elements.
    unsafe {
        let scale = vdupq_n_f32(inv_rms);
        let px = input.as_ptr();
        let pw = weight.as_ptr();
        let po = output.as_mut_ptr();
        for i in 0..chunks {
            let base = i * 4;
            let x = vld1q_f32(px.add(base));
            let w = vld1q_f32(pw.add(base));
            // Same association as the scalar tail: (x * inv_rms) * w.
            vst1q_f32(po.add(base), vmulq_f32(vmulq_f32(x, scale), w));
        }
    }

    for i in tail_start..n {
        output[i] = input[i] * inv_rms * weight[i];
    }
}

/// Portable fallback: `output[i] = input[i] * inv_rms * weight[i]` for each
/// `i < input.len()`. Panics (via indexing) if `output` or `weight` is shorter.
#[cfg(not(target_arch = "aarch64"))]
fn rms_norm_apply(output: &mut [f32], input: &[f32], weight: &[f32], inv_rms: f32) {
    for (idx, &x) in input.iter().enumerate() {
        output[idx] = x * inv_rms * weight[idx];
    }
}

/// SiLU (swish) activation: `x * sigmoid(x)`, evaluated as `x / (1 + e^-x)`.
///
/// Processes `min(output.len(), input.len())` elements; the rest of `output`
/// is left untouched.
pub fn silu(output: &mut [f32], input: &[f32]) {
    output
        .iter_mut()
        .zip(input)
        .for_each(|(dst, &x)| *dst = x / (1.0 + (-x).exp()));
}

/// GELU activation, tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
///
/// Processes `min(output.len(), input.len())` elements; the rest of `output`
/// is left untouched.
pub fn gelu(output: &mut [f32], input: &[f32]) {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6; // sqrt(2/pi)
    const CUBIC_COEFF: f32 = 0.044715;

    let pairs = output.iter_mut().zip(input);
    for (dst, &x) in pairs {
        let inner = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x);
        *dst = 0.5 * x * (1.0 + inner.tanh());
    }
}

165/// NEON-accelerated elementwise multiply
166pub fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
167    elementwise_binary_op(output, a, b, BinaryOp::Mul);
168}
169
170/// NEON-accelerated elementwise add
171pub fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
172    elementwise_binary_op(output, a, b, BinaryOp::Add);
173}
174
175enum BinaryOp {
176    Mul,
177    Add,
178}
179
/// NEON implementation of `output[i] = a[i] <op> b[i]` for every `i < a.len()`.
///
/// The main loop handles 16 lanes per iteration (four `float32x4_t` registers);
/// the final `len % 16` elements fall back to scalar code.
///
/// # Panics
///
/// Panics if `b` or `output` is shorter than `a`. The vector loop uses unchecked
/// loads/stores sized by `a`, so this check is what keeps the function sound.
#[cfg(target_arch = "aarch64")]
fn elementwise_binary_op(output: &mut [f32], a: &[f32], b: &[f32], op: BinaryOp) {
    use std::arch::aarch64::*;

    let n = a.len();
    assert!(
        b.len() >= n && output.len() >= n,
        "elementwise_binary_op: b ({}) and output ({}) must be at least as long as a ({n})",
        b.len(),
        output.len()
    );
    let chunks = n / 16;
    let tail_start = chunks * 16;

    // SAFETY: NEON is part of the aarch64 baseline. Each iteration touches
    // offsets `base..base + 16` with `base + 16 <= tail_start <= n`, and the assert
    // above guarantees `a`, `b` and `output` all hold at least `n` elements.
    // `output` is exclusively borrowed, so it cannot alias `a` or `b`.
    unsafe {
        let pa = a.as_ptr();
        let pb = b.as_ptr();
        let po = output.as_mut_ptr();
        for i in 0..chunks {
            let base = i * 16;
            let a0 = vld1q_f32(pa.add(base));
            let b0 = vld1q_f32(pb.add(base));
            let a1 = vld1q_f32(pa.add(base + 4));
            let b1 = vld1q_f32(pb.add(base + 4));
            let a2 = vld1q_f32(pa.add(base + 8));
            let b2 = vld1q_f32(pb.add(base + 8));
            let a3 = vld1q_f32(pa.add(base + 12));
            let b3 = vld1q_f32(pb.add(base + 12));

            let (r0, r1, r2, r3) = match op {
                BinaryOp::Mul => (
                    vmulq_f32(a0, b0),
                    vmulq_f32(a1, b1),
                    vmulq_f32(a2, b2),
                    vmulq_f32(a3, b3),
                ),
                BinaryOp::Add => (
                    vaddq_f32(a0, b0),
                    vaddq_f32(a1, b1),
                    vaddq_f32(a2, b2),
                    vaddq_f32(a3, b3),
                ),
            };

            vst1q_f32(po.add(base), r0);
            vst1q_f32(po.add(base + 4), r1);
            vst1q_f32(po.add(base + 8), r2);
            vst1q_f32(po.add(base + 12), r3);
        }
    }

    // Scalar remainder (bounds-checked).
    for i in tail_start..n {
        output[i] = match op {
            BinaryOp::Mul => a[i] * b[i],
            BinaryOp::Add => a[i] + b[i],
        };
    }
}

229#[cfg(not(target_arch = "aarch64"))]
230fn elementwise_binary_op(output: &mut [f32], a: &[f32], b: &[f32], op: BinaryOp) {
231    for i in 0..a.len() {
232        output[i] = match op {
233            BinaryOp::Mul => a[i] * b[i],
234            BinaryOp::Add => a[i] + b[i],
235        };
236    }
237}
238
/// In-place, numerically stable softmax over `values`.
///
/// The maximum is subtracted before exponentiating so the largest exponent is
/// `e^0 = 1`, which prevents overflow. When the maximum is `-inf` (every entry
/// is `-inf`, e.g. a fully masked attention row), the `-inf` entries become
/// zero probabilities instead of NaN. NaN entries are left as NaN so they
/// still surface. An empty slice is left unchanged.
pub fn softmax(values: &mut [f32]) {
    let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    if max_val == f32::NEG_INFINITY {
        // `-inf - -inf` is NaN, so the general path would poison the whole row.
        // `f32::max` skips NaN, so any remaining entries here are NaN; keep them.
        for v in values.iter_mut().filter(|v| **v == f32::NEG_INFINITY) {
            *v = 0.0;
        }
        return;
    }

    let mut sum: f32 = 0.0;
    for v in values.iter_mut() {
        *v = (*v - max_val).exp();
        sum += *v;
    }
    // With a finite max and no NaN, `sum >= 1`; stay defensive for NaN/+inf input.
    let inv_sum = if sum > 0.0 { 1.0 / sum } else { 0.0 };
    for v in values.iter_mut() {
        *v *= inv_sum;
    }
}

253/// Optimized grouped-query attention using NEON dot products.
254///
255/// Computes: for each head, score = softmax(Q·K^T / sqrt(d)), output = score·V
256#[allow(clippy::too_many_arguments)]
257pub fn attention(
258    output: &mut [f32],
259    q: &[f32],
260    k_cache: &[f32],
261    v_cache: &[f32],
262    seq_len: usize,
263    num_heads: usize,
264    num_kv_heads: usize,
265    head_dim: usize,
266) {
267    let kv_group_size = num_heads / num_kv_heads;
268    let scale = 1.0 / (head_dim as f32).sqrt();
269    let kv_stride = num_kv_heads * head_dim;
270
271    for h in 0..num_heads {
272        let kv_h = h / kv_group_size;
273        let q_offset = h * head_dim;
274        let q_head = &q[q_offset..q_offset + head_dim];
275
276        // Compute attention scores using NEON dot product
277        let mut scores = vec![0.0f32; seq_len];
278        for (t, score) in scores.iter_mut().enumerate() {
279            let k_offset = t * kv_stride + kv_h * head_dim;
280            *score =
281                dot_f32_neon(q_head, &k_cache[k_offset..k_offset + head_dim], head_dim) * scale;
282        }
283
284        softmax(&mut scores);
285
286        // Weighted sum of values
287        for d in 0..head_dim {
288            let mut sum: f32 = 0.0;
289            for (t, &score) in scores.iter().enumerate() {
290                let v_offset = t * kv_stride + kv_h * head_dim;
291                sum += score * v_cache[v_offset + d];
292            }
293            output[q_offset + d] = sum;
294        }
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain scalar dot product used as a reference for the accelerated kernels.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn dot_product_basic() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![1.0f32, 1.0, 1.0, 1.0];
        let result = dot_f32_neon(&a, &b, 4);
        assert!((result - 10.0).abs() < 1e-5);
    }

    #[test]
    fn dot_product_large() {
        let k = 576; // SmolLM hidden size
        let a: Vec<f32> = (0..k).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = (0..k).map(|i| ((k - i) as f32) * 0.001).collect();

        let neon_result = dot_f32_neon(&a, &b, k);

        // Reference
        let ref_result: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!(
            (neon_result - ref_result).abs() < 1e-1,
            "NEON={neon_result}, ref={ref_result}"
        );
    }

    #[test]
    fn dot_product_with_tail() {
        // Lengths that hit the 16-wide loop, the scalar tail, and both together
        // (576 above is an exact multiple of 16, and 4 is tail-only).
        for len in [1usize, 15, 16, 17, 31, 37, 50] {
            let a: Vec<f32> = (0..len).map(|i| (i % 7) as f32 - 3.0).collect();
            let b: Vec<f32> = (0..len).map(|i| (i % 5) as f32 * 0.5).collect();
            let got = dot_f32_neon(&a, &b, len);
            let expected = reference_dot(&a, &b);
            assert!((got - expected).abs() < 1e-4, "len={len}: {got} vs {expected}");
        }
    }

    #[test]
    fn matmul_vec_basic() {
        // [1, 2] * [[1, 2], [3, 4]]^T = [1*1+2*2, 1*3+2*4] = [5, 11]
        let input = [1.0f32, 2.0];
        let weight = [1.0, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 2];
        matmul_vec(&mut output, &input, &weight, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-5);
        assert!((output[1] - 11.0).abs() < 1e-5);
    }

    #[test]
    fn matmul_vec_larger() {
        let k = 64;
        let n = 32;
        let input: Vec<f32> = (0..k).map(|i| i as f32 * 0.1).collect();
        let weight: Vec<f32> = (0..n * k).map(|i| (i % 7) as f32 * 0.01).collect();
        let mut output = vec![0.0f32; n];
        let mut output_ref = vec![0.0f32; n];

        matmul_vec(&mut output, &input, &weight, k, n);

        for j in 0..n {
            let mut sum = 0.0f32;
            for l in 0..k {
                sum += input[l] * weight[j * k + l];
            }
            output_ref[j] = sum;
        }

        for j in 0..n {
            assert!(
                (output[j] - output_ref[j]).abs() < 1e-2,
                "mismatch at j={j}: {} vs {}",
                output[j],
                output_ref[j]
            );
        }
    }

    #[test]
    fn matmul_vec_odd_dimensions() {
        let k = 13;
        let n = 7;
        let input: Vec<f32> = (0..k).map(|i| i as f32).collect();
        let weight: Vec<f32> = (0..n * k).map(|i| (i as f32) * 0.01).collect();
        let mut output = vec![0.0f32; n];
        let mut output_ref = vec![0.0f32; n];

        matmul_vec(&mut output, &input, &weight, k, n);

        for j in 0..n {
            let mut sum = 0.0f32;
            for l in 0..k {
                sum += input[l] * weight[j * k + l];
            }
            output_ref[j] = sum;
        }

        for j in 0..n {
            assert!(
                (output[j] - output_ref[j]).abs() < 1e-2,
                "mismatch at j={j}: {} vs {}",
                output[j],
                output_ref[j]
            );
        }
    }

    #[test]
    fn rms_norm_basic() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let weight = [1.0f32; 4];
        let mut output = [0.0f32; 4];
        let mut output_ref = [0.0f32; 4];

        rms_norm(&mut output, &input, &weight, 1e-5);

        let sum_sq: f32 = input.iter().map(|x| x * x).sum();
        let inv_rms = 1.0 / (sum_sq / 4.0 + 1e-5).sqrt();
        for i in 0..4 {
            output_ref[i] = input[i] * inv_rms * weight[i];
        }

        for i in 0..4 {
            assert!((output[i] - output_ref[i]).abs() < 1e-5);
        }
    }

    #[test]
    fn rms_norm_non_multiple_of_four() {
        // 7 elements: one 4-wide NEON chunk plus a 3-element scalar tail on aarch64.
        let input: Vec<f32> = (1..=7).map(|i| i as f32).collect();
        let weight: Vec<f32> = (0..7).map(|i| 1.0 + i as f32 * 0.25).collect();
        let mut output = vec![0.0f32; 7];
        rms_norm(&mut output, &input, &weight, 1e-6);
        let inv_rms = 1.0 / (reference_dot(&input, &input) / 7.0 + 1e-6).sqrt();
        for i in 0..7 {
            let expected = input[i] * inv_rms * weight[i];
            assert!(
                (output[i] - expected).abs() < 1e-5,
                "i={i}: {} vs {expected}",
                output[i]
            );
        }
    }

    #[test]
    fn matmul_general() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let weight = [1.0, 0.0, 0.0, 1.0];
        let mut output = [0.0f32; 4];
        matmul(&mut output, &input, &weight, 2, 2, 2);
        assert!((output[0] - 1.0).abs() < 1e-5);
        assert!((output[1] - 2.0).abs() < 1e-5);
        assert!((output[2] - 3.0).abs() < 1e-5);
        assert!((output[3] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn dot_product_smollm_dimension() {
        // Test with actual SmolLM hidden dimension (576)
        let k = 576;
        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 * 0.01).collect();
        let b: Vec<f32> = (0..k).map(|i| ((i * 11 + 5) % 100) as f32 * 0.01).collect();

        let neon = dot_f32_neon(&a, &b, k);
        let reference: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

        assert!(
            (neon - reference).abs() / reference.abs() < 1e-4,
            "relative error too large: NEON={neon}, ref={reference}"
        );
    }

    #[test]
    fn elementwise_ops_cover_vector_and_tail() {
        let n = 35; // two 16-wide chunks plus a 3-element tail on aarch64
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32 * 0.5).collect();
        let mut sum = vec![0.0f32; n];
        let mut prod = vec![0.0f32; n];
        elementwise_add(&mut sum, &a, &b);
        elementwise_mul(&mut prod, &a, &b);
        for i in 0..n {
            assert_eq!(sum[i], a[i] + b[i], "add mismatch at {i}");
            assert_eq!(prod[i], a[i] * b[i], "mul mismatch at {i}");
        }
    }

    #[test]
    fn softmax_is_normalized_and_monotonic() {
        let mut v = [1.0f32, 2.0, 3.0, 0.5];
        softmax(&mut v);
        let total: f32 = v.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
        assert!(v[2] > v[1] && v[1] > v[0] && v[0] > v[3]);
    }

    #[test]
    fn activations_known_values() {
        let input = [0.0f32, 1.0, 10.0];
        let mut out = [0.0f32; 3];

        silu(&mut out, &input);
        assert_eq!(out[0], 0.0);
        assert!((out[1] - 0.731_058_6).abs() < 1e-6);
        assert!((out[2] - 10.0).abs() < 1e-3);

        gelu(&mut out, &input);
        assert_eq!(out[0], 0.0);
        assert!((out[1] - 0.841_192).abs() < 1e-4);
        assert!((out[2] - 10.0).abs() < 1e-4);
    }

    #[test]
    fn attention_single_position_copies_value_row() {
        // With one cached position softmax yields weight 1.0, so each head's
        // output is exactly the V row of the KV head it is grouped onto.
        let (num_heads, num_kv_heads, head_dim) = (4, 2, 3);
        let q: Vec<f32> = (0..num_heads * head_dim).map(|i| i as f32 * 0.1).collect();
        let k_cache: Vec<f32> = (0..num_kv_heads * head_dim).map(|i| i as f32).collect();
        let v_cache = [1.0f32, 2.0, 3.0, -4.0, -5.0, -6.0];
        let mut output = vec![0.0f32; num_heads * head_dim];
        attention(&mut output, &q, &k_cache, &v_cache, 1, num_heads, num_kv_heads, head_dim);
        assert_eq!(&output[0..3], &v_cache[0..3]);
        assert_eq!(&output[3..6], &v_cache[0..3]);
        assert_eq!(&output[6..9], &v_cache[3..6]);
        assert_eq!(&output[9..12], &v_cache[3..6]);
    }

    #[test]
    fn attention_matches_naive_reference() {
        let (seq_len, num_heads, num_kv_heads, head_dim) = (4usize, 2usize, 1usize, 4usize);
        let kv_stride = num_kv_heads * head_dim;
        let q: Vec<f32> = (0..num_heads * head_dim).map(|i| (i % 3) as f32 - 1.0).collect();
        let k_cache: Vec<f32> = (0..seq_len * kv_stride).map(|i| (i % 5) as f32 * 0.2).collect();
        let v_cache: Vec<f32> = (0..seq_len * kv_stride).map(|i| (i % 6) as f32 - 2.5).collect();
        let mut output = vec![0.0f32; num_heads * head_dim];
        attention(
            &mut output,
            &q,
            &k_cache,
            &v_cache,
            seq_len,
            num_heads,
            num_kv_heads,
            head_dim,
        );

        let scale = 1.0 / (head_dim as f32).sqrt();
        for h in 0..num_heads {
            let kv_h = h / (num_heads / num_kv_heads);
            let qh = &q[h * head_dim..(h + 1) * head_dim];
            let mut w: Vec<f32> = (0..seq_len)
                .map(|t| {
                    let off = t * kv_stride + kv_h * head_dim;
                    reference_dot(qh, &k_cache[off..off + head_dim]) * scale
                })
                .collect();
            let m = w.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut s = 0.0f32;
            for x in w.iter_mut() {
                *x = (*x - m).exp();
                s += *x;
            }
            for x in w.iter_mut() {
                *x /= s;
            }
            for d in 0..head_dim {
                let expected: f32 = (0..seq_len)
                    .map(|t| w[t] * v_cache[t * kv_stride + kv_h * head_dim + d])
                    .sum();
                assert!(
                    (output[h * head_dim + d] - expected).abs() < 1e-5,
                    "h={h} d={d}: {} vs {expected}",
                    output[h * head_dim + d]
                );
            }
        }
    }
}