ghostflow_core/neon.rs

//! ARM NEON SIMD optimizations
//!
//! Provides SIMD acceleration for ARM processors (mobile, Apple Silicon, etc.).

use crate::tensor::Tensor;
use crate::error::Result;

/// Check if NEON is available on this platform.
pub fn is_neon_available() -> bool {
    #[cfg(target_arch = "aarch64")]
    {
        true // NEON (Advanced SIMD) is mandatory on AArch64
    }
    #[cfg(all(target_arch = "arm", target_feature = "neon"))]
    {
        true
    }
    #[cfg(not(any(target_arch = "aarch64", all(target_arch = "arm", target_feature = "neon"))))]
    {
        false
    }
}

/// NEON-optimized element-wise vector addition.
pub fn add_neon(a: &[f32], b: &[f32], result: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), result.len());

    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            add_neon_impl(a, b, result);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Fallback to scalar
        for i in 0..a.len() {
            result[i] = a[i] + b[i];
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn add_neon_impl(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 4;

    // Process 4 elements at a time using 128-bit NEON registers
    for i in 0..chunks {
        let idx = i * 4;

        // Load 4 contiguous floats from a and b
        let va = vld1q_f32(a.as_ptr().add(idx));
        let vb = vld1q_f32(b.as_ptr().add(idx));

        // Lane-wise add
        let vc = vaddq_f32(va, vb);

        // Store the 4 results
        vst1q_f32(result.as_mut_ptr().add(idx), vc);
    }

    // Scalar tail for lengths that are not a multiple of 4
    for i in (chunks * 4)..len {
        result[i] = a[i] + b[i];
    }
}

/// NEON-optimized element-wise vector multiplication.
pub fn mul_neon(a: &[f32], b: &[f32], result: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), result.len());

    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            mul_neon_impl(a, b, result);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Fallback to scalar
        for i in 0..a.len() {
            result[i] = a[i] * b[i];
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn mul_neon_impl(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 4;

    // Process 4 lanes at a time
    for i in 0..chunks {
        let idx = i * 4;
        let va = vld1q_f32(a.as_ptr().add(idx));
        let vb = vld1q_f32(b.as_ptr().add(idx));
        let vc = vmulq_f32(va, vb);
        vst1q_f32(result.as_mut_ptr().add(idx), vc);
    }

    // Scalar tail
    for i in (chunks * 4)..len {
        result[i] = a[i] * b[i];
    }
}

/// NEON-optimized dot product.
pub fn dot_neon(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "aarch64")]
    {
        unsafe { dot_neon_impl(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn dot_neon_impl(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 4;

    // Vector accumulator: one partial sum per lane
    let mut acc = vdupq_n_f32(0.0);

    for i in 0..chunks {
        let idx = i * 4;
        let va = vld1q_f32(a.as_ptr().add(idx));
        let vb = vld1q_f32(b.as_ptr().add(idx));

        // Fused multiply-accumulate: acc += va * vb, lane-wise
        acc = vfmaq_f32(acc, va, vb);
    }

    // Horizontal sum of the four lanes
    let mut sum = vaddvq_f32(acc);

    // Scalar tail
    for i in (chunks * 4)..len {
        sum += a[i] * b[i];
    }

    sum
}
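
// A minimal sketch (not wired into dot_neon) showing a common refinement:
// splitting the reduction across two independent accumulators so consecutive
// FMAs do not serialize on a single register. The helper name below is
// illustrative, not part of the ghostflow API.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
unsafe fn dot_neon_two_acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let chunks = len / 8;

    let mut acc0 = vdupq_n_f32(0.0);
    let mut acc1 = vdupq_n_f32(0.0);

    for i in 0..chunks {
        let idx = i * 8;
        acc0 = vfmaq_f32(acc0, vld1q_f32(a.as_ptr().add(idx)), vld1q_f32(b.as_ptr().add(idx)));
        acc1 = vfmaq_f32(acc1, vld1q_f32(a.as_ptr().add(idx + 4)), vld1q_f32(b.as_ptr().add(idx + 4)));
    }

    // Combine the two partial vectors, then reduce horizontally
    let mut sum = vaddvq_f32(vaddq_f32(acc0, acc1));

    // Scalar tail for the last len % 8 elements
    for i in (chunks * 8)..len {
        sum += a[i] * b[i];
    }

    sum
}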

/// NEON-optimized ReLU (in place).
pub fn relu_neon(data: &mut [f32]) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            relu_neon_impl(data);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for x in data.iter_mut() {
            *x = x.max(0.0);
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn relu_neon_impl(data: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = data.len();
    let chunks = len / 4;
    let zero = vdupq_n_f32(0.0);

    // max(x, 0) on 4 lanes at a time
    for i in 0..chunks {
        let idx = i * 4;
        let v = vld1q_f32(data.as_ptr().add(idx));
        let result = vmaxq_f32(v, zero);
        vst1q_f32(data.as_mut_ptr().add(idx), result);
    }

    // Scalar tail
    for i in (chunks * 4)..len {
        data[i] = data[i].max(0.0);
    }
}

/// NEON-optimized sigmoid (in place).
pub fn sigmoid_neon(data: &mut [f32]) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            sigmoid_neon_impl(data);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for x in data.iter_mut() {
            *x = 1.0 / (1.0 + (-*x).exp());
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn sigmoid_neon_impl(data: &mut [f32]) {
    // NEON has no native exp instruction, so this path stays scalar for now.
    // A production version would use a fast vectorized approximation.
    for x in data.iter_mut() {
        *x = 1.0 / (1.0 + (-*x).exp());
    }
}
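
// A hedged sketch of one possible vectorized surrogate: the "hard sigmoid"
// clamp(0.2 * x + 0.5, 0, 1) needs only FMA, min, and max, so it maps
// directly onto NEON. This is an illustrative assumption about what a fast
// approximation could look like, not the approximation ghostflow uses.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
unsafe fn hard_sigmoid_neon_impl(data: &mut [f32]) {
    use std::arch::aarch64::*;

    let len = data.len();
    let chunks = len / 4;

    let slope = vdupq_n_f32(0.2);
    let half = vdupq_n_f32(0.5);
    let zero = vdupq_n_f32(0.0);
    let one = vdupq_n_f32(1.0);

    for i in 0..chunks {
        let idx = i * 4;
        let v = vld1q_f32(data.as_ptr().add(idx));
        // 0.5 + 0.2 * x, then clamp to [0, 1]
        let y = vfmaq_f32(half, v, slope);
        let y = vminq_f32(vmaxq_f32(y, zero), one);
        vst1q_f32(data.as_mut_ptr().add(idx), y);
    }

    // Scalar tail
    for x in data.iter_mut().skip(chunks * 4) {
        *x = (0.2 * *x + 0.5).clamp(0.0, 1.0);
    }
}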

/// NEON-optimized matrix multiplication (simplified).
///
/// `a` is `m x k`, `b` is `k x n`, and `result` is the `m x n` product, all
/// row-major.
pub fn matmul_neon(
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) {
    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            matmul_neon_impl(a, b, result, m, n, k);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Fallback to scalar: naive triple loop
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for p in 0..k {
                    sum += a[i * k + p] * b[p * n + j];
                }
                result[i * n + j] = sum;
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn matmul_neon_impl(
    a: &[f32],
    b: &[f32],
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) {
    use std::arch::aarch64::*;

    // Simplified NEON matmul; a production kernel would use blocking, register
    // tiling, and packed operands.
    for i in 0..m {
        for j in 0..n {
            let mut acc = vdupq_n_f32(0.0);
            let chunks = k / 4;

            for p in 0..chunks {
                let idx = p * 4;
                // Row i of `a` is contiguous, so a straight load works.
                let va = vld1q_f32(a.as_ptr().add(i * k + idx));
                // Column j of `b` is strided by n, so gather the four
                // elements into a temporary before loading.
                let col = [
                    b[idx * n + j],
                    b[(idx + 1) * n + j],
                    b[(idx + 2) * n + j],
                    b[(idx + 3) * n + j],
                ];
                let vb = vld1q_f32(col.as_ptr());
                acc = vfmaq_f32(acc, va, vb);
            }

            let mut sum = vaddvq_f32(acc);

            // Scalar tail
            for p in (chunks * 4)..k {
                sum += a[i * k + p] * b[p * n + j];
            }

            result[i * n + j] = sum;
        }
    }
}
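
// A hedged alternative sketch: if the caller pre-transposes B into an n x k
// row-major buffer, every inner product reads two contiguous rows and both
// operands can use plain vld1q_f32 loads, avoiding the strided gather above.
// The helper and its `b_t` layout are illustrative assumptions, not part of
// the ghostflow API.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
unsafe fn matmul_neon_bt_impl(
    a: &[f32],
    b_t: &[f32], // n x k, row-major: b_t[j * k + p] == b[p * n + j]
    result: &mut [f32],
    m: usize,
    n: usize,
    k: usize,
) {
    use std::arch::aarch64::*;

    for i in 0..m {
        for j in 0..n {
            let mut acc = vdupq_n_f32(0.0);
            let chunks = k / 4;

            for p in 0..chunks {
                let idx = p * 4;
                let va = vld1q_f32(a.as_ptr().add(i * k + idx));
                let vb = vld1q_f32(b_t.as_ptr().add(j * k + idx));
                acc = vfmaq_f32(acc, va, vb);
            }

            let mut sum = vaddvq_f32(acc);
            for p in (chunks * 4)..k {
                sum += a[i * k + p] * b_t[j * k + p];
            }

            result[i * n + j] = sum;
        }
    }
}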

/// NEON-optimized 2D convolution (simplified, single channel).
///
/// `input` is `input_h x input_w`, `kernel` is `kernel_h x kernel_w`, and
/// `output` is `(input_h - kernel_h + 1) x (input_w - kernel_w + 1)`, all
/// row-major, with no padding and stride 1.
pub fn conv2d_neon(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    input_h: usize,
    input_w: usize,
    kernel_h: usize,
    kernel_w: usize,
) {
    let output_h = input_h - kernel_h + 1;
    let output_w = input_w - kernel_w + 1;

    #[cfg(target_arch = "aarch64")]
    {
        unsafe {
            conv2d_neon_impl(input, kernel, output, input_h, input_w, kernel_h, kernel_w, output_h, output_w);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // Scalar fallback
        for i in 0..output_h {
            for j in 0..output_w {
                let mut sum = 0.0;
                for ki in 0..kernel_h {
                    for kj in 0..kernel_w {
                        sum += input[(i + ki) * input_w + (j + kj)] * kernel[ki * kernel_w + kj];
                    }
                }
                output[i * output_w + j] = sum;
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn conv2d_neon_impl(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    _input_h: usize,
    input_w: usize,
    kernel_h: usize,
    kernel_w: usize,
    output_h: usize,
    output_w: usize,
) {
    use std::arch::aarch64::*;

    // Simplified direct convolution, vectorized along the kernel width
    // (input rows and kernel rows are both contiguous). A production kernel
    // would use im2col or Winograd.
    for i in 0..output_h {
        for j in 0..output_w {
            let mut acc = vdupq_n_f32(0.0);
            let mut tail = 0.0f32;

            for ki in 0..kernel_h {
                let in_row = (i + ki) * input_w + j;
                let k_row = ki * kernel_w;
                let chunks = kernel_w / 4;

                for c in 0..chunks {
                    let vi = vld1q_f32(input.as_ptr().add(in_row + c * 4));
                    let vk = vld1q_f32(kernel.as_ptr().add(k_row + c * 4));
                    acc = vfmaq_f32(acc, vi, vk);
                }

                // Scalar tail for kernel widths that are not a multiple of 4
                for kj in (chunks * 4)..kernel_w {
                    tail += input[in_row + kj] * kernel[k_row + kj];
                }
            }

            output[i * output_w + j] = vaddvq_f32(acc) + tail;
        }
    }
}
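
// A hedged sketch of the im2col lowering mentioned above: each output
// position's receptive field is copied into one row of a
// (output_h * output_w) x (kernel_h * kernel_w) matrix, after which the
// convolution reduces to a matrix product (e.g. matmul_neon with n = 1 and the
// kernel flattened into a column). Layout assumptions: row-major, single
// channel, no padding, stride 1; the helper is illustrative, not ghostflow API.
#[allow(dead_code)]
fn im2col(
    input: &[f32],
    input_w: usize,
    kernel_h: usize,
    kernel_w: usize,
    output_h: usize,
    output_w: usize,
) -> Vec<f32> {
    let mut cols = Vec::with_capacity(output_h * output_w * kernel_h * kernel_w);
    for i in 0..output_h {
        for j in 0..output_w {
            for ki in 0..kernel_h {
                for kj in 0..kernel_w {
                    cols.push(input[(i + ki) * input_w + (j + kj)]);
                }
            }
        }
    }
    cols
}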

/// Tensor operations with NEON acceleration
impl Tensor {
    /// Add two tensors using NEON.
    pub fn add_neon(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();
        let mut result = vec![0.0; a.len()];

        add_neon(&a, &b, &mut result);

        Tensor::from_slice(&result, self.dims())
    }

    /// Multiply two tensors element-wise using NEON.
    pub fn mul_neon(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();
        let mut result = vec![0.0; a.len()];

        mul_neon(&a, &b, &mut result);

        Tensor::from_slice(&result, self.dims())
    }

    /// ReLU activation using NEON.
    pub fn relu_neon(&self) -> Tensor {
        let mut data = self.data_f32();
        relu_neon(&mut data);
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neon_availability() {
        let available = is_neon_available();
        #[cfg(target_arch = "aarch64")]
        assert!(available);
    }

    #[test]
    fn test_add_neon() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let mut result = vec![0.0; 8];

        add_neon(&a, &b, &mut result);

        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_dot_neon() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];

        let result = dot_neon(&a, &b);
        assert_eq!(result, 10.0);
    }

    #[test]
    fn test_relu_neon() {
        let mut data = vec![-1.0, 2.0, -3.0, 4.0];
        relu_neon(&mut data);
        assert_eq!(data, vec![0.0, 2.0, 0.0, 4.0]);
    }
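
    // Additional smoke tests; expected values are computed by hand from the
    // row-major layouts documented on matmul_neon and conv2d_neon.
    #[test]
    fn test_mul_neon_with_remainder() {
        // Length 5 exercises both the SIMD chunk and the scalar tail
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0, 2.0, 2.0, 2.0, 2.0];
        let mut result = vec![0.0; 5];

        mul_neon(&a, &b, &mut result);

        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_matmul_neon_2x2() {
        // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut result = vec![0.0; 4];

        matmul_neon(&a, &b, &mut result, 2, 2, 2);

        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_conv2d_neon_3x3_input_2x2_kernel() {
        // 3x3 input, 2x2 kernel [[1, 0], [0, 1]] sums each patch's diagonal
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let kernel = vec![1.0, 0.0, 0.0, 1.0];
        let mut output = vec![0.0; 4];

        conv2d_neon(&input, &kernel, &mut output, 3, 3, 2, 2);

        assert_eq!(output, vec![6.0, 8.0, 12.0, 14.0]);
    }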
}