/// Dot product of the first `len` elements of `a` and `b`.
///
/// On `aarch64` the bulk of the work is done with NEON: 16 elements per
/// iteration spread over four independent fused-multiply-add accumulators,
/// followed by a scalar loop over the remaining `len % 16` elements. Because
/// the vector path uses FMA and a different summation order, its result can
/// differ in the last bits from the scalar fallback used on other targets.
///
/// # Panics
///
/// Panics if `len` exceeds `a.len()` or `b.len()`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_f32_neon(a: &[f32], b: &[f32], len: usize) -> f32 {
    use std::arch::aarch64::*;

    // Bounds are checked once, here. The raw-pointer loads below are only sound
    // because both slices are known to hold at least `len` elements; without
    // this a too-large `len` would read past the end of `a`/`b` (undefined
    // behaviour) instead of panicking like the scalar fallback does.
    let (a, b) = (&a[..len], &b[..len]);
    let chunks = len / 16;
    let tail = chunks * 16;

    // SAFETY: in iteration `i` the highest offset read is 16 * i + 15, and
    // 16 * i + 15 < 16 * chunks <= len == a.len() == b.len(), so every 4-lane
    // load stays inside its slice. The NEON intrinsics are assumed to be
    // available for this target (the same assumption the file makes elsewhere).
    let mut result = unsafe {
        // Four separate accumulators so consecutive FMAs do not depend on each other.
        let mut sum0 = vdupq_n_f32(0.0);
        let mut sum1 = vdupq_n_f32(0.0);
        let mut sum2 = vdupq_n_f32(0.0);
        let mut sum3 = vdupq_n_f32(0.0);

        for i in 0..chunks {
            let pa = a.as_ptr().add(i * 16);
            let pb = b.as_ptr().add(i * 16);
            sum0 = vfmaq_f32(sum0, vld1q_f32(pa), vld1q_f32(pb));
            sum1 = vfmaq_f32(sum1, vld1q_f32(pa.add(4)), vld1q_f32(pb.add(4)));
            sum2 = vfmaq_f32(sum2, vld1q_f32(pa.add(8)), vld1q_f32(pb.add(8)));
            sum3 = vfmaq_f32(sum3, vld1q_f32(pa.add(12)), vld1q_f32(pb.add(12)));
        }

        // Pairwise-combine the accumulators, then reduce the 4 lanes to a scalar.
        let pair01 = vaddq_f32(sum0, sum1);
        let pair23 = vaddq_f32(sum2, sum3);
        vaddvq_f32(vaddq_f32(pair01, pair23))
    };

    // Scalar tail for the last `len % 16` elements (bounds already proven above).
    for (x, y) in a[tail..].iter().zip(&b[tail..]) {
        result += x * y;
    }
    result
}

