1use crate::traits::SimdError;
8#[cfg(target_arch = "x86")]
9use core::arch::x86::*;
10#[cfg(target_arch = "x86_64")]
11use core::arch::x86_64::*;
12
13#[cfg(feature = "no-std")]
14use alloc::{vec, vec::Vec};
15#[cfg(not(feature = "no-std"))]
16use std::{vec, vec::Vec};
17
/// Geometry of a 2-D convolution as consumed by
/// `AdvancedSimdOptimizer::optimized_convolution`.
///
/// The optimizer destructures the tuples positionally: `input_shape` as
/// `(in_channels, in_height, in_width)` and `kernel_shape` as
/// `(out_channels, k_height, k_width)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvolutionParams {
    /// Input extents, read as `(channels, height, width)`.
    pub input_shape: (usize, usize, usize),
    /// Kernel extents, read as `(output channels, height, width)`.
    pub kernel_shape: (usize, usize, usize),
    /// Step between neighbouring output positions; the same value is used
    /// for both spatial axes.
    pub stride: usize,
    /// Number of implicit zero rows/columns added on each side of the input.
    pub padding: usize,
}
29
/// Half-open index ranges describing one tile of a blocked matrix product,
/// plus the row lengths needed to address the flat operand slices.
///
/// `matrix_multiply_block` reads the left operand at `i * k + kk`, the right
/// operand at `kk * n + j` and the destination at `i * n + j`.
#[derive(Debug, Clone, Copy)]
struct BlockRange {
    /// First row of the tile (inclusive).
    i_start: usize,
    /// First column of the tile (inclusive).
    j_start: usize,
    /// First index along the shared dimension (inclusive).
    k_start: usize,
    /// One past the last row of the tile.
    i_end: usize,
    /// One past the last column of the tile.
    j_end: usize,
    /// One past the last index along the shared dimension.
    k_end: usize,
    /// Row length of the right operand and of the destination.
    n: usize,
    /// Row length of the left operand.
    k: usize,
}
41
/// Entry point for the cache-blocked and SIMD-accelerated numeric kernels
/// implemented on this type.
///
/// The three fields are tuning values fixed by `new()`. None of the methods
/// reviewed alongside this definition reads them, which is why each field
/// carries a `dead_code` allowance.
#[derive(Debug, Clone)]
pub struct AdvancedSimdOptimizer {
    /// Set to 64 by `new()`; presumably the cache-line size in bytes.
    #[allow(dead_code)]
    cache_line_size: usize,
    /// Set to 512 by `new()`; presumably a prefetch look-ahead distance
    /// (unit not established by the code).
    #[allow(dead_code)]
    prefetch_distance: usize,
    /// Set to 8 by `new()`; presumably the number of `f32` lanes per vector.
    #[allow(dead_code)]
    vectorization_width: usize,
}
52
53impl AdvancedSimdOptimizer {
54 pub fn new() -> Self {
56 Self {
57 cache_line_size: 64, prefetch_distance: 512, vectorization_width: 8, }
61 }
62
63 pub fn cache_aware_matrix_multiply(
65 &self,
66 a: &[f32],
67 b: &[f32],
68 c: &mut [f32],
69 m: usize,
70 n: usize,
71 k: usize,
72 ) -> Result<(), SimdError> {
73 if a.len() != m * k || b.len() != k * n || c.len() != m * n {
74 return Err(SimdError::DimensionMismatch {
75 expected: m * n,
76 actual: c.len(),
77 });
78 }
79
80 let block_size = self.calculate_optimal_block_size(m, n, k);
82
83 for i in (0..m).step_by(block_size) {
84 for j in (0..n).step_by(block_size) {
85 for kk in (0..k).step_by(block_size) {
86 let i_max = (i + block_size).min(m);
87 let j_max = (j + block_size).min(n);
88 let k_max = (kk + block_size).min(k);
89
90 self.matrix_multiply_block(
91 a,
92 b,
93 c,
94 &BlockRange {
95 i_start: i,
96 j_start: j,
97 k_start: kk,
98 i_end: i_max,
99 j_end: j_max,
100 k_end: k_max,
101 n,
102 k,
103 },
104 )?;
105 }
106 }
107 }
108
109 Ok(())
110 }
111
112 pub fn vectorized_dot_product(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
114 if a.len() != b.len() {
115 return Err(SimdError::DimensionMismatch {
116 expected: a.len(),
117 actual: b.len(),
118 });
119 }
120
121 let len = a.len();
122 if len == 0 {
123 return Ok(0.0);
124 }
125
126 let mut result = 0.0f32;
127
128 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
129 {
130 if crate::simd_feature_detected!("avx2") {
131 return unsafe { self.dot_product_avx2(a, b) };
132 } else if crate::simd_feature_detected!("sse2") {
133 return unsafe { self.dot_product_sse2(a, b) };
134 }
135 }
136
137 let chunks = len / 4;
139 let remainder = len % 4;
140
141 for i in 0..chunks {
142 let base = i * 4;
143 result += a[base] * b[base]
144 + a[base + 1] * b[base + 1]
145 + a[base + 2] * b[base + 2]
146 + a[base + 3] * b[base + 3];
147 }
148
149 for i in (chunks * 4)..(chunks * 4 + remainder) {
150 result += a[i] * b[i];
151 }
152
153 Ok(result)
154 }
155
156 pub fn optimized_convolution(
158 &self,
159 input: &[f32],
160 kernel: &[f32],
161 output: &mut [f32],
162 params: &ConvolutionParams,
163 ) -> Result<(), SimdError> {
164 let (in_channels, in_height, in_width) = params.input_shape;
165 let (out_channels, k_height, k_width) = params.kernel_shape;
166 let stride = params.stride;
167 let padding = params.padding;
168
169 let out_height = (in_height + 2 * padding - k_height) / stride + 1;
170 let out_width = (in_width + 2 * padding - k_width) / stride + 1;
171
172 if output.len() != out_channels * out_height * out_width {
173 return Err(SimdError::DimensionMismatch {
174 expected: out_channels * out_height * out_width,
175 actual: output.len(),
176 });
177 }
178
179 let im2col_data = self.im2col_transform(
181 input,
182 params.input_shape,
183 params.kernel_shape,
184 stride,
185 padding,
186 )?;
187
188 self.cache_aware_matrix_multiply(
190 kernel,
191 &im2col_data,
192 output,
193 out_channels,
194 out_height * out_width,
195 in_channels * k_height * k_width,
196 )?;
197
198 Ok(())
199 }
200
201 pub fn vectorized_reduction(&self, data: &[f32], op: ReductionOp) -> Result<f32, SimdError> {
203 if data.is_empty() {
204 return Err(SimdError::EmptyInput);
205 }
206
207 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
208 {
209 if crate::simd_feature_detected!("avx2") {
210 return unsafe { self.reduction_avx2(data, op) };
211 }
212 }
213
214 match op {
216 ReductionOp::Sum => Ok(data.iter().sum()),
217 ReductionOp::Max => Ok(data.iter().copied().fold(f32::NEG_INFINITY, f32::max)),
218 ReductionOp::Min => Ok(data.iter().copied().fold(f32::INFINITY, f32::min)),
219 ReductionOp::Mean => Ok(data.iter().sum::<f32>() / data.len() as f32),
220 }
221 }
222
223 fn calculate_optimal_block_size(&self, _m: usize, _n: usize, _k: usize) -> usize {
226 let cache_size = 32768; let element_size = 4; let block_elements = cache_size / (3 * element_size); let block_size = (block_elements as f32).sqrt() as usize;
232 block_size.clamp(8, 64) }
234
235 fn matrix_multiply_block(
236 &self,
237 a: &[f32],
238 b: &[f32],
239 c: &mut [f32],
240 block: &BlockRange,
241 ) -> Result<(), SimdError> {
242 for i in block.i_start..block.i_end {
243 for j in block.j_start..block.j_end {
244 let mut sum = 0.0f32;
245 for kk in block.k_start..block.k_end {
246 sum += a[i * block.k + kk] * b[kk * block.n + j];
247 }
248 c[i * block.n + j] += sum;
249 }
250 }
251 Ok(())
252 }
253
254 fn im2col_transform(
255 &self,
256 input: &[f32],
257 input_shape: (usize, usize, usize),
258 kernel_shape: (usize, usize, usize),
259 stride: usize,
260 padding: usize,
261 ) -> Result<Vec<f32>, SimdError> {
262 let (in_channels, in_height, in_width) = input_shape;
263 let (_, k_height, k_width) = kernel_shape;
264
265 let out_height = (in_height + 2 * padding - k_height) / stride + 1;
266 let out_width = (in_width + 2 * padding - k_width) / stride + 1;
267
268 let mut result = vec![0.0f32; in_channels * k_height * k_width * out_height * out_width];
269
270 for c in 0..in_channels {
271 for kh in 0..k_height {
272 for kw in 0..k_width {
273 for oh in 0..out_height {
274 for ow in 0..out_width {
275 let ih = oh * stride + kh;
276 let iw = ow * stride + kw;
277
278 let value = if ih >= padding
279 && ih < in_height + padding
280 && iw >= padding
281 && iw < in_width + padding
282 {
283 let adjusted_ih = ih - padding;
284 let adjusted_iw = iw - padding;
285 input[c * in_height * in_width
286 + adjusted_ih * in_width
287 + adjusted_iw]
288 } else {
289 0.0f32
290 };
291
292 let col_idx = (c * k_height * k_width + kh * k_width + kw)
293 * out_height
294 * out_width
295 + oh * out_width
296 + ow;
297 result[col_idx] = value;
298 }
299 }
300 }
301 }
302 }
303
304 Ok(result)
305 }
306
307 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
308 #[target_feature(enable = "avx2")]
309 unsafe fn dot_product_avx2(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
310 let len = a.len();
311 let mut sum = _mm256_setzero_ps();
312
313 let chunks = len / 8;
314 for i in 0..chunks {
315 let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
316 let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
317 let product = _mm256_mul_ps(a_vec, b_vec);
318 sum = _mm256_add_ps(sum, product);
319 }
320
321 let sum_high = _mm256_extractf128_ps(sum, 1);
323 let sum_low = _mm256_castps256_ps128(sum);
324 let sum128 = _mm_add_ps(sum_high, sum_low);
325
326 let mut result = [0.0f32; 4];
327 _mm_storeu_ps(result.as_mut_ptr(), sum128);
328 let mut final_sum = result[0] + result[1] + result[2] + result[3];
329
330 for i in (chunks * 8)..len {
332 final_sum += a[i] * b[i];
333 }
334
335 Ok(final_sum)
336 }
337
338 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
339 #[target_feature(enable = "sse2")]
340 unsafe fn dot_product_sse2(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
341 let len = a.len();
342 let mut sum = _mm_setzero_ps();
343
344 let chunks = len / 4;
345 for i in 0..chunks {
346 let a_vec = _mm_loadu_ps(a.as_ptr().add(i * 4));
347 let b_vec = _mm_loadu_ps(b.as_ptr().add(i * 4));
348 let product = _mm_mul_ps(a_vec, b_vec);
349 sum = _mm_add_ps(sum, product);
350 }
351
352 let mut result = [0.0f32; 4];
353 _mm_storeu_ps(result.as_mut_ptr(), sum);
354 let mut final_sum = result[0] + result[1] + result[2] + result[3];
355
356 for i in (chunks * 4)..len {
358 final_sum += a[i] * b[i];
359 }
360
361 Ok(final_sum)
362 }
363
364 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
365 #[target_feature(enable = "avx2")]
366 unsafe fn reduction_avx2(&self, data: &[f32], op: ReductionOp) -> Result<f32, SimdError> {
367 let len = data.len();
368 let chunks = len / 8;
369
370 let mut accumulator = match op {
371 ReductionOp::Sum | ReductionOp::Mean => _mm256_setzero_ps(),
372 ReductionOp::Max => _mm256_set1_ps(f32::NEG_INFINITY),
373 ReductionOp::Min => _mm256_set1_ps(f32::INFINITY),
374 };
375
376 for i in 0..chunks {
377 let data_vec = _mm256_loadu_ps(data.as_ptr().add(i * 8));
378 accumulator = match op {
379 ReductionOp::Sum | ReductionOp::Mean => _mm256_add_ps(accumulator, data_vec),
380 ReductionOp::Max => _mm256_max_ps(accumulator, data_vec),
381 ReductionOp::Min => _mm256_min_ps(accumulator, data_vec),
382 };
383 }
384
385 let mut result = [0.0f32; 8];
387 _mm256_storeu_ps(result.as_mut_ptr(), accumulator);
388
389 let mut final_result = match op {
390 ReductionOp::Sum | ReductionOp::Mean => result.iter().sum::<f32>(),
391 ReductionOp::Max => result.iter().copied().fold(f32::NEG_INFINITY, f32::max),
392 ReductionOp::Min => result.iter().copied().fold(f32::INFINITY, f32::min),
393 };
394
395 for val in data.iter().take(len).skip(chunks * 8) {
397 final_result = match op {
398 ReductionOp::Sum | ReductionOp::Mean => final_result + *val,
399 ReductionOp::Max => final_result.max(*val),
400 ReductionOp::Min => final_result.min(*val),
401 };
402 }
403
404 if matches!(op, ReductionOp::Mean) {
405 final_result /= len as f32;
406 }
407
408 Ok(final_result)
409 }
410}
411
412impl Default for AdvancedSimdOptimizer {
413 fn default() -> Self {
414 Self::new()
415 }
416}
417
/// Reduction selected by `AdvancedSimdOptimizer::vectorized_reduction`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Sum of all elements.
    Sum,
    /// Largest element.
    Max,
    /// Smallest element.
    Min,
    /// Sum of all elements divided by the element count.
    Mean,
}
426
/// Namespace type for the merge-sort routines below; it carries no state.
pub struct CacheAwareSort;

