1use crate::{
19 AccessPattern, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType, Op,
20 ScalarType, Stmt, TargetArch, TripCount, Value,
21};
22use rustc_hash::FxHashMap;
23use thiserror::Error;
24
25#[derive(Clone, Debug, Error)]
27pub enum VectorizeError {
28 #[error("loop {loop_id:?} cannot be vectorized: {reason}")]
30 NotVectorizable {
31 loop_id: LoopId,
33 reason: String,
35 },
36
37 #[error("invalid vector width {width} for type {ty:?}")]
39 InvalidWidth {
40 width: u8,
42 ty: ScalarType,
44 },
45}
46
47#[derive(Clone, Debug)]
49pub struct VectorizationInfo {
50 pub vectorizable: bool,
52 pub reason: Option<String>,
54 pub recommended_width: u8,
56 pub access_patterns: Vec<AccessPattern>,
58 pub has_fma: bool,
60 pub has_reduction: bool,
62}
63
64impl Default for VectorizationInfo {
65 fn default() -> Self {
66 Self {
67 vectorizable: false,
68 reason: Some("not analyzed".to_string()),
69 recommended_width: 1,
70 access_patterns: Vec::new(),
71 has_fma: false,
72 has_reduction: false,
73 }
74 }
75}
76
77#[derive(Clone, Debug)]
79pub struct VectorizeConfig {
80 pub target: TargetArch,
82 pub forced_width: u8,
84 pub generate_remainder: bool,
86 pub enable_fma: bool,
88 pub min_trip_count: usize,
90}
91
92impl Default for VectorizeConfig {
93 fn default() -> Self {
94 Self {
95 target: TargetArch::default(),
96 forced_width: 0,
97 generate_remainder: true,
98 enable_fma: true,
99 min_trip_count: 4,
100 }
101 }
102}
103
104pub struct VectorizePass {
106 config: VectorizeConfig,
107 analysis: FxHashMap<LoopId, VectorizationInfo>,
109}
110
111impl VectorizePass {
112 pub fn new(config: VectorizeConfig) -> Self {
114 Self {
115 config,
116 analysis: FxHashMap::default(),
117 }
118 }
119
120 pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, VectorizationInfo> {
122 self.analysis.clear();
123
124 for stmt in &ir.body.stmts {
125 self.analyze_stmt(stmt, &ir.loop_info);
126 }
127
128 self.analysis.clone()
129 }
130
131 fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
133 match stmt {
134 Stmt::Loop(lp) => {
135 let info = self.analyze_loop(lp, loop_info);
136 self.analysis.insert(lp.id, info);
137
138 for inner_stmt in &lp.body.stmts {
140 self.analyze_stmt(inner_stmt, loop_info);
141 }
142 }
143 _ => {}
144 }
145 }
146
147 fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> VectorizationInfo {
149 let mut info = VectorizationInfo::default();
150
151 if !lp.attrs.contains(LoopAttrs::VECTORIZE) {
153 info.reason = Some("loop not marked VECTORIZE".to_string());
154 return info;
155 }
156
157 let metadata = loop_info.iter().find(|m| m.id == lp.id);
159 let trip_count = metadata.map(|m| &m.trip_count);
160
161 match trip_count {
162 Some(TripCount::Static(n)) if *n < self.config.min_trip_count => {
163 info.reason = Some(format!(
164 "trip count {} below threshold {}",
165 n, self.config.min_trip_count
166 ));
167 return info;
168 }
169 Some(TripCount::Dynamic) => {
170 }
173 _ => {}
174 }
175
176 let (patterns, has_fma, has_reduction) = self.analyze_loop_body(&lp.body);
178 info.access_patterns = patterns.clone();
179 info.has_fma = has_fma;
180 info.has_reduction = has_reduction;
181
182 let all_sequential = patterns
184 .iter()
185 .all(|p| matches!(p, AccessPattern::Sequential | AccessPattern::Broadcast));
186
187 if !all_sequential {
188 info.reason = Some("non-sequential access pattern".to_string());
189 return info;
190 }
191
192 let elem_type = self.infer_element_type(&lp.body);
194 let width = if self.config.forced_width > 0 {
195 self.config.forced_width
196 } else {
197 LoopType::natural_vector_width(elem_type, self.config.target)
198 };
199
200 info.vectorizable = width > 1;
201 info.recommended_width = width;
202 info.reason = None;
203
204 info
205 }
206
207 fn analyze_loop_body(&self, body: &Body) -> (Vec<AccessPattern>, bool, bool) {
209 let mut patterns = Vec::new();
210 let mut has_fma = false;
211 let mut has_reduction = false;
212
213 for stmt in &body.stmts {
214 match stmt {
215 Stmt::Assign(_, op) => {
216 if let Op::Load(mem_ref) = op {
218 patterns.push(mem_ref.access.clone());
219 }
220
221 if self.config.enable_fma {
223 has_fma |= self.is_fma_opportunity(op);
224 }
225
226 if let Op::VecReduce(_, _) = op {
228 has_reduction = true;
229 }
230 }
231 Stmt::Store(mem_ref, _) => {
232 patterns.push(mem_ref.access.clone());
233 }
234 Stmt::Loop(inner) => {
235 if inner.attrs.contains(LoopAttrs::REDUCTION) {
237 has_reduction = true;
238 }
239 }
240 _ => {}
241 }
242 }
243
244 (patterns, has_fma, has_reduction)
245 }
246
247 fn is_fma_opportunity(&self, op: &Op) -> bool {
249 match op {
251 Op::Binary(BinOp::Add, _, _) => {
252 false
255 }
256 _ => false,
257 }
258 }
259
260 fn infer_element_type(&self, body: &Body) -> ScalarType {
262 for stmt in &body.stmts {
263 if let Stmt::Assign(_, Op::Load(mem_ref)) = stmt {
264 if let LoopType::Scalar(s) = &mem_ref.elem_ty {
265 return *s;
266 }
267 }
268 }
269 ScalarType::Float(32) }
271
272 pub fn vectorize(&self, ir: &mut LoopIR) -> Result<VectorizeReport, VectorizeError> {
274 let mut report = VectorizeReport::default();
275
276 for stmt in &mut ir.body.stmts {
277 self.vectorize_stmt(stmt, &mut ir.loop_info, &mut report)?;
278 }
279
280 Ok(report)
281 }
282
283 fn vectorize_stmt(
285 &self,
286 stmt: &mut Stmt,
287 loop_info: &mut Vec<LoopMetadata>,
288 report: &mut VectorizeReport,
289 ) -> Result<(), VectorizeError> {
290 match stmt {
291 Stmt::Loop(lp) => {
292 if let Some(info) = self.analysis.get(&lp.id) {
293 if info.vectorizable {
294 self.vectorize_loop(lp, info, loop_info, report)?;
295 }
296 }
297
298 for inner_stmt in &mut lp.body.stmts {
300 self.vectorize_stmt(inner_stmt, loop_info, report)?;
301 }
302 }
303 _ => {}
304 }
305 Ok(())
306 }
307
308 fn vectorize_loop(
310 &self,
311 lp: &mut Loop,
312 info: &VectorizationInfo,
313 loop_info: &mut Vec<LoopMetadata>,
314 report: &mut VectorizeReport,
315 ) -> Result<(), VectorizeError> {
316 let width = info.recommended_width;
317
318 lp.step = Value::i64(width as i64);
320
321 if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
323 meta.vector_width = Some(width);
324 }
325
326 self.vectorize_body(&mut lp.body, width)?;
328
329 report.vectorized_loops.push(VectorizedLoopInfo {
331 loop_id: lp.id,
332 vector_width: width,
333 has_fma: info.has_fma,
334 has_reduction: info.has_reduction,
335 });
336
337 Ok(())
338 }
339
340 fn vectorize_body(&self, body: &mut Body, width: u8) -> Result<(), VectorizeError> {
342 for stmt in &mut body.stmts {
343 match stmt {
344 Stmt::Assign(_, op) => {
345 *op = self.vectorize_op(op, width)?;
346 }
347 _ => {}
348 }
349 }
350 Ok(())
351 }
352
353 fn vectorize_op(&self, op: &Op, width: u8) -> Result<Op, VectorizeError> {
355 match op {
356 Op::Load(mem_ref) => {
357 let mut vec_ref = mem_ref.clone();
359 if let LoopType::Scalar(s) = &mem_ref.elem_ty {
360 vec_ref.elem_ty = LoopType::Vector(*s, width);
361 }
362 Ok(Op::Load(vec_ref))
363 }
364
365 Op::Binary(bin_op, a, b) => {
366 let vec_a = self.vectorize_value(a, width);
368 let vec_b = self.vectorize_value(b, width);
369 Ok(Op::Binary(*bin_op, vec_a, vec_b))
370 }
371
372 Op::Unary(un_op, a) => {
373 let vec_a = self.vectorize_value(a, width);
374 Ok(Op::Unary(*un_op, vec_a))
375 }
376
377 Op::Fma(a, b, c) => {
379 let vec_a = self.vectorize_value(a, width);
380 let vec_b = self.vectorize_value(b, width);
381 let vec_c = self.vectorize_value(c, width);
382 Ok(Op::Fma(vec_a, vec_b, vec_c))
383 }
384
385 _ => Ok(op.clone()),
387 }
388 }
389
390 fn vectorize_value(&self, val: &Value, width: u8) -> Value {
392 match val {
393 Value::Var(id, LoopType::Scalar(s)) => Value::Var(*id, LoopType::Vector(*s, width)),
394 Value::FloatConst(f, s) => {
395 Value::FloatConst(*f, *s)
397 }
398 Value::IntConst(i, s) => Value::IntConst(*i, *s),
399 _ => val.clone(),
400 }
401 }
402}
403
404#[derive(Clone, Debug, Default)]
406pub struct VectorizeReport {
407 pub vectorized_loops: Vec<VectorizedLoopInfo>,
409 pub failed_loops: Vec<(LoopId, String)>,
411}
412
413impl VectorizeReport {
414 pub fn any_vectorized(&self) -> bool {
416 !self.vectorized_loops.is_empty()
417 }
418
419 pub fn count(&self) -> usize {
421 self.vectorized_loops.len()
422 }
423}
424
425impl std::fmt::Display for VectorizeReport {
426 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
427 writeln!(f, "Vectorization Report")?;
428 writeln!(f, "====================")?;
429 writeln!(f, "Vectorized loops: {}", self.vectorized_loops.len())?;
430
431 for info in &self.vectorized_loops {
432 writeln!(
433 f,
434 " Loop {:?}: width={}, fma={}, reduction={}",
435 info.loop_id, info.vector_width, info.has_fma, info.has_reduction
436 )?;
437 }
438
439 if !self.failed_loops.is_empty() {
440 writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
441 for (id, reason) in &self.failed_loops {
442 writeln!(f, " Loop {:?}: {}", id, reason)?;
443 }
444 }
445
446 Ok(())
447 }
448}
449
450#[derive(Clone, Debug)]
452pub struct VectorizedLoopInfo {
453 pub loop_id: LoopId,
455 pub vector_width: u8,
457 pub has_fma: bool,
459 pub has_reduction: bool,
461}
462
/// Target-independent names for the SIMD operations the code generator can
/// ask for; the per-target intrinsic name is looked up separately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdIntrinsic {
    // Element-wise arithmetic.
    Add,
    Sub,
    Mul,
    Div,

    // Fused multiply-add family.
    Fmadd,
    Fmsub,
    Fnmadd,

    // Horizontal (cross-lane) operations.
    Hadd,
    HorizontalSum,

    // Element-wise minimum / maximum.
    Min,
    Max,

    // Element-wise comparisons.
    CmpEq,
    CmpLt,
    CmpLe,

    // Lane manipulation.
    Broadcast,
    Extract,
    Insert,
    Shuffle,

    // Memory access.
    LoadAligned,
    LoadUnaligned,
    StoreAligned,
    StoreUnaligned,
}
530
531impl SimdIntrinsic {
532 pub fn x86_name(&self, ty: ScalarType, width: u8) -> &'static str {
534 match (self, ty, width) {
535 (Self::Add, ScalarType::Float(32), 4) => "_mm_add_ps",
537 (Self::Sub, ScalarType::Float(32), 4) => "_mm_sub_ps",
538 (Self::Mul, ScalarType::Float(32), 4) => "_mm_mul_ps",
539 (Self::Div, ScalarType::Float(32), 4) => "_mm_div_ps",
540 (Self::Fmadd, ScalarType::Float(32), 4) => "_mm_fmadd_ps",
541 (Self::Min, ScalarType::Float(32), 4) => "_mm_min_ps",
542 (Self::Max, ScalarType::Float(32), 4) => "_mm_max_ps",
543 (Self::LoadAligned, ScalarType::Float(32), 4) => "_mm_load_ps",
544 (Self::StoreAligned, ScalarType::Float(32), 4) => "_mm_store_ps",
545
546 (Self::Add, ScalarType::Float(32), 8) => "_mm256_add_ps",
548 (Self::Sub, ScalarType::Float(32), 8) => "_mm256_sub_ps",
549 (Self::Mul, ScalarType::Float(32), 8) => "_mm256_mul_ps",
550 (Self::Div, ScalarType::Float(32), 8) => "_mm256_div_ps",
551 (Self::Fmadd, ScalarType::Float(32), 8) => "_mm256_fmadd_ps",
552 (Self::Min, ScalarType::Float(32), 8) => "_mm256_min_ps",
553 (Self::Max, ScalarType::Float(32), 8) => "_mm256_max_ps",
554 (Self::LoadAligned, ScalarType::Float(32), 8) => "_mm256_load_ps",
555 (Self::StoreAligned, ScalarType::Float(32), 8) => "_mm256_store_ps",
556 (Self::Hadd, ScalarType::Float(32), 8) => "_mm256_hadd_ps",
557
558 (Self::Add, ScalarType::Float(64), 2) => "_mm_add_pd",
560 (Self::Sub, ScalarType::Float(64), 2) => "_mm_sub_pd",
561 (Self::Mul, ScalarType::Float(64), 2) => "_mm_mul_pd",
562 (Self::Fmadd, ScalarType::Float(64), 2) => "_mm_fmadd_pd",
563
564 (Self::Add, ScalarType::Float(64), 4) => "_mm256_add_pd",
566 (Self::Sub, ScalarType::Float(64), 4) => "_mm256_sub_pd",
567 (Self::Mul, ScalarType::Float(64), 4) => "_mm256_mul_pd",
568 (Self::Fmadd, ScalarType::Float(64), 4) => "_mm256_fmadd_pd",
569
570 _ => "unknown_intrinsic",
571 }
572 }
573
574 pub fn arm_name(&self, ty: ScalarType, width: u8) -> &'static str {
576 match (self, ty, width) {
577 (Self::Add, ScalarType::Float(32), 4) => "vaddq_f32",
579 (Self::Sub, ScalarType::Float(32), 4) => "vsubq_f32",
580 (Self::Mul, ScalarType::Float(32), 4) => "vmulq_f32",
581 (Self::Fmadd, ScalarType::Float(32), 4) => "vfmaq_f32",
582 (Self::Min, ScalarType::Float(32), 4) => "vminq_f32",
583 (Self::Max, ScalarType::Float(32), 4) => "vmaxq_f32",
584 (Self::LoadAligned, ScalarType::Float(32), 4) => "vld1q_f32",
585 (Self::StoreAligned, ScalarType::Float(32), 4) => "vst1q_f32",
586
587 (Self::Add, ScalarType::Float(64), 2) => "vaddq_f64",
589 (Self::Sub, ScalarType::Float(64), 2) => "vsubq_f64",
590 (Self::Mul, ScalarType::Float(64), 2) => "vmulq_f64",
591 (Self::Fmadd, ScalarType::Float(64), 2) => "vfmaq_f64",
592
593 _ => "unknown_intrinsic",
594 }
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MemRef, Param, ValueId};
    use bhc_index::Idx;
    use bhc_intern::Symbol;
    use bhc_tensor_ir::BufferId;

    /// Builds a kernel containing a single loop of `trip_count` iterations
    /// that loads an `f32` element, multiplies it by 2.0 and stores it back.
    /// The loop is marked `VECTORIZE | INDEPENDENT` and all accesses are
    /// sequential. Returns the IR together with the loop's id.
    fn make_vectorizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
        let loop_id = LoopId::new(0);
        let induction = ValueId::new(0);
        let loaded = ValueId::new(1);
        let scaled = ValueId::new(2);

        let element = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(induction, LoopType::Scalar(ScalarType::I64)),
            elem_ty: LoopType::Scalar(ScalarType::F32),
            access: AccessPattern::Sequential,
        };

        // loaded = data[i]; scaled = loaded * 2.0; data[i] = scaled
        let mut inner = Body::new();
        inner.push(Stmt::Assign(loaded, Op::Load(element.clone())));
        inner.push(Stmt::Assign(
            scaled,
            Op::Binary(
                BinOp::Mul,
                Value::Var(loaded, LoopType::Scalar(ScalarType::F32)),
                Value::float(2.0, 32),
            ),
        ));
        inner.push(Stmt::Store(
            element,
            Value::Var(scaled, LoopType::Scalar(ScalarType::F32)),
        ));

        let kernel_loop = Loop {
            id: loop_id,
            var: induction,
            lower: Value::i64(0),
            upper: Value::i64(trip_count as i64),
            step: Value::i64(1),
            body: inner,
            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
        };

        let mut top_level = Body::new();
        top_level.push(Stmt::Loop(kernel_loop));

        let ir = LoopIR {
            name: Symbol::intern("test_kernel"),
            params: vec![Param {
                name: Symbol::intern("data"),
                ty: LoopType::Ptr(Box::new(LoopType::Scalar(ScalarType::F32))),
                is_ptr: true,
            }],
            return_ty: LoopType::Void,
            body: top_level,
            allocs: vec![],
            loop_info: vec![LoopMetadata {
                id: loop_id,
                trip_count: TripCount::Static(trip_count),
                vector_width: None,
                parallel_chunk: None,
                unroll_factor: None,
                dependencies: Vec::new(),
            }],
        };

        (ir, loop_id)
    }

    /// A long, sequential, annotated loop is accepted with a width above 1.
    #[test]
    fn test_vectorization_analysis() {
        let (ir, loop_id) = make_vectorizable_loop(1024);
        let mut pass = VectorizePass::new(VectorizeConfig::default());

        let results = pass.analyze(&ir);
        let info = results.get(&loop_id).expect("loop should be analyzed");

        assert!(info.vectorizable, "loop should be vectorizable");
        assert!(
            info.recommended_width > 1,
            "should recommend vector width > 1"
        );
    }

    /// A trip count below the default minimum is rejected.
    #[test]
    fn test_vectorization_below_threshold() {
        let (ir, loop_id) = make_vectorizable_loop(2);
        let mut pass = VectorizePass::new(VectorizeConfig::default());

        let results = pass.analyze(&ir);
        let info = results.get(&loop_id).expect("loop should be analyzed");

        assert!(!info.vectorizable, "small loop should not be vectorizable");
    }

    /// Analysis followed by transformation reports exactly one vectorized loop.
    #[test]
    fn test_vectorization_transform() {
        let (mut ir, _loop_id) = make_vectorizable_loop(1024);
        let mut pass = VectorizePass::new(VectorizeConfig::default());

        pass.analyze(&ir);
        let outcome = pass.vectorize(&mut ir);
        let report = outcome.expect("vectorization should succeed");

        assert!(report.any_vectorized(), "should have vectorized loops");
        assert_eq!(report.count(), 1, "should have vectorized 1 loop");
    }

    /// Spot-checks intrinsic names for 128-bit x86, 256-bit x86 and NEON.
    #[test]
    fn test_simd_intrinsic_names() {
        let f32_ty = ScalarType::F32;

        assert_eq!(SimdIntrinsic::Add.x86_name(f32_ty, 4), "_mm_add_ps");
        assert_eq!(SimdIntrinsic::Fmadd.x86_name(f32_ty, 4), "_mm_fmadd_ps");

        assert_eq!(SimdIntrinsic::Add.x86_name(f32_ty, 8), "_mm256_add_ps");
        assert_eq!(SimdIntrinsic::Hadd.x86_name(f32_ty, 8), "_mm256_hadd_ps");

        assert_eq!(SimdIntrinsic::Add.arm_name(f32_ty, 4), "vaddq_f32");
        assert_eq!(SimdIntrinsic::Fmadd.arm_name(f32_ty, 4), "vfmaq_f32");
    }

    /// Natural vector widths per element type and target.
    #[test]
    fn test_target_vector_widths() {
        let avx2_f32 = LoopType::natural_vector_width(ScalarType::F32, TargetArch::X86_64Avx2);
        assert_eq!(avx2_f32, 8);

        let sse2_f32 = LoopType::natural_vector_width(ScalarType::F32, TargetArch::X86_64Sse2);
        assert_eq!(sse2_f32, 4);

        let neon_f32 = LoopType::natural_vector_width(ScalarType::F32, TargetArch::Aarch64Neon);
        assert_eq!(neon_f32, 4);

        let avx2_f64 = LoopType::natural_vector_width(ScalarType::F64, TargetArch::X86_64Avx2);
        assert_eq!(avx2_f64, 4);
    }

    /// The textual report mentions the loop count, width and FMA flag.
    #[test]
    fn test_vectorize_report_display() {
        let entry = VectorizedLoopInfo {
            loop_id: LoopId::new(0),
            vector_width: 8,
            has_fma: true,
            has_reduction: false,
        };
        let report = VectorizeReport {
            vectorized_loops: vec![entry],
            failed_loops: vec![],
        };

        let rendered = format!("{}", report);

        assert!(rendered.contains("Vectorized loops: 1"));
        assert!(rendered.contains("width=8"));
        assert!(rendered.contains("fma=true"));
    }
}