Skip to main content

bhc_tensor_ir/
lib.rs

//! # BHC Tensor IR
//!
//! This crate defines the Tensor Intermediate Representation for the Basel
//! Haskell Compiler. Tensor IR is specifically designed for numeric optimization
//! and provides the foundation for guaranteed fusion and vectorization.
//!
//! ## Overview
//!
//! Tensor IR is the key to BHC's numeric performance. It captures:
//!
//! - **Shape and stride information**: for layout-aware optimization
//! - **Element types (dtypes)**: for unboxed numeric computation
//! - **Operation structure**: for fusion analysis
//! - **Aliasing information**: for safe in-place updates
//!
//! ## H26-SPEC Section 7 Compliance
//!
//! Per the H26 specification, every tensor operation in Tensor IR must track:
//!
//! | Property  | Description                                  |
//! |-----------|----------------------------------------------|
//! | `dtype`   | Element type (Float32, Float64, Int32, etc.) |
//! | `shape`   | Dimension sizes                              |
//! | `strides` | Byte strides per dimension                   |
//! | `layout`  | Memory layout (contiguous, strided, tiled)   |
//! | `alias`   | Aliasing/ownership information               |
//!
//! ## IR Pipeline Position
//!
//! ```text
//! Source Code
//!     |
//!     v
//! [Parse/AST]
//!     |
//!     v
//! [HIR]
//!     |
//!     v
//! [Core IR]   <- General-purpose optimizations
//!     |
//!     | (Numeric Profile only)
//!     v
//! [Tensor IR] <- This crate: shape-aware, fusion-ready
//!     |
//!     v
//! [Loop IR]   <- Explicit iteration
//! ```
//!
//! ## Guaranteed Fusion Patterns
//!
//! Per H26-SPEC Section 8, these patterns MUST fuse into a single traversal:
//!
//! 1. `map f (map g x)` -> single traversal
//! 2. `zipWith f (map g a) (map h b)` -> single traversal
//! 3. `sum (map f x)` -> single traversal
//! 4. `foldl' op z (map f x)` -> single traversal
//!
//! ## Main Types
//!
//! - [`TensorOp`][]: Tensor operations
//! - [`TensorMeta`][]: Metadata (dtype, shape, strides, layout, alias)
//! - [`Kernel`][]: A fused computation unit
//! - [`Shape`][]: Tensor dimensions
//! - [`DType`][]: Element types
//!
//! ## See Also
//!
//! - `bhc-core`: Core IR that lowers to Tensor IR
//! - `bhc-loop-ir`: Loop IR for explicit iteration
//! - H26-SPEC Section 7: Tensor Model
//! - H26-SPEC Section 8: Fusion Laws
73
74#![warn(missing_docs)]
75
76pub mod fusion;
77pub mod lower;
78
79use bhc_index::Idx;
80use bhc_intern::Symbol;
81use bhc_span::Span;
82use serde::{Deserialize, Serialize};
83use smallvec::SmallVec;
84
85/// A unique identifier for tensor operations.
86#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
87pub struct TensorId(u32);
88
89impl Idx for TensorId {
90    fn new(idx: usize) -> Self {
91        Self(idx as u32)
92    }
93
94    fn index(self) -> usize {
95        self.0 as usize
96    }
97}
98
99/// A unique identifier for kernels.
100#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
101pub struct KernelId(u32);
102
103impl Idx for KernelId {
104    fn new(idx: usize) -> Self {
105        Self(idx as u32)
106    }
107
108    fn index(self) -> usize {
109        self.0 as usize
110    }
111}
112
113/// A unique identifier for buffers (memory allocations).
114#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
115pub struct BufferId(u32);
116
117impl Idx for BufferId {
118    fn new(idx: usize) -> Self {
119        Self(idx as u32)
120    }
121
122    fn index(self) -> usize {
123        self.0 as usize
124    }
125}
126
/// Tensor element types (data types).
///
/// These represent the unboxed element types that can be stored
/// in tensors. Each dtype has a known size ([`DType::size_bytes`]) and
/// alignment ([`DType::alignment`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// Boolean (1 byte).
    Bool,
    /// 8-bit signed integer.
    Int8,
    /// 16-bit signed integer.
    Int16,
    /// 32-bit signed integer.
    Int32,
    /// 64-bit signed integer.
    Int64,
    /// 8-bit unsigned integer.
    UInt8,
    /// 16-bit unsigned integer.
    UInt16,
    /// 32-bit unsigned integer.
    UInt32,
    /// 64-bit unsigned integer.
    UInt64,
    /// 16-bit floating point (half precision).
    Float16,
    /// 32-bit floating point (single precision).
    Float32,
    /// 64-bit floating point (double precision).
    Float64,
    /// Brain floating point (bfloat16), 2 bytes.
    BFloat16,
    /// Complex number with single-precision components (8 bytes total).
    Complex64,
    /// Complex number with double-precision components (16 bytes total).
    Complex128,
}
164
165impl DType {
166    /// Returns the size in bytes of this dtype.
167    #[must_use]
168    pub const fn size_bytes(self) -> usize {
169        match self {
170            Self::Bool | Self::Int8 | Self::UInt8 => 1,
171            Self::Int16 | Self::UInt16 | Self::Float16 | Self::BFloat16 => 2,
172            Self::Int32 | Self::UInt32 | Self::Float32 => 4,
173            Self::Int64 | Self::UInt64 | Self::Float64 | Self::Complex64 => 8,
174            Self::Complex128 => 16,
175        }
176    }
177
178    /// Returns the alignment in bytes for this dtype.
179    #[must_use]
180    pub const fn alignment(self) -> usize {
181        self.size_bytes()
182    }
183
184    /// Returns true if this is a floating-point type.
185    #[must_use]
186    pub const fn is_float(self) -> bool {
187        matches!(
188            self,
189            Self::Float16 | Self::Float32 | Self::Float64 | Self::BFloat16
190        )
191    }
192
193    /// Returns true if this is an integer type.
194    #[must_use]
195    pub const fn is_integer(self) -> bool {
196        matches!(
197            self,
198            Self::Int8
199                | Self::Int16
200                | Self::Int32
201                | Self::Int64
202                | Self::UInt8
203                | Self::UInt16
204                | Self::UInt32
205                | Self::UInt64
206        )
207    }
208
209    /// Returns true if this is a signed type.
210    #[must_use]
211    pub const fn is_signed(self) -> bool {
212        matches!(
213            self,
214            Self::Int8
215                | Self::Int16
216                | Self::Int32
217                | Self::Int64
218                | Self::Float16
219                | Self::Float32
220                | Self::Float64
221                | Self::BFloat16
222                | Self::Complex64
223                | Self::Complex128
224        )
225    }
226}
227
/// A dimension size (may be static or dynamic).
///
/// Use [`Dim::static_value`] to read the extent when it is known.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Dim {
    /// A statically known dimension extent.
    Static(usize),
    /// A dynamically determined dimension, named by a symbol
    /// (presumably a size variable resolved at runtime — confirm).
    Dynamic(Symbol),
}
236
237impl Dim {
238    /// Returns the static value if known.
239    #[must_use]
240    pub const fn static_value(&self) -> Option<usize> {
241        match self {
242            Self::Static(n) => Some(*n),
243            Self::Dynamic(_) => None,
244        }
245    }
246
247    /// Returns true if this dimension is statically known.
248    #[must_use]
249    pub const fn is_static(&self) -> bool {
250        matches!(self, Self::Static(_))
251    }
252}
253
/// Tensor shape (list of dimensions).
///
/// A rank-0 (empty) shape denotes a scalar. Up to four dimensions are
/// stored inline without a heap allocation (`SmallVec<[Dim; 4]>`).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape(SmallVec<[Dim; 4]>);
257
258impl Shape {
259    /// Creates a new shape from dimensions.
260    #[must_use]
261    pub fn new(dims: impl IntoIterator<Item = Dim>) -> Self {
262        Self(dims.into_iter().collect())
263    }
264
265    /// Creates a shape from static dimensions.
266    #[must_use]
267    pub fn from_static(dims: impl IntoIterator<Item = usize>) -> Self {
268        Self(dims.into_iter().map(Dim::Static).collect())
269    }
270
271    /// Creates a scalar shape (rank 0).
272    #[must_use]
273    pub fn scalar() -> Self {
274        Self(SmallVec::new())
275    }
276
277    /// Returns the rank (number of dimensions).
278    #[must_use]
279    pub fn rank(&self) -> usize {
280        self.0.len()
281    }
282
283    /// Returns the dimensions.
284    #[must_use]
285    pub fn dims(&self) -> &[Dim] {
286        &self.0
287    }
288
289    /// Returns the total number of elements (if statically known).
290    #[must_use]
291    pub fn num_elements(&self) -> Option<usize> {
292        self.0
293            .iter()
294            .try_fold(1usize, |acc, dim| dim.static_value().map(|n| acc * n))
295    }
296
297    /// Returns true if this is a scalar (rank 0).
298    #[must_use]
299    pub fn is_scalar(&self) -> bool {
300        self.0.is_empty()
301    }
302
303    /// Returns true if all dimensions are statically known.
304    #[must_use]
305    pub fn is_static(&self) -> bool {
306        self.0.iter().all(Dim::is_static)
307    }
308}
309
/// Memory strides for each dimension, in bytes.
///
/// There is one entry per dimension, in the same order as the [`Shape`].
/// Values are signed (`i64`), presumably so that reversed views can use
/// negative strides — confirm.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Strides(SmallVec<[i64; 4]>);
313
314impl Strides {
315    /// Creates new strides.
316    #[must_use]
317    pub fn new(strides: impl IntoIterator<Item = i64>) -> Self {
318        Self(strides.into_iter().collect())
319    }
320
321    /// Computes contiguous (row-major) strides for a shape.
322    #[must_use]
323    pub fn contiguous(shape: &Shape, elem_size: usize) -> Option<Self> {
324        let mut strides = SmallVec::with_capacity(shape.rank());
325        let mut stride = elem_size as i64;
326
327        for dim in shape.dims().iter().rev() {
328            strides.push(stride);
329            stride *= dim.static_value()? as i64;
330        }
331
332        strides.reverse();
333        Some(Self(strides))
334    }
335
336    /// Returns the stride values.
337    #[must_use]
338    pub fn values(&self) -> &[i64] {
339        &self.0
340    }
341
342    /// Returns true if these strides represent contiguous memory.
343    #[must_use]
344    pub fn is_contiguous(&self, shape: &Shape, elem_size: usize) -> bool {
345        if let Some(contiguous) = Self::contiguous(shape, elem_size) {
346            self.0 == contiguous.0
347        } else {
348            false
349        }
350    }
351}
352
/// Memory layout of a tensor.
///
/// This is tracked alongside the explicit [`Strides`] in [`TensorMeta`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Layout {
    /// Contiguous memory (row-major by default).
    Contiguous,
    /// Strided layout (possibly non-contiguous); see the tensor's [`Strides`].
    Strided,
    /// Tiled layout for cache efficiency.
    Tiled(TileInfo),
}
363
/// Tiling information for cache-friendly layouts.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TileInfo {
    /// Tile sizes for each dimension.
    pub tile_sizes: SmallVec<[usize; 4]>,
    /// The dimension order within tiles (presumably a permutation of
    /// dimension indices — confirm).
    pub inner_order: SmallVec<[usize; 4]>,
}
372
/// Tensor metadata per H26-SPEC Section 7.3.
///
/// This bundles the five properties every tensor must track: dtype, shape,
/// strides, layout and alias.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Element type.
    pub dtype: DType,
    /// Tensor shape.
    pub shape: Shape,
    /// Memory strides, in bytes.
    pub strides: Strides,
    /// Memory layout.
    pub layout: Layout,
    /// Aliasing information: the buffer this tensor references, if any.
    /// [`TensorMeta::new_contiguous`] sets this to `None`.
    pub alias: Option<BufferId>,
}
387
388impl TensorMeta {
389    /// Creates metadata for a new contiguous tensor.
390    #[must_use]
391    pub fn new_contiguous(dtype: DType, shape: Shape) -> Option<Self> {
392        let strides = Strides::contiguous(&shape, dtype.size_bytes())?;
393        Some(Self {
394            dtype,
395            shape,
396            strides,
397            layout: Layout::Contiguous,
398            alias: None,
399        })
400    }
401
402    /// Returns the total size in bytes (if statically known).
403    #[must_use]
404    pub fn size_bytes(&self) -> Option<usize> {
405        self.shape
406            .num_elements()
407            .map(|n| n * self.dtype.size_bytes())
408    }
409}
410
/// A reference to a tensor value, together with its metadata.
///
/// Operations in [`TensorOp`] take their operands as `TensorRef`s, so every
/// operand carries its own dtype, shape, strides, layout and alias.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorRef {
    /// The tensor ID.
    pub id: TensorId,
    /// The metadata.
    pub meta: TensorMeta,
}
419
/// Tensor operations in the IR.
///
/// These operations form the building blocks of tensor computations.
/// The fusion pass analyzes these to produce fused kernels.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TensorOp {
    /// A constant tensor.
    Constant(ConstantOp),

    // === Elementwise Operations ===
    /// Unary elementwise operation.
    Unary(UnaryOp, TensorRef),
    /// Binary elementwise operation.
    Binary(BinaryOp, TensorRef, TensorRef),
    /// Map a (named) function over elements.
    Map(MapFn, TensorRef),
    /// Zip two tensors elementwise with a (named) function.
    ZipWith(ZipFn, TensorRef, TensorRef),

    // === Reductions ===
    /// Reduce along an axis.
    Reduce(ReduceOp, Axis, TensorRef),
    /// Full reduction to scalar.
    ReduceAll(ReduceOp, TensorRef),
    /// Scan (prefix reduction, e.g. prefix sum) along an axis.
    Scan(ReduceOp, Axis, TensorRef),
    /// Fold with initial value.
    ///
    /// NOTE(review): which operand is the initial value and which is the
    /// folded tensor is not fixed by this definition — confirm against the
    /// `lower`/`fusion` modules.
    Fold(FoldFn, TensorRef, TensorRef),

    // === Structure Operations ===
    /// Reshape to the given shape.
    Reshape(Shape, TensorRef),
    /// Slice a region.
    Slice(SliceSpec, TensorRef),
    /// Transpose (permute dimensions).
    Transpose(Permutation, TensorRef),
    /// Broadcast to the given (larger) shape.
    Broadcast(Shape, TensorRef),
    /// Concatenate along an axis.
    Concat(Axis, Vec<TensorRef>),
    /// Split along an axis.
    ///
    /// NOTE(review): the `Vec<usize>` is presumably the section sizes or the
    /// split points — confirm.
    Split(Axis, Vec<usize>, TensorRef),

    // === Linear Algebra ===
    /// Matrix multiplication.
    MatMul(TensorRef, TensorRef),
    /// Batched matrix multiplication.
    BatchMatMul(TensorRef, TensorRef),
    /// Dot product.
    Dot(TensorRef, TensorRef),
    /// Outer product.
    Outer(TensorRef, TensorRef),

    // === Convolution ===
    /// Convolution operation (presumably `(input, kernel)` operands — confirm).
    Conv(ConvSpec, TensorRef, TensorRef),

    // === Indexing ===
    /// Gather elements (presumably `(source, indices)` operands — confirm).
    Gather(Axis, TensorRef, TensorRef),
    /// Scatter elements (presumably `(target, indices, updates)` operands —
    /// confirm).
    Scatter(Axis, TensorRef, TensorRef, TensorRef),
}
483
/// A constant tensor operation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ConstantOp {
    /// Zeros tensor with the given metadata.
    Zeros(TensorMeta),
    /// Ones tensor with the given metadata.
    Ones(TensorMeta),
    /// Tensor filled with a value.
    Full(TensorMeta, ScalarValue),
    /// Range/arange tensor.
    ///
    /// NOTE(review): the three `i64`s are presumably `(start, stop, step)` —
    /// confirm.
    Range(DType, i64, i64, i64),
    /// Identity matrix (the `usize` is presumably the side length — confirm).
    Eye(DType, usize),
}
498
/// Unary elementwise operations (used by [`TensorOp::Unary`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryOp {
    /// Negation.
    Neg,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Reciprocal square root (`1 / sqrt(x)`).
    Rsqrt,
    /// Exponential.
    Exp,
    /// Natural logarithm.
    Log,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Hyperbolic tangent.
    Tanh,
    /// Sigmoid.
    Sigmoid,
    /// `ReLU`.
    Relu,
    /// Ceiling.
    Ceil,
    /// Floor.
    Floor,
    /// Round.
    Round,
    /// Bitwise not (integers).
    Not,
}
535
/// Binary elementwise operations (used by [`TensorOp::Binary`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Modulo.
    ///
    /// NOTE(review): the sign convention (truncated vs. floored) is not
    /// specified here — confirm.
    Mod,
    /// Power.
    Pow,
    /// Maximum.
    Max,
    /// Minimum.
    Min,
    /// Equality.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
    /// Bitwise and.
    And,
    /// Bitwise or.
    Or,
    /// Bitwise xor.
    Xor,
    /// Left shift.
    Shl,
    /// Right shift.
    Shr,
}
578
/// Reduction operations.
///
/// Used by [`TensorOp::Reduce`], [`TensorOp::ReduceAll`] and
/// [`TensorOp::Scan`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ReduceOp {
    /// Sum reduction.
    Sum,
    /// Product reduction.
    Prod,
    /// Maximum reduction.
    Max,
    /// Minimum reduction.
    Min,
    /// Logical and.
    All,
    /// Logical or.
    Any,
    /// Mean (sum / count).
    Mean,
}
597
/// An axis specification.
///
/// Negative values count from the last dimension (`-1` is the last axis);
/// use [`Axis::normalize`] to resolve it against a rank.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Axis(pub i32);
601
602impl Axis {
603    /// Creates a new axis.
604    #[must_use]
605    pub const fn new(axis: i32) -> Self {
606        Self(axis)
607    }
608
609    /// Normalizes a potentially negative axis to a positive index.
610    #[must_use]
611    pub const fn normalize(self, rank: usize) -> Option<usize> {
612        let axis = if self.0 < 0 {
613            (rank as i32) + self.0
614        } else {
615            self.0
616        };
617        if axis >= 0 && (axis as usize) < rank {
618            Some(axis as usize)
619        } else {
620            None
621        }
622    }
623}
624
/// A scalar value, e.g. the fill value of [`ConstantOp::Full`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ScalarValue {
    /// Boolean.
    Bool(bool),
    /// Integer (stored as `i64`).
    Int(i64),
    /// Floating point (stored as `f64`).
    Float(f64),
}
635
/// A map function (element-wise transformation), used by [`TensorOp::Map`].
///
/// Only the function's name is stored here; its definition lives elsewhere
/// (presumably resolved by name during lowering — confirm).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MapFn {
    /// The function name/identifier.
    pub name: Symbol,
    /// Source span.
    pub span: Span,
}
644
/// A zip function (combining two elements), used by [`TensorOp::ZipWith`].
///
/// Only the function's name is stored here; its definition lives elsewhere
/// (presumably resolved by name during lowering — confirm).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZipFn {
    /// The function name/identifier.
    pub name: Symbol,
    /// Source span.
    pub span: Span,
}
653
/// A fold function, used by [`TensorOp::Fold`].
///
/// Only the function's name is stored here; its definition lives elsewhere
/// (presumably resolved by name during lowering — confirm).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FoldFn {
    /// The combining function.
    pub name: Symbol,
    /// Source span.
    pub span: Span,
}
662
/// A slice specification, used by [`TensorOp::Slice`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SliceSpec {
    /// One [`SliceRange`] (start, stop, step) per dimension.
    pub ranges: SmallVec<[SliceRange; 4]>,
}
669
/// A range within a slice.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SliceRange {
    /// Start index (inclusive). `None` presumably means "from the beginning"
    /// — confirm.
    pub start: Option<i64>,
    /// Stop index (exclusive). `None` presumably means "to the end" — confirm.
    pub stop: Option<i64>,
    /// Step size.
    pub step: i64,
}
680
/// A permutation of dimension indices, used by [`TensorOp::Transpose`].
///
/// NOTE(review): whether entry `i` names the source axis of output axis `i`
/// (or the reverse) is not fixed by this definition — confirm.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Permutation(SmallVec<[usize; 4]>);
684
685impl Permutation {
686    /// Creates a new permutation.
687    #[must_use]
688    pub fn new(perm: impl IntoIterator<Item = usize>) -> Self {
689        Self(perm.into_iter().collect())
690    }
691
692    /// Returns the permutation as a slice.
693    #[must_use]
694    pub fn as_slice(&self) -> &[usize] {
695        &self.0
696    }
697
698    /// Returns true if this is the identity permutation.
699    #[must_use]
700    pub fn is_identity(&self) -> bool {
701        self.0.iter().enumerate().all(|(i, &p)| i == p)
702    }
703}
704
/// Convolution specification, used by [`TensorOp::Conv`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConvSpec {
    /// Padding per dimension (presumably `(before, after)` — confirm).
    pub padding: SmallVec<[(usize, usize); 4]>,
    /// Stride per dimension.
    pub strides: SmallVec<[usize; 4]>,
    /// Dilation per dimension.
    pub dilation: SmallVec<[usize; 4]>,
    /// Number of groups.
    pub groups: usize,
}
717
/// A fused computation kernel.
///
/// Kernels are the output of the fusion pass. Each kernel
/// represents a unit of computation that executes without
/// intermediate allocation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Kernel {
    /// Unique kernel identifier.
    pub id: KernelId,
    /// Kernel name (for debugging/profiling).
    pub name: Symbol,
    /// Input tensors.
    pub inputs: Vec<TensorRef>,
    /// Output tensors.
    pub outputs: Vec<TensorRef>,
    /// The computation body.
    pub body: KernelBody,
    /// Allocation requirements (see [`AllocInfo`]).
    pub allocs: Vec<AllocInfo>,
    /// Fusion information (for debugging and reporting).
    pub fusion_info: FusionInfo,
}
740
/// The body of a kernel.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum KernelBody {
    /// A simple fused operation: a sequence of tensor ops.
    Fused(Vec<TensorOp>),
    /// A loop nest (lowered from tensor ops).
    LoopNest(LoopNest),
}
749
/// A simple loop nest representation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopNest {
    /// The loops, from outermost to innermost.
    pub loops: Vec<LoopInfo>,
    /// The computation executed in the innermost loop.
    pub body: Vec<TensorOp>,
}
758
/// Information about a single loop.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopInfo {
    /// Loop variable name.
    pub var: Symbol,
    /// Lower bound.
    pub lower: i64,
    /// Upper bound. It may be symbolic ([`Dim::Dynamic`]); whether it is
    /// inclusive or exclusive is not specified here — confirm.
    pub upper: Dim,
    /// Step size.
    pub step: i64,
    /// Whether this loop can be parallelized.
    pub parallel: bool,
    /// Vectorization hint. `Some(n)` if the loop can be vectorized (presumably
    /// `n` is the vector width — confirm); `None` otherwise.
    pub vectorize: Option<usize>,
}
775
/// Allocation information for a kernel.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AllocInfo {
    /// The buffer being allocated.
    pub buffer: BufferId,
    /// Size in bytes.
    pub size: usize,
    /// Alignment requirement, in bytes.
    pub alignment: usize,
    /// Memory region to allocate from.
    pub region: AllocRegion,
}
788
/// Memory regions for allocation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllocRegion {
    /// Hot arena (bump allocated, scoped lifetime).
    HotArena,
    /// Pinned heap (for FFI).
    Pinned,
    /// General heap (GC managed).
    General,
    /// GPU device memory on the given target.
    DeviceMemory(DeviceTarget),
}
801
/// Target device for GPU memory allocation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DeviceTarget {
    /// NVIDIA GPU (CUDA); the `u32` is presumably the device ordinal — confirm.
    Cuda(u32),
    /// AMD GPU (`ROCm`); the `u32` is presumably the device ordinal — confirm.
    Rocm(u32),
    /// Any available GPU.
    Any,
}
812
/// Fusion information for debugging and reporting.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusionInfo {
    /// Original operations before fusion.
    pub original_ops: Vec<Symbol>,
    /// Fusion decisions made, as recorded by the fusion pass.
    pub decisions: Vec<FusionDecision>,
    /// Whether all expected fusions succeeded.
    pub complete: bool,
}
823
/// A fusion decision made by the compiler.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum FusionDecision {
    /// These operations were successfully fused.
    Fused(Vec<Symbol>),
    /// A materialization point was inserted for this operation.
    Materialized(Symbol, MaterializeReason),
    /// Fusion was blocked for this operation.
    Blocked(Symbol, FusionBlockReason),
}
834
/// Why a tensor was materialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum MaterializeReason {
    /// Used by multiple consumers.
    MultipleUses,
    /// Explicitly requested by programmer.
    Explicit,
    /// Required for control flow.
    ControlFlow,
}
845
/// Why fusion was blocked.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum FusionBlockReason {
    /// Shape mismatch between operations.
    ShapeMismatch,
    /// Incompatible data types.
    DTypeMismatch,
    /// Data dependency prevents fusion.
    DataDependency,
    /// Side effects prevent reordering.
    SideEffects,
}
858
/// Errors in Tensor IR operations.
///
/// Display messages are generated by `thiserror` from the `#[error]`
/// attributes below.
#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
pub enum TensorIrError {
    /// Shape mismatch in operation.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// Expected shape.
        expected: Shape,
        /// Actual shape.
        got: Shape,
    },

    /// Invalid axis for operation (see [`Axis::normalize`]).
    #[error("invalid axis {axis} for tensor of rank {rank}")]
    InvalidAxis {
        /// The axis specified (may be negative).
        axis: i32,
        /// The tensor rank.
        rank: usize,
    },

    /// Element type mismatch.
    #[error("dtype mismatch: expected {expected:?}, got {got:?}")]
    DTypeMismatch {
        /// Expected dtype.
        expected: DType,
        /// Actual dtype.
        got: DType,
    },

    /// Fusion failed for a guaranteed pattern (H26-SPEC Section 8).
    #[error("fusion failed for guaranteed pattern: {pattern}")]
    FusionFailed {
        /// The pattern that should have fused.
        pattern: String,
    },
}
896
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_sizes() {
        let cases = [(DType::Float32, 4), (DType::Float64, 8), (DType::Int32, 4)];
        for &(dtype, bytes) in &cases {
            assert_eq!(dtype.size_bytes(), bytes, "unexpected size for {:?}", dtype);
        }
    }

    #[test]
    fn test_shape_num_elements() {
        let shape = Shape::from_static([2, 3, 4]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), Some(2 * 3 * 4));
    }

    #[test]
    fn test_strides_contiguous() {
        // Row-major with 4-byte elements: the innermost stride is the element
        // size, and each outer stride multiplies in the next-inner extent.
        let strides = Strides::contiguous(&Shape::from_static([2, 3, 4]), 4)
            .expect("all dimensions are static");
        assert_eq!(strides.values(), &[48, 16, 4]);
    }

    #[test]
    fn test_axis_normalize() {
        let cases = [(-1, Some(2)), (1, Some(1)), (5, None)];
        for &(raw, expected) in &cases {
            assert_eq!(Axis::new(raw).normalize(3), expected, "axis {}", raw);
        }
    }

    #[test]
    fn test_permutation_identity() {
        assert!(Permutation::new([0, 1, 2]).is_identity());
        assert!(!Permutation::new([2, 0, 1]).is_identity());
    }
}