impl CacheAwareSort {
    /// Sorts `data` in ascending order with a top-down merge sort.
    ///
    /// Despite the name, the implementation is a plain scalar merge sort.
    /// One scratch buffer of `data.len()` elements is allocated per call and
    /// shared by all recursion levels. Elements are compared with `<=`; a
    /// comparison involving NaN is false, so inputs containing NaN never
    /// panic but end up in no particular order.
    pub fn vectorized_merge_sort(data: &mut [f32]) {
        if data.len() <= 1 {
            return;
        }

        let mut scratch = vec![0.0f32; data.len()];
        Self::merge_sort_with_scratch(data, &mut scratch);
    }

    /// Recursive worker: sorts `data` using `scratch` (same length) as the
    /// merge destination, so no recursion level allocates.
    fn merge_sort_with_scratch(data: &mut [f32], scratch: &mut [f32]) {
        debug_assert_eq!(data.len(), scratch.len());
        if data.len() <= 1 {
            return;
        }

        let mid = data.len() / 2;
        {
            // Each half sorts inside the matching half of the scratch buffer.
            let (left, right) = data.split_at_mut(mid);
            let (scratch_left, scratch_right) = scratch.split_at_mut(mid);
            Self::merge_sort_with_scratch(left, scratch_left);
            Self::merge_sort_with_scratch(right, scratch_right);
        }

        Self::cache_friendly_merge(data, scratch, mid);
        data.copy_from_slice(scratch);
    }

