// trueno_explain/simd.rs

//! SIMD vectorization analyzer
//!
//! Analyzes x86 assembly for SIMD instruction usage and vectorization patterns.

use crate::analyzer::{
    AnalysisReport, Analyzer, MemoryPattern, MudaType, MudaWarning, RegisterUsage, RooflineMetric,
};
use crate::error::Result;
use regex::Regex;

/// Supported SIMD architectures
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdArch {
    /// SSE2 (128-bit)
    Sse2,
    /// AVX/AVX2 (256-bit)
    Avx2,
    /// AVX-512 (512-bit)
    Avx512,
    /// ARM NEON (128-bit)
    Neon,
}

impl SimdArch {
    /// Width of a single `f32` element in bits.
    const F32_BITS: u32 = 32;

    /// Vector width in bits.
    ///
    /// This is a `const fn`, so the width can also be used where a
    /// compile-time constant is required (array lengths, `const` items).
    #[must_use]
    pub const fn width_bits(&self) -> u32 {
        match self {
            Self::Sse2 | Self::Neon => 128,
            Self::Avx2 => 256,
            Self::Avx512 => 512,
        }
    }

    /// Maximum number of `f32` elements per vector, i.e. the vector width
    /// divided by 32 bits.
    #[must_use]
    pub const fn f32_lanes(&self) -> u32 {
        self.width_bits() / Self::F32_BITS
    }
}

/// SIMD instruction counts
#[derive(Debug, Clone, Default)]
pub struct SimdInstructionCounts {
    /// Scalar instructions
    pub scalar: u32,
    /// SSE/SSE2 instructions (128-bit)
    pub sse: u32,
    /// AVX/AVX2 instructions (256-bit)
    pub avx: u32,
    /// AVX-512 instructions (512-bit)
    pub avx512: u32,
}

impl SimdInstructionCounts {
    /// Calculate vectorization ratio (0.0-1.0).
    ///
    /// The ratio is `(sse + avx + avx512) / (scalar + sse + avx + avx512)`.
    /// Returns `0.0` when every counter is zero.
    #[must_use]
    pub fn vectorization_ratio(&self) -> f32 {
        // Widen before summing: four `u32` counters can exceed `u32::MAX`,
        // which would panic in debug builds and wrap in release builds.
        let vectorized = u64::from(self.sse) + u64::from(self.avx) + u64::from(self.avx512);
        let total = vectorized + u64::from(self.scalar);
        if total == 0 {
            return 0.0;
        }
        // Rounding in the float conversion is fine for a coarse ratio.
        vectorized as f32 / total as f32
    }
}

