#![warn(missing_docs)]
#![allow(clippy::module_name_repetitions)]
use bhc_index::Idx;
use bhc_intern::Symbol;
use bhc_tensor_ir::{AllocRegion, BufferId, DType};
use bitflags::bitflags;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
pub mod lower;
pub mod parallel;
pub mod vectorize;
pub use lower::{lower_kernel, lower_kernels, LowerConfig, LowerError};
pub use parallel::{
ParFor, ParMap, ParReduce, ParallelConfig, ParallelPass, ParallelStrategy, Range,
};
pub use vectorize::{SimdIntrinsic, VectorizeConfig, VectorizePass, VectorizeReport};
/// Opaque `u32`-backed identifier for a value in the Loop IR.
///
/// Appears as the destination of [`Stmt::Assign`], as a loop variable
/// ([`Loop::var`]) and inside [`Value::Var`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ValueId(u32);
impl Idx for ValueId {
    /// Builds a `ValueId` from a zero-based index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in a `u32`. A silent wrapping cast would
    /// make two distinct indices share the same id.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::MAX as usize,
            "ValueId index {} does not fit in u32",
            idx
        );
        // Checked above, so this cast cannot truncate.
        Self(idx as u32)
    }

    /// Returns the index this id was built from.
    fn index(self) -> usize {
        // u32 -> usize is lossless on 32- and 64-bit platforms.
        self.0 as usize
    }
}
/// Opaque `u32`-backed identifier for a loop.
///
/// Ties a [`Loop`] to its [`LoopMetadata`] record and is referenced by
/// [`LoopDependency`] and [`AffineAccess`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LoopId(u32);
impl Idx for LoopId {
    /// Builds a `LoopId` from a zero-based index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in a `u32`. A silent wrapping cast would
    /// make two distinct loops share the same id.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::MAX as usize,
            "LoopId index {} does not fit in u32",
            idx
        );
        // Checked above, so this cast cannot truncate.
        Self(idx as u32)
    }

    /// Returns the index this id was built from.
    fn index(self) -> usize {
        // u32 -> usize is lossless on 32- and 64-bit platforms.
        self.0 as usize
    }
}
/// Opaque `u32`-backed identifier for a block.
///
/// In this file it is only used to label the incoming edges of [`Op::Phi`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BlockId(u32);
impl Idx for BlockId {
    /// Builds a `BlockId` from a zero-based index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in a `u32`. A silent wrapping cast would
    /// make two distinct blocks share the same id.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::MAX as usize,
            "BlockId index {} does not fit in u32",
            idx
        );
        // Checked above, so this cast cannot truncate.
        Self(idx as u32)
    }

    /// Returns the index this id was built from.
    fn index(self) -> usize {
        // u32 -> usize is lossless on 32- and 64-bit platforms.
        self.0 as usize
    }
}
/// One unit of Loop IR: a named body of statements together with its
/// parameters, return type, buffer allocations and per-loop metadata.
///
/// The tests in this file build these with names such as `"matmul_kernel"`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopIR {
/// Name of this unit.
pub name: Symbol,
/// Parameters.
pub params: Vec<Param>,
/// Return type; [`LoopType::Void`] when nothing is returned.
pub return_ty: LoopType,
/// Top-level statements.
pub body: Body,
/// Buffer allocations belonging to this unit.
pub allocs: Vec<Alloc>,
/// Metadata records, each tagged with the [`LoopId`] of the loop it describes.
pub loop_info: Vec<LoopMetadata>,
}
/// A parameter of a [`LoopIR`] unit.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Param {
/// Parameter name.
pub name: Symbol,
/// Parameter type.
pub ty: LoopType,
/// Whether the parameter is a pointer.
///
/// NOTE(review): how this relates to `ty` being [`LoopType::Ptr`] is not
/// visible in this file — confirm against the lowering code.
pub is_ptr: bool,
}
/// Type of a value in the Loop IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LoopType {
/// No value; reported as zero bytes by `size_bytes`.
Void,
/// A single scalar.
Scalar(ScalarType),
/// A vector: element type and number of lanes.
Vector(ScalarType, u8),
/// A pointer to a value of the boxed type.
Ptr(Box<LoopType>),
}
impl LoopType {
    /// Size of a value of this type in bytes.
    ///
    /// `Void` is zero bytes, a vector is `lanes * element size`, and every
    /// pointer is counted as 8 bytes whatever it points to.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        match self {
            Self::Void => 0,
            Self::Ptr(_) => 8,
            Self::Scalar(elem) => elem.size_bytes(),
            Self::Vector(elem, lanes) => elem.size_bytes() * usize::from(*lanes),
        }
    }

    /// Returns `true` only for [`LoopType::Void`].
    #[must_use]
    pub fn is_void(&self) -> bool {
        *self == Self::Void
    }

    /// Returns `true` only for [`LoopType::Vector`].
    #[must_use]
    pub fn is_vector(&self) -> bool {
        matches!(self, Self::Vector(..))
    }
}
/// Scalar element type; the `u8` payload of the numeric variants is the
/// width in bits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ScalarType {
/// Boolean; reported as one byte by `size_bytes`.
Bool,
/// Signed integer of the given bit width.
Int(u8),
/// Unsigned integer of the given bit width.
UInt(u8),
/// Floating-point number of the given bit width.
Float(u8),
}
impl ScalarType {
    /// 32-bit float.
    pub const F32: Self = Self::Float(32);
    /// 64-bit float.
    pub const F64: Self = Self::Float(64);
    /// 32-bit signed integer.
    pub const I32: Self = Self::Int(32);
    /// 64-bit signed integer.
    pub const I64: Self = Self::Int(64);

