ghostflow_core/ops/simd.rs

#[cfg(feature = "rayon")]
use rayon::prelude::*;

/// Element-wise ReLU (max(x, 0.0)). Exactly one of the blocks below survives
/// `cfg` expansion, so the widest SIMD path enabled at compile time is used
/// and the others fall away.
#[inline]
pub fn relu_simd(data: &[f32]) -> Vec<f32> {
    #[cfg(target_feature = "avx2")]
    {
        relu_avx2(data)
    }

    #[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
    {
        relu_sse2(data)
    }

    #[cfg(not(any(target_feature = "avx2", target_feature = "sse2")))]
    {
        relu_scalar(data)
    }
}

#[cfg(target_feature = "avx2")]
#[inline]
fn relu_avx2(data: &[f32]) -> Vec<f32> {
    use std::arch::x86_64::*;

    let mut result = Vec::with_capacity(data.len());
    unsafe {
        let zero = _mm256_setzero_ps();
        let chunks = data.chunks_exact(8);
        let remainder = chunks.remainder();

        // Eight lanes per iteration: unaligned load, max against zero, store.
        for chunk in chunks {
            let vec = _mm256_loadu_ps(chunk.as_ptr());
            let max = _mm256_max_ps(vec, zero);
            let mut out = [0.0f32; 8];
            _mm256_storeu_ps(out.as_mut_ptr(), max);
            result.extend_from_slice(&out);
        }

        // Elements that do not fill a whole chunk are handled scalar.
        result.extend(remainder.iter().map(|&x| x.max(0.0)));
    }
    result
}

#[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
#[inline]
fn relu_sse2(data: &[f32]) -> Vec<f32> {
    use std::arch::x86_64::*;

    let mut result = Vec::with_capacity(data.len());
    unsafe {
        let zero = _mm_setzero_ps();
        let chunks = data.chunks_exact(4);
        let remainder = chunks.remainder();

        // Same pattern as the AVX2 path, four lanes per iteration.
        for chunk in chunks {
            let vec = _mm_loadu_ps(chunk.as_ptr());
            let max = _mm_max_ps(vec, zero);
            let mut out = [0.0f32; 4];
            _mm_storeu_ps(out.as_mut_ptr(), max);
            result.extend_from_slice(&out);
        }

        result.extend(remainder.iter().map(|&x| x.max(0.0)));
    }
    result
}

#[allow(dead_code)]
#[inline]
fn relu_scalar(data: &[f32]) -> Vec<f32> {
    data.iter().map(|&x| x.max(0.0)).collect()
}

/// Logistic sigmoid 1 / (1 + e^-x), computed element-wise with the polynomial
/// approximation `fast_exp` below rather than the exact exponential.
#[inline]
pub fn sigmoid_simd(data: &[f32]) -> Vec<f32> {
    data.iter()
        .map(|&x| 1.0 / (1.0 + fast_exp(-x)))
        .collect()
}

/// Crude exponential: a fourth-order Taylor polynomial of e^x evaluated on |x|,
/// inverted for negative inputs so the result stays in (0, 1] there. Accurate
/// near zero, increasingly loose once |x| grows past about 1. The clamp mirrors
/// the overflow threshold of the real f32 exponential.
#[inline]
fn fast_exp(x: f32) -> f32 {
    let x = x.clamp(-88.0, 88.0);

    if x < 0.0 {
        let x = -x;
        let x2 = x * x;
        let x3 = x2 * x;
        let x4 = x2 * x2;
        // e^-x ~ 1 / e^x, with e^x ~ 1 + x + x^2/2 + x^3/6 + x^4/24
        1.0 / (1.0 + x + x2 * 0.5 + x3 * 0.16666667 + x4 * 0.041666667)
    } else {
        let x2 = x * x;
        let x3 = x2 * x;
        let x4 = x2 * x2;
        1.0 + x + x2 * 0.5 + x3 * 0.16666667 + x4 * 0.041666667
    }
}

/// GELU activation via the common tanh-based approximation
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
/// with `fast_tanh` standing in for the exact tanh.
#[inline]
pub fn gelu_simd(data: &[f32]) -> Vec<f32> {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const COEFF: f32 = 0.044715;

    data.iter()
        .map(|&x| {
            let inner = SQRT_2_OVER_PI * (x + COEFF * x.powi(3));
            0.5 * x * (1.0 + fast_tanh(inner))
        })
        .collect()
}

/// Rational approximation of tanh: x * (27 + x^2) / (27 + 9 * x^2), clamped to
/// [-3, 3] where the formula hits exactly +/-1, so it saturates instead of
/// overshooting the true asymptotes. Worst-case error is a few hundredths.
#[inline]
fn fast_tanh(x: f32) -> f32 {
    let x = x.clamp(-3.0, 3.0);

    let x2 = x * x;
    x * (27.0 + x2) / (27.0 + 9.0 * x2)
}

/// Element-wise addition of two equal-length slices, dispatched like `relu_simd`.
#[inline]
pub fn add_simd(a: &[f32], b: &[f32]) -> Vec<f32> {
    // Mismatched lengths would silently truncate the output, so catch that in
    // debug builds.
    debug_assert_eq!(a.len(), b.len(), "add_simd expects slices of equal length");

    #[cfg(target_feature = "avx2")]
    {
        add_avx2(a, b)
    }

    #[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
    {
        add_sse2(a, b)
    }

    #[cfg(not(any(target_feature = "avx2", target_feature = "sse2")))]
    {
        add_scalar(a, b)
    }
}

#[cfg(target_feature = "avx2")]
#[inline]
fn add_avx2(a: &[f32], b: &[f32]) -> Vec<f32> {
    use std::arch::x86_64::*;

    let mut result = Vec::with_capacity(a.len());
    unsafe {
        let chunks_a = a.chunks_exact(8);
        let chunks_b = b.chunks_exact(8);
        let remainder_a = chunks_a.remainder();
        let remainder_b = chunks_b.remainder();

        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
            let vec_a = _mm256_loadu_ps(chunk_a.as_ptr());
            let vec_b = _mm256_loadu_ps(chunk_b.as_ptr());
            let sum = _mm256_add_ps(vec_a, vec_b);
            let mut out = [0.0f32; 8];
            _mm256_storeu_ps(out.as_mut_ptr(), sum);
            result.extend_from_slice(&out);
        }

        result.extend(remainder_a.iter().zip(remainder_b.iter()).map(|(&x, &y)| x + y));
    }
    result
}

#[cfg(all(not(target_feature = "avx2"), target_feature = "sse2"))]
#[inline]
fn add_sse2(a: &[f32], b: &[f32]) -> Vec<f32> {
    use std::arch::x86_64::*;

    let mut result = Vec::with_capacity(a.len());
    unsafe {
        let chunks_a = a.chunks_exact(4);
        let chunks_b = b.chunks_exact(4);
        let remainder_a = chunks_a.remainder();
        let remainder_b = chunks_b.remainder();

        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
            let vec_a = _mm_loadu_ps(chunk_a.as_ptr());
            let vec_b = _mm_loadu_ps(chunk_b.as_ptr());
            let sum = _mm_add_ps(vec_a, vec_b);
            let mut out = [0.0f32; 4];
            _mm_storeu_ps(out.as_mut_ptr(), sum);
            result.extend_from_slice(&out);
        }

        result.extend(remainder_a.iter().zip(remainder_b.iter()).map(|(&x, &y)| x + y));
    }
    result
}

#[allow(dead_code)]
#[inline]
fn add_scalar(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_simd() {
        let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let result = relu_simd(&data);
        assert_eq!(result, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }
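
    // Sketch of an additional check (not in the original tests): ten elements
    // are enough to exercise both the 8-wide AVX2 main loop and the scalar
    // remainder path, and the same data works on the SSE2 and scalar builds.
    #[test]
    fn test_relu_simd_remainder() {
        let data: Vec<f32> = (-5..5).map(|i| i as f32).collect();
        let expected: Vec<f32> = data.iter().map(|&x| x.max(0.0)).collect();
        assert_eq!(relu_simd(&data), expected);
    }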

    #[test]
    fn test_sigmoid_simd() {
        let data = vec![0.0];
        let result = sigmoid_simd(&data);
        assert!((result[0] - 0.5).abs() < 0.01);
    }
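
    // Sketch of an extra check (not in the original tests): for small |x| the
    // polynomial fast_exp tracks the real exponential closely, so sigmoid_simd
    // should stay within a loose tolerance of 1 / (1 + e^-x) on [-1, 1].
    #[test]
    fn test_sigmoid_simd_small_inputs() {
        let data = vec![-1.0f32, -0.5, 0.25, 1.0];
        let result = sigmoid_simd(&data);
        for (&x, &y) in data.iter().zip(result.iter()) {
            let exact = 1.0 / (1.0 + (-x).exp());
            assert!((y - exact).abs() < 0.05, "sigmoid({}) = {}, expected ~{}", x, y, exact);
        }
    }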

    #[test]
    fn test_add_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let result = add_simd(&a, &b);
        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }
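
    // Sketches of further checks (not in the original tests). The reference
    // value for GELU(1) and the tolerances are loose bounds chosen to fit the
    // fast approximations above, not exact constants.
    #[test]
    fn test_gelu_simd_known_points() {
        let result = gelu_simd(&[0.0, 1.0]);
        assert!(result[0].abs() < 1e-6); // GELU(0) == 0
        assert!((result[1] - 0.841).abs() < 0.02); // GELU(1) is roughly 0.841
    }

    #[test]
    fn test_fast_tanh_tracks_std_tanh() {
        for i in -30..=30 {
            let x = i as f32 * 0.1;
            assert!((fast_tanh(x) - x.tanh()).abs() < 0.05);
        }
    }

    #[test]
    fn test_add_simd_remainder() {
        // Length 10 is not a multiple of 8 or 4, so the remainder path runs too.
        let a: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..10).map(|i| (i * 2) as f32).collect();
        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
        assert_eq!(add_simd(&a, &b), expected);
    }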
}