68/// SIMD code analyzer
69pub struct SimdAnalyzer {
70    /// Target architecture for analysis
71    pub target_arch: SimdArch,
72    /// Warn if vectorization ratio below threshold
73    pub vectorization_threshold: f32,
74}
75
76impl Default for SimdAnalyzer {
77    fn default() -> Self {
78        Self {
79            target_arch: SimdArch::Avx2,
80            vectorization_threshold: 0.5,
81        }
82    }
83}
84
85impl SimdAnalyzer {
86    /// Create a new SIMD analyzer for the given architecture
87    #[must_use]
88    pub fn new(arch: SimdArch) -> Self {
89        Self {
90            target_arch: arch,
91            ..Default::default()
92        }
93    }
94
95    /// Count SIMD instructions in assembly
96    fn count_instructions(&self, asm: &str) -> SimdInstructionCounts {
97        let mut counts = SimdInstructionCounts::default();
98
99        // AVX-512 patterns (zmm registers, 512-bit ops)
100        let avx512_pattern = Regex::new(r"(?i)(v\w+.*zmm|vp\w+.*zmm)").unwrap();
101        counts.avx512 = avx512_pattern.find_iter(asm).count() as u32;
102
103        // AVX/AVX2 patterns (ymm registers, 256-bit ops, v-prefix)
104        let avx_pattern = Regex::new(
105            r"(?i)(v\w+.*ymm|vp\w+.*ymm|vmovaps|vmovups|vmulps|vaddps|vsubps|vdivps|vfmadd|vfmsub)",
106        )
107        .unwrap();
108        counts.avx = avx_pattern.find_iter(asm).count() as u32;
109
110        // SSE patterns (xmm registers without v-prefix)
111        // Note: Rust regex doesn't support look-behind, so we match and filter
112        let sse_pattern =
113            Regex::new(r"(?i)\b(movaps|movups|mulps|addps|subps|divps)\b.*xmm").unwrap();
114        counts.sse = sse_pattern.find_iter(asm).count() as u32;
115
116        // Scalar floating-point (ss = scalar single-precision)
117        let scalar_pattern =
118            Regex::new(r"(?i)\b(movss|mulss|addss|subss|divss|cvtsi2ss|cvtss2si)\b").unwrap();
119        counts.scalar = scalar_pattern.find_iter(asm).count() as u32;
120
121        counts
122    }
123
124    /// Detect scalar fallback code (Muda of Overprocessing)
125    fn detect_scalar_fallback(&self, counts: &SimdInstructionCounts) -> Option<MudaWarning> {
126        let ratio = counts.vectorization_ratio();
127        if ratio < self.vectorization_threshold && counts.scalar > 0 {
128            Some(MudaWarning {
129                muda_type: MudaType::Overprocessing,
130                description: format!(
131                    "Low vectorization: {:.1}% (threshold: {:.1}%)",
132                    ratio * 100.0,
133                    self.vectorization_threshold * 100.0
134                ),
135                impact: format!(
136                    "Potential {:.1}x speedup from better vectorization",
137                    self.target_arch.f32_lanes()
138                ),
139                line: None,
140                suggestion: Some("Check for alignment issues or loop trip count".to_string()),
141            })
142        } else {
143            None
144        }
145    }
146}
147
148impl Analyzer for SimdAnalyzer {
149    fn target_name(&self) -> &str {
150        match self.target_arch {
151            SimdArch::Sse2 => "x86 ASM (SSE2)",
152            SimdArch::Avx2 => "x86 ASM (AVX2)",
153            SimdArch::Avx512 => "x86 ASM (AVX-512)",
154            SimdArch::Neon => "ARM ASM (NEON)",
155        }
156    }
157
158    fn analyze(&self, asm: &str) -> Result<AnalysisReport> {
159        let counts = self.count_instructions(asm);
160        let warnings = self.detect_muda(asm);
161
162        let total_instructions = counts.scalar + counts.sse + counts.avx + counts.avx512;
163        let vectorization = counts.vectorization_ratio();
164
165        Ok(AnalysisReport {
166            name: "simd_analysis".to_string(),
167            target: self.target_name().to_string(),
168            registers: RegisterUsage::default(),
169            memory: MemoryPattern::default(),
170            roofline: self.estimate_roofline(&AnalysisReport::default()),
171            warnings,
172            instruction_count: total_instructions,
173            estimated_occupancy: vectorization, // Repurpose as vectorization ratio
174        })
175    }
176
177    fn detect_muda(&self, asm: &str) -> Vec<MudaWarning> {
178        let mut warnings = Vec::new();
179        let counts = self.count_instructions(asm);
180
181        if let Some(w) = self.detect_scalar_fallback(&counts) {
182            warnings.push(w);
183        }
184
185        warnings
186    }
187
188    fn estimate_roofline(&self, _analysis: &AnalysisReport) -> RooflineMetric {
189        // SIMD typically memory-bound for large data
190        RooflineMetric {
191            arithmetic_intensity: 1.0,
192            theoretical_peak_gflops: 1000.0, // Placeholder
193            memory_bound: true,
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Every architecture reports its register width in bits.
    #[test]
    fn test_simd_arch_width() {
        assert_eq!(SimdArch::Sse2.width_bits(), 128);
        assert_eq!(SimdArch::Avx2.width_bits(), 256);
        assert_eq!(SimdArch::Avx512.width_bits(), 512);
        assert_eq!(SimdArch::Neon.width_bits(), 128);
    }

    /// Lane counts are the register width divided by 32 bits.
    #[test]
    fn test_simd_arch_lanes() {
        assert_eq!(SimdArch::Sse2.f32_lanes(), 4);
        assert_eq!(SimdArch::Avx2.f32_lanes(), 8);
        assert_eq!(SimdArch::Avx512.f32_lanes(), 16);
        assert_eq!(SimdArch::Neon.f32_lanes(), 4);
    }

    /// Four AVX lines are counted as exactly four AVX instructions and
    /// nothing else.
    #[test]
    fn test_count_avx_instructions() {
        let asm = r#"
            vmovaps ymm0, [rdi]
            vmovaps ymm1, [rsi]
            vaddps ymm2, ymm0, ymm1
            vmovaps [rdx], ymm2
        "#;

        let analyzer = SimdAnalyzer::new(SimdArch::Avx2);
        let counts = analyzer.count_instructions(asm);

        assert_eq!(counts.avx, 4, "Should detect every AVX instruction once");
        assert_eq!(counts.avx512, 0);
        assert_eq!(counts.sse, 0);
        assert_eq!(counts.scalar, 0);
    }

    /// Four legacy SSE lines are counted as exactly four SSE instructions and
    /// nothing else.
    #[test]
    fn test_count_sse_instructions() {
        let asm = r#"
            movaps xmm0, [rdi]
            movaps xmm1, [rsi]
            addps xmm0, xmm1
            movaps [rdx], xmm0
        "#;

        let analyzer = SimdAnalyzer::new(SimdArch::Sse2);
        let counts = analyzer.count_instructions(asm);

        assert_eq!(counts.sse, 4, "Should detect every SSE instruction once");
        assert_eq!(counts.avx, 0);
        assert_eq!(counts.avx512, 0);
        assert_eq!(counts.scalar, 0);
    }

    #[test]
    fn test_vectorization_ratio() {
        let counts = SimdInstructionCounts {
            scalar: 2,
            sse: 0,
            avx: 8,
            avx512: 0,
        };

        let ratio = counts.vectorization_ratio();
        assert!((ratio - 0.8).abs() < 0.01, "Expected 80% vectorization");
    }

    #[test]
    fn test_vectorization_ratio_zero() {
        let counts = SimdInstructionCounts::default();
        assert_eq!(counts.vectorization_ratio(), 0.0);
    }

    /// Purely scalar code yields exactly one Overprocessing warning.
    #[test]
    fn test_detect_scalar_fallback() {
        let asm = r#"
            movss xmm0, [rdi]
            mulss xmm0, xmm1
            addss xmm0, xmm2
        "#;

        let analyzer = SimdAnalyzer::new(SimdArch::Avx2);
        let warnings = analyzer.detect_muda(asm);

        assert_eq!(warnings.len(), 1, "Should warn once on scalar code");
        assert!(matches!(warnings[0].muda_type, MudaType::Overprocessing));
    }

    /// F051: Detects AVX2 instructions
    #[test]
    fn f051_detect_avx2_instructions() {
        let asm = "vmulps ymm0, ymm1, ymm2";
        let analyzer = SimdAnalyzer::new(SimdArch::Avx2);
        let counts = analyzer.count_instructions(asm);

        assert_eq!(counts.avx, 1, "Should detect vmulps exactly once");
    }

    /// F055: Calculates vectorization ratio
    #[test]
    fn f055_vectorization_ratio_positive() {
        let asm = r#"
            vmovaps ymm0, [rdi]
            vaddps ymm0, ymm0, ymm1
        "#;

        let analyzer = SimdAnalyzer::new(SimdArch::Avx2);
        let report = analyzer.analyze(asm).unwrap();

        // estimated_occupancy repurposed as vectorization ratio
        assert!(
            (report.estimated_occupancy - 1.0).abs() < f32::EPSILON,
            "Fully vectorized code should report a ratio of 100%"
        );
        assert_eq!(report.instruction_count, 2);
        assert!(report.warnings.is_empty());
    }
}