Skip to main content

trueno_explain/
simd.rs

1//! SIMD vectorization analyzer
2//!
3//! Analyzes x86 assembly for SIMD instruction usage and vectorization patterns.
4
5use crate::analyzer::{
6    AnalysisReport, Analyzer, MemoryPattern, MudaType, MudaWarning, RegisterUsage, RooflineMetric,
7};
8use crate::error::Result;
9use regex::Regex;
10
/// SIMD instruction-set families the analyzer can be pointed at.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdArch {
    /// x86 SSE2, 128-bit vectors
    Sse2,
    /// x86 AVX/AVX2, 256-bit vectors
    Avx2,
    /// x86 AVX-512, 512-bit vectors
    Avx512,
    /// ARM NEON, 128-bit vectors
    Neon,
}

impl SimdArch {
    /// Width of a single vector register, in bits.
    #[must_use]
    pub fn width_bits(&self) -> u32 {
        match *self {
            Self::Avx512 => 512,
            Self::Avx2 => 256,
            Self::Neon | Self::Sse2 => 128,
        }
    }

    /// Number of `f32` values that fit in a single vector register.
    #[must_use]
    pub fn f32_lanes(&self) -> u32 {
        // An f32 occupies 32 bits of the register.
        const F32_BITS: u32 = 32;
        self.width_bits() / F32_BITS
    }
}
41
/// Tally of recognised instructions, bucketed by the vector width they use.
#[derive(Debug, Clone, Default)]
pub struct SimdInstructionCounts {
    /// Scalar instructions
    pub scalar: u32,
    /// SSE/SSE2 instructions (128-bit)
    pub sse: u32,
    /// AVX/AVX2 instructions (256-bit)
    pub avx: u32,
    /// AVX-512 instructions (512-bit)
    pub avx512: u32,
}