/// Dot product of the first `len` elements of `a` and `b` (portable scalar
/// fallback; the name is kept for parity with the aarch64 NEON version).
///
/// # Panics
///
/// Panics if `len` exceeds `a.len()` or `b.len()`.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn dot_f32_neon(a: &[f32], b: &[f32], len: usize) -> f32 {
    let (a, b) = (&a[..len], &b[..len]);
    let mut sum: f32 = 0.0;
    for (x, y) in a.iter().zip(b) {
        sum += x * y;
    }
    sum
}
66
67pub fn matmul_vec(output: &mut [f32], input: &[f32], weight: &[f32], k: usize, n: usize) {
74 let n_chunks = n / 4;
76 let n_remainder = n % 4;
77
78 for chunk in 0..n_chunks {
79 let j0 = chunk * 4;
80 output[j0] = dot_f32_neon(input, &weight[j0 * k..(j0 + 1) * k], k);
81 output[j0 + 1] = dot_f32_neon(input, &weight[(j0 + 1) * k..(j0 + 2) * k], k);
82 output[j0 + 2] = dot_f32_neon(input, &weight[(j0 + 2) * k..(j0 + 3) * k], k);
83 output[j0 + 3] = dot_f32_neon(input, &weight[(j0 + 3) * k..(j0 + 4) * k], k);
84 }
85
86 let j_base = n_chunks * 4;
88 for r in 0..n_remainder {
89 let j = j_base + r;
90 output[j] = dot_f32_neon(input, &weight[j * k..(j + 1) * k], k);
91 }
92}
93
94pub fn matmul(output: &mut [f32], input: &[f32], weight: &[f32], m: usize, k: usize, n: usize) {
98 if m == 1 {
99 matmul_vec(output, input, weight, k, n);
100 } else {
101 for i in 0..m {
102 let in_row = &input[i * k..(i + 1) * k];
103 let out_row = &mut output[i * n..(i + 1) * n];
104 matmul_vec(out_row, in_row, weight, k, n);
105 }
106 }
107}
108
109pub fn rms_norm(output: &mut [f32], input: &[f32], weight: &[f32], eps: f32) {
111 let n = input.len();
112
113 let sum_sq = dot_f32_neon(input, input, n);
115 let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();
116
117 rms_norm_apply(output, input, weight, inv_rms);
119}
120
/// Writes `output[i] = input[i] * inv_rms * weight[i]` for every `i in 0..input.len()`,
/// four lanes at a time with NEON plus a scalar tail.
///
/// # Panics
///
/// Panics if `weight` or `output` is shorter than `input`.
#[cfg(target_arch = "aarch64")]
fn rms_norm_apply(output: &mut [f32], input: &[f32], weight: &[f32], inv_rms: f32) {
    use std::arch::aarch64::*;

    let n = input.len();
    // Checked once: without these re-slices the raw-pointer loop below would
    // read `weight` / write `output` out of bounds when either is shorter.
    let weight = &weight[..n];
    let output = &mut output[..n];
    let chunks = n / 4;

    // SAFETY: in iteration `i` the highest offset touched is 4 * i + 3, and
    // 4 * i + 3 < 4 * chunks <= n; `input`, `weight` and `output` all have
    // exactly n elements here. NEON availability is assumed for this target.
    unsafe {
        let scale = vdupq_n_f32(inv_rms);
        for i in 0..chunks {
            let base = i * 4;
            let x = vld1q_f32(input.as_ptr().add(base));
            let w = vld1q_f32(weight.as_ptr().add(base));
            // Same association as the scalar path: (x * inv_rms) * w.
            vst1q_f32(output.as_mut_ptr().add(base), vmulq_f32(vmulq_f32(x, scale), w));
        }
    }

    let tail = chunks * 4;
    for ((dst, &x), &w) in output[tail..].iter_mut().zip(&input[tail..]).zip(&weight[tail..]) {
        *dst = x * inv_rms * w;
    }
}

/// Writes `output[i] = input[i] * inv_rms * weight[i]` for every `i in 0..input.len()`
/// (portable scalar version).
///
/// # Panics
///
/// Panics if `weight` or `output` is shorter than `input`.
#[cfg(not(target_arch = "aarch64"))]
fn rms_norm_apply(output: &mut [f32], input: &[f32], weight: &[f32], inv_rms: f32) {
    let n = input.len();
    for ((dst, &x), &w) in output[..n].iter_mut().zip(input).zip(&weight[..n]) {
        *dst = x * inv_rms * w;
    }
}
148
/// SiLU ("swish") activation applied element-wise: `x * sigmoid(x)`,
/// evaluated as `x / (1 + e^(-x))`.
///
/// Processes `min(output.len(), input.len())` elements.
pub fn silu(output: &mut [f32], input: &[f32]) {
    output.iter_mut().zip(input).for_each(|(dst, &x)| {
        let denom = 1.0 + (-x).exp();
        *dst = x / denom;
    });
}
155
/// GELU activation (tanh approximation) applied element-wise:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Processes `min(output.len(), input.len())` elements.
pub fn gelu(output: &mut [f32], input: &[f32]) {
    // sqrt(2 / pi), the scale applied inside the tanh approximation.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;

    output.iter_mut().zip(input).for_each(|(dst, &x)| {
        let cubic = 0.044715 * x * x * x;
        let inner = SQRT_2_OVER_PI * (x + cubic);
        *dst = 0.5 * x * (1.0 + inner.tanh());
    });
}
164
/// Element-wise product: `output[i] = a[i] * b[i]` for every `i in 0..a.len()`.
///
/// # Panics
///
/// Panics if `b` or `output` is shorter than `a`.
pub fn elementwise_mul(output: &mut [f32], a: &[f32], b: &[f32]) {
    elementwise_binary_op(output, a, b, BinaryOp::Mul);
}

/// Element-wise sum: `output[i] = a[i] + b[i]` for every `i in 0..a.len()`.
///
/// # Panics
///
/// Panics if `b` or `output` is shorter than `a`.
pub fn elementwise_add(output: &mut [f32], a: &[f32], b: &[f32]) {
    elementwise_binary_op(output, a, b, BinaryOp::Add);
}

/// Binary operation applied pair-wise by `elementwise_binary_op`.
#[derive(Clone, Copy)]
enum BinaryOp {
    Mul,
    Add,
}

impl BinaryOp {
    /// Applies the operation to a single pair of scalars.
    #[inline]
    fn apply(self, x: f32, y: f32) -> f32 {
        match self {
            BinaryOp::Mul => x * y,
            BinaryOp::Add => x + y,
        }
    }
}

