use crate::{
AccessPattern, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType, Op,
ScalarType, Stmt, TargetArch, TripCount, Value,
};
use rustc_hash::FxHashMap;
use thiserror::Error;
/// Errors that the vectorization pass can report to its caller.
#[derive(Clone, Debug, Error)]
pub enum VectorizeError {
    /// A specific loop could not be vectorized.
    ///
    /// NOTE(review): this variant is not constructed anywhere in this file.
    #[error("loop {loop_id:?} cannot be vectorized: {reason}")]
    NotVectorizable {
        /// Identifier of the rejected loop.
        loop_id: LoopId,
        /// Human-readable explanation of the rejection.
        reason: String,
    },
    /// A vector width that is not usable for the given element type.
    ///
    /// NOTE(review): this variant is not constructed anywhere in this file.
    #[error("invalid vector width {width} for type {ty:?}")]
    InvalidWidth {
        /// The offending number of lanes.
        width: u8,
        /// The scalar element type the width was requested for.
        ty: ScalarType,
    },
}
/// Result of analysing a single loop for vectorization.
///
/// Produced by `VectorizePass::analyze` and consulted by `VectorizePass::vectorize`,
/// which only transforms loops whose `vectorizable` flag is set.
#[derive(Clone, Debug)]
pub struct VectorizationInfo {
    /// Whether the loop will be rewritten by `VectorizePass::vectorize`.
    pub vectorizable: bool,
    /// Human-readable explanation of why the loop was rejected, when one was recorded.
    pub reason: Option<String>,
    /// Number of lanes chosen for the loop (1 means scalar).
    pub recommended_width: u8,
    /// Access patterns of the loads and stores found directly in the loop body.
    pub access_patterns: Vec<AccessPattern>,
    /// Whether an FMA opportunity was detected in the loop body.
    pub has_fma: bool,
    /// Whether the loop body contains a reduction.
    pub has_reduction: bool,
}
impl Default for VectorizationInfo {
    /// The "nothing known yet" state: not vectorizable, scalar width, no patterns,
    /// and a reason of `"not analyzed"`.
    fn default() -> Self {
        let reason = Some(String::from("not analyzed"));
        Self {
            vectorizable: false,
            reason,
            recommended_width: 1,
            access_patterns: vec![],
            has_fma: false,
            has_reduction: false,
        }
    }
}
/// Tuning knobs for [`VectorizePass`].
#[derive(Clone, Debug)]
pub struct VectorizeConfig {
    /// Target architecture used to pick the natural vector width for an element type.
    pub target: TargetArch,
    /// When non-zero, overrides the natural vector width; `0` means "pick automatically".
    pub forced_width: u8,
    /// Whether a scalar remainder loop should be generated.
    ///
    /// NOTE(review): this flag is not read anywhere in this file.
    pub generate_remainder: bool,
    /// When `false`, FMA-opportunity detection is skipped during analysis.
    pub enable_fma: bool,
    /// Loops with a static trip count below this value are not vectorized.
    pub min_trip_count: usize,
}
impl Default for VectorizeConfig {
    /// Default target, automatic width, remainder generation and FMA enabled,
    /// and a minimum static trip count of 4.
    fn default() -> Self {
        let target = TargetArch::default();
        Self {
            target,
            forced_width: 0,
            generate_remainder: true,
            enable_fma: true,
            min_trip_count: 4,
        }
    }
}
/// Loop vectorization pass: analyse loops with `analyze`, then rewrite the
/// vectorizable ones with `vectorize`.
pub struct VectorizePass {
    /// Configuration the pass was created with.
    config: VectorizeConfig,
    /// Per-loop analysis results from the most recent `analyze` call.
    analysis: FxHashMap<LoopId, VectorizationInfo>,
}
impl VectorizePass {
pub fn new(config: VectorizeConfig) -> Self {
Self {
config,
analysis: FxHashMap::default(),
}
}
pub fn analyze(&mut self, ir: &LoopIR) -> FxHashMap<LoopId, VectorizationInfo> {
self.analysis.clear();
for stmt in &ir.body.stmts {
self.analyze_stmt(stmt, &ir.loop_info);
}
self.analysis.clone()
}
fn analyze_stmt(&mut self, stmt: &Stmt, loop_info: &[LoopMetadata]) {
if let Stmt::Loop(lp) = stmt {
let info = self.analyze_loop(lp, loop_info);
self.analysis.insert(lp.id, info);
for inner_stmt in &lp.body.stmts {
self.analyze_stmt(inner_stmt, loop_info);
}
}
}
fn analyze_loop(&self, lp: &Loop, loop_info: &[LoopMetadata]) -> VectorizationInfo {
let mut info = VectorizationInfo::default();
if !lp.attrs.contains(LoopAttrs::VECTORIZE) {
info.reason = Some("loop not marked VECTORIZE".to_string());
return info;
}
let metadata = loop_info.iter().find(|m| m.id == lp.id);
let trip_count = metadata.map(|m| &m.trip_count);
match trip_count {
Some(TripCount::Static(n)) if *n < self.config.min_trip_count => {
info.reason = Some(format!(
"trip count {} below threshold {}",
n, self.config.min_trip_count
));
return info;
}
Some(TripCount::Dynamic) => {
}
_ => {}
}
let (patterns, has_fma, has_reduction) = self.analyze_loop_body(&lp.body);
info.access_patterns = patterns.clone();
info.has_fma = has_fma;
info.has_reduction = has_reduction;
let all_sequential = patterns
.iter()
.all(|p| matches!(p, AccessPattern::Sequential | AccessPattern::Broadcast));
if !all_sequential {
info.reason = Some("non-sequential access pattern".to_string());
return info;
}
let elem_type = self.infer_element_type(&lp.body);
let width = if self.config.forced_width > 0 {
self.config.forced_width
} else {
LoopType::natural_vector_width(elem_type, self.config.target)
};
info.vectorizable = width > 1;
info.recommended_width = width;
info.reason = None;
info
}
fn analyze_loop_body(&self, body: &Body) -> (Vec<AccessPattern>, bool, bool) {
let mut patterns = Vec::new();
let mut has_fma = false;
let mut has_reduction = false;
for stmt in &body.stmts {
match stmt {
Stmt::Assign(_, op) => {
if let Op::Load(mem_ref) = op {
patterns.push(mem_ref.access.clone());
}
if self.config.enable_fma {
has_fma |= self.is_fma_opportunity(op);
}
if let Op::VecReduce(_, _) = op {
has_reduction = true;
}
}
Stmt::Store(mem_ref, _) => {
patterns.push(mem_ref.access.clone());
}
Stmt::Loop(inner)
if inner.attrs.contains(LoopAttrs::REDUCTION) => {
has_reduction = true;
}
_ => {}
}
}
(patterns, has_fma, has_reduction)
}
fn is_fma_opportunity(&self, op: &Op) -> bool {
match op {
Op::Binary(BinOp::Add, _, _) => {
false
}
_ => false,
}
}
fn infer_element_type(&self, body: &Body) -> ScalarType {
for stmt in &body.stmts {
if let Stmt::Assign(_, Op::Load(mem_ref)) = stmt {
if let LoopType::Scalar(s) = &mem_ref.elem_ty {
return *s;
}
}
}
ScalarType::Float(32) }
pub fn vectorize(&self, ir: &mut LoopIR) -> Result<VectorizeReport, VectorizeError> {
let mut report = VectorizeReport::default();
for stmt in &mut ir.body.stmts {
self.vectorize_stmt(stmt, &mut ir.loop_info, &mut report)?;
}
Ok(report)
}
fn vectorize_stmt(
&self,
stmt: &mut Stmt,
loop_info: &mut [LoopMetadata],
report: &mut VectorizeReport,
) -> Result<(), VectorizeError> {
if let Stmt::Loop(lp) = stmt {
if let Some(info) = self.analysis.get(&lp.id) {
if info.vectorizable {
self.vectorize_loop(lp, info, loop_info, report)?;
}
}
for inner_stmt in &mut lp.body.stmts {
self.vectorize_stmt(inner_stmt, loop_info, report)?;
}
}
Ok(())
}
fn vectorize_loop(
&self,
lp: &mut Loop,
info: &VectorizationInfo,
loop_info: &mut [LoopMetadata],
report: &mut VectorizeReport,
) -> Result<(), VectorizeError> {
let width = info.recommended_width;
lp.step = Value::i64(width as i64);
if let Some(meta) = loop_info.iter_mut().find(|m| m.id == lp.id) {
meta.vector_width = Some(width);
}
self.vectorize_body(&mut lp.body, width)?;
report.vectorized_loops.push(VectorizedLoopInfo {
loop_id: lp.id,
vector_width: width,
has_fma: info.has_fma,
has_reduction: info.has_reduction,
});
Ok(())
}
fn vectorize_body(&self, body: &mut Body, width: u8) -> Result<(), VectorizeError> {
for stmt in &mut body.stmts {
if let Stmt::Assign(_, op) = stmt {
*op = self.vectorize_op(op, width)?;
}
}
Ok(())
}
fn vectorize_op(&self, op: &Op, width: u8) -> Result<Op, VectorizeError> {
match op {
Op::Load(mem_ref) => {
let mut vec_ref = mem_ref.clone();
if let LoopType::Scalar(s) = &mem_ref.elem_ty {
vec_ref.elem_ty = LoopType::Vector(*s, width);
}
Ok(Op::Load(vec_ref))
}
Op::Binary(bin_op, a, b) => {
let vec_a = self.vectorize_value(a, width);
let vec_b = self.vectorize_value(b, width);
Ok(Op::Binary(*bin_op, vec_a, vec_b))
}
Op::Unary(un_op, a) => {
let vec_a = self.vectorize_value(a, width);
Ok(Op::Unary(*un_op, vec_a))
}
Op::Fma(a, b, c) => {
let vec_a = self.vectorize_value(a, width);
let vec_b = self.vectorize_value(b, width);
let vec_c = self.vectorize_value(c, width);
Ok(Op::Fma(vec_a, vec_b, vec_c))
}
_ => Ok(op.clone()),
}
}
fn vectorize_value(&self, val: &Value, width: u8) -> Value {
match val {
Value::Var(id, LoopType::Scalar(s)) => Value::Var(*id, LoopType::Vector(*s, width)),
Value::FloatConst(f, s) => {
Value::FloatConst(*f, *s)
}
Value::IntConst(i, s) => Value::IntConst(*i, *s),
_ => val.clone(),
}
}
}
/// Summary of what a `VectorizePass::vectorize` run did.
#[derive(Clone, Debug, Default)]
pub struct VectorizeReport {
    /// One entry per loop that was rewritten.
    pub vectorized_loops: Vec<VectorizedLoopInfo>,
    /// Loops that were not vectorized, each with a human-readable reason.
    pub failed_loops: Vec<(LoopId, String)>,
}
impl VectorizeReport {
    /// Returns `true` when at least one loop was vectorized.
    pub fn any_vectorized(&self) -> bool {
        self.count() > 0
    }

