Skip to main content

bhc_loop_ir/
lib.rs

1//! # BHC Loop IR
2//!
3//! This crate defines the Loop Intermediate Representation for the Basel
4//! Haskell Compiler. Loop IR makes iteration structure explicit and is the
5//! target for vectorization and low-level optimization.
6//!
7//! ## Overview
8//!
9//! Loop IR is the lowest-level IR before code generation. It provides:
10//!
11//! - **Explicit iteration**: Loops with bounds and strides
12//! - **Vectorization information**: Which loops can be SIMD-ized
13//! - **Parallelization hints**: Which loops can run in parallel
14//! - **Memory access patterns**: For cache optimization
15//!
16//! ## IR Pipeline Position
17//!
18//! ```text
19//! Source Code
20//!     |
21//!     v
22//! [Parse/AST]
23//!     |
24//!     v
25//! [HIR]
26//!     |
27//!     v
28//! [Core IR]
29//!     |
30//!     v
31//! [Tensor IR]  <- High-level tensor operations
32//!     |
33//!     v
34//! [Loop IR]    <- This crate: explicit iteration
35//!     |
36//!     v
37//! [Codegen]    <- LLVM IR / Native code
38//! ```
39//!
40//! ## Key Transformations
41//!
42//! Loop IR supports several important optimizations:
43//!
44//! 1. **Loop tiling**: Break loops into cache-friendly tiles
45//! 2. **Vectorization**: Convert scalar operations to SIMD
46//! 3. **Parallelization**: Mark loops for parallel execution
47//! 4. **Interchange**: Reorder loops for better memory access
48//! 5. **Unrolling**: Reduce loop overhead
49//!
50//! ## Main Types
51//!
52//! - [`LoopIR`]: The top-level IR structure
53//! - [`Loop`]: A single loop with bounds and body
54//! - [`Stmt`]: Statements within loop bodies
55//! - [`Value`]: SSA values (registers)
56//! - [`MemRef`]: Memory references with access patterns
57//!
58//! ## M3 Deliverables
59//!
60//! This crate implements the following M3 features:
61//!
62//! - **SIMD Types**: [`LoopType::VEC4F32`], [`LoopType::VEC8F32`], [`LoopType::VEC2F64`], [`LoopType::VEC4F64`]
63//! - **Auto-vectorization**: [`vectorize::VectorizePass`]
64//! - **Parallel primitives**: [`parallel::ParFor`], [`parallel::ParMap`], [`parallel::ParReduce`]
65//! - **SIMD intrinsics**: [`vectorize::SimdIntrinsic`]
66//!
67//! ## See Also
68//!
69//! - `bhc-tensor-ir`: Tensor IR that lowers to Loop IR
70//! - `bhc-codegen`: Code generation from Loop IR
71//! - H26-SPEC Section 7: Tensor Model (lowering)
72
73#![warn(missing_docs)]
74#![allow(clippy::module_name_repetitions)]
75
76use bhc_index::Idx;
77use bhc_intern::Symbol;
78use bhc_tensor_ir::{AllocRegion, BufferId, DType};
79use bitflags::bitflags;
80use serde::{Deserialize, Serialize};
81use smallvec::SmallVec;
82
83// ============================================================================
84// Submodules
85// ============================================================================
86
87pub mod lower;
88pub mod parallel;
89pub mod vectorize;
90
91// Re-export key types from submodules
92pub use lower::{lower_kernel, lower_kernels, LowerConfig, LowerError};
93pub use parallel::{
94    ParFor, ParMap, ParReduce, ParallelConfig, ParallelPass, ParallelStrategy, Range,
95};
96pub use vectorize::{SimdIntrinsic, VectorizeConfig, VectorizePass, VectorizeReport};
97
98/// A unique identifier for values (SSA registers).
99#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
100pub struct ValueId(u32);
101
102impl Idx for ValueId {
103    fn new(idx: usize) -> Self {
104        Self(idx as u32)
105    }
106
107    fn index(self) -> usize {
108        self.0 as usize
109    }
110}
111
112/// A unique identifier for loops.
113#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
114pub struct LoopId(u32);
115
116impl Idx for LoopId {
117    fn new(idx: usize) -> Self {
118        Self(idx as u32)
119    }
120
121    fn index(self) -> usize {
122        self.0 as usize
123    }
124}
125
126/// A unique identifier for basic blocks.
127#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
128pub struct BlockId(u32);
129
130impl Idx for BlockId {
131    fn new(idx: usize) -> Self {
132        Self(idx as u32)
133    }
134
135    fn index(self) -> usize {
136        self.0 as usize
137    }
138}
139
140/// The main Loop IR structure.
141#[derive(Clone, Debug, Serialize, Deserialize)]
142pub struct LoopIR {
143    /// Function name.
144    pub name: Symbol,
145    /// Function parameters.
146    pub params: Vec<Param>,
147    /// Return type.
148    pub return_ty: LoopType,
149    /// The body (list of statements and loops).
150    pub body: Body,
151    /// Memory allocations.
152    pub allocs: Vec<Alloc>,
153    /// Loop metadata for optimization.
154    pub loop_info: Vec<LoopMetadata>,
155}
156
157/// A function parameter.
158#[derive(Clone, Debug, Serialize, Deserialize)]
159pub struct Param {
160    /// Parameter name.
161    pub name: Symbol,
162    /// Parameter type.
163    pub ty: LoopType,
164    /// Whether this is a pointer to memory.
165    pub is_ptr: bool,
166}
167
168/// Types in Loop IR.
169#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
170pub enum LoopType {
171    /// Void (no value).
172    Void,
173    /// Scalar type.
174    Scalar(ScalarType),
175    /// Vector type (SIMD).
176    Vector(ScalarType, u8),
177    /// Pointer to memory.
178    Ptr(Box<LoopType>),
179}
180
181impl LoopType {
182    /// Returns the size in bytes.
183    #[must_use]
184    pub fn size_bytes(&self) -> usize {
185        match self {
186            Self::Void => 0,
187            Self::Scalar(s) => s.size_bytes(),
188            Self::Vector(s, width) => s.size_bytes() * (*width as usize),
189            Self::Ptr(_) => 8, // Assuming 64-bit pointers
190        }
191    }
192
193    /// Returns true if this is a void type.
194    #[must_use]
195    pub fn is_void(&self) -> bool {
196        matches!(self, Self::Void)
197    }
198
199    /// Returns true if this is a vector type.
200    #[must_use]
201    pub fn is_vector(&self) -> bool {
202        matches!(self, Self::Vector(_, _))
203    }
204}
205
206/// Scalar types in Loop IR.
207#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
208pub enum ScalarType {
209    /// Boolean.
210    Bool,
211    /// Signed integer with bit width.
212    Int(u8),
213    /// Unsigned integer with bit width.
214    UInt(u8),
215    /// Floating point with bit width.
216    Float(u8),
217}
218
219impl ScalarType {
220    /// Returns the size in bytes.
221    #[must_use]
222    pub const fn size_bytes(self) -> usize {
223        match self {
224            Self::Bool => 1,
225            Self::Int(bits) | Self::UInt(bits) | Self::Float(bits) => (bits as usize).div_ceil(8),
226        }
227    }
228
229    /// Converts from tensor DType.
230    #[must_use]
231    pub fn from_dtype(dtype: DType) -> Self {
232        match dtype {
233            DType::Bool => Self::Bool,
234            DType::Int8 => Self::Int(8),
235            DType::Int16 => Self::Int(16),
236            DType::Int32 => Self::Int(32),
237            DType::Int64 => Self::Int(64),
238            DType::UInt8 => Self::UInt(8),
239            DType::UInt16 => Self::UInt(16),
240            DType::UInt32 => Self::UInt(32),
241            DType::UInt64 => Self::UInt(64),
242            DType::Float16 | DType::BFloat16 => Self::Float(16),
243            DType::Float32 => Self::Float(32),
244            DType::Float64 => Self::Float(64),
245            DType::Complex64 => Self::Float(32), // Represented as pair
246            DType::Complex128 => Self::Float(64),
247        }
248    }
249
250    /// 32-bit float scalar type.
251    pub const F32: Self = Self::Float(32);
252
253    /// 64-bit float scalar type.
254    pub const F64: Self = Self::Float(64);
255
256    /// 32-bit signed integer scalar type.
257    pub const I32: Self = Self::Int(32);
258
259    /// 64-bit signed integer scalar type.
260    pub const I64: Self = Self::Int(64);
261}
262
263// ============================================================================
264// SIMD Type Aliases and Constructors (M3 Deliverable)
265// ============================================================================
266
267impl LoopType {
268    // --- Standard SIMD Vector Types ---
269
270    /// 4-wide 32-bit float vector (128-bit, SSE/NEON compatible).
271    pub const VEC4F32: Self = Self::Vector(ScalarType::F32, 4);
272
273    /// 8-wide 32-bit float vector (256-bit, AVX compatible).
274    pub const VEC8F32: Self = Self::Vector(ScalarType::F32, 8);
275
276    /// 2-wide 64-bit float vector (128-bit, SSE/NEON compatible).
277    pub const VEC2F64: Self = Self::Vector(ScalarType::F64, 2);
278
279    /// 4-wide 64-bit float vector (256-bit, AVX compatible).
280    pub const VEC4F64: Self = Self::Vector(ScalarType::F64, 4);
281
282    /// 4-wide 32-bit integer vector (128-bit).
283    pub const VEC4I32: Self = Self::Vector(ScalarType::I32, 4);
284
285    /// 8-wide 32-bit integer vector (256-bit).
286    pub const VEC8I32: Self = Self::Vector(ScalarType::I32, 8);
287
288    /// Returns the natural vector width for a scalar type on the target.
289    ///
290    /// # Target Widths
291    ///
292    /// | Target | F32 | F64 | I32 |
293    /// |--------|-----|-----|-----|
294    /// | x86_64 (SSE) | 4 | 2 | 4 |
295    /// | x86_64 (AVX) | 8 | 4 | 8 |
296    /// | aarch64 (NEON) | 4 | 2 | 4 |
297    #[must_use]
298    pub fn natural_vector_width(scalar: ScalarType, target: TargetArch) -> u8 {
299        match (target, scalar) {
300            // x86_64 with AVX (256-bit vectors)
301            (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(32)) => 8,
302            (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(64)) => 4,
303            (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Int(32)) => 8,
304            // x86_64 with SSE (128-bit vectors)
305            (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(32)) => 4,
306            (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(64)) => 2,
307            (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Int(32)) => 4,
308            // aarch64 with NEON (128-bit vectors)
309            (TargetArch::Aarch64Neon, ScalarType::Float(32)) => 4,
310            (TargetArch::Aarch64Neon, ScalarType::Float(64)) => 2,
311            (TargetArch::Aarch64Neon, ScalarType::Int(32)) => 4,
312            // Fallback: no vectorization
313            _ => 1,
314        }
315    }
316
317    /// Creates a vector type for the given scalar and width.
318    #[must_use]
319    pub const fn vector(scalar: ScalarType, width: u8) -> Self {
320        Self::Vector(scalar, width)
321    }
322
323    /// Returns the vector width if this is a vector type, otherwise None.
324    #[must_use]
325    pub fn vector_width(&self) -> Option<u8> {
326        match self {
327            Self::Vector(_, w) => Some(*w),
328            _ => None,
329        }
330    }
331
332    /// Returns the scalar element type if this is a vector type.
333    #[must_use]
334    pub fn element_type(&self) -> Option<ScalarType> {
335        match self {
336            Self::Vector(s, _) => Some(*s),
337            Self::Scalar(s) => Some(*s),
338            _ => None,
339        }
340    }
341}
342
343/// Target architecture for vectorization decisions.
344#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
345pub enum TargetArch {
346    /// x86_64 with SSE instructions (128-bit).
347    X86_64Sse,
348    /// x86_64 with SSE2 instructions (128-bit).
349    X86_64Sse2,
350    /// x86_64 with AVX instructions (256-bit).
351    X86_64Avx,
352    /// x86_64 with AVX2 instructions (256-bit).
353    X86_64Avx2,
354    /// aarch64 with NEON instructions (128-bit).
355    #[default]
356    Aarch64Neon,
357    /// Generic target (no vectorization).
358    Generic,
359}
360
361/// A memory allocation.
362#[derive(Clone, Debug, Serialize, Deserialize)]
363pub struct Alloc {
364    /// Buffer identifier.
365    pub buffer: BufferId,
366    /// Name for debugging.
367    pub name: Symbol,
368    /// Element type.
369    pub elem_ty: ScalarType,
370    /// Total size in elements.
371    pub size: AllocSize,
372    /// Alignment in bytes.
373    pub alignment: usize,
374    /// Allocation region.
375    pub region: AllocRegion,
376}
377
378/// Size of an allocation.
379#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
380pub enum AllocSize {
381    /// Statically known size.
382    Static(usize),
383    /// Dynamic size (computed at runtime).
384    Dynamic(ValueId),
385}
386
387/// The body of a function or loop.
388#[derive(Clone, Debug, Default, Serialize, Deserialize)]
389pub struct Body {
390    /// Statements in execution order.
391    pub stmts: Vec<Stmt>,
392}
393
394impl Body {
395    /// Creates an empty body.
396    #[must_use]
397    pub fn new() -> Self {
398        Self::default()
399    }
400
401    /// Adds a statement to the body.
402    pub fn push(&mut self, stmt: Stmt) {
403        self.stmts.push(stmt);
404    }
405}
406
407/// Statements in Loop IR.
408#[derive(Clone, Debug, Serialize, Deserialize)]
409pub enum Stmt {
410    /// An assignment: `%v = op`.
411    Assign(ValueId, Op),
412
413    /// A loop construct.
414    Loop(Loop),
415
416    /// A conditional branch.
417    If(IfStmt),
418
419    /// A store to memory.
420    Store(MemRef, Value),
421
422    /// A function call (for external functions).
423    Call(Option<ValueId>, Symbol, Vec<Value>),
424
425    /// A return statement.
426    Return(Option<Value>),
427
428    /// A barrier for synchronization.
429    Barrier(BarrierKind),
430
431    /// A comment/annotation.
432    Comment(String),
433}
434
435/// A loop construct.
436#[derive(Clone, Debug, Serialize, Deserialize)]
437pub struct Loop {
438    /// Unique loop identifier.
439    pub id: LoopId,
440    /// Loop variable.
441    pub var: ValueId,
442    /// Lower bound (inclusive).
443    pub lower: Value,
444    /// Upper bound (exclusive).
445    pub upper: Value,
446    /// Step size.
447    pub step: Value,
448    /// Loop body.
449    pub body: Body,
450    /// Loop attributes.
451    pub attrs: LoopAttrs,
452}
453
454bitflags! {
455    /// Loop attributes for optimization.
456    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
457    pub struct LoopAttrs: u32 {
458        /// Loop can be parallelized.
459        const PARALLEL = 0b0000_0001;
460        /// Loop can be vectorized.
461        const VECTORIZE = 0b0000_0010;
462        /// Loop should be unrolled.
463        const UNROLL = 0b0000_0100;
464        /// Loop is a reduction loop.
465        const REDUCTION = 0b0000_1000;
466        /// Loop iterations are independent.
467        const INDEPENDENT = 0b0001_0000;
468        /// Loop has been tiled.
469        const TILED = 0b0010_0000;
470        /// Loop is the innermost of a tile.
471        const TILE_INNER = 0b0100_0000;
472    }
473}
474
475/// Loop metadata for optimization.
476#[derive(Clone, Debug, Serialize, Deserialize)]
477pub struct LoopMetadata {
478    /// Loop identifier.
479    pub id: LoopId,
480    /// Trip count (iterations).
481    pub trip_count: TripCount,
482    /// Vectorization width (if applicable).
483    pub vector_width: Option<u8>,
484    /// Parallel chunk size (if applicable).
485    pub parallel_chunk: Option<usize>,
486    /// Unroll factor (if applicable).
487    pub unroll_factor: Option<u8>,
488    /// Dependencies with other loops.
489    pub dependencies: Vec<LoopDependency>,
490}
491
492/// Trip count information.
493#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
494pub enum TripCount {
495    /// Statically known trip count.
496    Static(usize),
497    /// Dynamic trip count.
498    Dynamic,
499    /// Bounded trip count (upper bound known).
500    Bounded(usize),
501}
502
503/// A dependency between loops.
504#[derive(Clone, Debug, Serialize, Deserialize)]
505pub struct LoopDependency {
506    /// Source loop.
507    pub source: LoopId,
508    /// Target loop.
509    pub target: LoopId,
510    /// Dependency type.
511    pub kind: DependencyKind,
512    /// Distance vector (for affine dependencies).
513    pub distance: Option<Vec<i32>>,
514}
515
516/// Kinds of dependencies.
517#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
518pub enum DependencyKind {
519    /// Flow dependency (read after write).
520    Flow,
521    /// Anti dependency (write after read).
522    Anti,
523    /// Output dependency (write after write).
524    Output,
525    /// Input dependency (read after read, for locality).
526    Input,
527}
528
529/// A conditional statement.
530#[derive(Clone, Debug, Serialize, Deserialize)]
531pub struct IfStmt {
532    /// Condition value.
533    pub cond: Value,
534    /// Then branch.
535    pub then_body: Body,
536    /// Else branch (optional).
537    pub else_body: Option<Body>,
538}
539
540/// A value (SSA reference or constant).
541#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
542pub enum Value {
543    /// A register/variable reference.
544    Var(ValueId, LoopType),
545    /// An integer constant.
546    IntConst(i64, ScalarType),
547    /// A floating-point constant.
548    FloatConst(f64, ScalarType),
549    /// A boolean constant.
550    BoolConst(bool),
551    /// Undefined value.
552    Undef(LoopType),
553}
554
555impl Value {
556    /// Returns the type of this value.
557    #[must_use]
558    pub fn ty(&self) -> LoopType {
559        match self {
560            Self::Var(_, ty) => ty.clone(),
561            Self::IntConst(_, s) => LoopType::Scalar(*s),
562            Self::FloatConst(_, s) => LoopType::Scalar(*s),
563            Self::BoolConst(_) => LoopType::Scalar(ScalarType::Bool),
564            Self::Undef(ty) => ty.clone(),
565        }
566    }
567
568    /// Creates an integer constant.
569    #[must_use]
570    pub fn int(n: i64, bits: u8) -> Self {
571        Self::IntConst(n, ScalarType::Int(bits))
572    }
573
574    /// Creates a 64-bit integer constant.
575    #[must_use]
576    pub fn i64(n: i64) -> Self {
577        Self::int(n, 64)
578    }
579
580    /// Creates a float constant.
581    #[must_use]
582    pub fn float(f: f64, bits: u8) -> Self {
583        Self::FloatConst(f, ScalarType::Float(bits))
584    }
585
586    /// Creates a 64-bit float constant.
587    #[must_use]
588    pub fn f64(f: f64) -> Self {
589        Self::float(f, 64)
590    }
591}
592
593/// Operations in Loop IR.
594#[derive(Clone, Debug, Serialize, Deserialize)]
595pub enum Op {
596    /// Load from memory.
597    Load(MemRef),
598
599    /// Binary arithmetic operation.
600    Binary(BinOp, Value, Value),
601
602    /// Unary operation.
603    Unary(UnOp, Value),
604
605    /// Comparison.
606    Cmp(CmpOp, Value, Value),
607
608    /// Select (conditional).
609    Select(Value, Value, Value),
610
611    /// Cast between types.
612    Cast(Value, LoopType),
613
614    /// Vector broadcast (scalar to vector).
615    Broadcast(Value, u8),
616
617    /// Vector extract (vector to scalar).
618    Extract(Value, u8),
619
620    /// Vector insert.
621    Insert(Value, Value, u8),
622
623    /// Vector shuffle.
624    Shuffle(Value, Value, Vec<i32>),
625
626    /// Reduction within a vector.
627    VecReduce(ReduceOp, Value),
628
629    /// Fused multiply-add: a * b + c.
630    Fma(Value, Value, Value),
631
632    /// Pointer arithmetic.
633    PtrAdd(Value, Value),
634
635    /// Get pointer to buffer element.
636    GetPtr(BufferId, Value),
637
638    /// Phi node (for SSA).
639    Phi(Vec<(BlockId, Value)>),
640}
641
642/// Binary operations.
643#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
644pub enum BinOp {
645    // Arithmetic
646    /// Addition.
647    Add,
648    /// Subtraction.
649    Sub,
650    /// Multiplication.
651    Mul,
652    /// Signed division.
653    SDiv,
654    /// Unsigned division.
655    UDiv,
656    /// Floating-point division.
657    FDiv,
658    /// Signed remainder.
659    SRem,
660    /// Unsigned remainder.
661    URem,
662    /// Floating-point remainder.
663    FRem,
664
665    // Bitwise
666    /// Bitwise AND.
667    And,
668    /// Bitwise OR.
669    Or,
670    /// Bitwise XOR.
671    Xor,
672    /// Left shift.
673    Shl,
674    /// Logical right shift.
675    LShr,
676    /// Arithmetic right shift.
677    AShr,
678
679    // Min/Max
680    /// Signed minimum.
681    SMin,
682    /// Unsigned minimum.
683    UMin,
684    /// Floating-point minimum.
685    FMin,
686    /// Signed maximum.
687    SMax,
688    /// Unsigned maximum.
689    UMax,
690    /// Floating-point maximum.
691    FMax,
692}
693
694/// Unary operations.
695#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
696pub enum UnOp {
697    /// Negation.
698    Neg,
699    /// Floating-point negation.
700    FNeg,
701    /// Bitwise NOT.
702    Not,
703    /// Absolute value.
704    Abs,
705    /// Floating-point absolute value.
706    FAbs,
707    /// Square root.
708    Sqrt,
709    /// Reciprocal square root.
710    Rsqrt,
711    /// Floor.
712    Floor,
713    /// Ceiling.
714    Ceil,
715    /// Round to nearest.
716    Round,
717    /// Truncate.
718    Trunc,
719    /// Exponential.
720    Exp,
721    /// Natural logarithm.
722    Log,
723    /// Sine.
724    Sin,
725    /// Cosine.
726    Cos,
727}
728
729/// Comparison operations.
730#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
731pub enum CmpOp {
732    /// Equal.
733    Eq,
734    /// Not equal.
735    Ne,
736    /// Signed less than.
737    SLt,
738    /// Signed less than or equal.
739    SLe,
740    /// Signed greater than.
741    SGt,
742    /// Signed greater than or equal.
743    SGe,
744    /// Unsigned less than.
745    ULt,
746    /// Unsigned less than or equal.
747    ULe,
748    /// Unsigned greater than.
749    UGt,
750    /// Unsigned greater than or equal.
751    UGe,
752    /// Floating-point ordered equal.
753    OEq,
754    /// Floating-point ordered not equal.
755    ONe,
756    /// Floating-point ordered less than.
757    OLt,
758    /// Floating-point ordered less than or equal.
759    OLe,
760    /// Floating-point ordered greater than.
761    OGt,
762    /// Floating-point ordered greater than or equal.
763    OGe,
764}
765
766/// Reduction operations.
767#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
768pub enum ReduceOp {
769    /// Sum reduction.
770    Add,
771    /// Product reduction.
772    Mul,
773    /// Minimum reduction.
774    Min,
775    /// Maximum reduction.
776    Max,
777    /// AND reduction.
778    And,
779    /// OR reduction.
780    Or,
781    /// XOR reduction.
782    Xor,
783}
784
785/// A memory reference.
786#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
787pub struct MemRef {
788    /// The buffer being accessed.
789    pub buffer: BufferId,
790    /// The index/offset.
791    pub index: Value,
792    /// The element type.
793    pub elem_ty: LoopType,
794    /// Access pattern information.
795    pub access: AccessPattern,
796}
797
798/// Memory access patterns for optimization.
799#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
800pub enum AccessPattern {
801    /// Sequential access (stride 1).
802    Sequential,
803    /// Strided access.
804    Strided(i64),
805    /// Random/indirect access.
806    Random,
807    /// Broadcast (same element for all iterations).
808    Broadcast,
809    /// Affine access (linear combination of loop indices).
810    Affine(AffineAccess),
811}
812
813/// Affine memory access pattern.
814#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
815pub struct AffineAccess {
816    /// Coefficients for each loop variable.
817    pub coefficients: SmallVec<[(LoopId, i64); 4]>,
818    /// Constant offset.
819    pub offset: i64,
820}
821
822/// Barrier kinds for synchronization.
823#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
824pub enum BarrierKind {
825    /// Memory fence.
826    MemFence,
827    /// Full barrier (all threads).
828    Full,
829    /// Thread group barrier.
830    ThreadGroup,
831}
832
833/// Errors in Loop IR.
834#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
835pub enum LoopIrError {
836    /// Type mismatch.
837    #[error("type mismatch: expected {expected:?}, got {got:?}")]
838    TypeMismatch {
839        /// Expected type.
840        expected: LoopType,
841        /// Actual type.
842        got: LoopType,
843    },
844
845    /// Invalid vector width.
846    #[error("invalid vector width {width} for type {ty:?}")]
847    InvalidVectorWidth {
848        /// The vector width.
849        width: u8,
850        /// The element type.
851        ty: ScalarType,
852    },
853
854    /// Out of bounds access.
855    #[error("buffer access out of bounds")]
856    OutOfBounds,
857
858    /// Invalid loop transformation.
859    #[error("invalid loop transformation: {reason}")]
860    InvalidTransform {
861        /// Reason for the error.
862        reason: String,
863    },
864}
865
866#[cfg(test)]
867mod tests {
868    use super::*;
869
870    #[test]
871    fn test_scalar_type_sizes() {
872        assert_eq!(ScalarType::Bool.size_bytes(), 1);
873        assert_eq!(ScalarType::Int(32).size_bytes(), 4);
874        assert_eq!(ScalarType::Float(64).size_bytes(), 8);
875    }
876
877    #[test]
878    fn test_loop_type_size() {
879        assert_eq!(LoopType::Scalar(ScalarType::Float(32)).size_bytes(), 4);
880        assert_eq!(LoopType::Vector(ScalarType::Float(32), 8).size_bytes(), 32);
881    }
882
883    #[test]
884    fn test_value_types() {
885        let v = Value::i64(42);
886        assert_eq!(v.ty(), LoopType::Scalar(ScalarType::Int(64)));
887
888        let f = Value::f64(2.5);
889        assert_eq!(f.ty(), LoopType::Scalar(ScalarType::Float(64)));
890    }
891
892    #[test]
893    fn test_loop_attrs() {
894        let attrs = LoopAttrs::PARALLEL | LoopAttrs::VECTORIZE;
895        assert!(attrs.contains(LoopAttrs::PARALLEL));
896        assert!(attrs.contains(LoopAttrs::VECTORIZE));
897        assert!(!attrs.contains(LoopAttrs::UNROLL));
898    }
899
900    #[test]
901    fn test_trip_count() {
902        let static_trip = TripCount::Static(100);
903        assert_eq!(static_trip, TripCount::Static(100));
904
905        let dynamic_trip = TripCount::Dynamic;
906        assert_eq!(dynamic_trip, TripCount::Dynamic);
907    }
908
909    // ========================================================================
910    // M3 Exit Criteria Integration Tests
911    // ========================================================================
912
913    /// M3 Exit Criterion 1: matmul microkernel auto-vectorizes on x86_64 and aarch64
914    ///
915    /// This test verifies that the vectorization pass correctly identifies
916    /// a matmul-like kernel as vectorizable and selects appropriate vector widths.
917    #[test]
918    fn test_m3_matmul_auto_vectorizes() {
919        use crate::vectorize::{VectorizeConfig, VectorizePass};
920        use bhc_index::Idx;
921        use bhc_tensor_ir::BufferId;
922
923        // Create a matmul-like loop structure (innermost loop computes dot product)
924        let loop_id = LoopId::new(0);
925        let loop_var = ValueId::new(0);
926
927        // Simulate the innermost loop of matmul: c[i,j] += a[i,k] * b[k,j]
928        let mem_ref = MemRef {
929            buffer: BufferId::new(0),
930            index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
931            elem_ty: LoopType::Scalar(ScalarType::F32),
932            access: AccessPattern::Sequential,
933        };
934
935        let mut body = Body::new();
936        let load_a = ValueId::new(1);
937        body.push(Stmt::Assign(load_a, Op::Load(mem_ref.clone())));
938
939        let load_b = ValueId::new(2);
940        body.push(Stmt::Assign(load_b, Op::Load(mem_ref.clone())));
941
942        let mul_result = ValueId::new(3);
943        body.push(Stmt::Assign(
944            mul_result,
945            Op::Binary(
946                BinOp::Mul,
947                Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
948                Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
949            ),
950        ));
951
952        // FMA opportunity: acc = acc + a * b
953        let acc = ValueId::new(4);
954        let fma_result = ValueId::new(5);
955        body.push(Stmt::Assign(
956            fma_result,
957            Op::Fma(
958                Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
959                Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
960                Value::Var(acc, LoopType::Scalar(ScalarType::F32)),
961            ),
962        ));
963
964        let lp = Loop {
965            id: loop_id,
966            var: loop_var,
967            lower: Value::i64(0),
968            upper: Value::i64(256), // K dimension
969            step: Value::i64(1),
970            body,
971            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
972        };
973
974        let mut outer_body = Body::new();
975        outer_body.push(Stmt::Loop(lp));
976
977        let ir = LoopIR {
978            name: bhc_intern::Symbol::intern("matmul_kernel"),
979            params: vec![],
980            return_ty: LoopType::Void,
981            body: outer_body,
982            allocs: vec![],
983            loop_info: vec![LoopMetadata {
984                id: loop_id,
985                trip_count: TripCount::Static(256),
986                vector_width: None,
987                parallel_chunk: None,
988                unroll_factor: None,
989                dependencies: Vec::new(),
990            }],
991        };
992
993        // Test on x86_64 AVX2
994        let config_x86 = VectorizeConfig {
995            target: TargetArch::X86_64Avx2,
996            ..Default::default()
997        };
998        let mut pass_x86 = VectorizePass::new(config_x86);
999        let analysis_x86 = pass_x86.analyze(&ir);
1000        let info_x86 = analysis_x86.get(&loop_id).expect("loop should be analyzed");
1001
1002        assert!(
1003            info_x86.vectorizable,
1004            "M3 FAIL: matmul kernel not vectorizable on x86_64 AVX2"
1005        );
1006        assert_eq!(
1007            info_x86.recommended_width, 8,
1008            "M3 FAIL: x86_64 AVX2 should use 8-wide vectors for f32"
1009        );
1010
1011        // Test on aarch64 NEON
1012        let config_arm = VectorizeConfig {
1013            target: TargetArch::Aarch64Neon,
1014            ..Default::default()
1015        };
1016        let mut pass_arm = VectorizePass::new(config_arm);
1017        let analysis_arm = pass_arm.analyze(&ir);
1018        let info_arm = analysis_arm.get(&loop_id).expect("loop should be analyzed");
1019
1020        assert!(
1021            info_arm.vectorizable,
1022            "M3 FAIL: matmul kernel not vectorizable on aarch64 NEON"
1023        );
1024        assert_eq!(
1025            info_arm.recommended_width, 4,
1026            "M3 FAIL: aarch64 NEON should use 4-wide vectors for f32"
1027        );
1028    }
1029
1030    /// M3 Exit Criterion 2: Reductions scale linearly up to 8 cores
1031    ///
1032    /// This test verifies that parallel reduction chunking distributes
1033    /// work evenly across workers.
1034    #[test]
1035    fn test_m3_reductions_scale_linearly() {
1036        use crate::parallel::{ParReduce, ParallelConfig, Range};
1037        use crate::ReduceOp;
1038
1039        let data_size = 1_000_000; // 1M elements
1040
1041        // Test with different worker counts
1042        for worker_count in [1, 2, 4, 8] {
1043            let config = ParallelConfig {
1044                worker_count,
1045                deterministic: true,
1046                ..Default::default()
1047            };
1048
1049            let par_reduce = ParReduce {
1050                size: data_size,
1051                op: ReduceOp::Add,
1052                config,
1053            };
1054
1055            let chunks = par_reduce.chunk_assignments();
1056
1057            // Verify correct number of chunks
1058            assert_eq!(
1059                chunks.len(),
1060                worker_count,
1061                "M3 FAIL: Expected {} chunks for {} workers",
1062                worker_count,
1063                worker_count
1064            );
1065
1066            // Verify total work is correct
1067            let total_work: usize = chunks.iter().map(|c| c.len()).sum();
1068            assert_eq!(
1069                total_work, data_size,
1070                "M3 FAIL: Total work should equal data size"
1071            );
1072
1073            // Verify work is evenly distributed (within 1 element difference)
1074            let expected_per_worker = data_size / worker_count;
1075            for (i, chunk) in chunks.iter().enumerate() {
1076                let diff = (chunk.len() as i64 - expected_per_worker as i64).abs();
1077                assert!(
1078                    diff <= 1,
1079                    "M3 FAIL: Worker {} has {} elements, expected ~{} (diff={})",
1080                    i,
1081                    chunk.len(),
1082                    expected_per_worker,
1083                    diff
1084                );
1085            }
1086        }
1087
1088        // Verify scaling: doubling workers should halve chunk size
1089        let _config_4 = ParallelConfig {
1090            worker_count: 4,
1091            ..Default::default()
1092        };
1093        let chunks_4 = Range::new(0, data_size as i64).chunk(4);
1094
1095        let _config_8 = ParallelConfig {
1096            worker_count: 8,
1097            ..Default::default()
1098        };
1099        let chunks_8 = Range::new(0, data_size as i64).chunk(8);
1100
1101        let avg_chunk_4: usize = chunks_4.iter().map(|c| c.len()).sum::<usize>() / 4;
1102        let avg_chunk_8: usize = chunks_8.iter().map(|c| c.len()).sum::<usize>() / 8;
1103
1104        // 8 workers should have approximately half the chunk size of 4 workers
1105        let ratio = avg_chunk_4 as f64 / avg_chunk_8 as f64;
1106        assert!(
1107            (ratio - 2.0).abs() < 0.1,
1108            "M3 FAIL: Chunk size ratio should be ~2.0, got {}",
1109            ratio
1110        );
1111    }
1112
1113    /// M3 Exit Criterion 3: Deterministic mode produces identical results across runs
1114    ///
1115    /// This test verifies that parallel chunking is deterministic when configured.
1116    #[test]
1117    fn test_m3_deterministic_mode() {
1118        use crate::parallel::{ParReduce, ParallelConfig, ParallelStrategy};
1119        use crate::ReduceOp;
1120
1121        let data_size = 100_000;
1122        let worker_count = 8;
1123
1124        // Configure deterministic mode
1125        let config = ParallelConfig {
1126            worker_count,
1127            deterministic: true,
1128            ..Default::default()
1129        };
1130
1131        let par_reduce = ParReduce {
1132            size: data_size,
1133            op: ReduceOp::Add,
1134            config: config.clone(),
1135        };
1136
1137        // Run multiple times and verify identical chunk assignments
1138        let chunks1 = par_reduce.chunk_assignments();
1139        let chunks2 = par_reduce.chunk_assignments();
1140        let chunks3 = par_reduce.chunk_assignments();
1141
1142        for i in 0..worker_count {
1143            assert_eq!(
1144                chunks1[i].start, chunks2[i].start,
1145                "M3 FAIL: Chunk {} start differs between runs",
1146                i
1147            );
1148            assert_eq!(
1149                chunks1[i].end, chunks2[i].end,
1150                "M3 FAIL: Chunk {} end differs between runs",
1151                i
1152            );
1153            assert_eq!(
1154                chunks2[i].start, chunks3[i].start,
1155                "M3 FAIL: Chunk {} start differs between runs",
1156                i
1157            );
1158            assert_eq!(
1159                chunks2[i].end, chunks3[i].end,
1160                "M3 FAIL: Chunk {} end differs between runs",
1161                i
1162            );
1163        }
1164
1165        // Verify strategy is Static for deterministic mode
1166        use crate::parallel::ParallelPass;
1167
1168        let parallel_config = ParallelConfig {
1169            worker_count: 8,
1170            deterministic: true,
1171            ..Default::default()
1172        };
1173
1174        // Build a simple parallelizable loop
1175        let loop_id = LoopId::new(0);
1176        let mut body = Body::new();
1177
1178        let lp = Loop {
1179            id: loop_id,
1180            var: ValueId::new(0),
1181            lower: Value::i64(0),
1182            upper: Value::i64(100000),
1183            step: Value::i64(1),
1184            body: Body::new(),
1185            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
1186        };
1187
1188        body.push(Stmt::Loop(lp));
1189
1190        let ir = LoopIR {
1191            name: bhc_intern::Symbol::intern("deterministic_test"),
1192            params: vec![],
1193            return_ty: LoopType::Void,
1194            body,
1195            allocs: vec![],
1196            loop_info: vec![LoopMetadata {
1197                id: loop_id,
1198                trip_count: TripCount::Static(100000),
1199                vector_width: None,
1200                parallel_chunk: None,
1201                unroll_factor: None,
1202                dependencies: Vec::new(),
1203            }],
1204        };
1205
1206        let mut pass = ParallelPass::new(parallel_config);
1207        let analysis = pass.analyze(&ir);
1208        let info = analysis.get(&loop_id).expect("loop should be analyzed");
1209
1210        assert!(
1211            info.parallelizable,
1212            "M3 FAIL: Loop should be parallelizable"
1213        );
1214        assert_eq!(
1215            info.strategy,
1216            ParallelStrategy::Static,
1217            "M3 FAIL: Deterministic mode should use Static scheduling"
1218        );
1219    }
1220
1221    /// M3 Integration: Complete pipeline test for vectorized parallel reduction
1222    #[test]
1223    fn test_m3_vectorized_parallel_reduction() {
1224        use crate::parallel::{ParallelConfig, ParallelPass};
1225        use crate::vectorize::{VectorizeConfig, VectorizePass};
1226        use bhc_index::Idx;
1227        use bhc_tensor_ir::BufferId;
1228
1229        // Create a reduction loop that should be both vectorized and parallelized
1230        let outer_loop_id = LoopId::new(0);
1231        let inner_loop_id = LoopId::new(1);
1232
1233        // Inner loop: vectorizable reduction
1234        let mem_ref = MemRef {
1235            buffer: BufferId::new(0),
1236            index: Value::Var(ValueId::new(1), LoopType::Scalar(ScalarType::I64)),
1237            elem_ty: LoopType::Scalar(ScalarType::F32),
1238            access: AccessPattern::Sequential,
1239        };
1240
1241        let mut inner_body = Body::new();
1242        let load_result = ValueId::new(2);
1243        inner_body.push(Stmt::Assign(load_result, Op::Load(mem_ref)));
1244
1245        let inner_loop = Loop {
1246            id: inner_loop_id,
1247            var: ValueId::new(1),
1248            lower: Value::i64(0),
1249            upper: Value::i64(1024),
1250            step: Value::i64(1),
1251            body: inner_body,
1252            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT | LoopAttrs::REDUCTION,
1253        };
1254
1255        // Outer loop: parallelizable
1256        let mut outer_body = Body::new();
1257        outer_body.push(Stmt::Loop(inner_loop));
1258
1259        let outer_loop = Loop {
1260            id: outer_loop_id,
1261            var: ValueId::new(0),
1262            lower: Value::i64(0),
1263            upper: Value::i64(10000),
1264            step: Value::i64(1),
1265            body: outer_body,
1266            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
1267        };
1268
1269        let mut top_body = Body::new();
1270        top_body.push(Stmt::Loop(outer_loop));
1271
1272        let ir = LoopIR {
1273            name: bhc_intern::Symbol::intern("vec_par_reduce"),
1274            params: vec![],
1275            return_ty: LoopType::Void,
1276            body: top_body,
1277            allocs: vec![],
1278            loop_info: vec![
1279                LoopMetadata {
1280                    id: outer_loop_id,
1281                    trip_count: TripCount::Static(10000),
1282                    vector_width: None,
1283                    parallel_chunk: None,
1284                    unroll_factor: None,
1285                    dependencies: Vec::new(),
1286                },
1287                LoopMetadata {
1288                    id: inner_loop_id,
1289                    trip_count: TripCount::Static(1024),
1290                    vector_width: None,
1291                    parallel_chunk: None,
1292                    unroll_factor: None,
1293                    dependencies: Vec::new(),
1294                },
1295            ],
1296        };
1297
1298        // Apply vectorization pass
1299        let vec_config = VectorizeConfig {
1300            target: TargetArch::X86_64Avx2,
1301            ..Default::default()
1302        };
1303        let mut vec_pass = VectorizePass::new(vec_config);
1304        let vec_analysis = vec_pass.analyze(&ir);
1305
1306        // Verify inner loop is vectorizable
1307        let inner_info = vec_analysis
1308            .get(&inner_loop_id)
1309            .expect("inner loop analyzed");
1310        assert!(
1311            inner_info.vectorizable,
1312            "M3 FAIL: Inner reduction loop should be vectorizable"
1313        );
1314
1315        // Apply parallelization pass
1316        let par_config = ParallelConfig {
1317            worker_count: 8,
1318            deterministic: true,
1319            ..Default::default()
1320        };
1321        let mut par_pass = ParallelPass::new(par_config);
1322        let par_analysis = par_pass.analyze(&ir);
1323
1324        // Verify outer loop is parallelizable
1325        let outer_info = par_analysis
1326            .get(&outer_loop_id)
1327            .expect("outer loop analyzed");
1328        assert!(
1329            outer_info.parallelizable,
1330            "M3 FAIL: Outer loop should be parallelizable"
1331        );
1332        assert_eq!(
1333            outer_info.num_chunks, 8,
1334            "M3 FAIL: Should have 8 parallel chunks"
1335        );
1336    }
1337}