/// Computes `output[i] = op(a[i], b[i])` for `i in 0..a.len()` with NEON,
/// 16 elements (four 4-lane registers) per iteration, plus a scalar tail.
///
/// # Panics
///
/// Panics if `b` or `output` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
fn elementwise_binary_op(output: &mut [f32], a: &[f32], b: &[f32], op: BinaryOp) {
    use std::arch::aarch64::*;

    let n = a.len();
    // Checked once: without these re-slices the raw-pointer loop below would
    // read `b` / write `output` out of bounds when either is shorter than `a`.
    let b = &b[..n];
    let output = &mut output[..n];
    let chunks = n / 16;

    // SAFETY: in iteration `i` the highest offset touched is 16 * i + 15, and
    // 16 * i + 15 < 16 * chunks <= n; `a`, `b` and `output` all have exactly n
    // elements here, and `output` cannot alias `a`/`b` (it is a `&mut`).
    // NEON availability is assumed for this target.
    unsafe {
        for i in 0..chunks {
            let base = i * 16;
            let pa = a.as_ptr().add(base);
            let pb = b.as_ptr().add(base);

            let a0 = vld1q_f32(pa);
            let b0 = vld1q_f32(pb);
            let a1 = vld1q_f32(pa.add(4));
            let b1 = vld1q_f32(pb.add(4));
            let a2 = vld1q_f32(pa.add(8));
            let b2 = vld1q_f32(pb.add(8));
            let a3 = vld1q_f32(pa.add(12));
            let b3 = vld1q_f32(pb.add(12));

            let (r0, r1, r2, r3) = match op {
                BinaryOp::Mul => (
                    vmulq_f32(a0, b0),
                    vmulq_f32(a1, b1),
                    vmulq_f32(a2, b2),
                    vmulq_f32(a3, b3),
                ),
                BinaryOp::Add => (
                    vaddq_f32(a0, b0),
                    vaddq_f32(a1, b1),
                    vaddq_f32(a2, b2),
                    vaddq_f32(a3, b3),
                ),
            };

            let po = output.as_mut_ptr().add(base);
            vst1q_f32(po, r0);
            vst1q_f32(po.add(4), r1);
            vst1q_f32(po.add(8), r2);
            vst1q_f32(po.add(12), r3);
        }
    }

    let tail = chunks * 16;
    for ((dst, &x), &y) in output[tail..].iter_mut().zip(&a[tail..]).zip(&b[tail..]) {
        *dst = op.apply(x, y);
    }
}

