1#![warn(missing_docs)]
74#![allow(clippy::module_name_repetitions)]
75
76use bhc_index::Idx;
77use bhc_intern::Symbol;
78use bhc_tensor_ir::{AllocRegion, BufferId, DType};
79use bitflags::bitflags;
80use serde::{Deserialize, Serialize};
81use smallvec::SmallVec;
82
83pub mod lower;
88pub mod parallel;
89pub mod vectorize;
90
91pub use lower::{lower_kernel, lower_kernels, LowerConfig, LowerError};
93pub use parallel::{
94 ParFor, ParMap, ParReduce, ParallelConfig, ParallelPass, ParallelStrategy, Range,
95};
96pub use vectorize::{SimdIntrinsic, VectorizeConfig, VectorizePass, VectorizeReport};
97
98#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
100pub struct ValueId(u32);
101
102impl Idx for ValueId {
103 fn new(idx: usize) -> Self {
104 Self(idx as u32)
105 }
106
107 fn index(self) -> usize {
108 self.0 as usize
109 }
110}
111
112#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
114pub struct LoopId(u32);
115
116impl Idx for LoopId {
117 fn new(idx: usize) -> Self {
118 Self(idx as u32)
119 }
120
121 fn index(self) -> usize {
122 self.0 as usize
123 }
124}
125
126#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
128pub struct BlockId(u32);
129
130impl Idx for BlockId {
131 fn new(idx: usize) -> Self {
132 Self(idx as u32)
133 }
134
135 fn index(self) -> usize {
136 self.0 as usize
137 }
138}
139
/// A kernel in loop form: its signature, statement body, buffer allocations, and
/// per-loop analysis metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopIR {
    /// Kernel name.
    pub name: Symbol,
    /// Formal parameters, in declaration order.
    pub params: Vec<Param>,
    /// Return type; `LoopType::Void` when the kernel returns nothing.
    pub return_ty: LoopType,
    /// Top-level statements of the kernel.
    pub body: Body,
    /// Buffer allocations made by the kernel.
    pub allocs: Vec<Alloc>,
    /// Per-loop metadata, matched to loops via `LoopMetadata::id`
    /// (presumably one entry per loop — confirm against producers).
    pub loop_info: Vec<LoopMetadata>,
}
156
/// A formal parameter of a [`LoopIR`] kernel.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Param {
    /// Parameter name.
    pub name: Symbol,
    /// Parameter type.
    pub ty: LoopType,
    /// Whether the parameter is passed by pointer
    /// (NOTE(review): presumably used for buffer arguments — confirm in lowering).
    pub is_ptr: bool,
}
167
/// Type of a value in the loop IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LoopType {
    /// No value (zero-sized).
    Void,
    /// A single scalar element.
    Scalar(ScalarType),
    /// A SIMD vector of `width` lanes of the given element type.
    Vector(ScalarType, u8),
    /// Pointer to a value of the inner type.
    Ptr(Box<LoopType>),
}
180
181impl LoopType {
182 #[must_use]
184 pub fn size_bytes(&self) -> usize {
185 match self {
186 Self::Void => 0,
187 Self::Scalar(s) => s.size_bytes(),
188 Self::Vector(s, width) => s.size_bytes() * (*width as usize),
189 Self::Ptr(_) => 8, }
191 }
192
193 #[must_use]
195 pub fn is_void(&self) -> bool {
196 matches!(self, Self::Void)
197 }
198
199 #[must_use]
201 pub fn is_vector(&self) -> bool {
202 matches!(self, Self::Vector(_, _))
203 }
204}
205
/// Scalar element type. The integer payload of the numeric variants is a bit width.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ScalarType {
    /// Boolean, stored in one byte.
    Bool,
    /// Signed integer of the given bit width.
    Int(u8),
    /// Unsigned integer of the given bit width.
    UInt(u8),
    /// Floating-point number of the given bit width.
    Float(u8),
}
218
219impl ScalarType {
220 #[must_use]
222 pub const fn size_bytes(self) -> usize {
223 match self {
224 Self::Bool => 1,
225 Self::Int(bits) | Self::UInt(bits) | Self::Float(bits) => (bits as usize).div_ceil(8),
226 }
227 }
228
229 #[must_use]
231 pub fn from_dtype(dtype: DType) -> Self {
232 match dtype {
233 DType::Bool => Self::Bool,
234 DType::Int8 => Self::Int(8),
235 DType::Int16 => Self::Int(16),
236 DType::Int32 => Self::Int(32),
237 DType::Int64 => Self::Int(64),
238 DType::UInt8 => Self::UInt(8),
239 DType::UInt16 => Self::UInt(16),
240 DType::UInt32 => Self::UInt(32),
241 DType::UInt64 => Self::UInt(64),
242 DType::Float16 | DType::BFloat16 => Self::Float(16),
243 DType::Float32 => Self::Float(32),
244 DType::Float64 => Self::Float(64),
245 DType::Complex64 => Self::Float(32), DType::Complex128 => Self::Float(64),
247 }
248 }
249
250 pub const F32: Self = Self::Float(32);
252
253 pub const F64: Self = Self::Float(64);
255
256 pub const I32: Self = Self::Int(32);
258
259 pub const I64: Self = Self::Int(64);
261}
262
263impl LoopType {
268 pub const VEC4F32: Self = Self::Vector(ScalarType::F32, 4);
272
273 pub const VEC8F32: Self = Self::Vector(ScalarType::F32, 8);
275
276 pub const VEC2F64: Self = Self::Vector(ScalarType::F64, 2);
278
279 pub const VEC4F64: Self = Self::Vector(ScalarType::F64, 4);
281
282 pub const VEC4I32: Self = Self::Vector(ScalarType::I32, 4);
284
285 pub const VEC8I32: Self = Self::Vector(ScalarType::I32, 8);
287
288 #[must_use]
298 pub fn natural_vector_width(scalar: ScalarType, target: TargetArch) -> u8 {
299 match (target, scalar) {
300 (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(32)) => 8,
302 (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(64)) => 4,
303 (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Int(32)) => 8,
304 (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(32)) => 4,
306 (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(64)) => 2,
307 (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Int(32)) => 4,
308 (TargetArch::Aarch64Neon, ScalarType::Float(32)) => 4,
310 (TargetArch::Aarch64Neon, ScalarType::Float(64)) => 2,
311 (TargetArch::Aarch64Neon, ScalarType::Int(32)) => 4,
312 _ => 1,
314 }
315 }
316
317 #[must_use]
319 pub const fn vector(scalar: ScalarType, width: u8) -> Self {
320 Self::Vector(scalar, width)
321 }
322
323 #[must_use]
325 pub fn vector_width(&self) -> Option<u8> {
326 match self {
327 Self::Vector(_, w) => Some(*w),
328 _ => None,
329 }
330 }
331
332 #[must_use]
334 pub fn element_type(&self) -> Option<ScalarType> {
335 match self {
336 Self::Vector(s, _) => Some(*s),
337 Self::Scalar(s) => Some(*s),
338 _ => None,
339 }
340 }
341}
342
/// Target instruction-set family, used to choose vector widths.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum TargetArch {
    /// x86-64 with SSE.
    X86_64Sse,
    /// x86-64 with SSE2.
    X86_64Sse2,
    /// x86-64 with AVX.
    X86_64Avx,
    /// x86-64 with AVX2.
    X86_64Avx2,
    /// AArch64 with NEON; this is the `Default` target.
    #[default]
    Aarch64Neon,
    /// No SIMD assumptions; `LoopType::natural_vector_width` returns 1 for it.
    Generic,
}
360
/// A buffer allocation made by a kernel.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Alloc {
    /// Buffer being allocated.
    pub buffer: BufferId,
    /// Name of the allocation.
    pub name: Symbol,
    /// Element type stored in the buffer.
    pub elem_ty: ScalarType,
    /// Allocation size, static or computed at run time.
    pub size: AllocSize,
    /// Required alignment (NOTE(review): presumably in bytes — confirm).
    pub alignment: usize,
    /// Memory region the buffer is placed in.
    pub region: AllocRegion,
}
377
/// Size of an [`Alloc`].
///
/// NOTE(review): the unit (elements vs. bytes) is not established here — confirm with lowering.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllocSize {
    /// Size known at compile time.
    Static(usize),
    /// Size held in a run-time IR value.
    Dynamic(ValueId),
}
386
/// An ordered sequence of statements (kernel body, loop body, or branch).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Body {
    /// Statements, in program order.
    pub stmts: Vec<Stmt>,
}
393
394impl Body {
395 #[must_use]
397 pub fn new() -> Self {
398 Self::default()
399 }
400
401 pub fn push(&mut self, stmt: Stmt) {
403 self.stmts.push(stmt);
404 }
405}
406
/// A statement in a [`Body`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Stmt {
    /// Evaluates an operation and binds the result to a value id.
    Assign(ValueId, Op),

    /// A counted loop.
    Loop(Loop),

    /// A conditional.
    If(IfStmt),

    /// Writes a value to memory.
    Store(MemRef, Value),

    /// Calls a named function with arguments, optionally binding its result.
    Call(Option<ValueId>, Symbol, Vec<Value>),

    /// Returns from the kernel, optionally with a value.
    Return(Option<Value>),

    /// A synchronization barrier.
    Barrier(BarrierKind),

    /// A free-form comment carried through the IR.
    Comment(String),
}
434
/// A counted loop over an induction variable.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Loop {
    /// Loop identifier, matching a `LoopMetadata::id`.
    pub id: LoopId,
    /// Induction variable.
    pub var: ValueId,
    /// Lower bound of the induction variable.
    pub lower: Value,
    /// Upper bound (NOTE(review): the tests pair `0..256` with a trip count of 256, which
    /// suggests it is exclusive — confirm).
    pub upper: Value,
    /// Per-iteration increment.
    pub step: Value,
    /// Loop body.
    pub body: Body,
    /// Transformation hints and properties.
    pub attrs: LoopAttrs,
}
453
bitflags! {
    /// Properties and transformation hints attached to a [`Loop`].
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct LoopAttrs: u32 {
        /// Candidate for parallel execution.
        const PARALLEL = 0b0000_0001;
        /// Candidate for vectorization.
        const VECTORIZE = 0b0000_0010;
        /// Candidate for unrolling.
        const UNROLL = 0b0000_0100;
        /// Loop performs a reduction.
        const REDUCTION = 0b0000_1000;
        /// Iterations are independent of one another
        /// (presumably asserted by the producer, not verified here).
        const INDEPENDENT = 0b0001_0000;
        /// Loop has been tiled.
        const TILED = 0b0010_0000;
        /// Loop is the inner loop of a tile.
        const TILE_INNER = 0b0100_0000;
    }
}
474
/// Analysis results and scheduling decisions for one loop.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopMetadata {
    /// Loop these facts describe.
    pub id: LoopId,
    /// Known iteration count.
    pub trip_count: TripCount,
    /// Chosen SIMD width, if any.
    pub vector_width: Option<u8>,
    /// Chosen parallel chunk size, if any.
    pub parallel_chunk: Option<usize>,
    /// Chosen unroll factor, if any.
    pub unroll_factor: Option<u8>,
    /// Dependences involving this loop.
    pub dependencies: Vec<LoopDependency>,
}
491
/// How much is known statically about a loop's iteration count.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TripCount {
    /// Exact iteration count.
    Static(usize),
    /// Iteration count known only at run time.
    Dynamic,
    /// Run-time count with a static bound (presumably an upper bound — confirm).
    Bounded(usize),
}
502
/// A data dependence between two loops.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopDependency {
    /// Loop containing the earlier access.
    pub source: LoopId,
    /// Loop containing the later access.
    pub target: LoopId,
    /// Kind of dependence.
    pub kind: DependencyKind,
    /// Dependence distance vector, when known.
    pub distance: Option<Vec<i32>>,
}
515
/// Classic data-dependence kinds (conventional compiler terminology).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum DependencyKind {
    /// Read after write (true dependence).
    Flow,
    /// Write after read.
    Anti,
    /// Write after write.
    Output,
    /// Read after read.
    Input,
}
528
/// A two-way conditional.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IfStmt {
    /// Condition.
    pub cond: Value,
    /// Statements run when the condition holds.
    pub then_body: Body,
    /// Statements run otherwise, if present.
    pub else_body: Option<Body>,
}
539
/// An operand: a typed variable reference or a constant.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Value {
    /// Reference to a defined value, annotated with its type.
    Var(ValueId, LoopType),
    /// Integer constant of the given scalar type.
    IntConst(i64, ScalarType),
    /// Floating-point constant of the given scalar type.
    FloatConst(f64, ScalarType),
    /// Boolean constant.
    BoolConst(bool),
    /// Undefined value of the given type.
    Undef(LoopType),
}
554
555impl Value {
556 #[must_use]
558 pub fn ty(&self) -> LoopType {
559 match self {
560 Self::Var(_, ty) => ty.clone(),
561 Self::IntConst(_, s) => LoopType::Scalar(*s),
562 Self::FloatConst(_, s) => LoopType::Scalar(*s),
563 Self::BoolConst(_) => LoopType::Scalar(ScalarType::Bool),
564 Self::Undef(ty) => ty.clone(),
565 }
566 }
567
568 #[must_use]
570 pub fn int(n: i64, bits: u8) -> Self {
571 Self::IntConst(n, ScalarType::Int(bits))
572 }
573
574 #[must_use]
576 pub fn i64(n: i64) -> Self {
577 Self::int(n, 64)
578 }
579
580 #[must_use]
582 pub fn float(f: f64, bits: u8) -> Self {
583 Self::FloatConst(f, ScalarType::Float(bits))
584 }
585
586 #[must_use]
588 pub fn f64(f: f64) -> Self {
589 Self::float(f, 64)
590 }
591}
592
/// A value-producing operation, bound to a [`ValueId`] by `Stmt::Assign`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Op {
    /// Load from memory.
    Load(MemRef),

    /// Binary arithmetic or logical operation.
    Binary(BinOp, Value, Value),

    /// Unary operation.
    Unary(UnOp, Value),

    /// Comparison producing a boolean.
    Cmp(CmpOp, Value, Value),

    /// Select between two values (presumably `(cond, if_true, if_false)` — confirm).
    Select(Value, Value, Value),

    /// Conversion to the given type.
    Cast(Value, LoopType),

    /// Replicate a scalar across the given number of lanes.
    Broadcast(Value, u8),

    /// Extract the lane at the given index.
    Extract(Value, u8),

    /// Insert an element at the given lane (presumably `(vector, element, lane)` — confirm).
    Insert(Value, Value, u8),

    /// Lane shuffle of two vectors by an index mask.
    Shuffle(Value, Value, Vec<i32>),

    /// Horizontal reduction of a vector's lanes.
    VecReduce(ReduceOp, Value),

    /// Fused multiply-add (presumably `a * b + c` — confirm).
    Fma(Value, Value, Value),

    /// Pointer plus offset.
    PtrAdd(Value, Value),

    /// Pointer into a buffer at the given index.
    GetPtr(BufferId, Value),

    /// SSA phi node selecting a value by predecessor block.
    Phi(Vec<(BlockId, Value)>),
}
641
/// Binary operators. Names follow the LLVM-style convention: `S`/`U`/`F` prefixes mark
/// signed, unsigned, and floating-point variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinOp {
    // Arithmetic.
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Signed division.
    SDiv,
    /// Unsigned division.
    UDiv,
    /// Floating-point division.
    FDiv,
    /// Signed remainder.
    SRem,
    /// Unsigned remainder.
    URem,
    /// Floating-point remainder.
    FRem,

    // Bitwise.
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise XOR.
    Xor,
    /// Shift left.
    Shl,
    /// Logical (zero-filling) shift right.
    LShr,
    /// Arithmetic (sign-filling) shift right.
    AShr,

    // Min / max.
    /// Signed minimum.
    SMin,
    /// Unsigned minimum.
    UMin,
    /// Floating-point minimum.
    FMin,
    /// Signed maximum.
    SMax,
    /// Unsigned maximum.
    UMax,
    /// Floating-point maximum.
    FMax,
}
693
/// Unary operators.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnOp {
    /// Integer negation.
    Neg,
    /// Floating-point negation.
    FNeg,
    /// Bitwise/logical NOT.
    Not,
    /// Integer absolute value.
    Abs,
    /// Floating-point absolute value.
    FAbs,
    /// Square root.
    Sqrt,
    /// Reciprocal square root.
    Rsqrt,
    /// Round toward negative infinity.
    Floor,
    /// Round toward positive infinity.
    Ceil,
    /// Round to nearest.
    Round,
    /// Round toward zero.
    Trunc,
    /// Exponential.
    Exp,
    /// Logarithm.
    Log,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
}
728
/// Comparison predicates, named after LLVM `icmp`/`fcmp`. `S`/`U` are signed/unsigned
/// integer comparisons; `O` marks ordered floating-point comparisons.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CmpOp {
    /// Integer equal.
    Eq,
    /// Integer not equal.
    Ne,
    /// Signed less than.
    SLt,
    /// Signed less or equal.
    SLe,
    /// Signed greater than.
    SGt,
    /// Signed greater or equal.
    SGe,
    /// Unsigned less than.
    ULt,
    /// Unsigned less or equal.
    ULe,
    /// Unsigned greater than.
    UGt,
    /// Unsigned greater or equal.
    UGe,
    /// Ordered float equal.
    OEq,
    /// Ordered float not equal.
    ONe,
    /// Ordered float less than.
    OLt,
    /// Ordered float less or equal.
    OLe,
    /// Ordered float greater than.
    OGt,
    /// Ordered float greater or equal.
    OGe,
}
765
/// Combining operator for reductions (`Op::VecReduce`, parallel reductions).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ReduceOp {
    /// Sum.
    Add,
    /// Product.
    Mul,
    /// Minimum.
    Min,
    /// Maximum.
    Max,
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise XOR.
    Xor,
}
784
/// A typed memory location: a buffer plus an index.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MemRef {
    /// Buffer being accessed.
    pub buffer: BufferId,
    /// Index into the buffer (presumably in elements — confirm).
    pub index: Value,
    /// Type of the loaded or stored value.
    pub elem_ty: LoopType,
    /// Known access pattern, used by analyses.
    pub access: AccessPattern,
}
797
/// Shape of a memory access across loop iterations.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessPattern {
    /// Consecutive elements.
    Sequential,
    /// Constant stride between accesses.
    Strided(i64),
    /// No known pattern.
    Random,
    /// Same location every iteration.
    Broadcast,
    /// Affine function of loop induction variables.
    Affine(AffineAccess),
}
812
/// Affine index expression, presumably `offset + Σ coeff · var(loop)` — confirm.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AffineAccess {
    /// Per-loop coefficients.
    pub coefficients: SmallVec<[(LoopId, i64); 4]>,
    /// Constant term.
    pub offset: i64,
}
821
/// Scope of a `Stmt::Barrier` (exact semantics are defined by the backend — confirm there).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum BarrierKind {
    /// Memory fence.
    MemFence,
    /// Full barrier.
    Full,
    /// Barrier across a thread group.
    ThreadGroup,
}
832
/// Errors reported while building or transforming loop IR.
#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
pub enum LoopIrError {
    /// An operand had a different type than required.
    #[error("type mismatch: expected {expected:?}, got {got:?}")]
    TypeMismatch {
        /// Required type.
        expected: LoopType,
        /// Type actually found.
        got: LoopType,
    },