    /// Size in bytes: 1 for `Bool`, otherwise the bit width rounded up to a
    /// whole number of bytes.
    #[must_use]
    pub const fn size_bytes(self) -> usize {
        let bits = match self {
            Self::Bool => return 1,
            Self::Int(b) | Self::UInt(b) | Self::Float(b) => b,
        };
        (bits as usize).div_ceil(8)
    }

    /// Maps a tensor-IR [`DType`] to the corresponding scalar type.
    ///
    /// `Float16` and `BFloat16` both become `Float(16)`, so the two formats
    /// are not distinguished after this conversion. `Complex64` and
    /// `Complex128` become `Float(32)` and `Float(64)` — presumably the type
    /// of a single component; confirm how callers account for the second one.
    #[must_use]
    pub fn from_dtype(dtype: DType) -> Self {
        match dtype {
            DType::Bool => Self::Bool,

            DType::Int8 => Self::Int(8),
            DType::Int16 => Self::Int(16),
            DType::Int32 => Self::I32,
            DType::Int64 => Self::I64,

            DType::UInt8 => Self::UInt(8),
            DType::UInt16 => Self::UInt(16),
            DType::UInt32 => Self::UInt(32),
            DType::UInt64 => Self::UInt(64),

            DType::Float16 | DType::BFloat16 => Self::Float(16),
            DType::Float32 | DType::Complex64 => Self::F32,
            DType::Float64 | DType::Complex128 => Self::F64,
        }
    }
}
impl LoopType {
pub const VEC4F32: Self = Self::Vector(ScalarType::F32, 4);
pub const VEC8F32: Self = Self::Vector(ScalarType::F32, 8);
pub const VEC2F64: Self = Self::Vector(ScalarType::F64, 2);
pub const VEC4F64: Self = Self::Vector(ScalarType::F64, 4);
pub const VEC4I32: Self = Self::Vector(ScalarType::I32, 4);
pub const VEC8I32: Self = Self::Vector(ScalarType::I32, 8);
#[must_use]
pub fn natural_vector_width(scalar: ScalarType, target: TargetArch) -> u8 {
match (target, scalar) {
(TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(32)) => 8,
(TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(64)) => 4,
(TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Int(32)) => 8,
(TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(32)) => 4,
(TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(64)) => 2,
(TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Int(32)) => 4,
(TargetArch::Aarch64Neon, ScalarType::Float(32)) => 4,
(TargetArch::Aarch64Neon, ScalarType::Float(64)) => 2,
(TargetArch::Aarch64Neon, ScalarType::Int(32)) => 4,
_ => 1,
}
}
#[must_use]
pub const fn vector(scalar: ScalarType, width: u8) -> Self {
Self::Vector(scalar, width)
}
#[must_use]
pub fn vector_width(&self) -> Option<u8> {
match self {
Self::Vector(_, w) => Some(*w),
_ => None,
}
}
#[must_use]
pub fn element_type(&self) -> Option<ScalarType> {
match self {
Self::Vector(s, _) => Some(*s),
Self::Scalar(s) => Some(*s),
_ => None,
}
}
}
/// Target architecture / SIMD feature level, used to pick vector widths.
///
/// NOTE(review): the derived `Default` is `Aarch64Neon`, not `Generic` —
/// confirm that this is intentional.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum TargetArch {
/// x86-64 with SSE.
X86_64Sse,
/// x86-64 with SSE2.
X86_64Sse2,
/// x86-64 with AVX.
X86_64Avx,
/// x86-64 with AVX2.
X86_64Avx2,
/// AArch64 with NEON; this is the `Default` value.
#[default]
Aarch64Neon,
/// No specific target; `LoopType::natural_vector_width` returns 1 for it.
Generic,
}
/// Description of one buffer allocation in a [`LoopIR`] unit.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Alloc {
/// Buffer this allocation provides.
pub buffer: BufferId,
/// Name of the allocation.
pub name: Symbol,
/// Element type stored in the buffer.
pub elem_ty: ScalarType,
/// Size of the allocation.
///
/// NOTE(review): whether this counts elements or bytes is not visible in
/// this file — confirm against the lowering code.
pub size: AllocSize,
/// Required alignment — presumably in bytes; confirm.
pub alignment: usize,
/// Allocation region, as defined by the tensor IR.
pub region: AllocRegion,
}
/// Size of an [`Alloc`]: either a constant or a value looked up by id.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllocSize {
/// Size is the given constant.
Static(usize),
/// Size is held by the given IR value.
Dynamic(ValueId),
}
/// An ordered list of statements; the derived `Default` is the empty list.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Body {
/// Statements, in the order they were pushed.
pub stmts: Vec<Stmt>,
}
impl Body {
    /// Creates a body with no statements.
    #[must_use]
    pub fn new() -> Self {
        Self { stmts: Vec::new() }
    }

