Skip to main content

sklears_simd/
target.rs

1//! Target-specific SIMD optimizations and compilation support
2//!
3//! This module provides target-specific optimizations for different CPU architectures
4//! and supports compile-time feature selection for optimal performance.
5
6use crate::traits::{SimdError, VectorArithmetic, VectorReduction};
7
8#[cfg(feature = "no-std")]
9use alloc::vec::Vec;
10
/// CPU micro-architecture (or workload class) that the dispatcher tunes for.
///
/// Drives the lane-width, cache-line and prefetch heuristics on `TargetOptimizedOps`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationTarget {
    /// Baseline tuning with no micro-architecture assumptions: used for x86/x86-64 CPUs
    /// lacking AVX2+FMA and as the fallback on architectures other than x86/x86-64/aarch64
    Generic,
    /// Optimize for Intel Haswell and newer (AVX2, FMA)
    Haswell,
    /// Optimize for Intel Skylake and newer (AVX2, FMA, enhanced instructions)
    Skylake,
    /// Optimize for AMD Zen and newer
    Zen,
    /// Optimize for ARM Cortex-A76 and newer
    CortexA76,
    /// Optimize for Apple Silicon M1/M2
    AppleSilicon,
    /// Optimize for Intel Granite Rapids with AVX10.1
    GraniteRapids,
    /// Optimize for Intel Diamond Rapids with enhanced AVX10
    DiamondRapids,
    /// Optimize for AMD Zen 5 with AVX-512
    Zen5,
    /// Optimize for server workloads (high throughput)
    Server,
    /// Optimize for mobile/embedded (low power)
    Mobile,
}
37
38/// Compile-time target configuration
39#[derive(Debug)]
40pub struct TargetConfig {
41    pub optimization_target: OptimizationTarget,
42    pub enable_fma: bool,
43    pub enable_avx512: bool,
44    pub enable_avx10: bool,
45    pub enable_amx: bool,
46    pub enable_sve2: bool,
47    pub enable_sme: bool,
48    pub enable_fp16: bool,
49    pub enable_bf16: bool,
50    pub prefer_throughput: bool,
51    pub prefer_latency: bool,
52}
53
54impl Default for TargetConfig {
55    fn default() -> Self {
56        Self {
57            optimization_target: OptimizationTarget::Generic,
58            enable_fma: cfg!(target_feature = "fma"),
59            enable_avx512: cfg!(target_feature = "avx512f"),
60            enable_avx10: false, // Not yet supported in stable Rust
61            enable_amx: false,   // Intel AMX not yet in stable Rust
62            enable_sve2: false,  // ARM SVE2 not yet in stable Rust
63            enable_sme: false,   // ARM SME not yet in stable Rust
64            enable_fp16: false,  // FP16 support is library-based
65            enable_bf16: false,  // BF16 support is library-based
66            prefer_throughput: true,
67            prefer_latency: false,
68        }
69    }
70}
71
72/// Target-specific vector operations dispatcher
73pub struct TargetOptimizedOps {
74    config: TargetConfig,
75}
76
77impl TargetOptimizedOps {
78    /// Create a new target-optimized operations instance
79    pub fn new(config: TargetConfig) -> Self {
80        Self { config }
81    }
82
83    /// Create with automatic target detection
84    pub fn auto_detect() -> Self {
85        let config = TargetConfig {
86            optimization_target: detect_optimization_target(),
87            enable_fma: {
88                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
89                {
90                    Self::detect_fma()
91                }
92                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
93                {
94                    false
95                }
96            },
97            enable_avx512: {
98                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
99                {
100                    Self::detect_avx512()
101                }
102                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
103                {
104                    false
105                }
106            },
107            enable_avx10: Self::detect_avx10(),
108            enable_amx: Self::detect_amx(),
109            enable_sve2: Self::detect_sve2(),
110            enable_sme: Self::detect_sme(),
111            enable_fp16: Self::detect_fp16(),
112            enable_bf16: Self::detect_bf16(),
113            prefer_throughput: true,
114            prefer_latency: false,
115        };
116
117        Self::new(config)
118    }
119
120    /// Detect FMA support
121    #[allow(dead_code)] // Used in offline detection queries; not yet wired to the config
122    fn detect_fma() -> bool {
123        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
124        {
125            crate::simd_feature_detected!("fma") || cfg!(target_feature = "fma")
126        }
127        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
128        {
129            false
130        }
131    }
132
133    /// Detect AVX-512 support
134    #[allow(dead_code)] // Used in offline detection queries; not yet wired to the config
135    fn detect_avx512() -> bool {
136        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
137        {
138            crate::simd_feature_detected!("avx512f") || cfg!(target_feature = "avx512f")
139        }
140        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
141        {
142            false
143        }
144    }
145
146    /// Detect AVX10 support (future Intel processors)
147    fn detect_avx10() -> bool {
148        // AVX10 is not yet available in stable Rust
149        // This is a placeholder for future implementation
150        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
151        {
152            // In the future, this would use crate::simd_feature_detected!("avx10.1")
153            false
154        }
155        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
156        {
157            false
158        }
159    }
160
161    /// Detect Intel AMX (Advanced Matrix Extensions) support
162    fn detect_amx() -> bool {
163        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
164        {
165            // AMX is not yet detectable via is_x86_feature_detected in stable Rust
166            // This would check for AMX-BF16, AMX-INT8, AMX-TILE in the future
167            false
168        }
169        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
170        {
171            false
172        }
173    }
174
175    /// Detect ARM SVE2 (Scalable Vector Extensions 2) support
176    fn detect_sve2() -> bool {
177        #[cfg(target_arch = "aarch64")]
178        {
179            // SVE2 detection is not yet available in stable Rust
180            // This would use is_aarch64_feature_detected!("sve2") in the future
181            false
182        }
183        #[cfg(not(target_arch = "aarch64"))]
184        {
185            false
186        }
187    }
188
189    /// Detect ARM SME (Scalable Matrix Extensions) support
190    fn detect_sme() -> bool {
191        #[cfg(target_arch = "aarch64")]
192        {
193            // SME detection is not yet available in stable Rust
194            // This would use is_aarch64_feature_detected!("sme") in the future
195            false
196        }
197        #[cfg(not(target_arch = "aarch64"))]
198        {
199            false
200        }
201    }
202
203    /// Detect hardware FP16 support
204    fn detect_fp16() -> bool {
205        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
206        {
207            // Intel F16C extension
208            crate::simd_feature_detected!("f16c")
209        }
210        #[cfg(target_arch = "aarch64")]
211        {
212            // ARM NEON has native FP16 support
213            true
214        }
215        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
216        {
217            false
218        }
219    }
220
221    /// Detect BF16 hardware support
222    fn detect_bf16() -> bool {
223        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
224        {
225            // Intel AVX512-BF16 or AMX-BF16
226            // Not yet detectable in stable Rust
227            false
228        }
229        #[cfg(target_arch = "aarch64")]
230        {
231            // ARM BF16 support in recent processors
232            // Not yet detectable in stable Rust
233            false
234        }
235        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
236        {
237            false
238        }
239    }
240
241    /// Get target-specific SIMD width for f32 operations
242    pub fn optimal_f32_width(&self) -> usize {
243        match self.config.optimization_target {
244            OptimizationTarget::Generic => 4,
245            OptimizationTarget::Haswell | OptimizationTarget::Skylake | OptimizationTarget::Zen => {
246                if self.config.enable_avx512 {
247                    16
248                } else {
249                    8
250                }
251            }
252            OptimizationTarget::GraniteRapids | OptimizationTarget::DiamondRapids => {
253                if self.config.enable_avx10 || self.config.enable_avx512 {
254                    16 // AVX10 unified or AVX-512 512-bit operations
255                } else {
256                    8
257                }
258            }
259            OptimizationTarget::Zen5 => {
260                if self.config.enable_avx512 {
261                    16 // AMD Zen 5 has full AVX-512 support
262                } else {
263                    8
264                }
265            }
266            OptimizationTarget::AppleSilicon | OptimizationTarget::CortexA76 => {
267                if self.config.enable_sve2 {
268                    8 // SVE2 with scalable vector width (conservative estimate)
269                } else {
270                    4 // NEON 128-bit
271                }
272            }
273            OptimizationTarget::Server => {
274                if self.config.enable_avx10 || self.config.enable_avx512 {
275                    16
276                } else {
277                    8
278                }
279            }
280            OptimizationTarget::Mobile => 4,
281        }
282    }
283
284    /// Get target-specific cache line size
285    pub fn cache_line_size(&self) -> usize {
286        match self.config.optimization_target {
287            OptimizationTarget::AppleSilicon => 128, // Apple Silicon has 128-byte cache lines
288            _ => 64,                                 // Most x86-64 CPUs have 64-byte cache lines
289        }
290    }
291
292    /// Get target-specific prefetch distance
293    pub fn prefetch_distance(&self) -> usize {
294        match self.config.optimization_target {
295            OptimizationTarget::Server => 512, // Aggressive prefetching for server workloads
296            OptimizationTarget::Mobile => 128, // Conservative prefetching for mobile
297            _ => 256,                          // Balanced prefetching
298        }
299    }
300}
301
302/// Detect the optimal target based on CPU features
303fn detect_optimization_target() -> OptimizationTarget {
304    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
305    {
306        if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
307            if crate::simd_feature_detected!("avx512f") {
308                OptimizationTarget::Skylake
309            } else {
310                OptimizationTarget::Haswell
311            }
312        } else {
313            OptimizationTarget::Generic
314        }
315    }
316
317    #[cfg(target_arch = "aarch64")]
318    {
319        // Detect Apple Silicon vs other ARM64
320        if cfg!(target_os = "macos") {
321            OptimizationTarget::AppleSilicon
322        } else {
323            OptimizationTarget::CortexA76
324        }
325    }
326
327    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
328    {
329        OptimizationTarget::Generic
330    }
331}
332
333/// Target-specific vector arithmetic implementation
334impl VectorArithmetic<f32> for TargetOptimizedOps {
335    fn add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
336        if a.len() != b.len() {
337            return Err(SimdError::DimensionMismatch {
338                expected: a.len(),
339                actual: b.len(),
340            });
341        }
342
343        let width = self.optimal_f32_width();
344        match width {
345            16 | 8 => self.add_avx2(a, b),
346            4 => self.add_sse(a, b),
347            _ => Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()),
348        }
349    }
350
351    fn sub(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
352        if a.len() != b.len() {
353            return Err(SimdError::DimensionMismatch {
354                expected: a.len(),
355                actual: b.len(),
356            });
357        }
358
359        let width = self.optimal_f32_width();
360        match width {
361            16 | 8 => self.sub_avx2(a, b),
362            4 => self.sub_sse(a, b),
363            _ => Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect()),
364        }
365    }
366
367    fn mul(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
368        if a.len() != b.len() {
369            return Err(SimdError::DimensionMismatch {
370                expected: a.len(),
371                actual: b.len(),
372            });
373        }
374
375        let width = self.optimal_f32_width();
376        match width {
377            16 | 8 => self.mul_avx2(a, b),
378            4 => self.mul_sse(a, b),
379            _ => Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect()),
380        }
381    }
382
383    fn div(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
384        if a.len() != b.len() {
385            return Err(SimdError::DimensionMismatch {
386                expected: a.len(),
387                actual: b.len(),
388            });
389        }
390
391        let width = self.optimal_f32_width();
392        match width {
393            16 | 8 => self.div_avx2(a, b),
394            4 => self.div_sse(a, b),
395            _ => Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect()),
396        }
397    }
398
399    fn fma(&self, a: &[f32], b: &[f32], c: &[f32]) -> Result<Vec<f32>, SimdError> {
400        if a.len() != b.len() || a.len() != c.len() {
401            return Err(SimdError::DimensionMismatch {
402                expected: a.len(),
403                actual: b.len().min(c.len()),
404            });
405        }
406
407        if self.config.enable_fma {
408            let width = self.optimal_f32_width();
409            match width {
410                16 | 8 => self.fma_avx2(a, b, c),
411                4 => self.fma_sse(a, b, c),
412                _ => Ok(a
413                    .iter()
414                    .zip(b.iter())
415                    .zip(c.iter())
416                    .map(|((&x, &y), &z)| x * y + z)
417                    .collect()),
418            }
419        } else {
420            // Fallback to separate multiply and add
421            let mul_result = self.mul(a, b)?;
422            self.add(&mul_result, c)
423        }
424    }
425
426    fn scale(&self, vector: &[f32], scalar: f32) -> Result<Vec<f32>, SimdError> {
427        let width = self.optimal_f32_width();
428        match width {
429            16 | 8 => self.scale_avx2(vector, scalar),
430            4 => self.scale_sse(vector, scalar),
431            _ => Ok(vector.iter().map(|&x| x * scalar).collect()),
432        }
433    }
434}
435
436/// AVX-512 implementations
437#[allow(dead_code)]
438impl TargetOptimizedOps {
439    #[cfg(target_arch = "x86_64")]
440    #[target_feature(enable = "avx512f")]
441    unsafe fn add_avx512(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
442        #[cfg(target_arch = "x86_64")]
443        {
444            use core::arch::x86_64::*;
445
446            let mut result = Vec::with_capacity(a.len());
447            let chunks = a.len() / 16;
448
449            for i in 0..chunks {
450                let offset = i * 16;
451                let va = _mm512_loadu_ps(a.as_ptr().add(offset));
452                let vb = _mm512_loadu_ps(b.as_ptr().add(offset));
453                let vr = _mm512_add_ps(va, vb);
454
455                let mut temp = [0f32; 16];
456                _mm512_storeu_ps(temp.as_mut_ptr(), vr);
457                result.extend_from_slice(&temp);
458            }
459
460            // Handle remaining elements
461            for i in (chunks * 16)..a.len() {
462                result.push(a[i] + b[i]);
463            }
464
465            Ok(result)
466        }
467        #[cfg(not(target_arch = "x86_64"))]
468        {
469            Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect())
470        }
471    }
472
473    #[cfg(target_arch = "x86_64")]
474    #[target_feature(enable = "avx512f")]
475    unsafe fn sub_avx512(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
476        #[cfg(target_arch = "x86_64")]
477        {
478            use core::arch::x86_64::*;
479
480            let mut result = Vec::with_capacity(a.len());
481            let chunks = a.len() / 16;
482
483            for i in 0..chunks {
484                let offset = i * 16;
485                let va = _mm512_loadu_ps(a.as_ptr().add(offset));
486                let vb = _mm512_loadu_ps(b.as_ptr().add(offset));
487                let vr = _mm512_sub_ps(va, vb);
488
489                let mut temp = [0f32; 16];
490                _mm512_storeu_ps(temp.as_mut_ptr(), vr);
491                result.extend_from_slice(&temp);
492            }
493
494            // Handle remaining elements
495            for i in (chunks * 16)..a.len() {
496                result.push(a[i] - b[i]);
497            }
498
499            Ok(result)
500        }
501        #[cfg(not(target_arch = "x86_64"))]
502        {
503            Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect())
504        }
505    }
506
507    #[cfg(target_arch = "x86_64")]
508    #[target_feature(enable = "avx512f")]
509    unsafe fn mul_avx512(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
510        #[cfg(target_arch = "x86_64")]
511        {
512            use core::arch::x86_64::*;
513
514            let mut result = Vec::with_capacity(a.len());
515            let chunks = a.len() / 16;
516
517            for i in 0..chunks {
518                let offset = i * 16;
519                let va = _mm512_loadu_ps(a.as_ptr().add(offset));
520                let vb = _mm512_loadu_ps(b.as_ptr().add(offset));
521                let vr = _mm512_mul_ps(va, vb);
522
523                let mut temp = [0f32; 16];
524                _mm512_storeu_ps(temp.as_mut_ptr(), vr);
525                result.extend_from_slice(&temp);
526            }
527
528            // Handle remaining elements
529            for i in (chunks * 16)..a.len() {
530                result.push(a[i] * b[i]);
531            }
532
533            Ok(result)
534        }
535        #[cfg(not(target_arch = "x86_64"))]
536        {
537            Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect())
538        }
539    }
540
541    #[cfg(target_arch = "x86_64")]
542    #[target_feature(enable = "avx512f")]
543    unsafe fn div_avx512(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
544        #[cfg(target_arch = "x86_64")]
545        {
546            use core::arch::x86_64::*;
547
548            let mut result = Vec::with_capacity(a.len());
549            let chunks = a.len() / 16;
550
551            for i in 0..chunks {
552                let offset = i * 16;
553                let va = _mm512_loadu_ps(a.as_ptr().add(offset));
554                let vb = _mm512_loadu_ps(b.as_ptr().add(offset));
555                let vr = _mm512_div_ps(va, vb);
556
557                let mut temp = [0f32; 16];
558                _mm512_storeu_ps(temp.as_mut_ptr(), vr);
559                result.extend_from_slice(&temp);
560            }
561
562            // Handle remaining elements
563            for i in (chunks * 16)..a.len() {
564                result.push(a[i] / b[i]);
565            }
566
567            Ok(result)
568        }
569        #[cfg(not(target_arch = "x86_64"))]
570        {
571            Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect())
572        }
573    }
574
575    #[cfg(target_arch = "x86_64")]
576    #[target_feature(enable = "avx512f")]
577    unsafe fn fma_avx512(&self, a: &[f32], b: &[f32], c: &[f32]) -> Result<Vec<f32>, SimdError> {
578        #[cfg(target_arch = "x86_64")]
579        {
580            use core::arch::x86_64::*;
581
582            let mut result = Vec::with_capacity(a.len());
583            let chunks = a.len() / 16;
584
585            for i in 0..chunks {
586                let offset = i * 16;
587                let va = _mm512_loadu_ps(a.as_ptr().add(offset));
588                let vb = _mm512_loadu_ps(b.as_ptr().add(offset));
589                let vc = _mm512_loadu_ps(c.as_ptr().add(offset));
590                let vr = _mm512_fmadd_ps(va, vb, vc);
591
592                let mut temp = [0f32; 16];
593                _mm512_storeu_ps(temp.as_mut_ptr(), vr);
594                result.extend_from_slice(&temp);
595            }
596
597            // Handle remaining elements
598            for i in (chunks * 16)..a.len() {
599                result.push(a[i] * b[i] + c[i]);
600            }
601
602            Ok(result)
603        }
604        #[cfg(not(target_arch = "x86_64"))]
605        {
606            Ok(a.iter()
607                .zip(b.iter())
608                .zip(c.iter())
609                .map(|((&x, &y), &z)| x * y + z)
610                .collect())
611        }
612    }
613
614    #[cfg(target_arch = "x86_64")]
615    #[target_feature(enable = "avx512f")]
616    unsafe fn scale_avx512(&self, vector: &[f32], scalar: f32) -> Result<Vec<f32>, SimdError> {
617        #[cfg(target_arch = "x86_64")]
618        {
619            use core::arch::x86_64::*;
620
621            let mut result = Vec::with_capacity(vector.len());
622            let chunks = vector.len() / 16;
623            let vs = _mm512_set1_ps(scalar);
624
625            for i in 0..chunks {
626                let offset = i * 16;
627                let vv = _mm512_loadu_ps(vector.as_ptr().add(offset));
628                let vr = _mm512_mul_ps(vv, vs);
629
630                let mut temp = [0f32; 16];
631                _mm512_storeu_ps(temp.as_mut_ptr(), vr);
632                result.extend_from_slice(&temp);
633            }
634
635            // Handle remaining elements
636            for val in vector.iter().skip(chunks * 16) {
637                result.push(*val * scalar);
638            }
639
640            Ok(result)
641        }
642        #[cfg(not(target_arch = "x86_64"))]
643        {
644            Ok(vector.iter().map(|&x| x * scalar).collect())
645        }
646    }
647}
648
649/// AVX2 implementations (similar structure, using _mm256_ intrinsics)
650impl TargetOptimizedOps {
651    fn add_avx2(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
652        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect())
653    }
654
655    fn sub_avx2(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
656        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect())
657    }
658
659    fn mul_avx2(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
660        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect())
661    }
662
663    fn div_avx2(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
664        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect())
665    }
666
667    fn fma_avx2(&self, a: &[f32], b: &[f32], c: &[f32]) -> Result<Vec<f32>, SimdError> {
668        // Use existing fma function - it works in-place
669        let mut result = a.to_vec();
670        crate::vector::fma(&mut result, b, c);
671        Ok(result)
672    }
673
674    fn scale_avx2(&self, vector: &[f32], scalar: f32) -> Result<Vec<f32>, SimdError> {
675        let mut result = vector.to_vec();
676        crate::vector::scale(&mut result, scalar);
677        Ok(result)
678    }
679}
680
681/// SSE implementations (similar structure, using _mm_ intrinsics)
682impl TargetOptimizedOps {
683    fn add_sse(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
684        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect())
685    }
686
687    fn sub_sse(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
688        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect())
689    }
690
691    fn mul_sse(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
692        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect())
693    }
694
695    fn div_sse(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, SimdError> {
696        Ok(a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect())
697    }
698
699    fn fma_sse(&self, a: &[f32], b: &[f32], c: &[f32]) -> Result<Vec<f32>, SimdError> {
700        Ok(a.iter()
701            .zip(b.iter())
702            .zip(c.iter())
703            .map(|((&x, &y), &z)| x * y + z)
704            .collect())
705    }
706
707    fn scale_sse(&self, vector: &[f32], scalar: f32) -> Result<Vec<f32>, SimdError> {
708        Ok(vector.iter().map(|&x| x * scalar).collect())
709    }
710}
711
712/// Target-specific reduction operations
713impl VectorReduction<f32> for TargetOptimizedOps {
714    fn sum(&self, vector: &[f32]) -> Result<f32, SimdError> {
715        if vector.is_empty() {
716            return Err(SimdError::EmptyInput);
717        }
718
719        let width = self.optimal_f32_width();
720        match width {
721            16 | 8 => self.sum_avx2(vector),
722            4 => self.sum_sse(vector),
723            _ => Ok(vector.iter().sum()),
724        }
725    }
726
727    fn min(&self, vector: &[f32]) -> Result<f32, SimdError> {
728        if vector.is_empty() {
729            return Err(SimdError::EmptyInput);
730        }
731
732        let (min_val, _) = crate::vector::min_max(vector);
733        Ok(min_val)
734    }
735
736    fn max(&self, vector: &[f32]) -> Result<f32, SimdError> {
737        if vector.is_empty() {
738            return Err(SimdError::EmptyInput);
739        }
740
741        let (_, max_val) = crate::vector::min_max(vector);
742        Ok(max_val)
743    }
744
745    fn dot_product(&self, a: &[f32], b: &[f32]) -> Result<f32, SimdError> {
746        if a.len() != b.len() {
747            return Err(SimdError::DimensionMismatch {
748                expected: a.len(),
749                actual: b.len(),
750            });
751        }
752
753        Ok(crate::vector::dot_product(a, b))
754    }
755
756    fn norm(&self, vector: &[f32]) -> Result<f32, SimdError> {
757        if vector.is_empty() {
758            return Err(SimdError::EmptyInput);
759        }
760
761        Ok(crate::vector::norm(vector))
762    }
763
764    fn mean(&self, vector: &[f32]) -> Result<f32, SimdError> {
765        if vector.is_empty() {
766            return Err(SimdError::EmptyInput);
767        }
768
769        let sum = self.sum(vector)?;
770        Ok(sum / vector.len() as f32)
771    }
772}
773
774/// Target-specific sum implementations
775impl TargetOptimizedOps {
776    #[allow(dead_code)] // Placeholder for dedicated AVX-512 sum path; not yet dispatched
777    fn sum_avx512(&self, vector: &[f32]) -> Result<f32, SimdError> {
778        Ok(crate::vector::sum(vector))
779    }
780
781    fn sum_avx2(&self, vector: &[f32]) -> Result<f32, SimdError> {
782        Ok(crate::vector::sum(vector))
783    }
784
785    fn sum_sse(&self, vector: &[f32]) -> Result<f32, SimdError> {
786        Ok(crate::vector::sum(vector))
787    }
788}
789
790/// Compilation target selection utilities
791pub mod compile_time {
792    use super::*;
793
794    /// Select implementation at compile time based on target features
795    #[macro_export]
796    macro_rules! select_target_impl {
797        ($generic:expr, $avx2:expr, $avx512:expr) => {{
798            #[cfg(target_feature = "avx512f")]
799            {
800                $avx512
801            }
802            #[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
803            {
804                $avx2
805            }
806            #[cfg(not(any(target_feature = "avx2", target_feature = "avx512f")))]
807            {
808                $generic
809            }
810        }};
811    }
812
813    /// Optimize for specific CPU microarchitecture
814    pub fn optimize_for_cpu() -> TargetConfig {
815        TargetConfig {
816            optimization_target: detect_optimization_target(),
817            enable_fma: cfg!(target_feature = "fma"),
818            enable_avx512: cfg!(target_feature = "avx512f"),
819            enable_avx10: false, // Not yet available in stable Rust
820            enable_amx: false,   // Intel AMX not yet in stable Rust
821            enable_sve2: false,  // ARM SVE2 not yet in stable Rust
822            enable_sme: false,   // ARM SME not yet in stable Rust
823            enable_fp16: false,  // FP16 support is library-based
824            enable_bf16: false,  // BF16 support is library-based
825            prefer_throughput: true,
826            prefer_latency: false,
827        }
828    }
829
830    /// Get CPU-specific optimization flags
831    #[allow(clippy::vec_init_then_push)]
832    pub fn get_optimization_flags() -> Vec<&'static str> {
833        #[allow(unused_mut)]
834        let mut flags = Vec::new();
835
836        #[cfg(target_feature = "sse2")]
837        flags.push("sse2");
838        #[cfg(target_feature = "avx")]
839        flags.push("avx");
840        #[cfg(target_feature = "avx2")]
841        flags.push("avx2");
842        #[cfg(target_feature = "fma")]
843        flags.push("fma");
844        #[cfg(target_feature = "avx512f")]
845        flags.push("avx512f");
846
847        flags
848    }
849}
850
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_target_detection() {
        let target = detect_optimization_target();
        println!("Detected optimization target: {:?}", target);

        // Runtime detection only ever produces these variants.
        assert!(
            matches!(
                target,
                OptimizationTarget::Generic
                    | OptimizationTarget::Haswell
                    | OptimizationTarget::Skylake
                    | OptimizationTarget::Zen
                    | OptimizationTarget::CortexA76
                    | OptimizationTarget::AppleSilicon
            ),
            "Invalid optimization target detected: {:?}",
            target
        );
    }

    #[test]
    fn test_target_config() {
        let config = TargetConfig::default();
        let ops = TargetOptimizedOps::new(config);

        assert!(ops.optimal_f32_width() >= 1);
        assert!(ops.cache_line_size() > 0);
        assert!(ops.prefetch_distance() > 0);
    }

    #[test]
    fn test_explicit_config_heuristics() {
        let server = TargetOptimizedOps::new(TargetConfig {
            optimization_target: OptimizationTarget::Server,
            enable_avx512: true,
            ..TargetConfig::default()
        });
        assert_eq!(server.optimal_f32_width(), 16);
        assert_eq!(server.prefetch_distance(), 512);

        let mobile = TargetOptimizedOps::new(TargetConfig {
            optimization_target: OptimizationTarget::Mobile,
            ..TargetConfig::default()
        });
        assert_eq!(mobile.optimal_f32_width(), 4);
        assert_eq!(mobile.prefetch_distance(), 128);

        let apple = TargetOptimizedOps::new(TargetConfig {
            optimization_target: OptimizationTarget::AppleSilicon,
            ..TargetConfig::default()
        });
        assert_eq!(apple.cache_line_size(), 128);
        assert_eq!(apple.optimal_f32_width(), 4);
    }

    #[test]
    fn test_auto_detect() {
        let ops = TargetOptimizedOps::auto_detect();

        assert!(ops.optimal_f32_width() >= 1);
        assert!(ops.optimal_f32_width() <= 16);
    }

    #[test]
    fn test_vector_arithmetic() {
        let ops = TargetOptimizedOps::auto_detect();

        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = ops.add(&a, &b).expect("operation should succeed");
        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);

        let result = ops.sub(&a, &b).expect("operation should succeed");
        assert_eq!(result, vec![-4.0, -4.0, -4.0, -4.0]);

        let result = ops.mul(&a, &b).expect("operation should succeed");
        assert_eq!(result, vec![5.0, 12.0, 21.0, 32.0]);

        let result = ops.scale(&a, 2.0).expect("operation should succeed");
        assert_eq!(result, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_vector_division() {
        let ops = TargetOptimizedOps::auto_detect();

        let numerators = [8.0f32, 9.0, 1.0];
        let denominators = [2.0f32, 3.0, 4.0];
        let result = ops
            .div(&numerators, &denominators)
            .expect("operation should succeed");
        assert_eq!(result, vec![4.0, 3.0, 0.25]);
    }

    #[test]
    fn test_vector_reductions() {
        let ops = TargetOptimizedOps::auto_detect();

        let vector = vec![1.0, 2.0, 3.0, 4.0];

        let sum = ops.sum(&vector).expect("operation should succeed");
        assert_eq!(sum, 10.0);

        let mean = ops.mean(&vector).expect("operation should succeed");
        assert_eq!(mean, 2.5);

        let min = ops
            .min(&vector)
            .expect("collection should not be empty for min/max");
        assert_eq!(min, 1.0);

        let max = ops
            .max(&vector)
            .expect("collection should not be empty for min/max");
        assert_eq!(max, 4.0);

        let norm = ops.norm(&vector).expect("operation should succeed");
        assert!((norm - (1.0 + 4.0 + 9.0 + 16.0_f32).sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_reductions_reject_empty_input() {
        let ops = TargetOptimizedOps::auto_detect();
        let empty: [f32; 0] = [];

        assert!(matches!(ops.sum(&empty), Err(SimdError::EmptyInput)));
        assert!(matches!(ops.mean(&empty), Err(SimdError::EmptyInput)));
        assert!(matches!(ops.min(&empty), Err(SimdError::EmptyInput)));
        assert!(matches!(ops.max(&empty), Err(SimdError::EmptyInput)));
        assert!(matches!(ops.norm(&empty), Err(SimdError::EmptyInput)));
    }

    #[test]
    fn test_fma_operation() {
        let ops = TargetOptimizedOps::auto_detect();

        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![1.0, 1.0, 1.0, 1.0];

        let result = ops.fma(&a, &b, &c).expect("operation should succeed");
        assert_eq!(result, vec![3.0, 7.0, 13.0, 21.0]); // a*b + c
    }

    #[test]
    fn test_fma_dimension_mismatch() {
        let ops = TargetOptimizedOps::auto_detect();
        let a = [1.0f32, 2.0, 3.0];
        let short = [1.0f32, 2.0];

        assert!(ops.fma(&a, &short, &a).is_err());
        assert!(ops.fma(&a, &a, &short).is_err());
    }

    #[test]
    fn test_error_handling() {
        let ops = TargetOptimizedOps::auto_detect();

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0]; // Different length

        let result = ops.add(&a, &b);
        assert!(result.is_err());

        match result {
            Err(SimdError::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected dimension mismatch error"),
        }
    }

    #[test]
    fn test_compile_time_features() {
        use compile_time::*;

        let flags = get_optimization_flags();
        println!("Available optimization flags: {:?}", flags);

        // The flag list may legitimately be empty on non-x86 platforms, so only print it.

        let config = optimize_for_cpu();
        println!("CPU-optimized config: {:?}", config);

        // Basic sanity checks
        assert!(matches!(
            config.optimization_target,
            OptimizationTarget::Generic
                | OptimizationTarget::Haswell
                | OptimizationTarget::Skylake
                | OptimizationTarget::Zen
                | OptimizationTarget::CortexA76
                | OptimizationTarget::AppleSilicon
                | OptimizationTarget::Server
                | OptimizationTarget::Mobile
        ));
    }
}