    /// A vector width is not valid for the element type.
    #[error("invalid vector width {width} for type {ty:?}")]
    InvalidVectorWidth {
        /// Requested lane count.
        width: u8,
        /// Element type.
        ty: ScalarType,
    },

    /// A buffer access falls outside the buffer.
    #[error("buffer access out of bounds")]
    OutOfBounds,

    /// A loop transformation could not be applied.
    #[error("invalid loop transformation: {reason}")]
    InvalidTransform {
        /// Explanation of why the transformation is invalid.
        reason: String,
    },
}
865
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's private imports (`Idx`, `BufferId`, `Symbol`)
    // and the crate-root re-exports of the parallel/vectorize types.
    use super::*;

    /// Metadata for a loop with a static trip count and no scheduling hints.
    fn static_loop_metadata(id: LoopId, trip_count: usize) -> LoopMetadata {
        LoopMetadata {
            id,
            trip_count: TripCount::Static(trip_count),
            vector_width: None,
            parallel_chunk: None,
            unroll_factor: None,
            dependencies: Vec::new(),
        }
    }

    /// A parameterless, allocation-free kernel returning `void`.
    fn void_kernel(name: &'static str, body: Body, loop_info: Vec<LoopMetadata>) -> LoopIR {
        LoopIR {
            name: Symbol::intern(name),
            params: vec![],
            return_ty: LoopType::Void,
            body,
            allocs: vec![],
            loop_info,
        }
    }

    /// Scalar `f32` reference to value `id`.
    fn f32_var(id: ValueId) -> Value {
        Value::Var(id, LoopType::Scalar(ScalarType::F32))
    }

    #[test]
    fn test_scalar_type_sizes() {
        assert_eq!(ScalarType::Bool.size_bytes(), 1);
        assert_eq!(ScalarType::Int(32).size_bytes(), 4);
        assert_eq!(ScalarType::Float(64).size_bytes(), 8);
    }

    #[test]
    fn test_loop_type_size() {
        assert_eq!(LoopType::Scalar(ScalarType::Float(32)).size_bytes(), 4);
        assert_eq!(LoopType::Vector(ScalarType::Float(32), 8).size_bytes(), 32);
    }

    #[test]
    fn test_value_types() {
        let v = Value::i64(42);
        assert_eq!(v.ty(), LoopType::Scalar(ScalarType::Int(64)));

        let f = Value::f64(2.5);
        assert_eq!(f.ty(), LoopType::Scalar(ScalarType::Float(64)));
    }

    #[test]
    fn test_loop_attrs() {
        let attrs = LoopAttrs::PARALLEL | LoopAttrs::VECTORIZE;
        assert!(attrs.contains(LoopAttrs::PARALLEL));
        assert!(attrs.contains(LoopAttrs::VECTORIZE));
        assert!(!attrs.contains(LoopAttrs::UNROLL));
    }

    #[test]
    fn test_trip_count() {
        let static_trip = TripCount::Static(100);
        assert_eq!(static_trip, TripCount::Static(100));

        let dynamic_trip = TripCount::Dynamic;
        assert_eq!(dynamic_trip, TripCount::Dynamic);
    }

    /// A matmul-style inner loop (loads, multiply, FMA) must vectorize at the native f32
    /// width on both x86_64 AVX2 and aarch64 NEON.
    #[test]
    fn test_m3_matmul_auto_vectorizes() {
        let loop_id = LoopId::new(0);
        let loop_var = ValueId::new(0);

        let mem_ref = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
            elem_ty: LoopType::Scalar(ScalarType::F32),
            access: AccessPattern::Sequential,
        };

        let mut body = Body::new();
        let load_a = ValueId::new(1);
        body.push(Stmt::Assign(load_a, Op::Load(mem_ref.clone())));
        let load_b = ValueId::new(2);
        body.push(Stmt::Assign(load_b, Op::Load(mem_ref)));

        body.push(Stmt::Assign(
            ValueId::new(3),
            Op::Binary(BinOp::Mul, f32_var(load_a), f32_var(load_b)),
        ));

        let acc = ValueId::new(4);
        body.push(Stmt::Assign(
            ValueId::new(5),
            Op::Fma(f32_var(load_a), f32_var(load_b), f32_var(acc)),
        ));

        let mut outer_body = Body::new();
        outer_body.push(Stmt::Loop(Loop {
            id: loop_id,
            var: loop_var,
            lower: Value::i64(0),
            upper: Value::i64(256),
            step: Value::i64(1),
            body,
            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
        }));

        let ir = void_kernel(
            "matmul_kernel",
            outer_body,
            vec![static_loop_metadata(loop_id, 256)],
        );

        for (target, expected_width, label) in [
            (TargetArch::X86_64Avx2, 8, "x86_64 AVX2"),
            (TargetArch::Aarch64Neon, 4, "aarch64 NEON"),
        ] {
            let mut pass = VectorizePass::new(VectorizeConfig {
                target,
                ..Default::default()
            });
            let analysis = pass.analyze(&ir);
            let info = analysis.get(&loop_id).expect("loop should be analyzed");

            assert!(
                info.vectorizable,
                "M3 FAIL: matmul kernel not vectorizable on {}",
                label
            );
            assert_eq!(
                info.recommended_width, expected_width,
                "M3 FAIL: {} should use {}-wide vectors for f32",
                label, expected_width
            );
        }
    }

    /// Reduction work must split evenly across workers, and the largest chunk (the critical
    /// path) must halve when the worker count doubles.
    #[test]
    fn test_m3_reductions_scale_linearly() {
        let data_size: usize = 1_000_000;

        for worker_count in [1, 2, 4, 8] {
            let par_reduce = ParReduce {
                size: data_size,
                op: ReduceOp::Add,
                config: ParallelConfig {
                    worker_count,
                    deterministic: true,
                    ..Default::default()
                },
            };

            let chunks = par_reduce.chunk_assignments();

            assert_eq!(
                chunks.len(),
                worker_count,
                "M3 FAIL: Expected {} chunks for {} workers",
                worker_count,
                worker_count
            );

            let total_work: usize = chunks.iter().map(|c| c.len()).sum();
            assert_eq!(
                total_work, data_size,
                "M3 FAIL: Total work should equal data size"
            );

            let expected_per_worker = data_size / worker_count;
            for (i, chunk) in chunks.iter().enumerate() {
                let diff = chunk.len().abs_diff(expected_per_worker);
                assert!(
                    diff <= 1,
                    "M3 FAIL: Worker {} has {} elements, expected ~{} (diff={})",
                    i,
                    chunk.len(),
                    expected_per_worker,
                    diff
                );
            }
        }

        // Compare the largest chunk rather than the average: the average is just
        // `data_size / n` by construction and would pass for any chunking.
        let chunks_4 = Range::new(0, data_size as i64).chunk(4);
        let chunks_8 = Range::new(0, data_size as i64).chunk(8);

        let max_chunk_4 = chunks_4
            .iter()
            .map(|c| c.len())
            .max()
            .expect("chunking a non-empty range yields chunks");
        let max_chunk_8 = chunks_8
            .iter()
            .map(|c| c.len())
            .max()
            .expect("chunking a non-empty range yields chunks");

        let ratio = max_chunk_4 as f64 / max_chunk_8 as f64;
        assert!(
            (ratio - 2.0).abs() < 0.1,
            "M3 FAIL: Chunk size ratio should be ~2.0, got {}",
            ratio
        );
    }

    /// Deterministic mode must produce identical chunk assignments on every call and use
    /// static scheduling.
    #[test]
    fn test_m3_deterministic_mode() {
        let data_size = 100_000;
        let worker_count = 8;

        let par_reduce = ParReduce {
            size: data_size,
            op: ReduceOp::Add,
            config: ParallelConfig {
                worker_count,
                deterministic: true,
                ..Default::default()
            },
        };

        let runs = [
            par_reduce.chunk_assignments(),
            par_reduce.chunk_assignments(),
            par_reduce.chunk_assignments(),
        ];

        // Check lengths first so a short run fails with a message rather than an index panic.
        for (run, chunks) in runs.iter().enumerate() {
            assert_eq!(
                chunks.len(),
                worker_count,
                "M3 FAIL: run {} produced {} chunks, expected {}",
                run,
                chunks.len(),
                worker_count
            );
        }

        for pair in runs.windows(2) {
            for (i, (a, b)) in pair[0].iter().zip(pair[1].iter()).enumerate() {
                assert_eq!(
                    a.start, b.start,
                    "M3 FAIL: Chunk {} start differs between runs",
                    i
                );
                assert_eq!(
                    a.end, b.end,
                    "M3 FAIL: Chunk {} end differs between runs",
                    i
                );
            }
        }

        let loop_id = LoopId::new(0);
        let mut body = Body::new();
        body.push(Stmt::Loop(Loop {
            id: loop_id,
            var: ValueId::new(0),
            lower: Value::i64(0),
            upper: Value::i64(100_000),
            step: Value::i64(1),
            body: Body::new(),
            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
        }));

        let ir = void_kernel(
            "deterministic_test",
            body,
            vec![static_loop_metadata(loop_id, 100_000)],
        );

        let mut pass = ParallelPass::new(ParallelConfig {
            worker_count,
            deterministic: true,
            ..Default::default()
        });
        let analysis = pass.analyze(&ir);
        let info = analysis.get(&loop_id).expect("loop should be analyzed");

        assert!(
            info.parallelizable,
            "M3 FAIL: Loop should be parallelizable"
        );
        assert_eq!(
            info.strategy,
            ParallelStrategy::Static,
            "M3 FAIL: Deterministic mode should use Static scheduling"
        );
    }

    /// A parallel outer loop around a vectorizable inner reduction: the vectorizer must
    /// accept the inner loop, and the parallelizer must split the outer loop into one chunk
    /// per worker.
    #[test]
    fn test_m3_vectorized_parallel_reduction() {
        let outer_loop_id = LoopId::new(0);
        let inner_loop_id = LoopId::new(1);
        let inner_var = ValueId::new(1);

        let mem_ref = MemRef {
            buffer: BufferId::new(0),
            index: Value::Var(inner_var, LoopType::Scalar(ScalarType::I64)),
            elem_ty: LoopType::Scalar(ScalarType::F32),
            access: AccessPattern::Sequential,
        };

        let mut inner_body = Body::new();
        inner_body.push(Stmt::Assign(ValueId::new(2), Op::Load(mem_ref)));

        let mut outer_body = Body::new();
        outer_body.push(Stmt::Loop(Loop {
            id: inner_loop_id,
            var: inner_var,
            lower: Value::i64(0),
            upper: Value::i64(1024),
            step: Value::i64(1),
            body: inner_body,
            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT | LoopAttrs::REDUCTION,
        }));

        let mut top_body = Body::new();
        top_body.push(Stmt::Loop(Loop {
            id: outer_loop_id,
            var: ValueId::new(0),
            lower: Value::i64(0),
            upper: Value::i64(10_000),
            step: Value::i64(1),
            body: outer_body,
            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
        }));

        let ir = void_kernel(
            "vec_par_reduce",
            top_body,
            vec![
                static_loop_metadata(outer_loop_id, 10_000),
                static_loop_metadata(inner_loop_id, 1024),
            ],
        );

        let mut vec_pass = VectorizePass::new(VectorizeConfig {
            target: TargetArch::X86_64Avx2,
            ..Default::default()
        });
        let vec_analysis = vec_pass.analyze(&ir);
        let inner_info = vec_analysis
            .get(&inner_loop_id)
            .expect("inner loop analyzed");
        assert!(
            inner_info.vectorizable,
            "M3 FAIL: Inner reduction loop should be vectorizable"
        );

        let mut par_pass = ParallelPass::new(ParallelConfig {
            worker_count: 8,
            deterministic: true,
            ..Default::default()
        });
        let par_analysis = par_pass.analyze(&ir);
        let outer_info = par_analysis
            .get(&outer_loop_id)
            .expect("outer loop analyzed");
        assert!(
            outer_info.parallelizable,
            "M3 FAIL: Outer loop should be parallelizable"
        );
        assert_eq!(
            outer_info.num_chunks, 8,
            "M3 FAIL: Should have 8 parallel chunks"
        );
    }
}