Skip to main content

bhc_loop_ir/
vectorize.rs

//! # Auto-Vectorization Pass
//!
//! This module implements auto-vectorization for Loop IR, transforming scalar
//! operations into SIMD operations based on loop analysis.
//!
//! ## M3 Exit Criteria
//!
//! - `matmul` microkernel auto-vectorizes on x86_64 and aarch64
//! - SIMD intrinsics: add, mul, fmadd, hadd
//!
//! ## Vectorization Strategy
//!
//! 1. **Analyze loops**: Identify vectorizable loops (nested loops are visited too)
//! 2. **Check access patterns**: Sequential access enables vectorization
//! 3. **Transform operations**: Scalar → Vector operations
//! 4. **Handle remainder**: Scalar loop for leftover iterations that do not fill a whole vector (not yet generated by this pass)
17
18use crate::{
19    AccessPattern, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType, Op,
20    ScalarType, Stmt, TargetArch, TripCount, Value,
21};
22use rustc_hash::FxHashMap;
23use thiserror::Error;
24
25/// Errors that can occur during vectorization.
26#[derive(Clone, Debug, Error)]
27pub enum VectorizeError {
28    /// Loop cannot be vectorized.
29    #[error("loop {loop_id:?} cannot be vectorized: {reason}")]
30    NotVectorizable {
31        /// Loop identifier.
32        loop_id: LoopId,
33        /// Reason vectorization failed.
34        reason: String,
35    },
36
37    /// Invalid vector width.
38    #[error("invalid vector width {width} for type {ty:?}")]
39    InvalidWidth {
40        /// Requested width.
41        width: u8,
42        /// Element type.
43        ty: ScalarType,
44    },
45}
46
47/// Result of vectorization analysis for a loop.
48#[derive(Clone, Debug)]
49pub struct VectorizationInfo {
50    /// Whether the loop can be vectorized.
51    pub vectorizable: bool,
52    /// Reason if not vectorizable.
53    pub reason: Option<String>,
54    /// Recommended vector width.
55    pub recommended_width: u8,
56    /// Access patterns in the loop.
57    pub access_patterns: Vec<AccessPattern>,
58    /// Whether FMA opportunities exist.
59    pub has_fma: bool,
60    /// Whether horizontal reduction is needed.
61    pub has_reduction: bool,
62}
63
64impl Default for VectorizationInfo {
65    fn default() -> Self {
66        Self {
67            vectorizable: false,
68            reason: Some("not analyzed".to_string()),
69            recommended_width: 1,
70            access_patterns: Vec::new(),
71            has_fma: false,
72            has_reduction: false,
73        }
74    }
75}
76
77/// Configuration for vectorization.
78#[derive(Clone, Debug)]
79pub struct VectorizeConfig {
80    /// Target architecture.
81    pub target: TargetArch,
82    /// Force a specific vector width (0 = auto).
83    pub forced_width: u8,
84    /// Generate remainder loop for non-aligned iterations.
85    pub generate_remainder: bool,
86    /// Enable FMA fusion.
87    pub enable_fma: bool,
88    /// Minimum trip count for vectorization.
89    pub min_trip_count: usize,
90}
91
92impl Default for VectorizeConfig {
93    fn default() -> Self {
94        Self {
95            target: TargetArch::default(),
96            forced_width: 0,
97            generate_remainder: true,
98            enable_fma: true,
99            min_trip_count: 4,
100        }
101    }
102}
103
104/// Vectorization pass state.
105pub struct VectorizePass {
106    config: VectorizeConfig,
107    /// Analysis results per loop.
108    analysis: FxHashMap<LoopId, VectorizationInfo>,
109}
110
111impl VectorizePass {
112    /// Create a new vectorization pass with the given configuration.
113    pub fn new(config: VectorizeConfig) -> Self {
114        Self {
115            config,
116            analysis: FxHashMap::default(),
117        }
118    }
119
120    /// Analyze a Loop IR function for vectorization opportunities.
121    pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, VectorizationInfo> {
122        self.analysis.clear();
123
124        for stmt in &ir.body.stmts {
125            self.analyze_stmt(stmt, &ir.loop_info);
126        }
127
128        self.analysis.clone()
129    }
130
131    /// Analyze a statement for vectorization.
132    fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
133        match stmt {
134            Stmt::Loop(lp) => {
135                let info = self.analyze_loop(lp, loop_info);
136                self.analysis.insert(lp.id, info);
137
138                // Recursively analyze nested loops
139                for inner_stmt in &lp.body.stmts {
140                    self.analyze_stmt(inner_stmt, loop_info);
141                }
142            }
143            _ => {}
144        }
145    }
146
147    /// Analyze a single loop for vectorization.
148    fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> VectorizationInfo {
149        let mut info = VectorizationInfo::default();
150
151        // Check if loop is marked as vectorizable
152        if !lp.attrs.contains(LoopAttrs::VECTORIZE) {
153            info.reason = Some("loop not marked VECTORIZE".to_string());
154            return info;
155        }
156
157        // Check trip count
158        let metadata = loop_info.iter().find(|m| m.id == lp.id);
159        let trip_count = metadata.map(|m| &m.trip_count);
160
161        match trip_count {
162            Some(TripCount::Static(n)) if *n < self.config.min_trip_count => {
163                info.reason = Some(format!(
164                    "trip count {} below threshold {}",
165                    n, self.config.min_trip_count
166                ));
167                return info;
168            }
169            Some(TripCount::Dynamic) => {
170                // Dynamic trip count requires runtime check
171                // Still vectorizable with remainder handling
172            }
173            _ => {}
174        }
175
176        // Analyze access patterns in loop body
177        let (patterns, has_fma, has_reduction) = self.analyze_loop_body(&lp.body);
178        info.access_patterns = patterns.clone();
179        info.has_fma = has_fma;
180        info.has_reduction = has_reduction;
181
182        // Check if all accesses are vectorization-friendly
183        let all_sequential = patterns
184            .iter()
185            .all(|p| matches!(p, AccessPattern::Sequential | AccessPattern::Broadcast));
186
187        if !all_sequential {
188            info.reason = Some("non-sequential access pattern".to_string());
189            return info;
190        }
191
192        // Determine vector width
193        let elem_type = self.infer_element_type(&lp.body);
194        let width = if self.config.forced_width > 0 {
195            self.config.forced_width
196        } else {
197            LoopType::natural_vector_width(elem_type, self.config.target)
198        };
199
200        info.vectorizable = width > 1;
201        info.recommended_width = width;
202        info.reason = None;
203
204        info
205    }
206
207    /// Analyze loop body for access patterns, FMA opportunities, and reductions.
208    fn analyze_loop_body(&self, body: &Body) -> (Vec<AccessPattern>, bool, bool) {
209        let mut patterns = Vec::new();
210        let mut has_fma = false;
211        let mut has_reduction = false;
212
213        for stmt in &body.stmts {
214            match stmt {
215                Stmt::Assign(_, op) => {
216                    // Check for load access patterns
217                    if let Op::Load(mem_ref) = op {
218                        patterns.push(mem_ref.access.clone());
219                    }
220
221                    // Check for FMA pattern: a * b + c or a + b * c
222                    if self.config.enable_fma {
223                        has_fma |= self.is_fma_opportunity(op);
224                    }
225
226                    // Check for reduction operations
227                    if let Op::VecReduce(_, _) = op {
228                        has_reduction = true;
229                    }
230                }
231                Stmt::Store(mem_ref, _) => {
232                    patterns.push(mem_ref.access.clone());
233                }
234                Stmt::Loop(inner) => {
235                    // Check if inner loop has reduction attribute
236                    if inner.attrs.contains(LoopAttrs::REDUCTION) {
237                        has_reduction = true;
238                    }
239                }
240                _ => {}
241            }
242        }
243
244        (patterns, has_fma, has_reduction)
245    }
246
247    /// Check if an operation can be replaced with FMA.
248    fn is_fma_opportunity(&self, op: &Op) -> bool {
249        // Pattern: Add(Mul(a, b), c) or Add(c, Mul(a, b))
250        match op {
251            Op::Binary(BinOp::Add, _, _) => {
252                // Would need to check operands are Mul results
253                // For now, return false as we'd need more context
254                false
255            }
256            _ => false,
257        }
258    }
259
260    /// Infer the element type from loop body operations.
261    fn infer_element_type(&self, body: &Body) -> ScalarType {
262        for stmt in &body.stmts {
263            if let Stmt::Assign(_, Op::Load(mem_ref)) = stmt {
264                if let LoopType::Scalar(s) = &mem_ref.elem_ty {
265                    return *s;
266                }
267            }
268        }
269        ScalarType::Float(32) // Default
270    }
271
272    /// Apply vectorization to a Loop IR function.
273    pub fn vectorize(&self, ir: &mut LoopIR) -> Result<VectorizeReport, VectorizeError> {
274        let mut report = VectorizeReport::default();
275
276        for stmt in &mut ir.body.stmts {
277            self.vectorize_stmt(stmt, &mut ir.loop_info, &mut report)?;
278        }
279
280        Ok(report)
281    }
282
283    /// Vectorize a statement.
284    fn vectorize_stmt(
285        &self,
286        stmt: &mut Stmt,
287        loop_info: &mut Vec<LoopMetadata>,
288        report: &mut VectorizeReport,
289    ) -> Result<(), VectorizeError> {
290        match stmt {
291            Stmt::Loop(lp) => {
292                if let Some(info) = self.analysis.get(&lp.id) {
293                    if info.vectorizable {
294                        self.vectorize_loop(lp, info, loop_info, report)?;
295                    }
296                }
297
298                // Recursively vectorize nested loops
299                for inner_stmt in &mut lp.body.stmts {
300                    self.vectorize_stmt(inner_stmt, loop_info, report)?;
301                }
302            }
303            _ => {}
304        }
305        Ok(())
306    }
307
308    /// Vectorize a single loop.
309    fn vectorize_loop(
310        &self,
311        lp: &mut Loop,
312        info: &VectorizationInfo,
313        loop_info: &mut Vec<LoopMetadata>,
314        report: &mut VectorizeReport,
315    ) -> Result<(), VectorizeError> {
316        let width = info.recommended_width;
317
318        // Update loop step
319        lp.step = Value::i64(width as i64);
320
321        // Update loop metadata
322        if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
323            meta.vector_width = Some(width);
324        }
325
326        // Transform operations in loop body to vector operations
327        self.vectorize_body(&mut lp.body, width)?;
328
329        // Record vectorization
330        report.vectorized_loops.push(VectorizedLoopInfo {
331            loop_id: lp.id,
332            vector_width: width,
333            has_fma: info.has_fma,
334            has_reduction: info.has_reduction,
335        });
336
337        Ok(())
338    }
339
340    /// Transform scalar operations in a body to vector operations.
341    fn vectorize_body(&self, body: &mut Body, width: u8) -> Result<(), VectorizeError> {
342        for stmt in &mut body.stmts {
343            match stmt {
344                Stmt::Assign(_, op) => {
345                    *op = self.vectorize_op(op, width)?;
346                }
347                _ => {}
348            }
349        }
350        Ok(())
351    }
352
353    /// Transform a scalar operation to a vector operation.
354    fn vectorize_op(&self, op: &Op, width: u8) -> Result<Op, VectorizeError> {
355        match op {
356            Op::Load(mem_ref) => {
357                // Transform scalar load to vector load
358                let mut vec_ref = mem_ref.clone();
359                if let LoopType::Scalar(s) = &mem_ref.elem_ty {
360                    vec_ref.elem_ty = LoopType::Vector(*s, width);
361                }
362                Ok(Op::Load(vec_ref))
363            }
364
365            Op::Binary(bin_op, a, b) => {
366                // Transform scalar binary op to vector binary op
367                let vec_a = self.vectorize_value(a, width);
368                let vec_b = self.vectorize_value(b, width);
369                Ok(Op::Binary(*bin_op, vec_a, vec_b))
370            }
371
372            Op::Unary(un_op, a) => {
373                let vec_a = self.vectorize_value(a, width);
374                Ok(Op::Unary(*un_op, vec_a))
375            }
376
377            // FMA is naturally vector
378            Op::Fma(a, b, c) => {
379                let vec_a = self.vectorize_value(a, width);
380                let vec_b = self.vectorize_value(b, width);
381                let vec_c = self.vectorize_value(c, width);
382                Ok(Op::Fma(vec_a, vec_b, vec_c))
383            }
384
385            // Keep other operations as-is
386            _ => Ok(op.clone()),
387        }
388    }
389
390    /// Transform a scalar value to a vector value.
391    fn vectorize_value(&self, val: &Value, width: u8) -> Value {
392        match val {
393            Value::Var(id, LoopType::Scalar(s)) => Value::Var(*id, LoopType::Vector(*s, width)),
394            Value::FloatConst(f, s) => {
395                // Scalar constant will be broadcast
396                Value::FloatConst(*f, *s)
397            }
398            Value::IntConst(i, s) => Value::IntConst(*i, *s),
399            _ => val.clone(),
400        }
401    }
402}
403
404/// Report of vectorization results.
405#[derive(Clone, Debug, Default)]
406pub struct VectorizeReport {
407    /// Loops that were vectorized.
408    pub vectorized_loops: Vec<VectorizedLoopInfo>,
409    /// Loops that could not be vectorized.
410    pub failed_loops: Vec<(LoopId, String)>,
411}
412
413impl VectorizeReport {
414    /// Returns true if any loops were vectorized.
415    pub fn any_vectorized(&self) -> bool {
416        !self.vectorized_loops.is_empty()
417    }
418
419    /// Returns the total number of vectorized loops.
420    pub fn count(&self) -> usize {
421        self.vectorized_loops.len()
422    }
423}
424
425impl std::fmt::Display for VectorizeReport {
426    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
427        writeln!(f, "Vectorization Report")?;
428        writeln!(f, "====================")?;
429        writeln!(f, "Vectorized loops: {}", self.vectorized_loops.len())?;
430
431        for info in &self.vectorized_loops {
432            writeln!(
433                f,
434                "  Loop {:?}: width={}, fma={}, reduction={}",
435                info.loop_id, info.vector_width, info.has_fma, info.has_reduction
436            )?;
437        }
438
439        if !self.failed_loops.is_empty() {
440            writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
441            for (id, reason) in &self.failed_loops {
442                writeln!(f, "  Loop {:?}: {}", id, reason)?;
443            }
444        }
445
446        Ok(())
447    }
448}
449
450/// Information about a vectorized loop.
451#[derive(Clone, Debug)]
452pub struct VectorizedLoopInfo {
453    /// Loop identifier.
454    pub loop_id: LoopId,
455    /// Vector width used.
456    pub vector_width: u8,
457    /// Whether FMA was used.
458    pub has_fma: bool,
459    /// Whether reduction was needed.
460    pub has_reduction: bool,
461}
462
463// ============================================================================
464// SIMD Intrinsics (M3 Deliverable)
465// ============================================================================
466
/// Target-independent SIMD operations.
///
/// Each variant is meant to map onto a hardware SIMD instruction on supported
/// targets; see `x86_name` and `arm_name` for the concrete intrinsic names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SimdIntrinsic {
    // Lane-wise arithmetic.
    /// Lane-wise sum: `out[i] = a[i] + b[i]`
    Add,
    /// Lane-wise difference: `out[i] = a[i] - b[i]`
    Sub,
    /// Lane-wise product: `out[i] = a[i] * b[i]`
    Mul,
    /// Lane-wise quotient: `out[i] = a[i] / b[i]`
    Div,

    // Fused multiply-add family (single rounding).
    /// `out[i] = a[i] * b[i] + c[i]`
    Fmadd,
    /// `out[i] = a[i] * b[i] - c[i]`
    Fmsub,
    /// `out[i] = -(a[i] * b[i]) + c[i]`
    Fnmadd,

    // Cross-lane (horizontal) operations.
    /// Pairwise add of adjacent lanes from both inputs:
    /// `hadd([a,b,c,d], [e,f,g,h]) = [a+b, c+d, e+f, g+h]`
    Hadd,
    /// Reduce every lane to a single sum: `sum([a,b,c,d]) = a+b+c+d`
    HorizontalSum,

    // Lane-wise selection.
    /// `out[i] = min(a[i], b[i])`
    Min,
    /// `out[i] = max(a[i], b[i])`
    Max,

    // Lane-wise comparisons producing all-ones / all-zeros masks.
    /// `out[i] = a[i] == b[i] ? ~0 : 0`
    CmpEq,
    /// `out[i] = a[i] < b[i] ? ~0 : 0`
    CmpLt,
    /// `out[i] = a[i] <= b[i] ? ~0 : 0`
    CmpLe,

    // Lane movement.
    /// Replicate a scalar into every lane
    Broadcast,
    /// Read one lane out of a vector
    Extract,
    /// Write one lane of a vector
    Insert,
    /// Permute lanes
    Shuffle,

    // Memory access.
    /// Load from an aligned address
    LoadAligned,
    /// Load from an address with no alignment guarantee
    LoadUnaligned,
    /// Store to an aligned address
    StoreAligned,
    /// Store to an address with no alignment guarantee
    StoreUnaligned,
}
530
531impl SimdIntrinsic {
532    /// Returns the x86 intrinsic name for this operation.
533    pub fn x86_name(&self, ty: ScalarType, width: u8) -> &'static str {
534        match (self, ty, width) {
535            // Float32 x 4 (SSE)
536            (Self::Add, ScalarType::Float(32), 4) => "_mm_add_ps",
537            (Self::Sub, ScalarType::Float(32), 4) => "_mm_sub_ps",
538            (Self::Mul, ScalarType::Float(32), 4) => "_mm_mul_ps",
539            (Self::Div, ScalarType::Float(32), 4) => "_mm_div_ps",
540            (Self::Fmadd, ScalarType::Float(32), 4) => "_mm_fmadd_ps",
541            (Self::Min, ScalarType::Float(32), 4) => "_mm_min_ps",
542            (Self::Max, ScalarType::Float(32), 4) => "_mm_max_ps",
543            (Self::LoadAligned, ScalarType::Float(32), 4) => "_mm_load_ps",
544            (Self::StoreAligned, ScalarType::Float(32), 4) => "_mm_store_ps",
545
546            // Float32 x 8 (AVX)
547            (Self::Add, ScalarType::Float(32), 8) => "_mm256_add_ps",
548            (Self::Sub, ScalarType::Float(32), 8) => "_mm256_sub_ps",
549            (Self::Mul, ScalarType::Float(32), 8) => "_mm256_mul_ps",
550            (Self::Div, ScalarType::Float(32), 8) => "_mm256_div_ps",
551            (Self::Fmadd, ScalarType::Float(32), 8) => "_mm256_fmadd_ps",
552            (Self::Min, ScalarType::Float(32), 8) => "_mm256_min_ps",
553            (Self::Max, ScalarType::Float(32), 8) => "_mm256_max_ps",
554            (Self::LoadAligned, ScalarType::Float(32), 8) => "_mm256_load_ps",
555            (Self::StoreAligned, ScalarType::Float(32), 8) => "_mm256_store_ps",
556            (Self::Hadd, ScalarType::Float(32), 8) => "_mm256_hadd_ps",
557
558            // Float64 x 2 (SSE2)
559            (Self::Add, ScalarType::Float(64), 2) => "_mm_add_pd",
560            (Self::Sub, ScalarType::Float(64), 2) => "_mm_sub_pd",
561            (Self::Mul, ScalarType::Float(64), 2) => "_mm_mul_pd",
562            (Self::Fmadd, ScalarType::Float(64), 2) => "_mm_fmadd_pd",
563
564            // Float64 x 4 (AVX)
565            (Self::Add, ScalarType::Float(64), 4) => "_mm256_add_pd",
566            (Self::Sub, ScalarType::Float(64), 4) => "_mm256_sub_pd",
567            (Self::Mul, ScalarType::Float(64), 4) => "_mm256_mul_pd",
568            (Self::Fmadd, ScalarType::Float(64), 4) => "_mm256_fmadd_pd",
569
570            _ => "unknown_intrinsic",
571        }
572    }
573
574    /// Returns the ARM NEON intrinsic name for this operation.
575    pub fn arm_name(&self, ty: ScalarType, width: u8) -> &'static str {
576        match (self, ty, width) {
577            // Float32 x 4 (NEON)
578            (Self::Add, ScalarType::Float(32), 4) => "vaddq_f32",
579            (Self::Sub, ScalarType::Float(32), 4) => "vsubq_f32",
580            (Self::Mul, ScalarType::Float(32), 4) => "vmulq_f32",
581            (Self::Fmadd, ScalarType::Float(32), 4) => "vfmaq_f32",
582            (Self::Min, ScalarType::Float(32), 4) => "vminq_f32",
583            (Self::Max, ScalarType::Float(32), 4) => "vmaxq_f32",
584            (Self::LoadAligned, ScalarType::Float(32), 4) => "vld1q_f32",
585            (Self::StoreAligned, ScalarType::Float(32), 4) => "vst1q_f32",
586
587            // Float64 x 2 (NEON)
588            (Self::Add, ScalarType::Float(64), 2) => "vaddq_f64",
589            (Self::Sub, ScalarType::Float(64), 2) => "vsubq_f64",
590            (Self::Mul, ScalarType::Float(64), 2) => "vmulq_f64",
591            (Self::Fmadd, ScalarType::Float(64), 2) => "vfmaq_f64",
592
593            _ => "unknown_intrinsic",
594        }
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MemRef, Param, ValueId};
    use bhc_index::Idx;
    use bhc_intern::Symbol;
    use bhc_tensor_ir::BufferId;

    /// Builds a single f32 kernel with a top-level loop.
    ///
    /// The loop is equivalent to `for i in 0..trip_count { data[i] = data[i] * 2.0 }`.
    /// It is marked `VECTORIZE | INDEPENDENT` and carries static trip-count
    /// metadata. Returns the IR together with the loop's id.
    fn make_vectorizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
        let id = LoopId::new(0);
        let induction = ValueId::new(0);
        let f32_scalar = LoopType::Scalar(ScalarType::F32);

        // `data[i]`, addressed directly by the induction variable.
        let element = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(induction, LoopType::Scalar(ScalarType::I64)),
            elem_ty: f32_scalar.clone(),
            access: AccessPattern::Sequential,
        };

        let loaded = ValueId::new(1);
        let scaled = ValueId::new(2);

        let mut kernel = Body::new();
        kernel.push(Stmt::Assign(loaded, Op::Load(element.clone())));
        kernel.push(Stmt::Assign(
            scaled,
            Op::Binary(
                BinOp::Mul,
                Value::Var(loaded, f32_scalar.clone()),
                Value::float(2.0, 32),
            ),
        ));
        kernel.push(Stmt::Store(element, Value::Var(scaled, f32_scalar.clone())));

        let mut top_level = Body::new();
        top_level.push(Stmt::Loop(Loop {
            id,
            var: induction,
            lower: Value::i64(0),
            upper: Value::i64(trip_count as i64),
            step: Value::i64(1),
            body: kernel,
            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
        }));

        let kernel_name = Symbol::intern("test_kernel");
        let data_param = Param {
            name: Symbol::intern("data"),
            ty: LoopType::Ptr(Box::new(f32_scalar)),
            is_ptr: true,
        };
        let metadata = LoopMetadata {
            id,
            trip_count: TripCount::Static(trip_count),
            vector_width: None,
            parallel_chunk: None,
            unroll_factor: None,
            dependencies: Vec::new(),
        };

        let ir = LoopIR {
            name: kernel_name,
            params: vec![data_param],
            return_ty: LoopType::Void,
            body: top_level,
            allocs: vec![],
            loop_info: vec![metadata],
        };

        (ir, id)
    }

    #[test]
    fn test_vectorization_analysis() {
        let (ir, loop_id) = make_vectorizable_loop(1024);

        let analysis = VectorizePass::new(VectorizeConfig::default()).analyze(&ir);
        let info = analysis.get(&loop_id).expect("loop should be analyzed");

        assert!(info.vectorizable, "loop should be vectorizable");
        assert!(
            info.recommended_width > 1,
            "should recommend vector width > 1"
        );
    }

    #[test]
    fn test_vectorization_below_threshold() {
        // Two iterations is below the default minimum trip count of 4.
        let (ir, loop_id) = make_vectorizable_loop(2);

        let analysis = VectorizePass::new(VectorizeConfig::default()).analyze(&ir);
        let info = analysis.get(&loop_id).expect("loop should be analyzed");

        assert!(!info.vectorizable, "small loop should not be vectorizable");
    }

    #[test]
    fn test_vectorization_transform() {
        let (mut ir, _) = make_vectorizable_loop(1024);

        let mut pass = VectorizePass::new(VectorizeConfig::default());
        pass.analyze(&ir);
        let report = pass
            .vectorize(&mut ir)
            .expect("vectorization should succeed");

        assert!(report.any_vectorized(), "should have vectorized loops");
        assert_eq!(report.count(), 1, "should have vectorized 1 loop");
    }

    #[test]
    fn test_simd_intrinsic_names() {
        let x86_cases = [
            // SSE
            (SimdIntrinsic::Add, 4, "_mm_add_ps"),
            (SimdIntrinsic::Fmadd, 4, "_mm_fmadd_ps"),
            // AVX
            (SimdIntrinsic::Add, 8, "_mm256_add_ps"),
            (SimdIntrinsic::Hadd, 8, "_mm256_hadd_ps"),
        ];
        for (op, width, expected) in x86_cases {
            assert_eq!(op.x86_name(ScalarType::F32, width), expected);
        }

        let neon_cases = [
            (SimdIntrinsic::Add, 4, "vaddq_f32"),
            (SimdIntrinsic::Fmadd, 4, "vfmaq_f32"),
        ];
        for (op, width, expected) in neon_cases {
            assert_eq!(op.arm_name(ScalarType::F32, width), expected);
        }
    }

    #[test]
    fn test_target_vector_widths() {
        let cases = [
            // AVX2: 8 x f32
            (ScalarType::F32, TargetArch::X86_64Avx2, 8),
            // SSE2: 4 x f32
            (ScalarType::F32, TargetArch::X86_64Sse2, 4),
            // NEON: 4 x f32
            (ScalarType::F32, TargetArch::Aarch64Neon, 4),
            // AVX2: 4 x f64
            (ScalarType::F64, TargetArch::X86_64Avx2, 4),
        ];
        for (elem, arch, expected) in cases {
            assert_eq!(LoopType::natural_vector_width(elem, arch), expected);
        }
    }

    #[test]
    fn test_vectorize_report_display() {
        let entry = VectorizedLoopInfo {
            loop_id: LoopId::new(0),
            vector_width: 8,
            has_fma: true,
            has_reduction: false,
        };
        let report = VectorizeReport {
            vectorized_loops: vec![entry],
            failed_loops: vec![],
        };

        let rendered = report.to_string();
        for needle in ["Vectorized loops: 1", "width=8", "fma=true"] {
            assert!(rendered.contains(needle));
        }
    }
}