1use super::unified_dispatcher::global_dispatcher;
7use crate::array::Array;
8use crate::error::{NumRs2Error, Result};
9
10pub trait SimdArrayOps {
15 fn simd_add(&self, other: &Self) -> Result<Array<f32>>;
17
18 fn simd_mul(&self, other: &Self) -> Result<Array<f32>>;
20
21 fn simd_sum(&self) -> f32;
23
24 fn simd_exp(&self) -> Array<f32>;
26
27 fn simd_log(&self) -> Array<f32>;
29
30 fn simd_sin_cos(&self) -> (Array<f32>, Array<f32>);
32
33 fn simd_matmul(&self, other: &Self) -> Result<Array<f32>>;
35
36 fn simd_dot(&self, other: &Self) -> Result<f32>;
38
39 fn simd_copy(&self) -> Result<Array<f32>>;
41}
42
43impl SimdArrayOps for Array<f32> {
44 fn simd_add(&self, other: &Self) -> Result<Array<f32>> {
45 if self.shape() != other.shape() {
46 return Err(NumRs2Error::ShapeMismatch {
47 expected: self.shape(),
48 actual: other.shape(),
49 });
50 }
51
52 let self_data = self.to_vec();
54 let other_data = other.to_vec();
55 let mut result_data = vec![0.0f32; self_data.len()];
56
57 let dispatcher = global_dispatcher();
59 match dispatcher.implementation_info().name {
60 "AVX2" | "AVX-512" => {
61 #[cfg(target_arch = "x86_64")]
62 unsafe {
63 super::avx2_ops::avx2_add_f32(&self_data, &other_data, &mut result_data);
64 }
65 #[cfg(not(target_arch = "x86_64"))]
66 {
67 for i in 0..self_data.len() {
68 result_data[i] = self_data[i] + other_data[i];
69 }
70 }
71 }
72 "NEON" => {
73 #[cfg(target_arch = "aarch64")]
74 {
75 for i in 0..self_data.len() {
77 result_data[i] = self_data[i] + other_data[i];
78 }
79 }
80 #[cfg(not(target_arch = "aarch64"))]
81 {
82 for i in 0..self_data.len() {
83 result_data[i] = self_data[i] + other_data[i];
84 }
85 }
86 }
87 _ => {
88 for i in 0..self_data.len() {
90 result_data[i] = self_data[i] + other_data[i];
91 }
92 }
93 }
94
95 Ok(Array::from_vec(result_data).reshape(&self.shape()))
96 }
97
98 fn simd_mul(&self, other: &Self) -> Result<Array<f32>> {
99 if self.shape() != other.shape() {
100 return Err(NumRs2Error::ShapeMismatch {
101 expected: self.shape(),
102 actual: other.shape(),
103 });
104 }
105
106 let self_data = self.to_vec();
107 let other_data = other.to_vec();
108 let mut result_data = vec![0.0f32; self_data.len()];
109
110 let dispatcher = global_dispatcher();
112 match dispatcher.implementation_info().name {
113 "AVX2" | "AVX-512" => {
114 #[cfg(target_arch = "x86_64")]
115 unsafe {
116 super::avx2_ops::avx2_mul_f32(&self_data, &other_data, &mut result_data);
117 }
118 #[cfg(not(target_arch = "x86_64"))]
119 {
120 for i in 0..self_data.len() {
121 result_data[i] = self_data[i] * other_data[i];
122 }
123 }
124 }
125 _ => {
126 for i in 0..self_data.len() {
127 result_data[i] = self_data[i] * other_data[i];
128 }
129 }
130 }
131
132 Ok(Array::from_vec(result_data).reshape(&self.shape()))
133 }
134
135 fn simd_sum(&self) -> f32 {
136 global_dispatcher().optimized_sum_f32(self)
137 }
138
139 fn simd_exp(&self) -> Array<f32> {
140 global_dispatcher().optimized_exp_f32(self)
141 }
142
143 fn simd_log(&self) -> Array<f32> {
144 global_dispatcher().optimized_log_f32(self)
145 }
146
147 fn simd_sin_cos(&self) -> (Array<f32>, Array<f32>) {
148 global_dispatcher().optimized_sin_cos_f32(self)
149 }
150
151 fn simd_matmul(&self, other: &Self) -> Result<Array<f32>> {
152 global_dispatcher().optimized_matmul_f32(self, other)
153 }
154
155 fn simd_dot(&self, other: &Self) -> Result<f32> {
156 global_dispatcher().optimized_dot_f32(self, other)
157 }
158
159 fn simd_copy(&self) -> Result<Array<f32>> {
160 global_dispatcher().optimized_copy_f32(self)
161 }
162}
163
/// Builds an `Array` from literal contents, mirroring the two forms of `vec!`.
///
/// * `simd_array![a, b, c]` — an array holding the listed elements (a
///   trailing comma is accepted).
/// * `simd_array![value; n]` — an array holding `n` copies of `value`.
///
/// The expansion names the array type through `$crate`, so callers do not
/// need to import `Array` themselves before using the macro.
#[macro_export]
macro_rules! simd_array {
    ($($x:expr),* $(,)?) => {
        $crate::array::Array::from_vec(vec![$($x),*])
    };
    ($x:expr; $n:expr) => {
        $crate::array::Array::from_vec(vec![$x; $n])
    };
}
174
175pub struct SimdPerformanceHints;
177
178impl SimdPerformanceHints {
179 pub fn optimal_array_size() -> usize {
181 let dispatcher = global_dispatcher();
182 match dispatcher.implementation_info().vector_width {
183 512 => 16 * 4, 256 => 8 * 4, 128 => 4 * 4, _ => 16, }
188 }
189
190 pub fn is_simd_friendly(size: usize) -> bool {
192 let dispatcher = global_dispatcher();
193 let vector_elements = match dispatcher.implementation_info().vector_width {
194 512 => 16, 256 => 8, 128 => 4, _ => 4, };
199
200 size.is_multiple_of(vector_elements) && size >= vector_elements * 2
201 }
202
203 pub fn alignment_requirement() -> usize {
205 let dispatcher = global_dispatcher();
206 dispatcher.implementation_info().vector_width / 8 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::ElementWiseMath;
    use crate::stats::Statistics;
    use approx::assert_relative_eq;

    /// Element-wise add and multiply on equal-shaped inputs give exact results.
    #[test]
    fn test_simd_array_ops() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]);

        let added = lhs
            .simd_add(&rhs)
            .expect("simd_add should succeed with equal-sized arrays");
        assert_eq!(added.to_vec(), vec![6.0, 8.0, 10.0, 12.0]);

        let multiplied = lhs
            .simd_mul(&rhs)
            .expect("simd_mul should succeed with equal-sized arrays");
        assert_eq!(multiplied.to_vec(), vec![5.0, 12.0, 21.0, 32.0]);
    }

    /// The SIMD sum and the statistics mean agree with hand-computed values.
    #[test]
    fn test_simd_reductions() {
        let array = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        assert_relative_eq!(array.simd_sum(), 10.0, epsilon = 1e-6);
        assert_relative_eq!(array.mean(), 2.5, epsilon = 1e-6);
    }

    /// Square roots of perfect squares, plus the exponential of 0 and 1.
    #[test]
    fn test_simd_math_functions() {
        let array = Array::from_vec(vec![1.0, 4.0, 9.0, 16.0]);

        // Materialise the result once instead of once per assertion.
        let roots = array.sqrt().to_vec();
        assert_relative_eq!(roots[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(roots[1], 2.0, epsilon = 1e-6);
        assert_relative_eq!(roots[2], 3.0, epsilon = 1e-6);
        assert_relative_eq!(roots[3], 4.0, epsilon = 1e-6);

        let exp_input = Array::from_vec(vec![0.0, 1.0]);
        let exp_result = exp_input.simd_exp();

        // NOTE(review): the dispatcher-routed result is only printed, never
        // asserted; the checks below exercise the direct/fallback paths
        // instead. Confirm whether `simd_exp` is accurate enough to assert.
        let result_vec = exp_result.to_vec();
        println!("exp_result values: {:?}", result_vec);
        println!("Expected: [1.0, {}]", std::f32::consts::E);

        #[cfg(target_arch = "x86_64")]
        {
            let direct_vec =
                crate::simd_optimize::avx2_enhanced::EnhancedSimdOps::vectorized_exp_f32(
                    &exp_input,
                )
                .to_vec();
            println!("Direct AVX2 result: {:?}", direct_vec);
            assert_relative_eq!(direct_vec[0], 1.0, epsilon = 1e-6);
            assert_relative_eq!(direct_vec[1], std::f32::consts::E, epsilon = 1e-5);
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            let fallback_vec = exp_input.map(|x| x.exp()).to_vec();
            assert_relative_eq!(fallback_vec[0], 1.0, epsilon = 1e-6);
            assert_relative_eq!(fallback_vec[1], std::f32::consts::E, epsilon = 1e-5);
        }
    }

    /// The hints stay within the bounds expected for any supported width.
    #[test]
    fn test_performance_hints() {
        let optimal_size = SimdPerformanceHints::optimal_array_size();
        assert!(optimal_size >= 16);

        // 64 is a multiple of every lane count the hints use (16, 8 and 4)
        // and at least two vectors long in each case, so this holds for any
        // reported vector width.
        let is_friendly = SimdPerformanceHints::is_simd_friendly(64);
        assert!(
            is_friendly,
            "size 64 should be SIMD-friendly for every supported vector width"
        );

        let alignment = SimdPerformanceHints::alignment_requirement();
        assert!(alignment >= 16);
    }

    /// The list form of `simd_array!` preserves element order and values.
    #[test]
    fn test_simd_array_macro() {
        let array = simd_array![1.0, 2.0, 3.0, 4.0];
        assert_eq!(array.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }
}