    /// Appends `stmt` after every statement already in the body.
    pub fn push(&mut self, stmt: Stmt) {
        let Self { stmts } = self;
        stmts.push(stmt);
    }
}
/// A statement in a [`Body`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Stmt {
/// Bind the result of an [`Op`] to a value id.
Assign(ValueId, Op),
/// A loop with its own nested body.
Loop(Loop),
/// A conditional with a `then` body and an optional `else` body.
If(IfStmt),
/// Store a value to the location described by the [`MemRef`].
Store(MemRef, Value),
/// Call the named function with the given arguments; the optional id
/// presumably receives the result.
Call(Option<ValueId>, Symbol, Vec<Value>),
/// Return, optionally with a value.
Return(Option<Value>),
/// A barrier of the given kind.
Barrier(BarrierKind),
/// Free-form comment text carried in the IR.
Comment(String),
}
/// A loop statement.
///
/// NOTE(review): whether `upper` is inclusive or exclusive is not stated in
/// this file; the tests pair `upper = 256` with `TripCount::Static(256)` for
/// `lower = 0, step = 1`, which suggests an exclusive bound — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Loop {
/// Identifier of this loop; matches `LoopMetadata::id` for its metadata.
pub id: LoopId,
/// The loop variable.
pub var: ValueId,
/// Lower bound.
pub lower: Value,
/// Upper bound.
pub upper: Value,
/// Step between iterations.
pub step: Value,
/// Statements executed on each iteration.
pub body: Body,
/// Attribute flags set on this loop.
pub attrs: LoopAttrs,
}
bitflags! {
/// Bit set of attributes attached to a [`Loop`].
///
/// Each flag occupies one bit of a `u32`. The per-flag notes restate the
/// flag names; how passes interpret them is not visible in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LoopAttrs: u32 {
/// Loop is marked parallel.
const PARALLEL = 0b0000_0001;
/// Loop is marked for vectorization.
const VECTORIZE = 0b0000_0010;
/// Loop is marked for unrolling.
const UNROLL = 0b0000_0100;
/// Loop is marked as a reduction.
const REDUCTION = 0b0000_1000;
/// Loop iterations are marked independent.
const INDEPENDENT = 0b0001_0000;
/// Loop is marked as tiled.
const TILED = 0b0010_0000;
/// Loop is marked as the inner loop of a tile.
const TILE_INNER = 0b0100_0000;
}
}
/// Side-table record of information about one loop.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopMetadata {
/// Loop this record describes.
pub id: LoopId,
/// What is known about the number of iterations.
pub trip_count: TripCount,
/// Chosen vector width, if any.
pub vector_width: Option<u8>,
/// Chosen parallel chunk size, if any.
pub parallel_chunk: Option<usize>,
/// Chosen unroll factor, if any.
pub unroll_factor: Option<u8>,
/// Dependencies recorded for this loop.
pub dependencies: Vec<LoopDependency>,
}
/// What is known about a loop's iteration count.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TripCount {
/// The count is the given constant.
Static(usize),
/// The count is not known statically.
Dynamic,
/// The count is bounded by the given value — presumably an upper bound;
/// confirm against the passes that produce it.
Bounded(usize),
}
/// A dependency between two loops.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopDependency {
/// Loop the dependency originates from.
pub source: LoopId,
/// Loop the dependency points to.
pub target: LoopId,
/// Kind of dependency.
pub kind: DependencyKind,
/// Distance vector, when one is recorded.
///
/// NOTE(review): the meaning and ordering of the entries are not defined
/// in this file — confirm against the dependence analysis.
pub distance: Option<Vec<i32>>,
}
/// Kind of a [`LoopDependency`].
///
/// The names match the usual data-dependence terms; the per-variant glosses
/// assume that reading and are not verified against code in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyKind {
/// Flow dependency (presumably read-after-write).
Flow,
/// Anti dependency (presumably write-after-read).
Anti,
/// Output dependency (presumably write-after-write).
Output,
/// Input dependency (presumably read-after-read).
Input,
}
/// A conditional statement.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IfStmt {
/// Condition value.
pub cond: Value,
/// Statements for the `then` branch.
pub then_body: Body,
/// Statements for the `else` branch, if there is one.
pub else_body: Option<Body>,
}
/// An operand: a reference to a value by id, a constant, or an undefined
/// value.
///
/// Only `PartialEq` is derived (not `Eq`) because `FloatConst` holds an `f64`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Value {
/// Reference to a value by id, together with its type.
Var(ValueId, LoopType),
/// Integer constant with its scalar type.
IntConst(i64, ScalarType),
/// Floating-point constant with its scalar type.
FloatConst(f64, ScalarType),
/// Boolean constant.
BoolConst(bool),
/// Undefined value of the given type.
Undef(LoopType),
}
impl Value {
    /// Type of this operand.
    ///
    /// Variables and undefined values report the type they carry; constants
    /// report a scalar of their own scalar type, and boolean constants are
    /// always `Scalar(Bool)`.
    #[must_use]
    pub fn ty(&self) -> LoopType {
        match self {
            Self::Var(_, ty) | Self::Undef(ty) => ty.clone(),
            Self::IntConst(_, scalar) | Self::FloatConst(_, scalar) => LoopType::Scalar(*scalar),
            Self::BoolConst(_) => LoopType::Scalar(ScalarType::Bool),
        }
    }

