Skip to main content

bhc_loop_ir/
lib.rs

1//! # BHC Loop IR
2//!
3//! This crate defines the Loop Intermediate Representation for the Basel
4//! Haskell Compiler. Loop IR makes iteration structure explicit and is the
5//! target for vectorization and low-level optimization.
6//!
7//! ## Overview
8//!
9//! Loop IR is the lowest-level IR before code generation. It provides:
10//!
11//! - **Explicit iteration**: Loops with bounds and strides
12//! - **Vectorization information**: Which loops can be SIMD-ized
13//! - **Parallelization hints**: Which loops can run in parallel
14//! - **Memory access patterns**: For cache optimization
15//!
16//! ## IR Pipeline Position
17//!
18//! ```text
19//! Source Code
20//!     |
21//!     v
22//! [Parse/AST]
23//!     |
24//!     v
25//! [HIR]
26//!     |
27//!     v
28//! [Core IR]
29//!     |
30//!     v
31//! [Tensor IR]  <- High-level tensor operations
32//!     |
33//!     v
34//! [Loop IR]    <- This crate: explicit iteration
35//!     |
36//!     v
37//! [Codegen]    <- LLVM IR / Native code
38//! ```
39//!
40//! ## Key Transformations
41//!
42//! Loop IR supports several important optimizations:
43//!
44//! 1. **Loop tiling**: Break loops into cache-friendly tiles
45//! 2. **Vectorization**: Convert scalar operations to SIMD
46//! 3. **Parallelization**: Mark loops for parallel execution
47//! 4. **Interchange**: Reorder loops for better memory access
48//! 5. **Unrolling**: Reduce loop overhead
49//!
50//! ## Main Types
51//!
52//! - [`LoopIR`]: The top-level IR structure
53//! - [`Loop`]: A single loop with bounds and body
54//! - [`Stmt`]: Statements within loop bodies
55//! - [`Value`]: SSA values (registers)
56//! - [`MemRef`]: Memory references with access patterns
57//!
58//! ## M3 Deliverables
59//!
60//! This crate implements the following M3 features:
61//!
62//! - **SIMD Types**: [`LoopType::VEC4F32`], [`LoopType::VEC8F32`], [`LoopType::VEC2F64`], [`LoopType::VEC4F64`]
63//! - **Auto-vectorization**: [`vectorize::VectorizePass`]
64//! - **Parallel primitives**: [`parallel::ParFor`], [`parallel::ParMap`], [`parallel::ParReduce`]
65//! - **SIMD intrinsics**: [`vectorize::SimdIntrinsic`]
66//!
67//! ## See Also
68//!
69//! - `bhc-tensor-ir`: Tensor IR that lowers to Loop IR
70//! - `bhc-codegen`: Code generation from Loop IR
71//! - H26-SPEC Section 7: Tensor Model (lowering)
72
73#![warn(missing_docs)]
74#![warn(clippy::all)]
75#![warn(clippy::pedantic)]
76#![allow(clippy::module_name_repetitions)]
77
78use bhc_index::Idx;
79use bhc_intern::Symbol;
80use bhc_tensor_ir::{AllocRegion, BufferId, DType};
81use bitflags::bitflags;
82use serde::{Deserialize, Serialize};
83use smallvec::SmallVec;
84
85// ============================================================================
86// Submodules
87// ============================================================================
88
89pub mod lower;
90pub mod parallel;
91pub mod vectorize;
92
93// Re-export key types from submodules
94pub use lower::{lower_kernel, lower_kernels, LowerConfig, LowerError};
95pub use parallel::{
96    ParFor, ParMap, ParReduce, ParallelConfig, ParallelPass, ParallelStrategy, Range,
97};
98pub use vectorize::{SimdIntrinsic, VectorizeConfig, VectorizePass, VectorizeReport};
99
100/// A unique identifier for values (SSA registers).
101#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
102pub struct ValueId(u32);
103
104impl Idx for ValueId {
105    fn new(idx: usize) -> Self {
106        Self(idx as u32)
107    }
108
109    fn index(self) -> usize {
110        self.0 as usize
111    }
112}
113
114/// A unique identifier for loops.
115#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
116pub struct LoopId(u32);
117
118impl Idx for LoopId {
119    fn new(idx: usize) -> Self {
120        Self(idx as u32)
121    }
122
123    fn index(self) -> usize {
124        self.0 as usize
125    }
126}
127
128/// A unique identifier for basic blocks.
129#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
130pub struct BlockId(u32);
131
132impl Idx for BlockId {
133    fn new(idx: usize) -> Self {
134        Self(idx as u32)
135    }
136
137    fn index(self) -> usize {
138        self.0 as usize
139    }
140}
141
142/// The main Loop IR structure.
143#[derive(Clone, Debug, Serialize, Deserialize)]
144pub struct LoopIR {
145    /// Function name.
146    pub name: Symbol,
147    /// Function parameters.
148    pub params: Vec<Param>,
149    /// Return type.
150    pub return_ty: LoopType,
151    /// The body (list of statements and loops).
152    pub body: Body,
153    /// Memory allocations.
154    pub allocs: Vec<Alloc>,
155    /// Loop metadata for optimization.
156    pub loop_info: Vec<LoopMetadata>,
157}
158
159/// A function parameter.
160#[derive(Clone, Debug, Serialize, Deserialize)]
161pub struct Param {
162    /// Parameter name.
163    pub name: Symbol,
164    /// Parameter type.
165    pub ty: LoopType,
166    /// Whether this is a pointer to memory.
167    pub is_ptr: bool,
168}
169
170/// Types in Loop IR.
171#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
172pub enum LoopType {
173    /// Void (no value).
174    Void,
175    /// Scalar type.
176    Scalar(ScalarType),
177    /// Vector type (SIMD).
178    Vector(ScalarType, u8),
179    /// Pointer to memory.
180    Ptr(Box<LoopType>),
181}
182
183impl LoopType {
184    /// Returns the size in bytes.
185    #[must_use]
186    pub fn size_bytes(&self) -> usize {
187        match self {
188            Self::Void => 0,
189            Self::Scalar(s) => s.size_bytes(),
190            Self::Vector(s, width) => s.size_bytes() * (*width as usize),
191            Self::Ptr(_) => 8, // Assuming 64-bit pointers
192        }
193    }
194
195    /// Returns true if this is a void type.
196    #[must_use]
197    pub fn is_void(&self) -> bool {
198        matches!(self, Self::Void)
199    }
200
201    /// Returns true if this is a vector type.
202    #[must_use]
203    pub fn is_vector(&self) -> bool {
204        matches!(self, Self::Vector(_, _))
205    }
206}
207
208/// Scalar types in Loop IR.
209#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
210pub enum ScalarType {
211    /// Boolean.
212    Bool,
213    /// Signed integer with bit width.
214    Int(u8),
215    /// Unsigned integer with bit width.
216    UInt(u8),
217    /// Floating point with bit width.
218    Float(u8),
219}
220
221impl ScalarType {
222    /// Returns the size in bytes.
223    #[must_use]
224    pub const fn size_bytes(self) -> usize {
225        match self {
226            Self::Bool => 1,
227            Self::Int(bits) | Self::UInt(bits) | Self::Float(bits) => (bits as usize + 7) / 8,
228        }
229    }
230
231    /// Converts from tensor DType.
232    #[must_use]
233    pub fn from_dtype(dtype: DType) -> Self {
234        match dtype {
235            DType::Bool => Self::Bool,
236            DType::Int8 => Self::Int(8),
237            DType::Int16 => Self::Int(16),
238            DType::Int32 => Self::Int(32),
239            DType::Int64 => Self::Int(64),
240            DType::UInt8 => Self::UInt(8),
241            DType::UInt16 => Self::UInt(16),
242            DType::UInt32 => Self::UInt(32),
243            DType::UInt64 => Self::UInt(64),
244            DType::Float16 | DType::BFloat16 => Self::Float(16),
245            DType::Float32 => Self::Float(32),
246            DType::Float64 => Self::Float(64),
247            DType::Complex64 => Self::Float(32), // Represented as pair
248            DType::Complex128 => Self::Float(64),
249        }
250    }
251
252    /// 32-bit float scalar type.
253    pub const F32: Self = Self::Float(32);
254
255    /// 64-bit float scalar type.
256    pub const F64: Self = Self::Float(64);
257
258    /// 32-bit signed integer scalar type.
259    pub const I32: Self = Self::Int(32);
260
261    /// 64-bit signed integer scalar type.
262    pub const I64: Self = Self::Int(64);
263}
264
265// ============================================================================
266// SIMD Type Aliases and Constructors (M3 Deliverable)
267// ============================================================================
268
269impl LoopType {
270    // --- Standard SIMD Vector Types ---
271
272    /// 4-wide 32-bit float vector (128-bit, SSE/NEON compatible).
273    pub const VEC4F32: Self = Self::Vector(ScalarType::F32, 4);
274
275    /// 8-wide 32-bit float vector (256-bit, AVX compatible).
276    pub const VEC8F32: Self = Self::Vector(ScalarType::F32, 8);
277
278    /// 2-wide 64-bit float vector (128-bit, SSE/NEON compatible).
279    pub const VEC2F64: Self = Self::Vector(ScalarType::F64, 2);
280
281    /// 4-wide 64-bit float vector (256-bit, AVX compatible).
282    pub const VEC4F64: Self = Self::Vector(ScalarType::F64, 4);
283
284    /// 4-wide 32-bit integer vector (128-bit).
285    pub const VEC4I32: Self = Self::Vector(ScalarType::I32, 4);
286
287    /// 8-wide 32-bit integer vector (256-bit).
288    pub const VEC8I32: Self = Self::Vector(ScalarType::I32, 8);
289
290    /// Returns the natural vector width for a scalar type on the target.
291    ///
292    /// # Target Widths
293    ///
294    /// | Target | F32 | F64 | I32 |
295    /// |--------|-----|-----|-----|
296    /// | x86_64 (SSE) | 4 | 2 | 4 |
297    /// | x86_64 (AVX) | 8 | 4 | 8 |
298    /// | aarch64 (NEON) | 4 | 2 | 4 |
299    #[must_use]
300    pub fn natural_vector_width(scalar: ScalarType, target: TargetArch) -> u8 {
301        match (target, scalar) {
302            // x86_64 with AVX (256-bit vectors)
303            (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(32)) => 8,
304            (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Float(64)) => 4,
305            (TargetArch::X86_64Avx | TargetArch::X86_64Avx2, ScalarType::Int(32)) => 8,
306            // x86_64 with SSE (128-bit vectors)
307            (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(32)) => 4,
308            (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Float(64)) => 2,
309            (TargetArch::X86_64Sse | TargetArch::X86_64Sse2, ScalarType::Int(32)) => 4,
310            // aarch64 with NEON (128-bit vectors)
311            (TargetArch::Aarch64Neon, ScalarType::Float(32)) => 4,
312            (TargetArch::Aarch64Neon, ScalarType::Float(64)) => 2,
313            (TargetArch::Aarch64Neon, ScalarType::Int(32)) => 4,
314            // Fallback: no vectorization
315            _ => 1,
316        }
317    }
318
319    /// Creates a vector type for the given scalar and width.
320    #[must_use]
321    pub const fn vector(scalar: ScalarType, width: u8) -> Self {
322        Self::Vector(scalar, width)
323    }
324
325    /// Returns the vector width if this is a vector type, otherwise None.
326    #[must_use]
327    pub fn vector_width(&self) -> Option<u8> {
328        match self {
329            Self::Vector(_, w) => Some(*w),
330            _ => None,
331        }
332    }
333
334    /// Returns the scalar element type if this is a vector type.
335    #[must_use]
336    pub fn element_type(&self) -> Option<ScalarType> {
337        match self {
338            Self::Vector(s, _) => Some(*s),
339            Self::Scalar(s) => Some(*s),
340            _ => None,
341        }
342    }
343}
344
345/// Target architecture for vectorization decisions.
346#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
347pub enum TargetArch {
348    /// x86_64 with SSE instructions (128-bit).
349    X86_64Sse,
350    /// x86_64 with SSE2 instructions (128-bit).
351    X86_64Sse2,
352    /// x86_64 with AVX instructions (256-bit).
353    X86_64Avx,
354    /// x86_64 with AVX2 instructions (256-bit).
355    X86_64Avx2,
356    /// aarch64 with NEON instructions (128-bit).
357    Aarch64Neon,
358    /// Generic target (no vectorization).
359    Generic,
360}
361
362impl Default for TargetArch {
363    fn default() -> Self {
364        // Default to AVX for x86_64, NEON for aarch64
365        #[cfg(target_arch = "x86_64")]
366        {
367            Self::X86_64Avx2
368        }
369        #[cfg(target_arch = "aarch64")]
370        {
371            Self::Aarch64Neon
372        }
373        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
374        {
375            Self::Generic
376        }
377    }
378}
379
380/// A memory allocation.
381#[derive(Clone, Debug, Serialize, Deserialize)]
382pub struct Alloc {
383    /// Buffer identifier.
384    pub buffer: BufferId,
385    /// Name for debugging.
386    pub name: Symbol,
387    /// Element type.
388    pub elem_ty: ScalarType,
389    /// Total size in elements.
390    pub size: AllocSize,
391    /// Alignment in bytes.
392    pub alignment: usize,
393    /// Allocation region.
394    pub region: AllocRegion,
395}
396
397/// Size of an allocation.
398#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
399pub enum AllocSize {
400    /// Statically known size.
401    Static(usize),
402    /// Dynamic size (computed at runtime).
403    Dynamic(ValueId),
404}
405
406/// The body of a function or loop.
407#[derive(Clone, Debug, Default, Serialize, Deserialize)]
408pub struct Body {
409    /// Statements in execution order.
410    pub stmts: Vec<Stmt>,
411}
412
413impl Body {
414    /// Creates an empty body.
415    #[must_use]
416    pub fn new() -> Self {
417        Self::default()
418    }
419
420    /// Adds a statement to the body.
421    pub fn push(&mut self, stmt: Stmt) {
422        self.stmts.push(stmt);
423    }
424}
425
426/// Statements in Loop IR.
427#[derive(Clone, Debug, Serialize, Deserialize)]
428pub enum Stmt {
429    /// An assignment: `%v = op`.
430    Assign(ValueId, Op),
431
432    /// A loop construct.
433    Loop(Loop),
434
435    /// A conditional branch.
436    If(IfStmt),
437
438    /// A store to memory.
439    Store(MemRef, Value),
440
441    /// A function call (for external functions).
442    Call(Option<ValueId>, Symbol, Vec<Value>),
443
444    /// A return statement.
445    Return(Option<Value>),
446
447    /// A barrier for synchronization.
448    Barrier(BarrierKind),
449
450    /// A comment/annotation.
451    Comment(String),
452}
453
454/// A loop construct.
455#[derive(Clone, Debug, Serialize, Deserialize)]
456pub struct Loop {
457    /// Unique loop identifier.
458    pub id: LoopId,
459    /// Loop variable.
460    pub var: ValueId,
461    /// Lower bound (inclusive).
462    pub lower: Value,
463    /// Upper bound (exclusive).
464    pub upper: Value,
465    /// Step size.
466    pub step: Value,
467    /// Loop body.
468    pub body: Body,
469    /// Loop attributes.
470    pub attrs: LoopAttrs,
471}
472
473bitflags! {
474    /// Loop attributes for optimization.
475    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
476    pub struct LoopAttrs: u32 {
477        /// Loop can be parallelized.
478        const PARALLEL = 0b0000_0001;
479        /// Loop can be vectorized.
480        const VECTORIZE = 0b0000_0010;
481        /// Loop should be unrolled.
482        const UNROLL = 0b0000_0100;
483        /// Loop is a reduction loop.
484        const REDUCTION = 0b0000_1000;
485        /// Loop iterations are independent.
486        const INDEPENDENT = 0b0001_0000;
487        /// Loop has been tiled.
488        const TILED = 0b0010_0000;
489        /// Loop is the innermost of a tile.
490        const TILE_INNER = 0b0100_0000;
491    }
492}
493
494/// Loop metadata for optimization.
495#[derive(Clone, Debug, Serialize, Deserialize)]
496pub struct LoopMetadata {
497    /// Loop identifier.
498    pub id: LoopId,
499    /// Trip count (iterations).
500    pub trip_count: TripCount,
501    /// Vectorization width (if applicable).
502    pub vector_width: Option<u8>,
503    /// Parallel chunk size (if applicable).
504    pub parallel_chunk: Option<usize>,
505    /// Unroll factor (if applicable).
506    pub unroll_factor: Option<u8>,
507    /// Dependencies with other loops.
508    pub dependencies: Vec<LoopDependency>,
509}
510
511/// Trip count information.
512#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
513pub enum TripCount {
514    /// Statically known trip count.
515    Static(usize),
516    /// Dynamic trip count.
517    Dynamic,
518    /// Bounded trip count (upper bound known).
519    Bounded(usize),
520}
521
522/// A dependency between loops.
523#[derive(Clone, Debug, Serialize, Deserialize)]
524pub struct LoopDependency {
525    /// Source loop.
526    pub source: LoopId,
527    /// Target loop.
528    pub target: LoopId,
529    /// Dependency type.
530    pub kind: DependencyKind,
531    /// Distance vector (for affine dependencies).
532    pub distance: Option<Vec<i32>>,
533}
534
535/// Kinds of dependencies.
536#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
537pub enum DependencyKind {
538    /// Flow dependency (read after write).
539    Flow,
540    /// Anti dependency (write after read).
541    Anti,
542    /// Output dependency (write after write).
543    Output,
544    /// Input dependency (read after read, for locality).
545    Input,
546}
547
548/// A conditional statement.
549#[derive(Clone, Debug, Serialize, Deserialize)]
550pub struct IfStmt {
551    /// Condition value.
552    pub cond: Value,
553    /// Then branch.
554    pub then_body: Body,
555    /// Else branch (optional).
556    pub else_body: Option<Body>,
557}
558
559/// A value (SSA reference or constant).
560#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
561pub enum Value {
562    /// A register/variable reference.
563    Var(ValueId, LoopType),
564    /// An integer constant.
565    IntConst(i64, ScalarType),
566    /// A floating-point constant.
567    FloatConst(f64, ScalarType),
568    /// A boolean constant.
569    BoolConst(bool),
570    /// Undefined value.
571    Undef(LoopType),
572}
573
574impl Value {
575    /// Returns the type of this value.
576    #[must_use]
577    pub fn ty(&self) -> LoopType {
578        match self {
579            Self::Var(_, ty) => ty.clone(),
580            Self::IntConst(_, s) => LoopType::Scalar(*s),
581            Self::FloatConst(_, s) => LoopType::Scalar(*s),
582            Self::BoolConst(_) => LoopType::Scalar(ScalarType::Bool),
583            Self::Undef(ty) => ty.clone(),
584        }
585    }
586
587    /// Creates an integer constant.
588    #[must_use]
589    pub fn int(n: i64, bits: u8) -> Self {
590        Self::IntConst(n, ScalarType::Int(bits))
591    }
592
593    /// Creates a 64-bit integer constant.
594    #[must_use]
595    pub fn i64(n: i64) -> Self {
596        Self::int(n, 64)
597    }
598
599    /// Creates a float constant.
600    #[must_use]
601    pub fn float(f: f64, bits: u8) -> Self {
602        Self::FloatConst(f, ScalarType::Float(bits))
603    }
604
605    /// Creates a 64-bit float constant.
606    #[must_use]
607    pub fn f64(f: f64) -> Self {
608        Self::float(f, 64)
609    }
610}
611
612/// Operations in Loop IR.
613#[derive(Clone, Debug, Serialize, Deserialize)]
614pub enum Op {
615    /// Load from memory.
616    Load(MemRef),
617
618    /// Binary arithmetic operation.
619    Binary(BinOp, Value, Value),
620
621    /// Unary operation.
622    Unary(UnOp, Value),
623
624    /// Comparison.
625    Cmp(CmpOp, Value, Value),
626
627    /// Select (conditional).
628    Select(Value, Value, Value),
629
630    /// Cast between types.
631    Cast(Value, LoopType),
632
633    /// Vector broadcast (scalar to vector).
634    Broadcast(Value, u8),
635
636    /// Vector extract (vector to scalar).
637    Extract(Value, u8),
638
639    /// Vector insert.
640    Insert(Value, Value, u8),
641
642    /// Vector shuffle.
643    Shuffle(Value, Value, Vec<i32>),
644
645    /// Reduction within a vector.
646    VecReduce(ReduceOp, Value),
647
648    /// Fused multiply-add: a * b + c.
649    Fma(Value, Value, Value),
650
651    /// Pointer arithmetic.
652    PtrAdd(Value, Value),
653
654    /// Get pointer to buffer element.
655    GetPtr(BufferId, Value),
656
657    /// Phi node (for SSA).
658    Phi(Vec<(BlockId, Value)>),
659}
660
661/// Binary operations.
662#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
663pub enum BinOp {
664    // Arithmetic
665    /// Addition.
666    Add,
667    /// Subtraction.
668    Sub,
669    /// Multiplication.
670    Mul,
671    /// Signed division.
672    SDiv,
673    /// Unsigned division.
674    UDiv,
675    /// Floating-point division.
676    FDiv,
677    /// Signed remainder.
678    SRem,
679    /// Unsigned remainder.
680    URem,
681    /// Floating-point remainder.
682    FRem,
683
684    // Bitwise
685    /// Bitwise AND.
686    And,
687    /// Bitwise OR.
688    Or,
689    /// Bitwise XOR.
690    Xor,
691    /// Left shift.
692    Shl,
693    /// Logical right shift.
694    LShr,
695    /// Arithmetic right shift.
696    AShr,
697
698    // Min/Max
699    /// Signed minimum.
700    SMin,
701    /// Unsigned minimum.
702    UMin,
703    /// Floating-point minimum.
704    FMin,
705    /// Signed maximum.
706    SMax,
707    /// Unsigned maximum.
708    UMax,
709    /// Floating-point maximum.
710    FMax,
711}
712
713/// Unary operations.
714#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
715pub enum UnOp {
716    /// Negation.
717    Neg,
718    /// Floating-point negation.
719    FNeg,
720    /// Bitwise NOT.
721    Not,
722    /// Absolute value.
723    Abs,
724    /// Floating-point absolute value.
725    FAbs,
726    /// Square root.
727    Sqrt,
728    /// Reciprocal square root.
729    Rsqrt,
730    /// Floor.
731    Floor,
732    /// Ceiling.
733    Ceil,
734    /// Round to nearest.
735    Round,
736    /// Truncate.
737    Trunc,
738    /// Exponential.
739    Exp,
740    /// Natural logarithm.
741    Log,
742    /// Sine.
743    Sin,
744    /// Cosine.
745    Cos,
746}
747
748/// Comparison operations.
749#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
750pub enum CmpOp {
751    /// Equal.
752    Eq,
753    /// Not equal.
754    Ne,
755    /// Signed less than.
756    SLt,
757    /// Signed less than or equal.
758    SLe,
759    /// Signed greater than.
760    SGt,
761    /// Signed greater than or equal.
762    SGe,
763    /// Unsigned less than.
764    ULt,
765    /// Unsigned less than or equal.
766    ULe,
767    /// Unsigned greater than.
768    UGt,
769    /// Unsigned greater than or equal.
770    UGe,
771    /// Floating-point ordered equal.
772    OEq,
773    /// Floating-point ordered not equal.
774    ONe,
775    /// Floating-point ordered less than.
776    OLt,
777    /// Floating-point ordered less than or equal.
778    OLe,
779    /// Floating-point ordered greater than.
780    OGt,
781    /// Floating-point ordered greater than or equal.
782    OGe,
783}
784
785/// Reduction operations.
786#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
787pub enum ReduceOp {
788    /// Sum reduction.
789    Add,
790    /// Product reduction.
791    Mul,
792    /// Minimum reduction.
793    Min,
794    /// Maximum reduction.
795    Max,
796    /// AND reduction.
797    And,
798    /// OR reduction.
799    Or,
800    /// XOR reduction.
801    Xor,
802}
803
804/// A memory reference.
805#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
806pub struct MemRef {
807    /// The buffer being accessed.
808    pub buffer: BufferId,
809    /// The index/offset.
810    pub index: Value,
811    /// The element type.
812    pub elem_ty: LoopType,
813    /// Access pattern information.
814    pub access: AccessPattern,
815}
816
817/// Memory access patterns for optimization.
818#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
819pub enum AccessPattern {
820    /// Sequential access (stride 1).
821    Sequential,
822    /// Strided access.
823    Strided(i64),
824    /// Random/indirect access.
825    Random,
826    /// Broadcast (same element for all iterations).
827    Broadcast,
828    /// Affine access (linear combination of loop indices).
829    Affine(AffineAccess),
830}
831
832/// Affine memory access pattern.
833#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
834pub struct AffineAccess {
835    /// Coefficients for each loop variable.
836    pub coefficients: SmallVec<[(LoopId, i64); 4]>,
837    /// Constant offset.
838    pub offset: i64,
839}
840
841/// Barrier kinds for synchronization.
842#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
843pub enum BarrierKind {
844    /// Memory fence.
845    MemFence,
846    /// Full barrier (all threads).
847    Full,
848    /// Thread group barrier.
849    ThreadGroup,
850}
851
852/// Errors in Loop IR.
853#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
854pub enum LoopIrError {
855    /// Type mismatch.
856    #[error("type mismatch: expected {expected:?}, got {got:?}")]
857    TypeMismatch {
858        /// Expected type.
859        expected: LoopType,
860        /// Actual type.
861        got: LoopType,
862    },
863
864    /// Invalid vector width.
865    #[error("invalid vector width {width} for type {ty:?}")]
866    InvalidVectorWidth {
867        /// The vector width.
868        width: u8,
869        /// The element type.
870        ty: ScalarType,
871    },
872
873    /// Out of bounds access.
874    #[error("buffer access out of bounds")]
875    OutOfBounds,
876
877    /// Invalid loop transformation.
878    #[error("invalid loop transformation: {reason}")]
879    InvalidTransform {
880        /// Reason for the error.
881        reason: String,
882    },
883}
884
885#[cfg(test)]
886mod tests {
887    use super::*;
888
889    #[test]
890    fn test_scalar_type_sizes() {
891        assert_eq!(ScalarType::Bool.size_bytes(), 1);
892        assert_eq!(ScalarType::Int(32).size_bytes(), 4);
893        assert_eq!(ScalarType::Float(64).size_bytes(), 8);
894    }
895
896    #[test]
897    fn test_loop_type_size() {
898        assert_eq!(LoopType::Scalar(ScalarType::Float(32)).size_bytes(), 4);
899        assert_eq!(LoopType::Vector(ScalarType::Float(32), 8).size_bytes(), 32);
900    }
901
902    #[test]
903    fn test_value_types() {
904        let v = Value::i64(42);
905        assert_eq!(v.ty(), LoopType::Scalar(ScalarType::Int(64)));
906
907        let f = Value::f64(3.14);
908        assert_eq!(f.ty(), LoopType::Scalar(ScalarType::Float(64)));
909    }
910
911    #[test]
912    fn test_loop_attrs() {
913        let attrs = LoopAttrs::PARALLEL | LoopAttrs::VECTORIZE;
914        assert!(attrs.contains(LoopAttrs::PARALLEL));
915        assert!(attrs.contains(LoopAttrs::VECTORIZE));
916        assert!(!attrs.contains(LoopAttrs::UNROLL));
917    }
918
919    #[test]
920    fn test_trip_count() {
921        let static_trip = TripCount::Static(100);
922        assert_eq!(static_trip, TripCount::Static(100));
923
924        let dynamic_trip = TripCount::Dynamic;
925        assert_eq!(dynamic_trip, TripCount::Dynamic);
926    }
927
928    // ========================================================================
929    // M3 Exit Criteria Integration Tests
930    // ========================================================================
931
932    /// M3 Exit Criterion 1: matmul microkernel auto-vectorizes on x86_64 and aarch64
933    ///
934    /// This test verifies that the vectorization pass correctly identifies
935    /// a matmul-like kernel as vectorizable and selects appropriate vector widths.
936    #[test]
937    fn test_m3_matmul_auto_vectorizes() {
938        use crate::vectorize::{VectorizeConfig, VectorizePass};
939        use bhc_index::Idx;
940        use bhc_tensor_ir::BufferId;
941
942        // Create a matmul-like loop structure (innermost loop computes dot product)
943        let loop_id = LoopId::new(0);
944        let loop_var = ValueId::new(0);
945
946        // Simulate the innermost loop of matmul: c[i,j] += a[i,k] * b[k,j]
947        let mem_ref = MemRef {
948            buffer: BufferId::new(0),
949            index: Value::Var(loop_var, LoopType::Scalar(ScalarType::I64)),
950            elem_ty: LoopType::Scalar(ScalarType::F32),
951            access: AccessPattern::Sequential,
952        };
953
954        let mut body = Body::new();
955        let load_a = ValueId::new(1);
956        body.push(Stmt::Assign(load_a, Op::Load(mem_ref.clone())));
957
958        let load_b = ValueId::new(2);
959        body.push(Stmt::Assign(load_b, Op::Load(mem_ref.clone())));
960
961        let mul_result = ValueId::new(3);
962        body.push(Stmt::Assign(
963            mul_result,
964            Op::Binary(
965                BinOp::Mul,
966                Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
967                Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
968            ),
969        ));
970
971        // FMA opportunity: acc = acc + a * b
972        let acc = ValueId::new(4);
973        let fma_result = ValueId::new(5);
974        body.push(Stmt::Assign(
975            fma_result,
976            Op::Fma(
977                Value::Var(load_a, LoopType::Scalar(ScalarType::F32)),
978                Value::Var(load_b, LoopType::Scalar(ScalarType::F32)),
979                Value::Var(acc, LoopType::Scalar(ScalarType::F32)),
980            ),
981        ));
982
983        let lp = Loop {
984            id: loop_id,
985            var: loop_var,
986            lower: Value::i64(0),
987            upper: Value::i64(256), // K dimension
988            step: Value::i64(1),
989            body,
990            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT,
991        };
992
993        let mut outer_body = Body::new();
994        outer_body.push(Stmt::Loop(lp));
995
996        let ir = LoopIR {
997            name: bhc_intern::Symbol::intern("matmul_kernel"),
998            params: vec![],
999            return_ty: LoopType::Void,
1000            body: outer_body,
1001            allocs: vec![],
1002            loop_info: vec![LoopMetadata {
1003                id: loop_id,
1004                trip_count: TripCount::Static(256),
1005                vector_width: None,
1006                parallel_chunk: None,
1007                unroll_factor: None,
1008                dependencies: Vec::new(),
1009            }],
1010        };
1011
1012        // Test on x86_64 AVX2
1013        let config_x86 = VectorizeConfig {
1014            target: TargetArch::X86_64Avx2,
1015            ..Default::default()
1016        };
1017        let mut pass_x86 = VectorizePass::new(config_x86);
1018        let analysis_x86 = pass_x86.analyze(&ir);
1019        let info_x86 = analysis_x86.get(&loop_id).expect("loop should be analyzed");
1020
1021        assert!(
1022            info_x86.vectorizable,
1023            "M3 FAIL: matmul kernel not vectorizable on x86_64 AVX2"
1024        );
1025        assert_eq!(
1026            info_x86.recommended_width, 8,
1027            "M3 FAIL: x86_64 AVX2 should use 8-wide vectors for f32"
1028        );
1029
1030        // Test on aarch64 NEON
1031        let config_arm = VectorizeConfig {
1032            target: TargetArch::Aarch64Neon,
1033            ..Default::default()
1034        };
1035        let mut pass_arm = VectorizePass::new(config_arm);
1036        let analysis_arm = pass_arm.analyze(&ir);
1037        let info_arm = analysis_arm.get(&loop_id).expect("loop should be analyzed");
1038
1039        assert!(
1040            info_arm.vectorizable,
1041            "M3 FAIL: matmul kernel not vectorizable on aarch64 NEON"
1042        );
1043        assert_eq!(
1044            info_arm.recommended_width, 4,
1045            "M3 FAIL: aarch64 NEON should use 4-wide vectors for f32"
1046        );
1047    }
1048
1049    /// M3 Exit Criterion 2: Reductions scale linearly up to 8 cores
1050    ///
1051    /// This test verifies that parallel reduction chunking distributes
1052    /// work evenly across workers.
1053    #[test]
1054    fn test_m3_reductions_scale_linearly() {
1055        use crate::parallel::{ParReduce, ParallelConfig, Range};
1056        use crate::ReduceOp;
1057
1058        let data_size = 1_000_000; // 1M elements
1059
1060        // Test with different worker counts
1061        for worker_count in [1, 2, 4, 8] {
1062            let config = ParallelConfig {
1063                worker_count,
1064                deterministic: true,
1065                ..Default::default()
1066            };
1067
1068            let par_reduce = ParReduce {
1069                size: data_size,
1070                op: ReduceOp::Add,
1071                config,
1072            };
1073
1074            let chunks = par_reduce.chunk_assignments();
1075
1076            // Verify correct number of chunks
1077            assert_eq!(
1078                chunks.len(),
1079                worker_count,
1080                "M3 FAIL: Expected {} chunks for {} workers",
1081                worker_count,
1082                worker_count
1083            );
1084
1085            // Verify total work is correct
1086            let total_work: usize = chunks.iter().map(|c| c.len()).sum();
1087            assert_eq!(
1088                total_work, data_size,
1089                "M3 FAIL: Total work should equal data size"
1090            );
1091
1092            // Verify work is evenly distributed (within 1 element difference)
1093            let expected_per_worker = data_size / worker_count;
1094            for (i, chunk) in chunks.iter().enumerate() {
1095                let diff = (chunk.len() as i64 - expected_per_worker as i64).abs();
1096                assert!(
1097                    diff <= 1,
1098                    "M3 FAIL: Worker {} has {} elements, expected ~{} (diff={})",
1099                    i,
1100                    chunk.len(),
1101                    expected_per_worker,
1102                    diff
1103                );
1104            }
1105        }
1106
1107        // Verify scaling: doubling workers should halve chunk size
1108        let _config_4 = ParallelConfig {
1109            worker_count: 4,
1110            ..Default::default()
1111        };
1112        let chunks_4 = Range::new(0, data_size as i64).chunk(4);
1113
1114        let _config_8 = ParallelConfig {
1115            worker_count: 8,
1116            ..Default::default()
1117        };
1118        let chunks_8 = Range::new(0, data_size as i64).chunk(8);
1119
1120        let avg_chunk_4: usize = chunks_4.iter().map(|c| c.len()).sum::<usize>() / 4;
1121        let avg_chunk_8: usize = chunks_8.iter().map(|c| c.len()).sum::<usize>() / 8;
1122
1123        // 8 workers should have approximately half the chunk size of 4 workers
1124        let ratio = avg_chunk_4 as f64 / avg_chunk_8 as f64;
1125        assert!(
1126            (ratio - 2.0).abs() < 0.1,
1127            "M3 FAIL: Chunk size ratio should be ~2.0, got {}",
1128            ratio
1129        );
1130    }
1131
1132    /// M3 Exit Criterion 3: Deterministic mode produces identical results across runs
1133    ///
1134    /// This test verifies that parallel chunking is deterministic when configured.
1135    #[test]
1136    fn test_m3_deterministic_mode() {
1137        use crate::parallel::{ParReduce, ParallelConfig, ParallelStrategy};
1138        use crate::ReduceOp;
1139
1140        let data_size = 100_000;
1141        let worker_count = 8;
1142
1143        // Configure deterministic mode
1144        let config = ParallelConfig {
1145            worker_count,
1146            deterministic: true,
1147            ..Default::default()
1148        };
1149
1150        let par_reduce = ParReduce {
1151            size: data_size,
1152            op: ReduceOp::Add,
1153            config: config.clone(),
1154        };
1155
1156        // Run multiple times and verify identical chunk assignments
1157        let chunks1 = par_reduce.chunk_assignments();
1158        let chunks2 = par_reduce.chunk_assignments();
1159        let chunks3 = par_reduce.chunk_assignments();
1160
1161        for i in 0..worker_count {
1162            assert_eq!(
1163                chunks1[i].start, chunks2[i].start,
1164                "M3 FAIL: Chunk {} start differs between runs",
1165                i
1166            );
1167            assert_eq!(
1168                chunks1[i].end, chunks2[i].end,
1169                "M3 FAIL: Chunk {} end differs between runs",
1170                i
1171            );
1172            assert_eq!(
1173                chunks2[i].start, chunks3[i].start,
1174                "M3 FAIL: Chunk {} start differs between runs",
1175                i
1176            );
1177            assert_eq!(
1178                chunks2[i].end, chunks3[i].end,
1179                "M3 FAIL: Chunk {} end differs between runs",
1180                i
1181            );
1182        }
1183
1184        // Verify strategy is Static for deterministic mode
1185        use crate::parallel::ParallelPass;
1186
1187        let parallel_config = ParallelConfig {
1188            worker_count: 8,
1189            deterministic: true,
1190            ..Default::default()
1191        };
1192
1193        // Build a simple parallelizable loop
1194        let loop_id = LoopId::new(0);
1195        let mut body = Body::new();
1196
1197        let lp = Loop {
1198            id: loop_id,
1199            var: ValueId::new(0),
1200            lower: Value::i64(0),
1201            upper: Value::i64(100000),
1202            step: Value::i64(1),
1203            body: Body::new(),
1204            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
1205        };
1206
1207        body.push(Stmt::Loop(lp));
1208
1209        let ir = LoopIR {
1210            name: bhc_intern::Symbol::intern("deterministic_test"),
1211            params: vec![],
1212            return_ty: LoopType::Void,
1213            body,
1214            allocs: vec![],
1215            loop_info: vec![LoopMetadata {
1216                id: loop_id,
1217                trip_count: TripCount::Static(100000),
1218                vector_width: None,
1219                parallel_chunk: None,
1220                unroll_factor: None,
1221                dependencies: Vec::new(),
1222            }],
1223        };
1224
1225        let mut pass = ParallelPass::new(parallel_config);
1226        let analysis = pass.analyze(&ir);
1227        let info = analysis.get(&loop_id).expect("loop should be analyzed");
1228
1229        assert!(
1230            info.parallelizable,
1231            "M3 FAIL: Loop should be parallelizable"
1232        );
1233        assert_eq!(
1234            info.strategy,
1235            ParallelStrategy::Static,
1236            "M3 FAIL: Deterministic mode should use Static scheduling"
1237        );
1238    }
1239
1240    /// M3 Integration: Complete pipeline test for vectorized parallel reduction
1241    #[test]
1242    fn test_m3_vectorized_parallel_reduction() {
1243        use crate::parallel::{ParallelConfig, ParallelPass};
1244        use crate::vectorize::{VectorizeConfig, VectorizePass};
1245        use bhc_index::Idx;
1246        use bhc_tensor_ir::BufferId;
1247
1248        // Create a reduction loop that should be both vectorized and parallelized
1249        let outer_loop_id = LoopId::new(0);
1250        let inner_loop_id = LoopId::new(1);
1251
1252        // Inner loop: vectorizable reduction
1253        let mem_ref = MemRef {
1254            buffer: BufferId::new(0),
1255            index: Value::Var(ValueId::new(1), LoopType::Scalar(ScalarType::I64)),
1256            elem_ty: LoopType::Scalar(ScalarType::F32),
1257            access: AccessPattern::Sequential,
1258        };
1259
1260        let mut inner_body = Body::new();
1261        let load_result = ValueId::new(2);
1262        inner_body.push(Stmt::Assign(load_result, Op::Load(mem_ref)));
1263
1264        let inner_loop = Loop {
1265            id: inner_loop_id,
1266            var: ValueId::new(1),
1267            lower: Value::i64(0),
1268            upper: Value::i64(1024),
1269            step: Value::i64(1),
1270            body: inner_body,
1271            attrs: LoopAttrs::VECTORIZE | LoopAttrs::INDEPENDENT | LoopAttrs::REDUCTION,
1272        };
1273
1274        // Outer loop: parallelizable
1275        let mut outer_body = Body::new();
1276        outer_body.push(Stmt::Loop(inner_loop));
1277
1278        let outer_loop = Loop {
1279            id: outer_loop_id,
1280            var: ValueId::new(0),
1281            lower: Value::i64(0),
1282            upper: Value::i64(10000),
1283            step: Value::i64(1),
1284            body: outer_body,
1285            attrs: LoopAttrs::PARALLEL | LoopAttrs::INDEPENDENT,
1286        };
1287
1288        let mut top_body = Body::new();
1289        top_body.push(Stmt::Loop(outer_loop));
1290
1291        let ir = LoopIR {
1292            name: bhc_intern::Symbol::intern("vec_par_reduce"),
1293            params: vec![],
1294            return_ty: LoopType::Void,
1295            body: top_body,
1296            allocs: vec![],
1297            loop_info: vec![
1298                LoopMetadata {
1299                    id: outer_loop_id,
1300                    trip_count: TripCount::Static(10000),
1301                    vector_width: None,
1302                    parallel_chunk: None,
1303                    unroll_factor: None,
1304                    dependencies: Vec::new(),
1305                },
1306                LoopMetadata {
1307                    id: inner_loop_id,
1308                    trip_count: TripCount::Static(1024),
1309                    vector_width: None,
1310                    parallel_chunk: None,
1311                    unroll_factor: None,
1312                    dependencies: Vec::new(),
1313                },
1314            ],
1315        };
1316
1317        // Apply vectorization pass
1318        let vec_config = VectorizeConfig {
1319            target: TargetArch::X86_64Avx2,
1320            ..Default::default()
1321        };
1322        let mut vec_pass = VectorizePass::new(vec_config);
1323        let vec_analysis = vec_pass.analyze(&ir);
1324
1325        // Verify inner loop is vectorizable
1326        let inner_info = vec_analysis
1327            .get(&inner_loop_id)
1328            .expect("inner loop analyzed");
1329        assert!(
1330            inner_info.vectorizable,
1331            "M3 FAIL: Inner reduction loop should be vectorizable"
1332        );
1333
1334        // Apply parallelization pass
1335        let par_config = ParallelConfig {
1336            worker_count: 8,
1337            deterministic: true,
1338            ..Default::default()
1339        };
1340        let mut par_pass = ParallelPass::new(par_config);
1341        let par_analysis = par_pass.analyze(&ir);
1342
1343        // Verify outer loop is parallelizable
1344        let outer_info = par_analysis
1345            .get(&outer_loop_id)
1346            .expect("outer loop analyzed");
1347        assert!(
1348            outer_info.parallelizable,
1349            "M3 FAIL: Outer loop should be parallelizable"
1350        );
1351        assert_eq!(
1352            outer_info.num_chunks, 8,
1353            "M3 FAIL: Should have 8 parallel chunks"
1354        );
1355    }
1356}