    /// Number of loops that were vectorized.
    pub fn count(&self) -> usize {
        let Self {
            vectorized_loops, ..
        } = self;
        vectorized_loops.len()
    }
}
impl std::fmt::Display for VectorizeReport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Vectorization Report")?;
writeln!(f, "====================")?;
writeln!(f, "Vectorized loops: {}", self.vectorized_loops.len())?;
for info in &self.vectorized_loops {
writeln!(
f,
" Loop {:?}: width={}, fma={}, reduction={}",
info.loop_id, info.vector_width, info.has_fma, info.has_reduction
)?;
}
if !self.failed_loops.is_empty() {
writeln!(f, "\nFailed loops: {}", self.failed_loops.len())?;
for (id, reason) in &self.failed_loops {
writeln!(f, " Loop {:?}: {}", id, reason)?;
}
}
Ok(())
}
}
/// Record of one loop rewritten by the vectorization pass.
#[derive(Clone, Debug)]
pub struct VectorizedLoopInfo {
    /// Identifier of the rewritten loop.
    pub loop_id: LoopId,
    /// Number of lanes the loop now processes per iteration.
    pub vector_width: u8,
    /// Whether the analysis flagged an FMA opportunity in this loop.
    pub has_fma: bool,
    /// Whether the analysis found a reduction in this loop.
    pub has_reduction: bool,
}
/// Target-independent SIMD operations.
///
/// `x86_name` and `arm_name` translate an (operation, element type, width) triple into
/// a concrete intrinsic name. Combinations without a mapping yield `"unknown_intrinsic"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SimdIntrinsic {
    // Lane-wise arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    // Fused multiply-add family.
    Fmadd,
    Fmsub,
    Fnmadd,
    // Horizontal operations.
    Hadd,
    HorizontalSum,
    // Lane-wise min/max.
    Min,
    Max,
    // Comparisons.
    CmpEq,
    CmpLt,
    CmpLe,
    // Lane movement.
    Broadcast,
    Extract,
    Insert,
    Shuffle,
    // Memory access.
    LoadAligned,
    LoadUnaligned,
    StoreAligned,
    StoreUnaligned,
}
impl SimdIntrinsic {
pub fn x86_name(&self, ty: ScalarType, width: u8) -> &'static str {
match (self, ty, width) {
(Self::Add, ScalarType::Float(32), 4) => "_mm_add_ps",
(Self::Sub, ScalarType::Float(32), 4) => "_mm_sub_ps",
(Self::Mul, ScalarType::Float(32), 4) => "_mm_mul_ps",
(Self::Div, ScalarType::Float(32), 4) => "_mm_div_ps",
(Self::Fmadd, ScalarType::Float(32), 4) => "_mm_fmadd_ps",
(Self::Min, ScalarType::Float(32), 4) => "_mm_min_ps",
(Self::Max, ScalarType::Float(32), 4) => "_mm_max_ps",
(Self::LoadAligned, ScalarType::Float(32), 4) => "_mm_load_ps",
(Self::StoreAligned, ScalarType::Float(32), 4) => "_mm_store_ps",
(Self::Add, ScalarType::Float(32), 8) => "_mm256_add_ps",
(Self::Sub, ScalarType::Float(32), 8) => "_mm256_sub_ps",
(Self::Mul, ScalarType::Float(32), 8) => "_mm256_mul_ps",
(Self::Div, ScalarType::Float(32), 8) => "_mm256_div_ps",
(Self::Fmadd, ScalarType::Float(32), 8) => "_mm256_fmadd_ps",
(Self::Min, ScalarType::Float(32), 8) => "_mm256_min_ps",
(Self::Max, ScalarType::Float(32), 8) => "_mm256_max_ps",
(Self::LoadAligned, ScalarType::Float(32), 8) => "_mm256_load_ps",
(Self::StoreAligned, ScalarType::Float(32), 8) => "_mm256_store_ps",
(Self::Hadd, ScalarType::Float(32), 8) => "_mm256_hadd_ps",
(Self::Add, ScalarType::Float(64), 2) => "_mm_add_pd",
(Self::Sub, ScalarType::Float(64), 2) => "_mm_sub_pd",
(Self::Mul, ScalarType::Float(64), 2) => "_mm_mul_pd",
(Self::Fmadd, ScalarType::Float(64), 2) => "_mm_fmadd_pd",
(Self::Add, ScalarType::Float(64), 4) => "_mm256_add_pd",
(Self::Sub, ScalarType::Float(64), 4) => "_mm256_sub_pd",
(Self::Mul, ScalarType::Float(64), 4) => "_mm256_mul_pd",
(Self::Fmadd, ScalarType::Float(64), 4) => "_mm256_fmadd_pd",
_ => "unknown_intrinsic",
}
}
pub fn arm_name(&self, ty: ScalarType, width: u8) -> &'static str {
match (self, ty, width) {
(Self::Add, ScalarType::Float(32), 4) => "vaddq_f32",
(Self::Sub, ScalarType::Float(32), 4) => "vsubq_f32",
(Self::Mul, ScalarType::Float(32), 4) => "vmulq_f32",
(Self::Fmadd, ScalarType::Float(32), 4) => "vfmaq_f32",
(Self::Min, ScalarType::Float(32), 4) => "vminq_f32",
(Self::Max, ScalarType::Float(32), 4) => "vmaxq_f32",
(Self::LoadAligned, ScalarType::Float(32), 4) => "vld1q_f32",
(Self::StoreAligned, ScalarType::Float(32), 4) => "vst1q_f32",
(Self::Add, ScalarType::Float(64), 2) => "vaddq_f64",
(Self::Sub, ScalarType::Float(64), 2) => "vsubq_f64",
(Self::Mul, ScalarType::Float(64), 2) => "vmulq_f64",
(Self::Fmadd, ScalarType::Float(64), 2) => "vfmaq_f64",
_ => "unknown_intrinsic",
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MemRef, Param, ValueId};
    use bhc_index::Idx;
    use bhc_intern::Symbol;
    use bhc_tensor_ir::BufferId;

    /// Builds a kernel with one `VECTORIZE | INDEPENDENT` loop over `0..trip_count`.
    ///
    /// The loop body loads `data[i]`, multiplies it by `2.0`, and stores the product
    /// back to `data[i]`. Returns the IR together with the loop's id.
    fn make_vectorizable_loop(trip_count: usize) -> (LoopIR, LoopId) {
        let id = LoopId::new(0);
        let index_var = ValueId::new(0);
        let loaded = ValueId::new(1);
        let scaled = ValueId::new(2);
        let f32_scalar = || LoopType::Scalar(ScalarType::F32);

        let data_ref = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(index_var, LoopType::Scalar(ScalarType::I64)),
            elem_ty: f32_scalar(),
            access: AccessPattern::Sequential,
        };

        let mut inner = Body::new();
        inner.push(Stmt::Assign(loaded, Op::Load(data_ref.clone())));
        inner.push(Stmt::Assign(
            scaled,
            Op::Binary(
                BinOp::Mul,
                Value::Var(loaded, f32_scalar()),
                Value::float(2.0, 32),
            ),
        ));
        inner.push(Stmt::Store(data_ref, Value::Var(scaled, f32_scalar())));

        let mut outer = Body::new();
        outer.push(Stmt::Loop(Loop {
            id,
            var: index_var,
            lower: Value::i64(0),
            upper: Value::i64(trip_count as i64),
            step: Value::i64(1),
            body: inner,
            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
        }));

        let metadata = LoopMetadata {
            id,
            trip_count: TripCount::Static(trip_count),
            vector_width: None,
            parallel_chunk: None,
            unroll_factor: None,
            dependencies: Vec::new(),
        };
        let data_param = Param {
            name: Symbol::intern("data"),
            ty: LoopType::Ptr(Box::new(f32_scalar())),
            is_ptr: true,
        };
        let ir = LoopIR {
            name: Symbol::intern("test_kernel"),
            params: vec![data_param],
            return_ty: LoopType::Void,
            body: outer,
            allocs: vec![],
            loop_info: vec![metadata],
        };
        (ir, id)
    }

    #[test]
    fn test_vectorization_analysis() {
        let (ir, loop_id) = make_vectorizable_loop(1024);
        let analysis = VectorizePass::new(VectorizeConfig::default()).analyze(&ir);
        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(info.vectorizable, "loop should be vectorizable");
        assert!(
            info.recommended_width > 1,
            "should recommend vector width > 1"
        );
    }

    #[test]
    fn test_vectorization_below_threshold() {
        let (ir, loop_id) = make_vectorizable_loop(2);
        let analysis = VectorizePass::new(VectorizeConfig::default()).analyze(&ir);
        let info = analysis.get(&loop_id).expect("loop should be analyzed");
        assert!(!info.vectorizable, "small loop should not be vectorizable");
    }

    #[test]
    fn test_vectorization_transform() {
        let (mut ir, _) = make_vectorizable_loop(1024);
        let mut pass = VectorizePass::new(VectorizeConfig::default());
        pass.analyze(&ir);
        let report = pass
            .vectorize(&mut ir)
            .expect("vectorization should succeed");
        assert!(report.any_vectorized(), "should have vectorized loops");
        assert_eq!(report.count(), 1, "should have vectorized 1 loop");
    }

    #[test]
    fn test_simd_intrinsic_names() {
        let x86_cases = [
            (SimdIntrinsic::Add, 4, "_mm_add_ps"),
            (SimdIntrinsic::Fmadd, 4, "_mm_fmadd_ps"),
            (SimdIntrinsic::Add, 8, "_mm256_add_ps"),
            (SimdIntrinsic::Hadd, 8, "_mm256_hadd_ps"),
        ];
        for &(intrinsic, width, expected) in x86_cases.iter() {
            assert_eq!(intrinsic.x86_name(ScalarType::F32, width), expected);
        }

        let arm_cases = [
            (SimdIntrinsic::Add, 4, "vaddq_f32"),
            (SimdIntrinsic::Fmadd, 4, "vfmaq_f32"),
        ];
        for &(intrinsic, width, expected) in arm_cases.iter() {
            assert_eq!(intrinsic.arm_name(ScalarType::F32, width), expected);
        }
    }

    #[test]
    fn test_target_vector_widths() {
        let cases = [
            (ScalarType::F32, TargetArch::X86_64Avx2, 8),
            (ScalarType::F32, TargetArch::X86_64Sse2, 4),
            (ScalarType::F32, TargetArch::Aarch64Neon, 4),
            (ScalarType::F64, TargetArch::X86_64Avx2, 4),
        ];
        for &(elem, target, expected) in cases.iter() {
            assert_eq!(LoopType::natural_vector_width(elem, target), expected);
        }
    }

    #[test]
    fn test_vectorize_report_display() {
        let entry = VectorizedLoopInfo {
            loop_id: LoopId::new(0),
            vector_width: 8,
            has_fma: true,
            has_reduction: false,
        };
        let report = VectorizeReport {
            vectorized_loops: vec![entry],
            failed_loops: vec![],
        };
        let output = report.to_string();
        assert!(output.contains("Vectorized loops: 1"));
        assert!(output.contains("width=8"));
        assert!(output.contains("fma=true"));
    }
}