Skip to main content

bhc_loop_ir/
vectorize.rs

1//! # Auto-Vectorization Pass
2//!
3//! This module implements auto-vectorization for Loop IR, transforming scalar
4//! operations to SIMD operations based on loop analysis.
5//!
6//! ## M3 Exit Criteria
7//!
8//! - `matmul` microkernel auto-vectorizes on x86_64 and aarch64
9//! - SIMD intrinsics: add, mul, fmadd, hadd
10//!
11//! ## Vectorization Strategy
12//!
13//! 1. **Analyze loops**: Identify vectorizable innermost loops
14//! 2. **Check access patterns**: Sequential access enables vectorization
15//! 3. **Transform operations**: Scalar → Vector operations
//! 4. **Handle remainder**: Scalar loop for the trailing iterations that do not fill a whole vector
17
18use crate::{
19    AccessPattern, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType, Op,
20    ScalarType, Stmt, TargetArch, TripCount, Value,
21};
22use rustc_hash::FxHashMap;
23use thiserror::Error;
24
25/// Errors that can occur during vectorization.
26#[derive(Clone, Debug, Error)]
27pub enum VectorizeError {
28    /// Loop cannot be vectorized.
29    #[error("loop {loop_id:?} cannot be vectorized: {reason}")]
30    NotVectorizable {
31        /// Loop identifier.
32        loop_id: LoopId,
33        /// Reason vectorization failed.
34        reason: String,
35    },
36
37    /// Invalid vector width.
38    #[error("invalid vector width {width} for type {ty:?}")]
39    InvalidWidth {
40        /// Requested width.
41        width: u8,
42        /// Element type.
43        ty: ScalarType,
44    },
45}
46
47/// Result of vectorization analysis for a loop.
48#[derive(Clone, Debug)]
49pub struct VectorizationInfo {
50    /// Whether the loop can be vectorized.
51    pub vectorizable: bool,
52    /// Reason if not vectorizable.
53    pub reason: Option<String>,
54    /// Recommended vector width.
55    pub recommended_width: u8,
56    /// Access patterns in the loop.
57    pub access_patterns: Vec<AccessPattern>,
58    /// Whether FMA opportunities exist.
59    pub has_fma: bool,
60    /// Whether horizontal reduction is needed.
61    pub has_reduction: bool,
62}
63
64impl Default for VectorizationInfo {
65    fn default() -> Self {
66        Self {
67            vectorizable: false,
68            reason: Some("not analyzed".to_string()),
69            recommended_width: 1,
70            access_patterns: Vec::new(),
71            has_fma: false,
72            has_reduction: false,
73        }
74    }
75}
76
77/// Configuration for vectorization.
78#[derive(Clone, Debug)]
79pub struct VectorizeConfig {
80    /// Target architecture.
81    pub target: TargetArch,
82    /// Force a specific vector width (0 = auto).
83    pub forced_width: u8,
84    /// Generate remainder loop for non-aligned iterations.
85    pub generate_remainder: bool,
86    /// Enable FMA fusion.
87    pub enable_fma: bool,
88    /// Minimum trip count for vectorization.
89    pub min_trip_count: usize,
90}
91
92impl Default for VectorizeConfig {
93    fn default() -> Self {
94        Self {
95            target: TargetArch::default(),
96            forced_width: 0,
97            generate_remainder: true,
98            enable_fma: true,
99            min_trip_count: 4,
100        }
101    }
102}
103
104/// Vectorization pass state.
105pub struct VectorizePass {
106    config: VectorizeConfig,
107    /// Analysis results per loop.
108    analysis: FxHashMap<LoopId, VectorizationInfo>,
109}
110
111impl VectorizePass {
112    /// Create a new vectorization pass with the given configuration.
113    pub fn new(config: VectorizeConfig) -> Self {
114        Self {
115            config,
116            analysis: FxHashMap::default(),
117        }
118    }
119
120    /// Analyze a Loop IR function for vectorization opportunities.
121    pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, VectorizationInfo> {
122        self.analysis.clear();
123
124        for stmt in &ir.body.stmts {
125            self.analyze_stmt(stmt, &ir.loop_info);
126        }
127
128        self.analysis.clone()
129    }
130
131    /// Analyze a statement for vectorization.
132    fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
133        if let Stmt::Loop(lp) = stmt {
134            let info = self.analyze_loop(lp, loop_info);
135            self.analysis.insert(lp.id, info);
136
137            // Recursively analyze nested loops
138            for inner_stmt in &lp.body.stmts {
139                self.analyze_stmt(inner_stmt, loop_info);
140            }
141        }
142    }
143
144    /// Analyze a single loop for vectorization.
145    fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> VectorizationInfo {
146        let mut info = VectorizationInfo::default();
147
148        // Check if loop is marked as vectorizable
149        if !lp.attrs.contains(LoopAttrs::VECTORIZE) {
150            info.reason = Some("loop not marked VECTORIZE".to_string());
151            return info;
152        }
153
154        // Check trip count
155        let metadata = loop_info.iter().find(|m| m.id == lp.id);
156        let trip_count = metadata.map(|m| &m.trip_count);
157
158        match trip_count {
159            Some(TripCount::Static(n)) if *n < self.config.min_trip_count => {
160                info.reason = Some(format!(
161                    "trip count {} below threshold {}",
162                    n, self.config.min_trip_count
163                ));
164                return info;
165            }
166            Some(TripCount::Dynamic) => {
167                // Dynamic trip count requires runtime check
168                // Still vectorizable with remainder handling
169            }
170            _ => {}
171        }
172
173        // Analyze access patterns in loop body
174        let (patterns, has_fma, has_reduction) = self.analyze_loop_body(&lp.body);
175        info.access_patterns = patterns.clone();
176        info.has_fma = has_fma;
177        info.has_reduction = has_reduction;
178
179        // Check if all accesses are vectorization-friendly
180        let all_sequential = patterns
181            .iter()
182            .all(|p| matches!(p, AccessPattern::Sequential | AccessPattern::Broadcast));
183
184        if !all_sequential {
185            info.reason = Some("non-sequential access pattern".to_string());
186            return info;
187        }
188
189        // Determine vector width
190        let elem_type = self.infer_element_type(&lp.body);
191        let width = if self.config.forced_width > 0 {
192            self.config.forced_width
193        } else {
194            LoopType::natural_vector_width(elem_type, self.config.target)
195        };
196
197        info.vectorizable = width > 1;
198        info.recommended_width = width;
199        info.reason = None;
200
201        info
202    }
203
204    /// Analyze loop body for access patterns, FMA opportunities, and reductions.
205    fn analyze_loop_body(&self, body: &Body) -> (Vec<AccessPattern>, bool, bool) {
206        let mut patterns = Vec::new();
207        let mut has_fma = false;
208        let mut has_reduction = false;
209
210        for stmt in &body.stmts {
211            match stmt {
212                Stmt::Assign(_, op) => {
213                    // Check for load access patterns
214                    if let Op::Load(mem_ref) = op {
215                        patterns.push(mem_ref.access.clone());
216                    }
217
218                    // Check for FMA pattern: a * b + c or a + b * c
219                    if self.config.enable_fma {
220                        has_fma |= self.is_fma_opportunity(op);
221                    }
222
223                    // Check for reduction operations
224                    if let Op::VecReduce(_, _) = op {
225                        has_reduction = true;
226                    }
227                }
228                Stmt::Store(mem_ref, _) => {
229                    patterns.push(mem_ref.access.clone());
230                }
231                Stmt::Loop(inner)
232                    // Check if inner loop has reduction attribute
233                    if inner.attrs.contains(LoopAttrs::REDUCTION) => {
234                        has_reduction = true;
235                    }
236                _ => {}
237            }
238        }
239
240        (patterns, has_fma, has_reduction)
241    }
242
243    /// Check if an operation can be replaced with FMA.
244    fn is_fma_opportunity(&self, op: &Op) -> bool {
245        // Pattern: Add(Mul(a, b), c) or Add(c, Mul(a, b))
246        match op {
247            Op::Binary(BinOp::Add, _, _) => {
248                // Would need to check operands are Mul results
249                // For now, return false as we'd need more context
250                false
251            }
252            _ => false,
253        }
254    }
255
256    /// Infer the element type from loop body operations.
257    fn infer_element_type(&self, body: &Body) -> ScalarType {
258        for stmt in &body.stmts {
259            if let Stmt::Assign(_, Op::Load(mem_ref)) = stmt {
260                if let LoopType::Scalar(s) = &mem_ref.elem_ty {
261                    return *s;
262                }
263            }
264        }
265        ScalarType::Float(32) // Default
266    }
267
268    /// Apply vectorization to a Loop IR function.
269    pub fn vectorize(&self, ir: &mut LoopIR) -> Result<VectorizeReport, VectorizeError> {
270        let mut report = VectorizeReport::default();
271
272        for stmt in &mut ir.body.stmts {
273            self.vectorize_stmt(stmt, &mut ir.loop_info, &mut report)?;
274        }
275
276        Ok(report)
277    }
278
279    /// Vectorize a statement.
280    fn vectorize_stmt(
281        &self,
282        stmt: &mut Stmt,
283        loop_info: &mut [LoopMetadata],
284        report: &mut VectorizeReport,
285    ) -> Result<(), VectorizeError> {
286        if let Stmt::Loop(lp) = stmt {
287            if let Some(info) = self.analysis.get(&lp.id) {
288                if info.vectorizable {
289                    self.vectorize_loop(lp, info, loop_info, report)?;
290                }
291            }
292
293            // Recursively vectorize nested loops
294            for inner_stmt in &mut lp.body.stmts {
295                self.vectorize_stmt(inner_stmt, loop_info, report)?;
296            }
297        }
298        Ok(())
299    }
300
301    /// Vectorize a single loop.
302    fn vectorize_loop(
303        &self,
304        lp: &mut Loop,
305        info: &VectorizationInfo,
306        loop_info: &mut [LoopMetadata],
307        report: &mut VectorizeReport,
308    ) -> Result<(), VectorizeError> {
309        let width = info.recommended_width;
310
311        // Update loop step
312        lp.step = Value::i64(width as i64);
313
314        // Update loop metadata
315        if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
316            meta.vector_width = Some(width);
317        }
318
319        // Transform operations in loop body to vector operations
320        self.vectorize_body(&mut lp.body, width)?;
321
322        // Record vectorization
323        report.vectorized_loops.push(VectorizedLoopInfo {
324            loop_id: lp.id,
325            vector_width: width,
326            has_fma: info.has_fma,
327            has_reduction: info.has_reduction,
328        });
329
330        Ok(())
331    }
332
333    /// Transform scalar operations in a body to vector operations.
334    fn vectorize_body(&self, body: &mut Body, width: u8) -> Result<(), VectorizeError> {
335        for stmt in &mut body.stmts {
336            if let Stmt::Assign(_, op) = stmt {
337                *op = self.vectorize_op(op, width)?;
338            }
339        }
340        Ok(())
341    }
342
343    /// Transform a scalar operation to a vector operation.
344    fn vectorize_op(&self, op: &Op, width: u8) -> Result<Op, VectorizeError> {
345        match op {
346            Op::Load(mem_ref) => {
347                // Transform scalar load to vector load
348                let mut vec_ref = mem_ref.clone();
349                if let LoopType::Scalar(s) = &mem_ref.elem_ty {
350                    vec_ref.elem_ty = LoopType::Vector(*s, width);
351                }
352                Ok(Op::Load(vec_ref))
353            }
354
355            Op::Binary(bin_op, a, b) => {
356                // Transform scalar binary op to vector binary op
357                let vec_a = self.vectorize_value(a, width);
358                let vec_b = self.vectorize_value(b, width);
359                Ok(Op::Binary(*bin_op, vec_a, vec_b))
360            }
361
362            Op::Unary(un_op, a) => {
363                let vec_a = self.vectorize_value(a, width);
364                Ok(Op::Unary(*un_op, vec_a))
365            }
366
367            // FMA is naturally vector
368            Op::Fma(a, b, c) => {
369                let vec_a = self.vectorize_value(a, width);
370                let vec_b = self.vectorize_value(b, width);
371                let vec_c = self.vectorize_value(c, width);
372                Ok(Op::Fma(vec_a, vec_b, vec_c))
373            }
374
375            // Keep other operations as-is
376            _ => Ok(op.clone()),
377        }
378    }
379
380    /// Transform a scalar value to a vector value.
381    fn vectorize_value(&self, val: &Value, width: u8) -> Value {
382        match val {
383            Value::Var(id, LoopType::Scalar(s)) => Value::Var(*id, LoopType::Vector(*s, width)),
384            Value::FloatConst(f, s) => {
385                // Scalar constant will be broadcast
386                Value::FloatConst(*f, *s)
387            }
388            Value::IntConst(i, s) => Value::IntConst(*i, *s),
389            _ => val.clone(),
390        }
391    }
392}
393
394/// Report of vectorization results.
395#[derive(Clone, Debug, Default)]
396pub struct VectorizeReport {
397    /// Loops that were vectorized.
398    pub vectorized_loops: Vec<VectorizedLoopInfo>,
399    /// Loops that could not be vectorized.
400    pub failed_loops: Vec<(LoopId, String)>,
401}
402
403impl VectorizeReport {
404    /// Returns true if any loops were vectorized.
405    pub fn any_vectorized(&self) -> bool {
406        !self.vectorized_loops.is_empty()
407    }
408
409    /// Returns the total number of vectorized loops.
410    pub fn count(&self) -> usize {
411        self.vectorized_loops.len()
412    }
413}
414
415impl std::fmt::Display for VectorizeReport {
416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        writeln!(f, "Vectorization Report")?;
418        writeln!(f, "====================")?;
419        writeln!(f, "Vectorized loops: {}", self.vectorized_loops.len())?;
420
421        for info in &self.vectorized_loops {
422            writeln!(
423                f,
424                "  Loop {:?}: width={}, fma={}, reduction={}",
425                info.loop_id, info.vector_width, info.has_fma, info.has_reduction
426            )?;
427        }
428
429        if !self.failed_loops.is_empty() {
430            writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
431            for (id, reason) in &self.failed_loops {
432                writeln!(f, "  Loop {:?}: {}", id, reason)?;
433            }
434        }
435
436        Ok(())
437    }
438}
439
440/// Information about a vectorized loop.
441#[derive(Clone, Debug)]
442pub struct VectorizedLoopInfo {
443    /// Loop identifier.
444    pub loop_id: LoopId,
445    /// Vector width used.
446    pub vector_width: u8,
447    /// Whether FMA was used.
448    pub has_fma: bool,
449    /// Whether reduction was needed.
450    pub has_reduction: bool,
451}
452
453// ============================================================================
454// SIMD Intrinsics (M3 Deliverable)
455// ============================================================================
456
/// Catalogue of SIMD operations the backend can name.
///
/// Each variant corresponds to a hardware SIMD instruction on targets that
/// support it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SimdIntrinsic {
    // --- Arithmetic ---
    /// Lane-wise addition: `result[i] = a[i] + b[i]`
    Add,
    /// Lane-wise subtraction: `result[i] = a[i] - b[i]`
    Sub,
    /// Lane-wise multiplication: `result[i] = a[i] * b[i]`
    Mul,
    /// Lane-wise division: `result[i] = a[i] / b[i]`
    Div,

    // --- Fused Multiply-Add ---
    /// Fused multiply then add: `result[i] = a[i] * b[i] + c[i]`
    Fmadd,
    /// Fused multiply then subtract: `result[i] = a[i] * b[i] - c[i]`
    Fmsub,
    /// Fused negated multiply then add: `result[i] = -(a[i] * b[i]) + c[i]`
    Fnmadd,

    // --- Horizontal Operations ---
    /// Pairwise horizontal add: `hadd([a,b,c,d], [e,f,g,h]) = [a+b, c+d, e+f, g+h]`
    Hadd,
    /// Sum of every lane in one vector: `sum([a,b,c,d]) = a+b+c+d`
    HorizontalSum,

    // --- Min/Max ---
    /// Lane-wise minimum: `result[i] = min(a[i], b[i])`
    Min,
    /// Lane-wise maximum: `result[i] = max(a[i], b[i])`
    Max,

    // --- Comparison ---
    /// Lane-wise equality mask: `result[i] = a[i] == b[i] ? ~0 : 0`
    CmpEq,
    /// Lane-wise less-than mask: `result[i] = a[i] < b[i] ? ~0 : 0`
    CmpLt,
    /// Lane-wise less-or-equal mask: `result[i] = a[i] <= b[i] ? ~0 : 0`
    CmpLe,

    // --- Data Movement ---
    /// Copy one scalar into every lane
    Broadcast,
    /// Read a single lane out of a vector
    Extract,
    /// Write a single lane of a vector
    Insert,
    /// Rearrange lanes (shuffle/permute)
    Shuffle,

    // --- Load/Store ---
    /// Load from an aligned address
    LoadAligned,
    /// Load from a possibly unaligned address
    LoadUnaligned,
    /// Store to an aligned address
    StoreAligned,
    /// Store to a possibly unaligned address
    StoreUnaligned,
}
520
521impl SimdIntrinsic {
522    /// Returns the x86 intrinsic name for this operation.
523    pub fn x86_name(&self, ty: ScalarType, width: u8) -> &'static str {
524        match (self, ty, width) {
525            // Float32 x 4 (SSE)
526            (Self::Add, ScalarType::Float(32), 4) => "_mm_add_ps",
527            (Self::Sub, ScalarType::Float(32), 4) => "_mm_sub_ps",
528            (Self::Mul, ScalarType::Float(32), 4) => "_mm_mul_ps",
529            (Self::Div, ScalarType::Float(32), 4) => "_mm_div_ps",
530            (Self::Fmadd, ScalarType::Float(32), 4) => "_mm_fmadd_ps",
531            (Self::Min, ScalarType::Float(32), 4) => "_mm_min_ps",
532            (Self::Max, ScalarType::Float(32), 4) => "_mm_max_ps",
533            (Self::LoadAligned, ScalarType::Float(32), 4) => "_mm_load_ps",
534            (Self::StoreAligned, ScalarType::Float(32), 4) => "_mm_store_ps",
535
536            // Float32 x 8 (AVX)
537            (Self::Add, ScalarType::Float(32), 8) => "_mm256_add_ps",
538            (Self::Sub, ScalarType::Float(32), 8) => "_mm256_sub_ps",
539            (Self::Mul, ScalarType::Float(32), 8) => "_mm256_mul_ps",
540            (Self::Div, ScalarType::Float(32), 8) => "_mm256_div_ps",
541            (Self::Fmadd, ScalarType::Float(32), 8) => "_mm256_fmadd_ps",
542            (Self::Min, ScalarType::Float(32), 8) => "_mm256_min_ps",
543            (Self::Max, ScalarType::Float(32), 8) => "_mm256_max_ps",
544            (Self::LoadAligned, ScalarType::Float(32), 8) => "_mm256_load_ps",
545            (Self::StoreAligned, ScalarType::Float(32), 8) => "_mm256_store_ps",
546            (Self::Hadd, ScalarType::Float(32), 8) => "_mm256_hadd_ps",
547
548            // Float64 x 2 (SSE2)
549            (Self::Add, ScalarType::Float(64), 2) => "_mm_add_pd",
550            (Self::Sub, ScalarType::Float(64), 2) => "_mm_sub_pd",
551            (Self::Mul, ScalarType::Float(64), 2) => "_mm_mul_pd",
552            (Self::Fmadd, ScalarType::Float(64), 2) => "_mm_fmadd_pd",
553
554            // Float64 x 4 (AVX)
555            (Self::Add, ScalarType::Float(64), 4) => "_mm256_add_pd",
556            (Self::Sub, ScalarType::Float(64), 4) => "_mm256_sub_pd",
557            (Self::Mul, ScalarType::Float(64), 4) => "_mm256_mul_pd",
558            (Self::Fmadd, ScalarType::Float(64), 4) => "_mm256_fmadd_pd",
559
560            _ => "unknown_intrinsic",
561        }
562    }
563
564    /// Returns the ARM NEON intrinsic name for this operation.
565    pub fn arm_name(&self, ty: ScalarType, width: u8) -> &'static str {
566        match (self, ty, width) {
567            // Float32 x 4 (NEON)
568            (Self::Add, ScalarType::Float(32), 4) => "vaddq_f32",
569            (Self::Sub, ScalarType::Float(32), 4) => "vsubq_f32",
570            (Self::Mul, ScalarType::Float(32), 4) => "vmulq_f32",
571            (Self::Fmadd, ScalarType::Float(32), 4) => "vfmaq_f32",
572            (Self::Min, ScalarType::Float(32), 4) => "vminq_f32",
573            (Self::Max, ScalarType::Float(32), 4) => "vmaxq_f32",
574            (Self::LoadAligned, ScalarType::Float(32), 4) => "vld1q_f32",
575            (Self::StoreAligned, ScalarType::Float(32), 4) => "vst1q_f32",
576
577            // Float64 x 2 (NEON)
578            (Self::Add, ScalarType::Float(64), 2) => "vaddq_f64",
579            (Self::Sub, ScalarType::Float(64), 2) => "vsubq_f64",
580            (Self::Mul, ScalarType::Float(64), 2) => "vmulq_f64",
581            (Self::Fmadd, ScalarType::Float(64), 2) => "vfmaq_f64",
582
583            _ => "unknown_intrinsic",
584        }
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MemRef, Param, ValueId};
    use bhc_index::Idx;
    use bhc_intern::Symbol;
    use bhc_tensor_ir::BufferId;

    /// Builds a kernel with a single loop marked `VECTORIZE | INDEPENDENT`
    /// whose body loads an `f32`, multiplies it by 2.0 and stores it back,
    /// all through sequential accesses. Returns the IR and the loop's id.
    fn make_vectorizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
        let loop_id = LoopId::new(0);
        let loop_var = ValueId::new(0);

        let mem_ref = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
            elem_ty: LoopType::Scalar(ScalarType::F32),
            access: AccessPattern::Sequential,
        };

        let mut body = Body::new();
        let load_result = ValueId::new(1);
        body.push(Stmt::Assign(load_result, Op::Load(mem_ref.clone())));

        let mul_result = ValueId::new(2);
        body.push(Stmt::Assign(
            mul_result,
            Op::Binary(
                BinOp::Mul,
                Value::Var(load_result, LoopType::Scalar(ScalarType::F32)),
                Value::float(2.0, 32),
            ),
        ));

        body.push(Stmt::Store(
            mem_ref,
            Value::Var(mul_result, LoopType::Scalar(ScalarType::F32)),
        ));

        let lp = Loop {
            id: loop_id,
            var: loop_var,
            lower: Value::i64(0),
            upper: Value::i64(trip_count as i64),
            step: Value::i64(1),
            body,
            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
        };

        let mut outer_body = Body::new();
        outer_body.push(Stmt::Loop(lp));

        let ir = LoopIR {
            name: Symbol::intern("test_kernel"),
            params: vec![Param {
                name: Symbol::intern("data"),
                ty: LoopType::Ptr(Box::new(LoopType::Scalar(ScalarType::F32))),
                is_ptr: true,
            }],
            return_ty: LoopType::Void,
            body: outer_body,
            allocs: vec![],
            loop_info: vec![LoopMetadata {
                id: loop_id,
                trip_count: TripCount::Static(trip_count),
                vector_width: None,
                parallel_chunk: None,
                unroll_factor: None,
                dependencies: Vec::new(),
            }],
        };

        (ir, loop_id)
    }

    #[test]
    fn test_vectorization_analysis() {
        let (ir, loop_id) = make_vectorizable_loop(1024);

        let mut pass = VectorizePass::new(VectorizeConfig::default());
        let analysis = pass.analyze(&ir);

        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(info.vectorizable, "loop should be vectorizable");
        assert!(
            info.recommended_width > 1,
            "should recommend vector width > 1"
        );
        assert!(
            info.reason.is_none(),
            "accepted loop should not carry a rejection reason"
        );
    }

    #[test]
    fn test_vectorization_below_threshold() {
        let (ir, loop_id) = make_vectorizable_loop(2); // Below default threshold of 4

        let mut pass = VectorizePass::new(VectorizeConfig::default());
        let analysis = pass.analyze(&ir);

        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(!info.vectorizable, "small loop should not be vectorizable");
        assert!(
            info.reason.is_some(),
            "rejected loop should explain why it was rejected"
        );
    }

    #[test]
    fn test_vectorization_transform() {
        let (mut ir, loop_id) = make_vectorizable_loop(1024);

        let mut pass = VectorizePass::new(VectorizeConfig::default());
        pass.analyze(&ir);
        let report = pass
            .vectorize(&mut ir)
            .expect("vectorization should succeed");

        assert!(report.any_vectorized(), "should have vectorized loops");
        assert_eq!(report.count(), 1, "should have vectorized 1 loop");

        // The report and the IR metadata must agree on what was done.
        let recorded = &report.vectorized_loops[0];
        assert_eq!(recorded.loop_id, loop_id, "report should name the loop");
        assert_eq!(
            ir.loop_info[0].vector_width,
            Some(recorded.vector_width),
            "loop metadata should record the vector width"
        );
    }

    #[test]
    fn test_simd_intrinsic_names() {
        // x86 SSE
        assert_eq!(
            SimdIntrinsic::Add.x86_name(ScalarType::F32, 4),
            "_mm_add_ps"
        );
        assert_eq!(
            SimdIntrinsic::Fmadd.x86_name(ScalarType::F32, 4),
            "_mm_fmadd_ps"
        );

        // x86 AVX
        assert_eq!(
            SimdIntrinsic::Add.x86_name(ScalarType::F32, 8),
            "_mm256_add_ps"
        );
        assert_eq!(
            SimdIntrinsic::Hadd.x86_name(ScalarType::F32, 8),
            "_mm256_hadd_ps"
        );

        // ARM NEON
        assert_eq!(SimdIntrinsic::Add.arm_name(ScalarType::F32, 4), "vaddq_f32");
        assert_eq!(
            SimdIntrinsic::Fmadd.arm_name(ScalarType::F32, 4),
            "vfmaq_f32"
        );
    }

    #[test]
    fn test_target_vector_widths() {
        // AVX should use 8-wide for f32
        assert_eq!(
            LoopType::natural_vector_width(ScalarType::F32, TargetArch::X86_64Avx2),
            8
        );

        // SSE should use 4-wide for f32
        assert_eq!(
            LoopType::natural_vector_width(ScalarType::F32, TargetArch::X86_64Sse2),
            4
        );

        // NEON should use 4-wide for f32
        assert_eq!(
            LoopType::natural_vector_width(ScalarType::F32, TargetArch::Aarch64Neon),
            4
        );

        // AVX should use 4-wide for f64
        assert_eq!(
            LoopType::natural_vector_width(ScalarType::F64, TargetArch::X86_64Avx2),
            4
        );
    }

    #[test]
    fn test_vectorize_report_display() {
        let report = VectorizeReport {
            vectorized_loops: vec![VectorizedLoopInfo {
                loop_id: LoopId::new(0),
                vector_width: 8,
                has_fma: true,
                has_reduction: false,
            }],
            failed_loops: vec![],
        };

        let output = report.to_string();
        assert!(output.contains("Vectorized loops: 1"));
        assert!(output.contains("width=8"));
        assert!(output.contains("fma=true"));
        assert!(
            !output.contains("Failed loops"),
            "failed section should be omitted when there are no failures"
        );
    }

    #[test]
    fn test_vectorize_report_display_failed_loops() {
        let report = VectorizeReport {
            vectorized_loops: vec![],
            failed_loops: vec![(
                LoopId::new(3),
                "non-sequential access pattern".to_string(),
            )],
        };

        let output = report.to_string();
        assert!(output.contains("Vectorized loops: 0"));
        assert!(output.contains("Failed loops: 1"));
        assert!(output.contains("non-sequential access pattern"));
    }
}