impl SimdInstructionCounts {
    /// Fraction of the counted instructions that are vector instructions
    /// (SSE, AVX or AVX-512).
    ///
    /// Returns a value in `0.0..=1.0`, or `0.0` when every counter is zero.
    #[must_use]
    pub fn vectorization_ratio(&self) -> f32 {
        // Sum in u64: four u32 counters can never overflow it, whereas a u32
        // sum panics in debug builds and silently wraps in release builds
        // when the (public) counters are large.
        let vectorized = u64::from(self.sse) + u64::from(self.avx) + u64::from(self.avx512);
        let total = vectorized + u64::from(self.scalar);
        if total == 0 {
            // Nothing counted: report 0.0 instead of dividing by zero.
            return 0.0;
        }
        vectorized as f32 / total as f32
    }
}
67
68/// SIMD code analyzer
69pub struct SimdAnalyzer {
70    /// Target architecture for analysis
71    pub target_arch: SimdArch,
72    /// Warn if vectorization ratio below threshold
73    pub vectorization_threshold: f32,
74}
75
76impl Default for SimdAnalyzer {
77    fn default() -> Self {
78        Self {
79            target_arch: SimdArch::Avx2,
80            vectorization_threshold: 0.5,
81        }
82    }
83}
84
85impl SimdAnalyzer {
86    /// Create a new SIMD analyzer for the given architecture
87    #[must_use]
88    pub fn new(arch: SimdArch) -> Self {
89        Self {
90            target_arch: arch,
91            ..Default::default()
92        }
93    }
94
95    /// Count SIMD instructions in assembly
96    fn count_instructions(&self, asm: &str) -> SimdInstructionCounts {
97        let mut counts = SimdInstructionCounts::default();
98
99        // AVX-512 patterns (zmm registers, 512-bit ops)
100        let avx512_pattern =
101            Regex::new(r"(?i)(v\w+.*zmm|vp\w+.*zmm)").expect("valid regex pattern");
102        counts.avx512 = avx512_pattern.find_iter(asm).count() as u32;
103
104        // AVX/AVX2 patterns (ymm registers, 256-bit ops, v-prefix)
105        let avx_pattern = Regex::new(
106            r"(?i)(v\w+.*ymm|vp\w+.*ymm|vmovaps|vmovups|vmulps|vaddps|vsubps|vdivps|vfmadd|vfmsub)",
107        )
108        .expect("valid regex pattern");
109        counts.avx = avx_pattern.find_iter(asm).count() as u32;
110
111        // SSE patterns (xmm registers without v-prefix)
112        // Note: Rust regex doesn't support look-behind, so we match and filter
113        let sse_pattern = Regex::new(r"(?i)\b(movaps|movups|mulps|addps|subps|divps)\b.*xmm")
114            .expect("valid regex pattern");
115        counts.sse = sse_pattern.find_iter(asm).count() as u32;
116
117        // Scalar floating-point (ss = scalar single-precision)
118        let scalar_pattern =
119            Regex::new(r"(?i)\b(movss|mulss|addss|subss|divss|cvtsi2ss|cvtss2si)\b")
120                .expect("valid regex pattern");
121        counts.scalar = scalar_pattern.find_iter(asm).count() as u32;
122
123        counts
124    }
125
126    /// Detect scalar fallback code (Muda of Overprocessing)
127    fn detect_scalar_fallback(&self, counts: &SimdInstructionCounts) -> Option<MudaWarning> {
128        let ratio = counts.vectorization_ratio();
129        if ratio < self.vectorization_threshold && counts.scalar > 0 {
130            Some(MudaWarning {
131                muda_type: MudaType::Overprocessing,
132                description: format!(
133                    "Low vectorization: {:.1}% (threshold: {:.1}%)",
134                    ratio * 100.0,
135                    self.vectorization_threshold * 100.0
136                ),
137                impact: format!(
138                    "Potential {:.1}x speedup from better vectorization",
139                    self.target_arch.f32_lanes()
140                ),
141                line: None,
142                suggestion: Some("Check for alignment issues or loop trip count".to_string()),
143            })
144        } else {
145            None
146        }
147    }
148}
149
150impl Analyzer for SimdAnalyzer {
151    fn target_name(&self) -> &str {
152        match self.target_arch {
153            SimdArch::Sse2 => "x86 ASM (SSE2)",
154            SimdArch::Avx2 => "x86 ASM (AVX2)",
155            SimdArch::Avx512 => "x86 ASM (AVX-512)",
156            SimdArch::Neon => "ARM ASM (NEON)",
157        }
158    }
159
160    fn analyze(&self, asm: &str) -> Result<AnalysisReport> {
161        let counts = self.count_instructions(asm);
162        let warnings = self.detect_muda(asm);
163
164        let total_instructions = counts.scalar + counts.sse + counts.avx + counts.avx512;
165        let vectorization = counts.vectorization_ratio();
166
167        Ok(AnalysisReport {
168            name: "simd_analysis".to_string(),
169            target: self.target_name().to_string(),
170            registers: RegisterUsage::default(),
171            memory: MemoryPattern::default(),
172            roofline: self.estimate_roofline(&AnalysisReport::default()),
173            warnings,
174            instruction_count: total_instructions,
175            estimated_occupancy: vectorization, // Repurpose as vectorization ratio
176        })
177    }
178
179    fn detect_muda(&self, asm: &str) -> Vec<MudaWarning> {
180        let mut warnings = Vec::new();
181        let counts = self.count_instructions(asm);
182
183        if let Some(w) = self.detect_scalar_fallback(&counts) {
184            warnings.push(w);
185        }
186
187        warnings
188    }
189
190    fn estimate_roofline(&self, _analysis: &AnalysisReport) -> RooflineMetric {
191        // SIMD typically memory-bound for large data
192        RooflineMetric {
193            arithmetic_intensity: 1.0,
194            theoretical_peak_gflops: 1000.0, // Placeholder
195            memory_bound: true,
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the instruction counter of a fresh analyzer for `arch` over `asm`.
    fn count(arch: SimdArch, asm: &str) -> SimdInstructionCounts {
        SimdAnalyzer::new(arch).count_instructions(asm)
    }

    // Register widths of the x86 architectures.
    #[test]
    fn test_simd_arch_width() {
        let expected = [
            (SimdArch::Sse2, 128),
            (SimdArch::Avx2, 256),
            (SimdArch::Avx512, 512),
        ];
        for (arch, bits) in expected {
            assert_eq!(arch.width_bits(), bits);
        }
    }

    // f32 lane counts of the x86 architectures.
    #[test]
    fn test_simd_arch_lanes() {
        let expected = [
            (SimdArch::Sse2, 4),
            (SimdArch::Avx2, 8),
            (SimdArch::Avx512, 16),
        ];
        for (arch, lanes) in expected {
            assert_eq!(arch.f32_lanes(), lanes);
        }
    }

    // A ymm-based listing must register in the AVX bucket.
    #[test]
    fn test_count_avx_instructions() {
        let listing = r"
            vmovaps ymm0, [rdi]
            vmovaps ymm1, [rsi]
            vaddps ymm2, ymm0, ymm1
            vmovaps [rdx], ymm2
        ";

        let tally = count(SimdArch::Avx2, listing);

        assert!(tally.avx > 0, "Should detect AVX instructions");
    }

    // An xmm-based listing without v-prefixes must register in the SSE bucket.
    #[test]
    fn test_count_sse_instructions() {
        let listing = r"
            movaps xmm0, [rdi]
            movaps xmm1, [rsi]
            addps xmm0, xmm1
            movaps [rdx], xmm0
        ";

        let tally = count(SimdArch::Sse2, listing);

        assert!(tally.sse > 0, "Should detect SSE instructions");
    }

    // 8 vector instructions out of 10 total is an 80% ratio.
    #[test]
    fn test_vectorization_ratio() {
        let tally = SimdInstructionCounts {
            scalar: 2,
            sse: 0,
            avx: 8,
            avx512: 0,
        };

        let delta = (tally.vectorization_ratio() - 0.8).abs();
        assert!(delta < 0.01, "Expected 80% vectorization");
    }

    // With nothing counted the ratio is defined as zero.
    #[test]
    fn test_vectorization_ratio_zero() {
        let empty = SimdInstructionCounts::default();
        assert_eq!(empty.vectorization_ratio(), 0.0);
    }

    // Purely scalar code must trigger a warning.
    #[test]
    fn test_detect_scalar_fallback() {
        let listing = r"
            movss xmm0, [rdi]
            mulss xmm0, xmm1
            addss xmm0, xmm2
        ";

        let found = SimdAnalyzer::new(SimdArch::Avx2).detect_muda(listing);

        assert!(!found.is_empty(), "Should warn on scalar code");
    }

    /// F051: Detects AVX2 instructions
    #[test]
    fn f051_detect_avx2_instructions() {
        let tally = count(SimdArch::Avx2, "vmulps ymm0, ymm1, ymm2");

        assert!(tally.avx > 0, "Should detect vmulps");
    }

    /// F055: Calculates vectorization ratio
    #[test]
    fn f055_vectorization_ratio_positive() {
        let listing = r"
            vmovaps ymm0, [rdi]
            vaddps ymm0, ymm0, ymm1
        ";

        let report = SimdAnalyzer::new(SimdArch::Avx2).analyze(listing).unwrap();

        // estimated_occupancy doubles as the vectorization ratio
        let ratio = report.estimated_occupancy;
        assert!(ratio > 0.0, "Vectorization ratio should be > 0%");
    }
}