    /// Merges the sorted runs `data[..mid]` and `data[mid..]` into the first
    /// `data.len()` slots of `temp`, taking from the left run on ties.
    fn cache_friendly_merge(data: &[f32], temp: &mut [f32], mid: usize) {
        let (left, right) = data.split_at(mid);
        let mut li = 0;
        let mut ri = 0;

        for slot in temp[..data.len()].iter_mut() {
            // Take from the left run when the right run is exhausted, or when
            // both runs have elements and the left head is not greater.
            let take_left = ri == right.len() || (li < left.len() && left[li] <= right[ri]);
            if take_left {
                *slot = left[li];
                li += 1;
            } else {
                *slot = right[ri];
                ri += 1;
            }
        }
    }
}
477
// NOTE(review): no non-snake-case identifier is visible in this module, so
// the `non_snake_case` allowance looks like a leftover -- confirm before
// removing it.
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Dot product of two four-element vectors, compared with a tolerance.
    #[test]
    fn test_vectorized_dot_product() {
        let optimizer = AdvancedSimdOptimizer::new();
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];

        let result = optimizer
            .vectorized_dot_product(&a, &b)
            .expect("operation should succeed");

        // 1*5 + 2*6 + 3*7 + 4*8
        let expected = 5.0 + 12.0 + 21.0 + 32.0;
        assert!((result - expected).abs() < 1e-6);
    }

    /// Every reduction over a five-element input (shorter than one AVX2
    /// vector, so only the scalar tail handling is exercised there).
    #[test]
    fn test_vectorized_reduction() {
        let optimizer = AdvancedSimdOptimizer::new();
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0];

        let cases = [
            (ReductionOp::Sum, 15.0f32),
            (ReductionOp::Max, 5.0f32),
            (ReductionOp::Min, 1.0f32),
            (ReductionOp::Mean, 3.0f32),
        ];

        for &(op, expected) in cases.iter() {
            let actual = optimizer
                .vectorized_reduction(&data, op)
                .expect("operation should succeed");
            assert_eq!(actual, expected);
        }
    }

    /// Sorting a short unsorted vector yields ascending order.
    #[test]
    fn test_cache_aware_sort() {
        let mut data = vec![5.0, 2.0, 8.0, 1.0, 9.0, 3.0];
        CacheAwareSort::vectorized_merge_sort(&mut data);

        assert_eq!(data, vec![1.0, 2.0, 3.0, 5.0, 8.0, 9.0]);
    }
}