use crate::tensor::Tensor;
use crate::error::Result;

/// Reports whether NEON SIMD instructions are available on the current target.
///
/// NEON is mandatory on AArch64; on 32-bit ARM it is only available when the
/// `neon` target feature is enabled at compile time.
pub fn is_neon_available() -> bool {
    #[cfg(target_arch = "aarch64")]
    {
        true
    }
    #[cfg(all(target_arch = "arm", target_feature = "neon"))]
    {
        true
    }
    #[cfg(not(any(target_arch = "aarch64", all(target_arch = "arm", target_feature = "neon"))))]
    {
        false
    }
}

/// Element-wise addition: `result[i] = a[i] + b[i]`.
///
/// Uses NEON intrinsics on AArch64; other targets (including 32-bit ARM, for
/// which this module has no intrinsic path) fall back to a scalar loop.
///
/// # Panics
///
/// Panics if the three slices differ in length.
pub fn add_neon(a: &[f32], b: &[f32], result: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), result.len());

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: slice lengths were checked above, and NEON is mandatory on AArch64.
        unsafe {
            add_neon_impl(a, b, result);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for i in 0..a.len() {
            result[i] = a[i] + b[i];
        }
    }
}

/// # Safety
///
/// All three slices must have equal length (checked by the public wrapper).
#[cfg(target_arch = "aarch64")]
unsafe fn add_neon_impl(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 4;

    // Add four lanes at a time.
    for i in 0..chunks {
        let idx = i * 4;
        let va = vld1q_f32(a.as_ptr().add(idx));
        let vb = vld1q_f32(b.as_ptr().add(idx));
        let vc = vaddq_f32(va, vb);
        vst1q_f32(result.as_mut_ptr().add(idx), vc);
    }

    // Scalar tail for lengths that are not a multiple of 4.
    for i in (chunks * 4)..len {
        result[i] = a[i] + b[i];
    }
}

/// Element-wise multiplication: `result[i] = a[i] * b[i]`.
///
/// # Panics
///
/// Panics if the three slices differ in length.
pub fn mul_neon(a: &[f32], b: &[f32], result: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), result.len());

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: slice lengths were checked above.
        unsafe {
            mul_neon_impl(a, b, result);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for i in 0..a.len() {
            result[i] = a[i] * b[i];
        }
    }
}

/// # Safety
///
/// All three slices must have equal length (checked by the public wrapper).
#[cfg(target_arch = "aarch64")]
unsafe fn mul_neon_impl(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 4;

    // Multiply four lanes at a time.
    for i in 0..chunks {
        let idx = i * 4;
        let va = vld1q_f32(a.as_ptr().add(idx));
        let vb = vld1q_f32(b.as_ptr().add(idx));
        let vc = vmulq_f32(va, vb);
        vst1q_f32(result.as_mut_ptr().add(idx), vc);
    }

    // Scalar tail.
    for i in (chunks * 4)..len {
        result[i] = a[i] * b[i];
    }
}

/// Dot product of `a` and `b`.
///
/// # Panics
///
/// Panics if the slices differ in length.
pub fn dot_neon(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: slice lengths were checked above.
        unsafe { dot_neon_impl(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}

/// # Safety
///
/// Both slices must have equal length (checked by the public wrapper).
#[cfg(target_arch = "aarch64")]
unsafe fn dot_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 4;

    // Four running partial sums, one per lane.
    let mut acc = vdupq_n_f32(0.0);

    for i in 0..chunks {
        let idx = i * 4;
        let va = vld1q_f32(a.as_ptr().add(idx));
        let vb = vld1q_f32(b.as_ptr().add(idx));
        // Fused multiply-add: acc += va * vb, lane-wise.
        acc = vfmaq_f32(acc, va, vb);
    }

    // Horizontal add collapses the four partial sums into one scalar.
    let mut sum = vaddvq_f32(acc);

    for i in (chunks * 4)..len {
        sum += a[i] * b[i];
    }

    sum
}

/// Applies ReLU in place: `x = max(x, 0)`.
pub fn relu_neon(data: &mut [f32]) {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: the implementation only touches indices within `data`.
        unsafe {
            relu_neon_impl(data);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for x in data.iter_mut() {
            *x = x.max(0.0);
        }
    }
}

/// # Safety
///
/// All pointer arithmetic is derived from `data.len()` and stays in bounds.
#[cfg(target_arch = "aarch64")]
unsafe fn relu_neon_impl(data: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = data.len();
    let chunks = len / 4;
    let zero = vdupq_n_f32(0.0);

    // Lane-wise max against zero, four elements at a time.
    for i in 0..chunks {
        let idx = i * 4;
        let v = vld1q_f32(data.as_ptr().add(idx));
        let result = vmaxq_f32(v, zero);
        vst1q_f32(data.as_mut_ptr().add(idx), result);
    }

    // Scalar tail.
    for i in (chunks * 4)..len {
        data[i] = data[i].max(0.0);
    }
}

/// Applies the logistic sigmoid in place: `x = 1 / (1 + e^(-x))`.
///
/// NEON has no `exp` intrinsic, so this is scalar on every target; a
/// vectorized polynomial approximation could replace it later.
pub fn sigmoid_neon(data: &mut [f32]) {
    for x in data.iter_mut() {
        *x = 1.0 / (1.0 + (-*x).exp());
    }
}

/// Row-major matrix multiplication: `result = a * b`, where `a` is `m x k`,
/// `b` is `k x n`, and `result` is `m x n`.
///
/// # Panics
///
/// Panics if a slice length does not match its matrix dimensions.
pub fn matmul_neon(
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(result.len(), m * n);

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: dimensions were checked above.
        unsafe {
            matmul_neon_impl(a, b, result, m, n, k);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for p in 0..k {
                    sum += a[i * k + p] * b[p * n + j];
                }
                result[i * n + j] = sum;
            }
        }
    }
}

/// # Safety
///
/// Slice lengths must match the given dimensions (checked by the public
/// wrapper).
#[cfg(target_arch = "aarch64")]
unsafe fn matmul_neon_impl(
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) {
    use std::arch::aarch64::*;

    for i in 0..m {
        for j in 0..n {
            let mut acc = vdupq_n_f32(0.0);
            let chunks = k / 4;

            for p in 0..chunks {
                let idx = p * 4;
                let va = vld1q_f32(a.as_ptr().add(i * k + idx));
                // Column `j` of `b` is strided by `n`, so gather the four
                // elements into a contiguous buffer before the vector load; a
                // single `vld1q_f32` at `idx * n + j` would read a row
                // fragment of `b` instead of the column.
                let col = [
                    b[idx * n + j],
                    b[(idx + 1) * n + j],
                    b[(idx + 2) * n + j],
                    b[(idx + 3) * n + j],
                ];
                let vb = vld1q_f32(col.as_ptr());
                acc = vfmaq_f32(acc, va, vb);
            }

            let mut sum = vaddvq_f32(acc);

            // Scalar tail for `k` not divisible by 4.
            for p in (chunks * 4)..k {
                sum += a[i * k + p] * b[p * n + j];
            }

            result[i * n + j] = sum;
        }
    }
}

/// Valid (no padding, stride 1) 2-D convolution of a row-major `input` with
/// `kernel`, following the ML convention (cross-correlation: the kernel is
/// not flipped). The output is
/// `(input_h - kernel_h + 1) x (input_w - kernel_w + 1)`.
pub fn conv2d_neon(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    input_h: usize,
    input_w: usize,
    kernel_h: usize,
    kernel_w: usize,
) {
    let output_h = input_h - kernel_h + 1;
    let output_w = input_w - kernel_w + 1;

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: all indices derived from these dimensions stay within the slices.
        unsafe {
            conv2d_neon_impl(
                input, kernel, output, input_w, kernel_h, kernel_w, output_h, output_w,
            );
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for i in 0..output_h {
            for j in 0..output_w {
                let mut sum = 0.0;
                for ki in 0..kernel_h {
                    for kj in 0..kernel_w {
                        sum += input[(i + ki) * input_w + (j + kj)] * kernel[ki * kernel_w + kj];
                    }
                }
                output[i * output_w + j] = sum;
            }
        }
    }
}

/// # Safety
///
/// Slice lengths must match the given dimensions (guaranteed by the public
/// wrapper, which derives `output_h`/`output_w` from the input sizes).
#[cfg(target_arch = "aarch64")]
unsafe fn conv2d_neon_impl(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    input_w: usize,
    kernel_h: usize,
    kernel_w: usize,
    output_h: usize,
    output_w: usize,
) {
    use std::arch::aarch64::*;

    let kw_chunks = kernel_w / 4;

    for i in 0..output_h {
        for j in 0..output_w {
            let mut acc = vdupq_n_f32(0.0);
            let mut tail = 0.0f32;

            for ki in 0..kernel_h {
                let in_row = (i + ki) * input_w + j;
                let k_row = ki * kernel_w;

                // Each kernel row and its matching input window row are both
                // contiguous, so they can be multiplied four lanes at a time.
                // Splatting scalars into every lane would quadruple the sum
                // after the horizontal add.
                for c in 0..kw_chunks {
                    let kj = c * 4;
                    let v_in = vld1q_f32(input.as_ptr().add(in_row + kj));
                    let v_k = vld1q_f32(kernel.as_ptr().add(k_row + kj));
                    acc = vfmaq_f32(acc, v_in, v_k);
                }

                // Scalar tail for kernel widths that are not a multiple of 4.
                for kj in (kw_chunks * 4)..kernel_w {
                    tail += input[in_row + kj] * kernel[k_row + kj];
                }
            }

            output[i * output_w + j] = vaddvq_f32(acc) + tail;
        }
    }
}

impl Tensor {
    pub fn add_neon(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();
        let mut result = vec![0.0; a.len()];

        add_neon(&a, &b, &mut result);

        Tensor::from_slice(&result, self.dims())
    }

    pub fn mul_neon(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();
        let mut result = vec![0.0; a.len()];

        mul_neon(&a, &b, &mut result);

        Tensor::from_slice(&result, self.dims())
    }

    pub fn relu_neon(&self) -> Tensor {
        let mut data = self.data_f32();
        relu_neon(&mut data);
        // The dims are unchanged, so reconstruction cannot fail.
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neon_availability() {
        // The runtime answer should agree with the compile-time target.
        let expected = cfg!(target_arch = "aarch64")
            || cfg!(all(target_arch = "arm", target_feature = "neon"));
        assert_eq!(is_neon_available(), expected);
    }

    #[test]
    fn test_add_neon() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let mut result = vec![0.0; 8];

        add_neon(&a, &b, &mut result);

        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

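    #[test]
    fn test_mul_neon() {
        // Length 6 covers both the 4-wide vector path and the scalar tail.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![2.0; 6];
        let mut result = vec![0.0; 6];

        mul_neon(&a, &b, &mut result);

        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
    }
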
    #[test]
    fn test_dot_neon() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];

        let result = dot_neon(&a, &b);
        assert_eq!(result, 10.0);
    }

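    #[test]
    fn test_dot_neon_remainder() {
        // Length 5 forces one vector chunk plus the scalar remainder loop.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0; 5];

        assert_eq!(dot_neon(&a, &b), 30.0);
    }
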
    #[test]
    fn test_relu_neon() {
        let mut data = vec![-1.0, 2.0, -3.0, 4.0];
        relu_neon(&mut data);
        assert_eq!(data, vec![0.0, 2.0, 0.0, 4.0]);
    }
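
    #[test]
    fn test_sigmoid_neon() {
        // sigmoid(0) = 0.5; allow a small tolerance for float math.
        let mut data = vec![0.0];
        sigmoid_neon(&mut data);
        assert!((data[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_matmul_neon() {
        // 2x2 * 2x2 exercises the scalar tail (k < 4).
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut result = vec![0.0; 4];
        matmul_neon(&a, &b, &mut result, 2, 2, 2);
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);

        // 1x4 * 4x1 exercises the vector path (k = 4).
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0; 4];
        let mut result = vec![0.0; 1];
        matmul_neon(&a, &b, &mut result, 1, 1, 4);
        assert_eq!(result, vec![10.0]);
    }

    #[test]
    fn test_conv2d_neon() {
        // 3x3 input, 2x2 all-ones kernel: each output is a 2x2 window sum.
        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let kernel = vec![1.0; 4];
        let mut output = vec![0.0; 4];
        conv2d_neon(&input, &kernel, &mut output, 3, 3, 2, 2);
        assert_eq!(output, vec![12.0, 16.0, 24.0, 28.0]);
    }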
}