/// Computes `output[i] = op(a[i], b[i])` for `i in 0..a.len()` (portable scalar version).
///
/// # Panics
///
/// Panics if `b` or `output` is shorter than `a`.
#[cfg(not(target_arch = "aarch64"))]
fn elementwise_binary_op(output: &mut [f32], a: &[f32], b: &[f32], op: BinaryOp) {
    let n = a.len();
    for ((dst, &x), &y) in output[..n].iter_mut().zip(a).zip(&b[..n]) {
        *dst = op.apply(x, y);
    }
}
238
/// Numerically stable softmax, computed in place.
///
/// The maximum is subtracted before exponentiating so that the largest term
/// is `exp(0)`, then every entry is scaled by the reciprocal of the total.
/// An empty slice is left unchanged. If every entry is `-inf`, the shift
/// produces `-inf - -inf = NaN` and the result is NaN.
pub fn softmax(values: &mut [f32]) {
    let peak = values.iter().fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));

    let mut total = 0.0f32;
    for slot in values.iter_mut() {
        let e = (*slot - peak).exp();
        *slot = e;
        total += e;
    }

    let scale = total.recip();
    values.iter_mut().for_each(|slot| *slot *= scale);
}
252
253#[allow(clippy::too_many_arguments)]
257pub fn attention(
258 output: &mut [f32],
259 q: &[f32],
260 k_cache: &[f32],
261 v_cache: &[f32],
262 seq_len: usize,
263 num_heads: usize,
264 num_kv_heads: usize,
265 head_dim: usize,
266) {
267 let kv_group_size = num_heads / num_kv_heads;
268 let scale = 1.0 / (head_dim as f32).sqrt();
269 let kv_stride = num_kv_heads * head_dim;
270
271 for h in 0..num_heads {
272 let kv_h = h / kv_group_size;
273 let q_offset = h * head_dim;
274 let q_head = &q[q_offset..q_offset + head_dim];
275
276 let mut scores = vec![0.0f32; seq_len];
278 for (t, score) in scores.iter_mut().enumerate() {
279 let k_offset = t * kv_stride + kv_h * head_dim;
280 *score =
281 dot_f32_neon(q_head, &k_cache[k_offset..k_offset + head_dim], head_dim) * scale;
282 }
283
284 softmax(&mut scores);
285
286 for d in 0..head_dim {
288 let mut sum: f32 = 0.0;
289 for (t, &score) in scores.iter().enumerate() {
290 let v_offset = t * kv_stride + kv_h * head_dim;
291 sum += score * v_cache[v_offset + d];
292 }
293 output[q_offset + d] = sum;
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference for `matmul_vec`: row `j` of `weight` is `weight[j * k..(j + 1) * k]`.
    fn naive_matvec(input: &[f32], weight: &[f32], k: usize, n: usize) -> Vec<f32> {
        (0..n)
            .map(|j| {
                let mut acc = 0.0f32;
                for l in 0..k {
                    acc += input[l] * weight[j * k + l];
                }
                acc
            })
            .collect()
    }

    /// Runs `matmul_vec` and checks every output against `naive_matvec` (abs tol 1e-2).
    fn assert_matvec_matches_naive(input: &[f32], weight: &[f32], k: usize, n: usize) {
        let mut output = vec![0.0f32; n];
        matmul_vec(&mut output, input, weight, k, n);
        let output_ref = naive_matvec(input, weight, k, n);
        for j in 0..n {
            assert!(
                (output[j] - output_ref[j]).abs() < 1e-2,
                "mismatch at j={j}: {} vs {}",
                output[j],
                output_ref[j]
            );
        }
    }

    #[test]
    fn dot_product_basic() {
        let ones = [1.0f32; 4];
        let result = dot_f32_neon(&[1.0, 2.0, 3.0, 4.0], &ones, 4);
        assert!((result - 10.0).abs() < 1e-5);
    }

    #[test]
    fn dot_product_large() {
        let k = 576;
        let a: Vec<f32> = (0..k).map(|i| i as f32 * 0.001).collect();
        let b: Vec<f32> = (0..k).map(|i| (k - i) as f32 * 0.001).collect();

        let neon_result = dot_f32_neon(&a, &b, k);
        let ref_result: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!(
            (neon_result - ref_result).abs() < 1e-1,
            "NEON={neon_result}, ref={ref_result}"
        );
    }

    #[test]
    fn matmul_vec_basic() {
        // Two rows of two weights: [1, 2] and [3, 4].
        let weight = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 2];
        matmul_vec(&mut output, &[1.0, 2.0], &weight, 2, 2);
        assert!((output[0] - 5.0).abs() < 1e-5);
        assert!((output[1] - 11.0).abs() < 1e-5);
    }

    #[test]
    fn matmul_vec_larger() {
        let (k, n) = (64, 32);
        let input: Vec<f32> = (0..k).map(|i| i as f32 * 0.1).collect();
        let weight: Vec<f32> = (0..n * k).map(|i| (i % 7) as f32 * 0.01).collect();
        assert_matvec_matches_naive(&input, &weight, k, n);
    }

    #[test]
    fn matmul_vec_odd_dimensions() {
        let (k, n) = (13, 7);
        let input: Vec<f32> = (0..k).map(|i| i as f32).collect();
        let weight: Vec<f32> = (0..n * k).map(|i| i as f32 * 0.01).collect();
        assert_matvec_matches_naive(&input, &weight, k, n);
    }

    #[test]
    fn rms_norm_basic() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let weight = [1.0f32; 4];
        let mut output = [0.0f32; 4];
        rms_norm(&mut output, &input, &weight, 1e-5);

        let sum_sq: f32 = input.iter().map(|x| x * x).sum();
        let inv_rms = 1.0 / (sum_sq / 4.0 + 1e-5).sqrt();
        for ((got, x), w) in output.iter().zip(&input).zip(&weight) {
            let want = x * inv_rms * w;
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn matmul_general() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let identity = [1.0, 0.0, 0.0, 1.0];
        let mut output = [0.0f32; 4];
        matmul(&mut output, &input, &identity, 2, 2, 2);
        for (got, want) in output.iter().zip([1.0f32, 2.0, 3.0, 4.0]) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn dot_product_smollm_dimension() {
        let k = 576;
        let a: Vec<f32> = (0..k).map(|i| ((i * 7 + 3) % 100) as f32 * 0.01).collect();
        let b: Vec<f32> = (0..k).map(|i| ((i * 11 + 5) % 100) as f32 * 0.01).collect();

        let neon = dot_f32_neon(&a, &b, k);
        let reference: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!(
            (neon - reference).abs() / reference.abs() < 1e-4,
            "relative error too large: NEON={neon}, ref={reference}"
        );
    }
}