    /// Signed integer constant `n` with a width of `bits` bits.
    #[must_use]
    pub fn int(n: i64, bits: u8) -> Self {
        let scalar = ScalarType::Int(bits);
        Self::IntConst(n, scalar)
    }

    /// 64-bit signed integer constant.
    #[must_use]
    pub fn i64(n: i64) -> Self {
        Self::IntConst(n, ScalarType::Int(64))
    }

    /// Floating-point constant `f` with a width of `bits` bits.
    #[must_use]
    pub fn float(f: f64, bits: u8) -> Self {
        let scalar = ScalarType::Float(bits);
        Self::FloatConst(f, scalar)
    }

    /// 64-bit floating-point constant.
    #[must_use]
    pub fn f64(f: f64) -> Self {
        Self::FloatConst(f, ScalarType::Float(64))
    }
}
/// Operation whose result is bound by [`Stmt::Assign`].
///
/// Operand roles noted as "presumably" are inferred from the variant names
/// and payload types; this file does not define them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Op {
/// Load from the location described by the [`MemRef`].
Load(MemRef),
/// Binary operation on two operands.
Binary(BinOp, Value, Value),
/// Unary operation on one operand.
Unary(UnOp, Value),
/// Comparison of two operands.
Cmp(CmpOp, Value, Value),
/// Selection among values — presumably `(condition, if_true, if_false)`.
Select(Value, Value, Value),
/// Conversion of a value to the given type.
Cast(Value, LoopType),
/// Broadcast of a value — the `u8` is presumably the lane count.
Broadcast(Value, u8),
/// Extraction from a value — the `u8` is presumably the lane index.
Extract(Value, u8),
/// Insertion into a value — presumably `(vector, element, lane index)`.
Insert(Value, Value, u8),
/// Shuffle of two values using the given index list.
Shuffle(Value, Value, Vec<i32>),
/// Reduction of a vector value with the given [`ReduceOp`].
VecReduce(ReduceOp, Value),
/// Fused multiply-add — presumably `a * b + c` in operand order.
Fma(Value, Value, Value),
/// Pointer addition — presumably `(pointer, offset)`.
PtrAdd(Value, Value),
/// Pointer into a buffer — the `Value` is presumably an offset or index.
GetPtr(BufferId, Value),
/// Phi node: one incoming value per listed block.
Phi(Vec<(BlockId, Value)>),
}
/// Binary operator for [`Op::Binary`].
///
/// The names follow LLVM-style conventions, where an `S`/`U`/`F` prefix
/// presumably selects the signed-integer, unsigned-integer or
/// floating-point form. The per-variant notes restate the names and are not
/// verified against lowering code in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinOp {
/// Addition.
Add,
/// Subtraction.
Sub,
/// Multiplication.
Mul,
/// Signed division.
SDiv,
/// Unsigned division.
UDiv,
/// Floating-point division.
FDiv,
/// Signed remainder.
SRem,
/// Unsigned remainder.
URem,
/// Floating-point remainder.
FRem,
/// Bitwise AND.
And,
/// Bitwise OR.
Or,
/// Bitwise XOR.
Xor,
/// Shift left.
Shl,
/// Logical shift right.
LShr,
/// Arithmetic shift right.
AShr,
/// Signed minimum.
SMin,
/// Unsigned minimum.
UMin,
/// Floating-point minimum.
FMin,
/// Signed maximum.
SMax,
/// Unsigned maximum.
UMax,
/// Floating-point maximum.
FMax,
}
/// Unary operator for [`Op::Unary`].
///
/// The per-variant notes restate the variant names; an `F` prefix presumably
/// marks the floating-point form. Rounding modes and domain behaviour are
/// not defined in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnOp {
/// Negation.
Neg,
/// Floating-point negation.
FNeg,
/// Logical or bitwise NOT.
Not,
/// Absolute value.
Abs,
/// Floating-point absolute value.
FAbs,
/// Square root.
Sqrt,
/// Reciprocal square root.
Rsqrt,
/// Floor.
Floor,
/// Ceiling.
Ceil,
/// Round.
Round,
/// Truncate.
Trunc,
/// Exponential.
Exp,
/// Logarithm (base not stated in this file).
Log,
/// Sine.
Sin,
/// Cosine.
Cos,
}
/// Comparison operator for [`Op::Cmp`].
///
/// The names follow LLVM-style predicates: `S`/`U` presumably mean signed and
/// unsigned integer comparison, and `O` presumably means ordered
/// floating-point comparison. This is not verified against code in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CmpOp {
/// Equal.
Eq,
/// Not equal.
Ne,
/// Signed less-than.
SLt,
/// Signed less-than-or-equal.
SLe,
/// Signed greater-than.
SGt,
/// Signed greater-than-or-equal.
SGe,
/// Unsigned less-than.
ULt,
/// Unsigned less-than-or-equal.
ULe,
/// Unsigned greater-than.
UGt,
/// Unsigned greater-than-or-equal.
UGe,
/// Ordered equal.
OEq,
/// Ordered not-equal.
ONe,
/// Ordered less-than.
OLt,
/// Ordered less-than-or-equal.
OLe,
/// Ordered greater-than.
OGt,
/// Ordered greater-than-or-equal.
OGe,
}
/// Combining operator for a reduction; used by [`Op::VecReduce`] and, in the
/// tests of this file, by `parallel::ParReduce`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ReduceOp {
/// Sum.
Add,
/// Product.
Mul,
/// Minimum.
Min,
/// Maximum.
Max,
/// AND.
And,
/// OR.
Or,
/// XOR.
Xor,
}
/// A memory reference: a buffer, an index into it, the element type accessed
/// and the access pattern recorded for it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MemRef {
/// Buffer being accessed.
pub buffer: BufferId,
/// Index into the buffer.
pub index: Value,
/// Type of the element loaded or stored.
pub elem_ty: LoopType,
/// Access pattern recorded for this reference.
pub access: AccessPattern,
}
/// Access pattern recorded on a [`MemRef`].
///
/// The per-variant notes restate the variant names; how passes use them is
/// not visible in this file.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessPattern {
/// Sequential access.
Sequential,
/// Strided access with the given stride (units not stated in this file).
Strided(i64),
/// Random access.
Random,
/// Broadcast access.
Broadcast,
/// Access described by an affine expression over loops.
Affine(AffineAccess),
}
/// Affine access description: per-loop integer coefficients plus a constant
/// offset — presumably `sum(coefficient * loop variable) + offset`; confirm
/// against the consumer.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AffineAccess {
/// `(loop, coefficient)` pairs; up to four are stored inline before the
/// `SmallVec` spills to the heap.
pub coefficients: SmallVec<[(LoopId, i64); 4]>,
/// Constant term.
pub offset: i64,
}
/// Kind of barrier for [`Stmt::Barrier`].
///
/// The exact synchronization semantics of each kind are not defined in this
/// file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum BarrierKind {
/// Memory fence.
MemFence,
/// Full barrier.
Full,
/// Thread-group barrier.
ThreadGroup,
}
/// Errors reported for Loop IR; `Display` text comes from the `#[error]`
/// attributes via `thiserror`.
#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
pub enum LoopIrError {
/// A value had a different type than required.
#[error("type mismatch: expected {expected:?}, got {got:?}")]
TypeMismatch {
/// Type that was required.
expected: LoopType,
/// Type that was found.
got: LoopType,
},
/// A vector width is not valid for the given element type.
#[error("invalid vector width {width} for type {ty:?}")]
InvalidVectorWidth {
/// Offending width.
width: u8,
/// Element type the width was requested for.
ty: ScalarType,
},
/// A buffer access was out of bounds.
#[error("buffer access out of bounds")]
OutOfBounds,
/// A loop transformation could not be applied.
#[error("invalid loop transformation: {reason}")]
InvalidTransform {
/// Explanation of why the transformation is invalid.
reason: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scalar_type_sizes() {
assert_eq!(ScalarType::Bool.size_bytes(), 1);
assert_eq!(ScalarType::Int(32).size_bytes(), 4);
assert_eq!(ScalarType::Float(64).size_bytes(), 8);
}
#[test]
fn test_loop_type_size() {
assert_eq!(LoopType::Scalar(ScalarType::Float(32)).size_bytes(), 4);
assert_eq!(LoopType::Vector(ScalarType::Float(32), 8).size_bytes(), 32);
}
#[test]
fn test_value_types() {
let v = Value::i64(42);
assert_eq!(v.ty(), LoopType::Scalar(ScalarType::Int(64)));
let f = Value::f64(2.5);
assert_eq!(f.ty(), LoopType::Scalar(ScalarType::Float(64)));
}
#[test]
fn test_loop_attrs() {
let attrs = LoopAttrs::PARALLEL | LoopAttrs::VECTORIZE;
assert!(attrs.contains(LoopAttrs::PARALLEL));
assert!(attrs.contains(LoopAttrs::VECTORIZE));
assert!(!attrs.contains(LoopAttrs::UNROLL));
}
#[test]
fn test_trip_count() {
let static_trip = TripCount::Static(100);
assert_eq!(static_trip, TripCount::Static(100));
let dynamic_trip = TripCount::Dynamic;
assert_eq!(dynamic_trip, TripCount::Dynamic);
}
#[test]
fn test_m3_matmul_auto_vectorizes() {
use crate::vectorize::{VectorizeConfig, VectorizePass};
use bhc_index::Idx;
use bhc_tensor_ir::BufferId;
let loop_id = LoopId::new(0);
let loop_var = ValueId::new(0);
let mem_ref = MemRef {
buffer: BufferId::new(0),
index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
elem_ty: LoopType::Scalar(ScalarType::F32),
access: AccessPattern::Sequential,
};
let mut body = Body::new();
let load_a = ValueId::new(1);
body.push(Stmt::Assign(load_a, Op::Load(mem_ref.clone())));
let load_b = ValueId::new(2);
body.push(Stmt::Assign(load_b, Op::Load(mem_ref.clone())));
let mul_result = ValueId::new(3);
body.push(Stmt::Assign(
mul_result,
Op::Binary(
BinOp::Mul,
Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
),
));
let acc = ValueId::new(4);
let fma_result = ValueId::new(5);
body.push(Stmt::Assign(
fma_result,
Op::Fma(
Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
Value::Var(acc, LoopType::Scalar(ScalarType::F32)),
),
));
let lp = Loop {
id: loop_id,
var: loop_var,
lower: Value::i64(0),
upper: Value::i64(256), step: Value::i64(1),
body,
attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
};
let mut outer_body = Body::new();
outer_body.push(Stmt::Loop(lp));
let ir = LoopIR {
name: bhc_intern::Symbol::intern("matmul_kernel"),
params: vec![],
return_ty: LoopType::Void,
body: outer_body,
allocs: vec![],
loop_info: vec![LoopMetadata {
id: loop_id,
trip_count: TripCount::Static(256),
vector_width: None,
parallel_chunk: None,
unroll_factor: None,
dependencies: Vec::new(),
}],
};
let config_x86 = VectorizeConfig {
target: TargetArch::X86_64Avx2,
..Default::default()
};
let mut pass_x86 = VectorizePass::new(config_x86);
let analysis_x86 = pass_x86.analyze(&ir);
let info_x86 = analysis_x86.get(&loop_id).expect("loop should be analyzed");
assert!(
info_x86.vectorizable,
"M3 FAIL: matmul kernel not vectorizable on x86_64 AVX2"
);
assert_eq!(
info_x86.recommended_width, 8,
"M3 FAIL: x86_64 AVX2 should use 8-wide vectors for f32"
);
let config_arm = VectorizeConfig {
target: TargetArch::Aarch64Neon,
..Default::default()
};
let mut pass_arm = VectorizePass::new(config_arm);
let analysis_arm = pass_arm.analyze(&ir);
let info_arm = analysis_arm.get(&loop_id).expect("loop should be analyzed");
assert!(
info_arm.vectorizable,
"M3 FAIL: matmul kernel not vectorizable on aarch64 NEON"
);
assert_eq!(
info_arm.recommended_width, 4,
"M3 FAIL: aarch64 NEON should use 4-wide vectors for f32"
);
}
#[test]
fn test_m3_reductions_scale_linearly() {
use crate::parallel::{ParReduce, ParallelConfig, Range};
use crate::ReduceOp;
let data_size = 1_000_000;
for worker_count in [1, 2, 4, 8] {
let config = ParallelConfig {
worker_count,
deterministic: true,
..Default::default()
};
let par_reduce = ParReduce {
size: data_size,
op: ReduceOp::Add,
config,
};
let chunks = par_reduce.chunk_assignments();
assert_eq!(
chunks.len(),
worker_count,
"M3 FAIL: Expected {} chunks for {} workers",
worker_count,
worker_count
);
let total_work: usize = chunks.iter().map(|c| c.len()).sum();
assert_eq!(
total_work, data_size,
"M3 FAIL: Total work should equal data size"
);
let expected_per_worker = data_size / worker_count;
for (i, chunk) in chunks.iter().enumerate() {
let diff = (chunk.len() as i64 - expected_per_worker as i64).abs();
assert!(
diff <= 1,
"M3 FAIL: Worker {} has {} elements, expected ~{} (diff={})",
i,
chunk.len(),
expected_per_worker,
diff
);
}
}
let _config_4 = ParallelConfig {
worker_count: 4,
..Default::default()
};
let chunks_4 = Range::new(0, data_size as i64).chunk(4);
let _config_8 = ParallelConfig {
worker_count: 8,
..Default::default()
};
let chunks_8 = Range::new(0, data_size as i64).chunk(8);
let avg_chunk_4: usize = chunks_4.iter().map(|c| c.len()).sum::<usize>() / 4;
let avg_chunk_8: usize = chunks_8.iter().map(|c| c.len()).sum::<usize>() / 8;
let ratio = avg_chunk_4 as f64 / avg_chunk_8 as f64;
assert!(
(ratio - 2.0).abs() < 0.1,
"M3 FAIL: Chunk size ratio should be ~2.0, got {}",
ratio
);
}
#[test]
fn test_m3_deterministic_mode() {
use crate::parallel::{ParReduce, ParallelConfig, ParallelStrategy};
use crate::ReduceOp;
let data_size = 100_000;
let worker_count = 8;
let config = ParallelConfig {
worker_count,
deterministic: true,
..Default::default()
};
let par_reduce = ParReduce {
size: data_size,
op: ReduceOp::Add,
config: config.clone(),
};
let chunks1 = par_reduce.chunk_assignments();
let chunks2 = par_reduce.chunk_assignments();
let chunks3 = par_reduce.chunk_assignments();
for i in 0..worker_count {
assert_eq!(
chunks1[i].start, chunks2[i].start,
"M3 FAIL: Chunk {} start differs between runs",
i
);
assert_eq!(
chunks1[i].end, chunks2[i].end,
"M3 FAIL: Chunk {} end differs between runs",
i
);
assert_eq!(
chunks2[i].start, chunks3[i].start,
"M3 FAIL: Chunk {} start differs between runs",
i
);
assert_eq!(
chunks2[i].end, chunks3[i].end,
"M3 FAIL: Chunk {} end differs between runs",
i
);
}
use crate::parallel::ParallelPass;
let parallel_config = ParallelConfig {
worker_count: 8,
deterministic: true,
..Default::default()
};
let loop_id = LoopId::new(0);
let mut body = Body::new();
let lp = Loop {
id: loop_id,
var: ValueId::new(0),
lower: Value::i64(0),
upper: Value::i64(100000),
step: Value::i64(1),
body: Body::new(),
attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
};
body.push(Stmt::Loop(lp));
let ir = LoopIR {
name: bhc_intern::Symbol::intern("deterministic_test"),
params: vec![],
return_ty: LoopType::Void,
body,
allocs: vec![],
loop_info: vec![LoopMetadata {
id: loop_id,
trip_count: TripCount::Static(100000),
vector_width: None,
parallel_chunk: None,
unroll_factor: None,
dependencies: Vec::new(),
}],
};
let mut pass = ParallelPass::new(parallel_config);
let analysis = pass.analyze(&ir);
let info = analysis.get(&loop_id).expect("loop should be analyzed");
assert!(
info.parallelizable,
"M3 FAIL: Loop should be parallelizable"
);
assert_eq!(
info.strategy,
ParallelStrategy::Static,
"M3 FAIL: Deterministic mode should use Static scheduling"
);
}
#[test]
fn test_m3_vectorized_parallel_reduction() {
use crate::parallel::{ParallelConfig, ParallelPass};
use crate::vectorize::{VectorizeConfig, VectorizePass};
use bhc_index::Idx;
use bhc_tensor_ir::BufferId;
let outer_loop_id = LoopId::new(0);
let inner_loop_id = LoopId::new(1);
let mem_ref = MemRef {
buffer: BufferId::new(0),
index: Value::Var(ValueId::new(1), LoopType::Scalar(ScalarType::I64)),
elem_ty: LoopType::Scalar(ScalarType::F32),
access: AccessPattern::Sequential,
};
let mut inner_body = Body::new();
let load_result = ValueId::new(2);
inner_body.push(Stmt::Assign(load_result, Op::Load(mem_ref)));
let inner_loop = Loop {
id: inner_loop_id,
var: ValueId::new(1),
lower: Value::i64(0),
upper: Value::i64(1024),
step: Value::i64(1),
body: inner_body,
attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT | LoopAttrs::REDUCTION,
};
let mut outer_body = Body::new();
outer_body.push(Stmt::Loop(inner_loop));
let outer_loop = Loop {
id: outer_loop_id,
var: ValueId::new(0),
lower: Value::i64(0),
upper: Value::i64(10000),
step: Value::i64(1),
body: outer_body,
attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
};
let mut top_body = Body::new();
top_body.push(Stmt::Loop(outer_loop));
let ir = LoopIR {
name: bhc_intern::Symbol::intern("vec_par_reduce"),
params: vec![],
return_ty: LoopType::Void,
body: top_body,
allocs: vec![],
loop_info: vec![
LoopMetadata {
id: outer_loop_id,
trip_count: TripCount::Static(10000),
vector_width: None,
parallel_chunk: None,
unroll_factor: None,
dependencies: Vec::new(),
},
LoopMetadata {
id: inner_loop_id,
trip_count: TripCount::Static(1024),
vector_width: None,
parallel_chunk: None,
unroll_factor: None,
dependencies: Vec::new(),
},
],
};
let vec_config = VectorizeConfig {
target: TargetArch::X86_64Avx2,
..Default::default()
};
let mut vec_pass = VectorizePass::new(vec_config);
let vec_analysis = vec_pass.analyze(&ir);
let inner_info = vec_analysis
.get(&inner_loop_id)
.expect("inner loop analyzed");
assert!(
inner_info.vectorizable,
"M3 FAIL: Inner reduction loop should be vectorizable"
);
let par_config = ParallelConfig {
worker_count: 8,
deterministic: true,
..Default::default()
};
let mut par_pass = ParallelPass::new(par_config);
let par_analysis = par_pass.analyze(&ir);
let outer_info = par_analysis
.get(&outer_loop_id)
.expect("outer loop analyzed");
assert!(
outer_info.parallelizable,
"M3 FAIL: Outer loop should be parallelizable"
);
assert_eq!(
outer_info.num_chunks, 8,
"M3 FAIL: Should have 8 parallel chunks"
);
}
}