1#![warn(missing_docs)]
74#![warn(clippy::all)]
75#![warn(clippy::pedantic)]
76#![allow(clippy::module_name_repetitions)]
77
78use bhc_index::Idx;
79use bhc_intern::Symbol;
80use bhc_tensor_ir::{AllocRegion, BufferId, DType};
81use bitflags::bitflags;
82use serde::{Deserialize, Serialize};
83use smallvec::SmallVec;
84
85pub mod lower;
90pub mod parallel;
91pub mod vectorize;
92
93pub use lower::{lower_kernel, lower_kernels, LowerConfig, LowerError};
95pub use parallel::{
96 ParFor, ParMap, ParReduce, ParallelConfig, ParallelPass, ParallelStrategy, Range,
97};
98pub use vectorize::{SimdIntrinsic, VectorizeConfig, VectorizePass, VectorizeReport};
99
100#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
102pub struct ValueId(u32);
103
104impl Idx for ValueId {
105 fn new(idx: usize) -> Self {
106 Self(idx as u32)
107 }
108
109 fn index(self) -> usize {
110 self.0 as usize
111 }
112}
113
114#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
116pub struct LoopId(u32);
117
118impl Idx for LoopId {
119 fn new(idx: usize) -> Self {
120 Self(idx as u32)
121 }
122
123 fn index(self) -> usize {
124 self.0 as usize
125 }
126}
127
128#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
130pub struct BlockId(u32);
131
132impl Idx for BlockId {
133 fn new(idx: usize) -> Self {
134 Self(idx as u32)
135 }
136
137 fn index(self) -> usize {
138 self.0 as usize
139 }
140}
141
142#[derive(Clone, Debug, Serialize, Deserialize)]
144pub struct LoopIR {
145 pub name: Symbol,
147 pub params: Vec<Param>,
149 pub return_ty: LoopType,
151 pub body: Body,
153 pub allocs: Vec<Alloc>,
155 pub loop_info: Vec<LoopMetadata>,
157}
158
159#[derive(Clone, Debug, Serialize, Deserialize)]
161pub struct Param {
162 pub name: Symbol,
164 pub ty: LoopType,
166 pub is_ptr: bool,
168}
169
170#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
172pub enum LoopType {
173 Void,
175 Scalar(ScalarType),
177 Vector(ScalarType, u8),
179 Ptr(Box<LoopType>),
181}
182
183impl LoopType {
184 #[must_use]
186 pub fn size_bytes(&self) -> usize {
187 match self {
188 Self::Void => 0,
189 Self::Scalar(s) => s.size_bytes(),
190 Self::Vector(s, width) => s.size_bytes() * (*width as usize),
191 Self::Ptr(_) => 8, }
193 }
194
195 #[must_use]
197 pub fn is_void(&self) -> bool {
198 matches!(self, Self::Void)
199 }
200
201 #[must_use]
203 pub fn is_vector(&self) -> bool {
204 matches!(self, Self::Vector(_, _))
205 }
206}
207
208#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
210pub enum ScalarType {
211 Bool,
213 Int(u8),
215 UInt(u8),
217 Float(u8),
219}
220
221impl ScalarType {
222 #[must_use]
224 pub const fn size_bytes(self) -> usize {
225 match self {
226 Self::Bool => 1,
227 Self::Int(bits) | Self::UInt(bits) | Self::Float(bits) => (bits as usize + 7) / 8,
228 }
229 }
230
231 #[must_use]
233 pub fn from_dtype(dtype: DType) -> Self {
234 match dtype {
235 DType::Bool => Self::Bool,
236 DType::Int8 => Self::Int(8),
237 DType::Int16 => Self::Int(16),
238 DType::Int32 => Self::Int(32),
239 DType::Int64 => Self::Int(64),
240 DType::UInt8 => Self::UInt(8),
241 DType::UInt16 => Self::UInt(16),
242 DType::UInt32 => Self::UInt(32),
243 DType::UInt64 => Self::UInt(64),
244 DType::Float16 | DType::BFloat16 => Self::Float(16),
245 DType::Float32 => Self::Float(32),
246 DType::Float64 => Self::Float(64),
247 DType::Complex64 => Self::Float(32), DType::Complex128 => Self::Float(64),
249 }
250 }
251
252 pub const F32: Self = Self::Float(32);
254
255 pub const F64: Self = Self::Float(64);
257
258 pub const I32: Self = Self::Int(32);
260
261 pub const I64: Self = Self::Int(64);
263}
264
265impl LoopType {
270 pub const VEC4F32: Self = Self::Vector(ScalarType::F32, 4);
274
275 pub const VEC8F32: Self = Self::Vector(ScalarType::F32, 8);
277
278 pub const VEC2F64: Self = Self::Vector(ScalarType::F64, 2);
280
281 pub const VEC4F64: Self = Self::Vector(ScalarType::F64, 4);
283
284 pub const VEC4I32: Self = Self::Vector(ScalarType::I32, 4);
286
287 pub const VEC8I32: Self = Self::Vector(ScalarType::I32, 8);
289
290 #[must_use]
300 pub fn natural_vector_width(scalar: ScalarType, target: TargetArch) -> u8 {
301 match (target, scalar) {
302 (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(32)) => 8,
304 (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(64)) => 4,
305 (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Int(32)) => 8,
306 (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(32)) => 4,
308 (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(64)) => 2,
309 (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Int(32)) => 4,
310 (TargetArch::Aarch64Neon, ScalarType::Float(32)) => 4,
312 (TargetArch::Aarch64Neon, ScalarType::Float(64)) => 2,
313 (TargetArch::Aarch64Neon, ScalarType::Int(32)) => 4,
314 _ => 1,
316 }
317 }
318
319 #[must_use]
321 pub const fn vector(scalar: ScalarType, width: u8) -> Self {
322 Self::Vector(scalar, width)
323 }
324
325 #[must_use]
327 pub fn vector_width(&self) -> Option<u8> {
328 match self {
329 Self::Vector(_, w) => Some(*w),
330 _ => None,
331 }
332 }
333
334 #[must_use]
336 pub fn element_type(&self) -> Option<ScalarType> {
337 match self {
338 Self::Vector(s, _) => Some(*s),
339 Self::Scalar(s) => Some(*s),
340 _ => None,
341 }
342 }
343}
344
345#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
347pub enum TargetArch {
348 X86_64Sse,
350 X86_64Sse2,
352 X86_64Avx,
354 X86_64Avx2,
356 Aarch64Neon,
358 Generic,
360}
361
362impl Default for TargetArch {
363 fn default() -> Self {
364 #[cfg(target_arch = "x86_64")]
366 {
367 Self::X86_64Avx2
368 }
369 #[cfg(target_arch = "aarch64")]
370 {
371 Self::Aarch64Neon
372 }
373 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
374 {
375 Self::Generic
376 }
377 }
378}
379
380#[derive(Clone, Debug, Serialize, Deserialize)]
382pub struct Alloc {
383 pub buffer: BufferId,
385 pub name: Symbol,
387 pub elem_ty: ScalarType,
389 pub size: AllocSize,
391 pub alignment: usize,
393 pub region: AllocRegion,
395}
396
397#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
399pub enum AllocSize {
400 Static(usize),
402 Dynamic(ValueId),
404}
405
406#[derive(Clone, Debug, Default, Serialize, Deserialize)]
408pub struct Body {
409 pub stmts: Vec<Stmt>,
411}
412
413impl Body {
414 #[must_use]
416 pub fn new() -> Self {
417 Self::default()
418 }
419
420 pub fn push(&mut self, stmt: Stmt) {
422 self.stmts.push(stmt);
423 }
424}
425
426#[derive(Clone, Debug, Serialize, Deserialize)]
428pub enum Stmt {
429 Assign(ValueId, Op),
431
432 Loop(Loop),
434
435 If(IfStmt),
437
438 Store(MemRef, Value),
440
441 Call(Option<ValueId>, Symbol, Vec<Value>),
443
444 Return(Option<Value>),
446
447 Barrier(BarrierKind),
449
450 Comment(String),
452}
453
454#[derive(Clone, Debug, Serialize, Deserialize)]
456pub struct Loop {
457 pub id: LoopId,
459 pub var: ValueId,
461 pub lower: Value,
463 pub upper: Value,
465 pub step: Value,
467 pub body: Body,
469 pub attrs: LoopAttrs,
471}
472
473bitflags! {
474 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
476 pub struct LoopAttrs: u32 {
477 const PARALLEL = 0b0000_0001;
479 const VECTORIZE = 0b0000_0010;
481 const UNROLL = 0b0000_0100;
483 const REDUCTION = 0b0000_1000;
485 const INDEPENDENT = 0b0001_0000;
487 const TILED = 0b0010_0000;
489 const TILE_INNER = 0b0100_0000;
491 }
492}
493
494#[derive(Clone, Debug, Serialize, Deserialize)]
496pub struct LoopMetadata {
497 pub id: LoopId,
499 pub trip_count: TripCount,
501 pub vector_width: Option<u8>,
503 pub parallel_chunk: Option<usize>,
505 pub unroll_factor: Option<u8>,
507 pub dependencies: Vec<LoopDependency>,
509}
510
511#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
513pub enum TripCount {
514 Static(usize),
516 Dynamic,
518 Bounded(usize),
520}
521
522#[derive(Clone, Debug, Serialize, Deserialize)]
524pub struct LoopDependency {
525 pub source: LoopId,
527 pub target: LoopId,
529 pub kind: DependencyKind,
531 pub distance: Option<Vec<i32>>,
533}
534
535#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
537pub enum DependencyKind {
538 Flow,
540 Anti,
542 Output,
544 Input,
546}
547
548#[derive(Clone, Debug, Serialize, Deserialize)]
550pub struct IfStmt {
551 pub cond: Value,
553 pub then_body: Body,
555 pub else_body: Option<Body>,
557}
558
559#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
561pub enum Value {
562 Var(ValueId, LoopType),
564 IntConst(i64, ScalarType),
566 FloatConst(f64, ScalarType),
568 BoolConst(bool),
570 Undef(LoopType),
572}
573
574impl Value {
575 #[must_use]
577 pub fn ty(&self) -> LoopType {
578 match self {
579 Self::Var(_, ty) => ty.clone(),
580 Self::IntConst(_, s) => LoopType::Scalar(*s),
581 Self::FloatConst(_, s) => LoopType::Scalar(*s),
582 Self::BoolConst(_) => LoopType::Scalar(ScalarType::Bool),
583 Self::Undef(ty) => ty.clone(),
584 }
585 }
586
587 #[must_use]
589 pub fn int(n: i64, bits: u8) -> Self {
590 Self::IntConst(n, ScalarType::Int(bits))
591 }
592
593 #[must_use]
595 pub fn i64(n: i64) -> Self {
596 Self::int(n, 64)
597 }
598
599 #[must_use]
601 pub fn float(f: f64, bits: u8) -> Self {
602 Self::FloatConst(f, ScalarType::Float(bits))
603 }
604
605 #[must_use]
607 pub fn f64(f: f64) -> Self {
608 Self::float(f, 64)
609 }
610}
611
612#[derive(Clone, Debug, Serialize, Deserialize)]
614pub enum Op {
615 Load(MemRef),
617
618 Binary(BinOp, Value, Value),
620
621 Unary(UnOp, Value),
623
624 Cmp(CmpOp, Value, Value),
626
627 Select(Value, Value, Value),
629
630 Cast(Value, LoopType),
632
633 Broadcast(Value, u8),
635
636 Extract(Value, u8),
638
639 Insert(Value, Value, u8),
641
642 Shuffle(Value, Value, Vec<i32>),
644
645 VecReduce(ReduceOp, Value),
647
648 Fma(Value, Value, Value),
650
651 PtrAdd(Value, Value),
653
654 GetPtr(BufferId, Value),
656
657 Phi(Vec<(BlockId, Value)>),
659}
660
661#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
663pub enum BinOp {
664 Add,
667 Sub,
669 Mul,
671 SDiv,
673 UDiv,
675 FDiv,
677 SRem,
679 URem,
681 FRem,
683
684 And,
687 Or,
689 Xor,
691 Shl,
693 LShr,
695 AShr,
697
698 SMin,
701 UMin,
703 FMin,
705 SMax,
707 UMax,
709 FMax,
711}
712
713#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
715pub enum UnOp {
716 Neg,
718 FNeg,
720 Not,
722 Abs,
724 FAbs,
726 Sqrt,
728 Rsqrt,
730 Floor,
732 Ceil,
734 Round,
736 Trunc,
738 Exp,
740 Log,
742 Sin,
744 Cos,
746}
747
748#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
750pub enum CmpOp {
751 Eq,
753 Ne,
755 SLt,
757 SLe,
759 SGt,
761 SGe,
763 ULt,
765 ULe,
767 UGt,
769 UGe,
771 OEq,
773 ONe,
775 OLt,
777 OLe,
779 OGt,
781 OGe,
783}
784
785#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
787pub enum ReduceOp {
788 Add,
790 Mul,
792 Min,
794 Max,
796 And,
798 Or,
800 Xor,
802}
803
804#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
806pub struct MemRef {
807 pub buffer: BufferId,
809 pub index: Value,
811 pub elem_ty: LoopType,
813 pub access: AccessPattern,
815}
816
817#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
819pub enum AccessPattern {
820 Sequential,
822 Strided(i64),
824 Random,
826 Broadcast,
828 Affine(AffineAccess),
830}
831
832#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
834pub struct AffineAccess {
835 pub coefficients: SmallVec<[(LoopId, i64); 4]>,
837 pub offset: i64,
839}
840
841#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
843pub enum BarrierKind {
844 MemFence,
846 Full,
848 ThreadGroup,
850}
851
852#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
854pub enum LoopIrError {
855 #[error("type mismatch: expected {expected:?}, got {got:?}")]
857 TypeMismatch {
858 expected: LoopType,
860 got: LoopType,
862 },
863
864 #[error("invalid vector width {width} for type {ty:?}")]
866 InvalidVectorWidth {
867 width: u8,
869 ty: ScalarType,
871 },
872
873 #[error("buffer access out of bounds")]
875 OutOfBounds,
876
877 #[error("invalid loop transformation: {reason}")]
879 InvalidTransform {
880 reason: String,
882 },
883}
884
885#[cfg(test)]
886mod tests {
887 use super::*;
888
889 #[test]
890 fn test_scalar_type_sizes() {
891 assert_eq!(ScalarType::Bool.size_bytes(), 1);
892 assert_eq!(ScalarType::Int(32).size_bytes(), 4);
893 assert_eq!(ScalarType::Float(64).size_bytes(), 8);
894 }
895
896 #[test]
897 fn test_loop_type_size() {
898 assert_eq!(LoopType::Scalar(ScalarType::Float(32)).size_bytes(), 4);
899 assert_eq!(LoopType::Vector(ScalarType::Float(32), 8).size_bytes(), 32);
900 }
901
902 #[test]
903 fn test_value_types() {
904 let v = Value::i64(42);
905 assert_eq!(v.ty(), LoopType::Scalar(ScalarType::Int(64)));
906
907 let f = Value::f64(3.14);
908 assert_eq!(f.ty(), LoopType::Scalar(ScalarType::Float(64)));
909 }
910
911 #[test]
912 fn test_loop_attrs() {
913 let attrs = LoopAttrs::PARALLEL | LoopAttrs::VECTORIZE;
914 assert!(attrs.contains(LoopAttrs::PARALLEL));
915 assert!(attrs.contains(LoopAttrs::VECTORIZE));
916 assert!(!attrs.contains(LoopAttrs::UNROLL));
917 }
918
919 #[test]
920 fn test_trip_count() {
921 let static_trip = TripCount::Static(100);
922 assert_eq!(static_trip, TripCount::Static(100));
923
924 let dynamic_trip = TripCount::Dynamic;
925 assert_eq!(dynamic_trip, TripCount::Dynamic);
926 }
927
928 #[test]
937 fn test_m3_matmul_auto_vectorizes() {
938 use crate::vectorize::{VectorizeConfig, VectorizePass};
939 use bhc_index::Idx;
940 use bhc_tensor_ir::BufferId;
941
942 let loop_id = LoopId::new(0);
944 let loop_var = ValueId::new(0);
945
946 let mem_ref = MemRef {
948 buffer: BufferId::new(0),
949 index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
950 elem_ty: LoopType::Scalar(ScalarType::F32),
951 access: AccessPattern::Sequential,
952 };
953
954 let mut body = Body::new();
955 let load_a = ValueId::new(1);
956 body.push(Stmt::Assign(load_a, Op::Load(mem_ref.clone())));
957
958 let load_b = ValueId::new(2);
959 body.push(Stmt::Assign(load_b, Op::Load(mem_ref.clone())));
960
961 let mul_result = ValueId::new(3);
962 body.push(Stmt::Assign(
963 mul_result,
964 Op::Binary(
965 BinOp::Mul,
966 Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
967 Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
968 ),
969 ));
970
971 let acc = ValueId::new(4);
973 let fma_result = ValueId::new(5);
974 body.push(Stmt::Assign(
975 fma_result,
976 Op::Fma(
977 Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
978 Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
979 Value::Var(acc, LoopType::Scalar(ScalarType::F32)),
980 ),
981 ));
982
983 let lp = Loop {
984 id: loop_id,
985 var: loop_var,
986 lower: Value::i64(0),
987 upper: Value::i64(256), step: Value::i64(1),
989 body,
990 attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
991 };
992
993 let mut outer_body = Body::new();
994 outer_body.push(Stmt::Loop(lp));
995
996 let ir = LoopIR {
997 name: bhc_intern::Symbol::intern("matmul_kernel"),
998 params: vec![],
999 return_ty: LoopType::Void,
1000 body: outer_body,
1001 allocs: vec![],
1002 loop_info: vec![LoopMetadata {
1003 id: loop_id,
1004 trip_count: TripCount::Static(256),
1005 vector_width: None,
1006 parallel_chunk: None,
1007 unroll_factor: None,
1008 dependencies: Vec::new(),
1009 }],
1010 };
1011
1012 let config_x86 = VectorizeConfig {
1014 target: TargetArch::X86_64Avx2,
1015 ..Default::default()
1016 };
1017 let mut pass_x86 = VectorizePass::new(config_x86);
1018 let analysis_x86 = pass_x86.analyze(&ir);
1019 let info_x86 = analysis_x86.get(&loop_id).expect("loop should be analyzed");
1020
1021 assert!(
1022 info_x86.vectorizable,
1023 "M3 FAIL: matmul kernel not vectorizable on x86_64 AVX2"
1024 );
1025 assert_eq!(
1026 info_x86.recommended_width, 8,
1027 "M3 FAIL: x86_64 AVX2 should use 8-wide vectors for f32"
1028 );
1029
1030 let config_arm = VectorizeConfig {
1032 target: TargetArch::Aarch64Neon,
1033 ..Default::default()
1034 };
1035 let mut pass_arm = VectorizePass::new(config_arm);
1036 let analysis_arm = pass_arm.analyze(&ir);
1037 let info_arm = analysis_arm.get(&loop_id).expect("loop should be analyzed");
1038
1039 assert!(
1040 info_arm.vectorizable,
1041 "M3 FAIL: matmul kernel not vectorizable on aarch64 NEON"
1042 );
1043 assert_eq!(
1044 info_arm.recommended_width, 4,
1045 "M3 FAIL: aarch64 NEON should use 4-wide vectors for f32"
1046 );
1047 }
1048
1049 #[test]
1054 fn test_m3_reductions_scale_linearly() {
1055 use crate::parallel::{ParReduce, ParallelConfig, Range};
1056 use crate::ReduceOp;
1057
1058 let data_size = 1_000_000; for worker_count in [1, 2, 4, 8] {
1062 let config = ParallelConfig {
1063 worker_count,
1064 deterministic: true,
1065 ..Default::default()
1066 };
1067
1068 let par_reduce = ParReduce {
1069 size: data_size,
1070 op: ReduceOp::Add,
1071 config,
1072 };
1073
1074 let chunks = par_reduce.chunk_assignments();
1075
1076 assert_eq!(
1078 chunks.len(),
1079 worker_count,
1080 "M3 FAIL: Expected {} chunks for {} workers",
1081 worker_count,
1082 worker_count
1083 );
1084
1085 let total_work: usize = chunks.iter().map(|c| c.len()).sum();
1087 assert_eq!(
1088 total_work, data_size,
1089 "M3 FAIL: Total work should equal data size"
1090 );
1091
1092 let expected_per_worker = data_size / worker_count;
1094 for (i, chunk) in chunks.iter().enumerate() {
1095 let diff = (chunk.len() as i64 - expected_per_worker as i64).abs();
1096 assert!(
1097 diff <= 1,
1098 "M3 FAIL: Worker {} has {} elements, expected ~{} (diff={})",
1099 i,
1100 chunk.len(),
1101 expected_per_worker,
1102 diff
1103 );
1104 }
1105 }
1106
1107 let _config_4 = ParallelConfig {
1109 worker_count: 4,
1110 ..Default::default()
1111 };
1112 let chunks_4 = Range::new(0, data_size as i64).chunk(4);
1113
1114 let _config_8 = ParallelConfig {
1115 worker_count: 8,
1116 ..Default::default()
1117 };
1118 let chunks_8 = Range::new(0, data_size as i64).chunk(8);
1119
1120 let avg_chunk_4: usize = chunks_4.iter().map(|c| c.len()).sum::<usize>() / 4;
1121 let avg_chunk_8: usize = chunks_8.iter().map(|c| c.len()).sum::<usize>() / 8;
1122
1123 let ratio = avg_chunk_4 as f64 / avg_chunk_8 as f64;
1125 assert!(
1126 (ratio - 2.0).abs() < 0.1,
1127 "M3 FAIL: Chunk size ratio should be ~2.0, got {}",
1128 ratio
1129 );
1130 }
1131
1132 #[test]
1136 fn test_m3_deterministic_mode() {
1137 use crate::parallel::{ParReduce, ParallelConfig, ParallelStrategy};
1138 use crate::ReduceOp;
1139
1140 let data_size = 100_000;
1141 let worker_count = 8;
1142
1143 let config = ParallelConfig {
1145 worker_count,
1146 deterministic: true,
1147 ..Default::default()
1148 };
1149
1150 let par_reduce = ParReduce {
1151 size: data_size,
1152 op: ReduceOp::Add,
1153 config: config.clone(),
1154 };
1155
1156 let chunks1 = par_reduce.chunk_assignments();
1158 let chunks2 = par_reduce.chunk_assignments();
1159 let chunks3 = par_reduce.chunk_assignments();
1160
1161 for i in 0..worker_count {
1162 assert_eq!(
1163 chunks1[i].start, chunks2[i].start,
1164 "M3 FAIL: Chunk {} start differs between runs",
1165 i
1166 );
1167 assert_eq!(
1168 chunks1[i].end, chunks2[i].end,
1169 "M3 FAIL: Chunk {} end differs between runs",
1170 i
1171 );
1172 assert_eq!(
1173 chunks2[i].start, chunks3[i].start,
1174 "M3 FAIL: Chunk {} start differs between runs",
1175 i
1176 );
1177 assert_eq!(
1178 chunks2[i].end, chunks3[i].end,
1179 "M3 FAIL: Chunk {} end differs between runs",
1180 i
1181 );
1182 }
1183
1184 use crate::parallel::ParallelPass;
1186
1187 let parallel_config = ParallelConfig {
1188 worker_count: 8,
1189 deterministic: true,
1190 ..Default::default()
1191 };
1192
1193 let loop_id = LoopId::new(0);
1195 let mut body = Body::new();
1196
1197 let lp = Loop {
1198 id: loop_id,
1199 var: ValueId::new(0),
1200 lower: Value::i64(0),
1201 upper: Value::i64(100000),
1202 step: Value::i64(1),
1203 body: Body::new(),
1204 attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
1205 };
1206
1207 body.push(Stmt::Loop(lp));
1208
1209 let ir = LoopIR {
1210 name: bhc_intern::Symbol::intern("deterministic_test"),
1211 params: vec![],
1212 return_ty: LoopType::Void,
1213 body,
1214 allocs: vec![],
1215 loop_info: vec![LoopMetadata {
1216 id: loop_id,
1217 trip_count: TripCount::Static(100000),
1218 vector_width: None,
1219 parallel_chunk: None,
1220 unroll_factor: None,
1221 dependencies: Vec::new(),
1222 }],
1223 };
1224
1225 let mut pass = ParallelPass::new(parallel_config);
1226 let analysis = pass.analyze(&ir);
1227 let info = analysis.get(&loop_id).expect("loop should be analyzed");
1228
1229 assert!(
1230 info.parallelizable,
1231 "M3 FAIL: Loop should be parallelizable"
1232 );
1233 assert_eq!(
1234 info.strategy,
1235 ParallelStrategy::Static,
1236 "M3 FAIL: Deterministic mode should use Static scheduling"
1237 );
1238 }
1239
1240 #[test]
1242 fn test_m3_vectorized_parallel_reduction() {
1243 use crate::parallel::{ParallelConfig, ParallelPass};
1244 use crate::vectorize::{VectorizeConfig, VectorizePass};
1245 use bhc_index::Idx;
1246 use bhc_tensor_ir::BufferId;
1247
1248 let outer_loop_id = LoopId::new(0);
1250 let inner_loop_id = LoopId::new(1);
1251
1252 let mem_ref = MemRef {
1254 buffer: BufferId::new(0),
1255 index: Value::Var(ValueId::new(1), LoopType::Scalar(ScalarType::I64)),
1256 elem_ty: LoopType::Scalar(ScalarType::F32),
1257 access: AccessPattern::Sequential,
1258 };
1259
1260 let mut inner_body = Body::new();
1261 let load_result = ValueId::new(2);
1262 inner_body.push(Stmt::Assign(load_result, Op::Load(mem_ref)));
1263
1264 let inner_loop = Loop {
1265 id: inner_loop_id,
1266 var: ValueId::new(1),
1267 lower: Value::i64(0),
1268 upper: Value::i64(1024),
1269 step: Value::i64(1),
1270 body: inner_body,
1271 attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT | LoopAttrs::REDUCTION,
1272 };
1273
1274 let mut outer_body = Body::new();
1276 outer_body.push(Stmt::Loop(inner_loop));
1277
1278 let outer_loop = Loop {
1279 id: outer_loop_id,
1280 var: ValueId::new(0),
1281 lower: Value::i64(0),
1282 upper: Value::i64(10000),
1283 step: Value::i64(1),
1284 body: outer_body,
1285 attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
1286 };
1287
1288 let mut top_body = Body::new();
1289 top_body.push(Stmt::Loop(outer_loop));
1290
1291 let ir = LoopIR {
1292 name: bhc_intern::Symbol::intern("vec_par_reduce"),
1293 params: vec![],
1294 return_ty: LoopType::Void,
1295 body: top_body,
1296 allocs: vec![],
1297 loop_info: vec![
1298 LoopMetadata {
1299 id: outer_loop_id,
1300 trip_count: TripCount::Static(10000),
1301 vector_width: None,
1302 parallel_chunk: None,
1303 unroll_factor: None,
1304 dependencies: Vec::new(),
1305 },
1306 LoopMetadata {
1307 id: inner_loop_id,
1308 trip_count: TripCount::Static(1024),
1309 vector_width: None,
1310 parallel_chunk: None,
1311 unroll_factor: None,
1312 dependencies: Vec::new(),
1313 },
1314 ],
1315 };
1316
1317 let vec_config = VectorizeConfig {
1319 target: TargetArch::X86_64Avx2,
1320 ..Default::default()
1321 };
1322 let mut vec_pass = VectorizePass::new(vec_config);
1323 let vec_analysis = vec_pass.analyze(&ir);
1324
1325 let inner_info = vec_analysis
1327 .get(&inner_loop_id)
1328 .expect("inner loop analyzed");
1329 assert!(
1330 inner_info.vectorizable,
1331 "M3 FAIL: Inner reduction loop should be vectorizable"
1332 );
1333
1334 let par_config = ParallelConfig {
1336 worker_count: 8,
1337 deterministic: true,
1338 ..Default::default()
1339 };
1340 let mut par_pass = ParallelPass::new(par_config);
1341 let par_analysis = par_pass.analyze(&ir);
1342
1343 let outer_info = par_analysis
1345 .get(&outer_loop_id)
1346 .expect("outer loop analyzed");
1347 assert!(
1348 outer_info.parallelizable,
1349 "M3 FAIL: Outer loop should be parallelizable"
1350 );
1351 assert_eq!(
1352 outer_info.num_chunks, 8,
1353 "M3 FAIL: Should have 8 parallel chunks"
1354 );
1355 }
1356}