1#![warn(missing_docs)]
75
76pub mod fusion;
77pub mod lower;
78
79use bhc_index::Idx;
80use bhc_intern::Symbol;
81use bhc_span::Span;
82use serde::{Deserialize, Serialize};
83use smallvec::SmallVec;
84
85#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
87pub struct TensorId(u32);
88
89impl Idx for TensorId {
90 fn new(idx: usize) -> Self {
91 Self(idx as u32)
92 }
93
94 fn index(self) -> usize {
95 self.0 as usize
96 }
97}
98
99#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
101pub struct KernelId(u32);
102
103impl Idx for KernelId {
104 fn new(idx: usize) -> Self {
105 Self(idx as u32)
106 }
107
108 fn index(self) -> usize {
109 self.0 as usize
110 }
111}
112
113#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
115pub struct BufferId(u32);
116
117impl Idx for BufferId {
118 fn new(idx: usize) -> Self {
119 Self(idx as u32)
120 }
121
122 fn index(self) -> usize {
123 self.0 as usize
124 }
125}
126
127#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
132pub enum DType {
133 Bool,
135 Int8,
137 Int16,
139 Int32,
141 Int64,
143 UInt8,
145 UInt16,
147 UInt32,
149 UInt64,
151 Float16,
153 Float32,
155 Float64,
157 BFloat16,
159 Complex64,
161 Complex128,
163}
164
165impl DType {
166 #[must_use]
168 pub const fn size_bytes(self) -> usize {
169 match self {
170 Self::Bool | Self::Int8 | Self::UInt8 => 1,
171 Self::Int16 | Self::UInt16 | Self::Float16 | Self::BFloat16 => 2,
172 Self::Int32 | Self::UInt32 | Self::Float32 => 4,
173 Self::Int64 | Self::UInt64 | Self::Float64 | Self::Complex64 => 8,
174 Self::Complex128 => 16,
175 }
176 }
177
178 #[must_use]
180 pub const fn alignment(self) -> usize {
181 self.size_bytes()
182 }
183
184 #[must_use]
186 pub const fn is_float(self) -> bool {
187 matches!(
188 self,
189 Self::Float16 | Self::Float32 | Self::Float64 | Self::BFloat16
190 )
191 }
192
193 #[must_use]
195 pub const fn is_integer(self) -> bool {
196 matches!(
197 self,
198 Self::Int8
199 | Self::Int16
200 | Self::Int32
201 | Self::Int64
202 | Self::UInt8
203 | Self::UInt16
204 | Self::UInt32
205 | Self::UInt64
206 )
207 }
208
209 #[must_use]
211 pub const fn is_signed(self) -> bool {
212 matches!(
213 self,
214 Self::Int8
215 | Self::Int16
216 | Self::Int32
217 | Self::Int64
218 | Self::Float16
219 | Self::Float32
220 | Self::Float64
221 | Self::BFloat16
222 | Self::Complex64
223 | Self::Complex128
224 )
225 }
226}
227
228#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
230pub enum Dim {
231 Static(usize),
233 Dynamic(Symbol),
235}
236
237impl Dim {
238 #[must_use]
240 pub const fn static_value(&self) -> Option<usize> {
241 match self {
242 Self::Static(n) => Some(*n),
243 Self::Dynamic(_) => None,
244 }
245 }
246
247 #[must_use]
249 pub const fn is_static(&self) -> bool {
250 matches!(self, Self::Static(_))
251 }
252}
253
254#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
256pub struct Shape(SmallVec<[Dim; 4]>);
257
258impl Shape {
259 #[must_use]
261 pub fn new(dims: impl IntoIterator<Item = Dim>) -> Self {
262 Self(dims.into_iter().collect())
263 }
264
265 #[must_use]
267 pub fn from_static(dims: impl IntoIterator<Item = usize>) -> Self {
268 Self(dims.into_iter().map(Dim::Static).collect())
269 }
270
271 #[must_use]
273 pub fn scalar() -> Self {
274 Self(SmallVec::new())
275 }
276
277 #[must_use]
279 pub fn rank(&self) -> usize {
280 self.0.len()
281 }
282
283 #[must_use]
285 pub fn dims(&self) -> &[Dim] {
286 &self.0
287 }
288
289 #[must_use]
291 pub fn num_elements(&self) -> Option<usize> {
292 self.0
293 .iter()
294 .try_fold(1usize, |acc, dim| dim.static_value().map(|n| acc * n))
295 }
296
297 #[must_use]
299 pub fn is_scalar(&self) -> bool {
300 self.0.is_empty()
301 }
302
303 #[must_use]
305 pub fn is_static(&self) -> bool {
306 self.0.iter().all(Dim::is_static)
307 }
308}
309
310#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
312pub struct Strides(SmallVec<[i64; 4]>);
313
314impl Strides {
315 #[must_use]
317 pub fn new(strides: impl IntoIterator<Item = i64>) -> Self {
318 Self(strides.into_iter().collect())
319 }
320
321 #[must_use]
323 pub fn contiguous(shape: &Shape, elem_size: usize) -> Option<Self> {
324 let mut strides = SmallVec::with_capacity(shape.rank());
325 let mut stride = elem_size as i64;
326
327 for dim in shape.dims().iter().rev() {
328 strides.push(stride);
329 stride *= dim.static_value()? as i64;
330 }
331
332 strides.reverse();
333 Some(Self(strides))
334 }
335
336 #[must_use]
338 pub fn values(&self) -> &[i64] {
339 &self.0
340 }
341
342 #[must_use]
344 pub fn is_contiguous(&self, shape: &Shape, elem_size: usize) -> bool {
345 if let Some(contiguous) = Self::contiguous(shape, elem_size) {
346 self.0 == contiguous.0
347 } else {
348 false
349 }
350 }
351}
352
353#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
355pub enum Layout {
356 Contiguous,
358 Strided,
360 Tiled(TileInfo),
362}
363
364#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
366pub struct TileInfo {
367 pub tile_sizes: SmallVec<[usize; 4]>,
369 pub inner_order: SmallVec<[usize; 4]>,
371}
372
373#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
375pub struct TensorMeta {
376 pub dtype: DType,
378 pub shape: Shape,
380 pub strides: Strides,
382 pub layout: Layout,
384 pub alias: Option<BufferId>,
386}
387
388impl TensorMeta {
389 #[must_use]
391 pub fn new_contiguous(dtype: DType, shape: Shape) -> Option<Self> {
392 let strides = Strides::contiguous(&shape, dtype.size_bytes())?;
393 Some(Self {
394 dtype,
395 shape,
396 strides,
397 layout: Layout::Contiguous,
398 alias: None,
399 })
400 }
401
402 #[must_use]
404 pub fn size_bytes(&self) -> Option<usize> {
405 self.shape
406 .num_elements()
407 .map(|n| n * self.dtype.size_bytes())
408 }
409}
410
411#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
413pub struct TensorRef {
414 pub id: TensorId,
416 pub meta: TensorMeta,
418}
419
420#[derive(Clone, Debug, Serialize, Deserialize)]
425pub enum TensorOp {
426 Constant(ConstantOp),
428
429 Unary(UnaryOp, TensorRef),
432 Binary(BinaryOp, TensorRef, TensorRef),
434 Map(MapFn, TensorRef),
436 ZipWith(ZipFn, TensorRef, TensorRef),
438
439 Reduce(ReduceOp, Axis, TensorRef),
442 ReduceAll(ReduceOp, TensorRef),
444 Scan(ReduceOp, Axis, TensorRef),
446 Fold(FoldFn, TensorRef, TensorRef),
448
449 Reshape(Shape, TensorRef),
452 Slice(SliceSpec, TensorRef),
454 Transpose(Permutation, TensorRef),
456 Broadcast(Shape, TensorRef),
458 Concat(Axis, Vec<TensorRef>),
460 Split(Axis, Vec<usize>, TensorRef),
462
463 MatMul(TensorRef, TensorRef),
466 BatchMatMul(TensorRef, TensorRef),
468 Dot(TensorRef, TensorRef),
470 Outer(TensorRef, TensorRef),
472
473 Conv(ConvSpec, TensorRef, TensorRef),
476
477 Gather(Axis, TensorRef, TensorRef),
480 Scatter(Axis, TensorRef, TensorRef, TensorRef),
482}
483
484#[derive(Clone, Debug, Serialize, Deserialize)]
486pub enum ConstantOp {
487 Zeros(TensorMeta),
489 Ones(TensorMeta),
491 Full(TensorMeta, ScalarValue),
493 Range(DType, i64, i64, i64),
495 Eye(DType, usize),
497}
498
499#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
501pub enum UnaryOp {
502 Neg,
504 Abs,
506 Sqrt,
508 Rsqrt,
510 Exp,
512 Log,
514 Sin,
516 Cos,
518 Tan,
520 Tanh,
522 Sigmoid,
524 Relu,
526 Ceil,
528 Floor,
530 Round,
532 Not,
534}
535
536#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
538pub enum BinaryOp {
539 Add,
541 Sub,
543 Mul,
545 Div,
547 Mod,
549 Pow,
551 Max,
553 Min,
555 Eq,
557 Ne,
559 Lt,
561 Le,
563 Gt,
565 Ge,
567 And,
569 Or,
571 Xor,
573 Shl,
575 Shr,
577}
578
579#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
581pub enum ReduceOp {
582 Sum,
584 Prod,
586 Max,
588 Min,
590 All,
592 Any,
594 Mean,
596}
597
598#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
600pub struct Axis(pub i32);
601
602impl Axis {
603 #[must_use]
605 pub const fn new(axis: i32) -> Self {
606 Self(axis)
607 }
608
609 #[must_use]
611 pub const fn normalize(self, rank: usize) -> Option<usize> {
612 let axis = if self.0 < 0 {
613 (rank as i32) + self.0
614 } else {
615 self.0
616 };
617 if axis >= 0 && (axis as usize) < rank {
618 Some(axis as usize)
619 } else {
620 None
621 }
622 }
623}
624
625#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
627pub enum ScalarValue {
628 Bool(bool),
630 Int(i64),
632 Float(f64),
634}
635
636#[derive(Clone, Debug, Serialize, Deserialize)]
638pub struct MapFn {
639 pub name: Symbol,
641 pub span: Span,
643}
644
645#[derive(Clone, Debug, Serialize, Deserialize)]
647pub struct ZipFn {
648 pub name: Symbol,
650 pub span: Span,
652}
653
654#[derive(Clone, Debug, Serialize, Deserialize)]
656pub struct FoldFn {
657 pub name: Symbol,
659 pub span: Span,
661}
662
663#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
665pub struct SliceSpec {
666 pub ranges: SmallVec<[SliceRange; 4]>,
668}
669
670#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
672pub struct SliceRange {
673 pub start: Option<i64>,
675 pub stop: Option<i64>,
677 pub step: i64,
679}
680
681#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
683pub struct Permutation(SmallVec<[usize; 4]>);
684
685impl Permutation {
686 #[must_use]
688 pub fn new(perm: impl IntoIterator<Item = usize>) -> Self {
689 Self(perm.into_iter().collect())
690 }
691
692 #[must_use]
694 pub fn as_slice(&self) -> &[usize] {
695 &self.0
696 }
697
698 #[must_use]
700 pub fn is_identity(&self) -> bool {
701 self.0.iter().enumerate().all(|(i, &p)| i == p)
702 }
703}
704
705#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
707pub struct ConvSpec {
708 pub padding: SmallVec<[(usize, usize); 4]>,
710 pub strides: SmallVec<[usize; 4]>,
712 pub dilation: SmallVec<[usize; 4]>,
714 pub groups: usize,
716}
717
718#[derive(Clone, Debug, Serialize, Deserialize)]
724pub struct Kernel {
725 pub id: KernelId,
727 pub name: Symbol,
729 pub inputs: Vec<TensorRef>,
731 pub outputs: Vec<TensorRef>,
733 pub body: KernelBody,
735 pub allocs: Vec<AllocInfo>,
737 pub fusion_info: FusionInfo,
739}
740
741#[derive(Clone, Debug, Serialize, Deserialize)]
743pub enum KernelBody {
744 Fused(Vec<TensorOp>),
746 LoopNest(LoopNest),
748}
749
750#[derive(Clone, Debug, Serialize, Deserialize)]
752pub struct LoopNest {
753 pub loops: Vec<LoopInfo>,
755 pub body: Vec<TensorOp>,
757}
758
759#[derive(Clone, Debug, Serialize, Deserialize)]
761pub struct LoopInfo {
762 pub var: Symbol,
764 pub lower: i64,
766 pub upper: Dim,
768 pub step: i64,
770 pub parallel: bool,
772 pub vectorize: Option<usize>,
774}
775
776#[derive(Clone, Debug, Serialize, Deserialize)]
778pub struct AllocInfo {
779 pub buffer: BufferId,
781 pub size: usize,
783 pub alignment: usize,
785 pub region: AllocRegion,
787}
788
789#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
791pub enum AllocRegion {
792 HotArena,
794 Pinned,
796 General,
798 DeviceMemory(DeviceTarget),
800}
801
802#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
804pub enum DeviceTarget {
805 Cuda(u32),
807 Rocm(u32),
809 Any,
811}
812
813#[derive(Clone, Debug, Serialize, Deserialize)]
815pub struct FusionInfo {
816 pub original_ops: Vec<Symbol>,
818 pub decisions: Vec<FusionDecision>,
820 pub complete: bool,
822}
823
824#[derive(Clone, Debug, Serialize, Deserialize)]
826pub enum FusionDecision {
827 Fused(Vec<Symbol>),
829 Materialized(Symbol, MaterializeReason),
831 Blocked(Symbol, FusionBlockReason),
833}
834
835#[derive(Clone, Debug, Serialize, Deserialize)]
837pub enum MaterializeReason {
838 MultipleUses,
840 Explicit,
842 ControlFlow,
844}
845
846#[derive(Clone, Debug, Serialize, Deserialize)]
848pub enum FusionBlockReason {
849 ShapeMismatch,
851 DTypeMismatch,
853 DataDependency,
855 SideEffects,
857}
858
859#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
861pub enum TensorIrError {
862 #[error("shape mismatch: expected {expected:?}, got {got:?}")]
864 ShapeMismatch {
865 expected: Shape,
867 got: Shape,
869 },
870
871 #[error("invalid axis {axis} for tensor of rank {rank}")]
873 InvalidAxis {
874 axis: i32,
876 rank: usize,
878 },
879
880 #[error("dtype mismatch: expected {expected:?}, got {got:?}")]
882 DTypeMismatch {
883 expected: DType,
885 got: DType,
887 },
888
889 #[error("fusion failed for guaranteed pattern: {pattern}")]
891 FusionFailed {
892 pattern: String,
894 },
895}
896
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the byte width of a few element types.
    #[test]
    fn test_dtype_sizes() {
        let expected = [
            (DType::Float32, 4),
            (DType::Float64, 8),
            (DType::Int32, 4),
        ];
        for &(dtype, bytes) in expected.iter() {
            assert_eq!(dtype.size_bytes(), bytes);
        }
    }

    /// A 2x3x4 static shape has 24 elements and rank 3.
    #[test]
    fn test_shape_num_elements() {
        let dims = Shape::from_static([2, 3, 4]);
        assert_eq!(dims.num_elements(), Some(24));
        assert_eq!(dims.rank(), 3);
    }

    /// Contiguous strides for a 2x3x4 shape of 4-byte elements.
    #[test]
    fn test_strides_contiguous() {
        let dims = Shape::from_static([2, 3, 4]);
        let computed = Strides::contiguous(&dims, 4).unwrap();
        assert_eq!(computed.values(), &[48, 16, 4]);
    }

    /// Negative, in-range and out-of-range axes against rank 3.
    #[test]
    fn test_axis_normalize() {
        let cases = [(-1, Some(2)), (1, Some(1)), (5, None)];
        for &(raw, resolved) in cases.iter() {
            assert_eq!(Axis::new(raw).normalize(3), resolved);
        }
    }

    /// Only the in-order permutation is the identity.
    #[test]
    fn test_permutation_identity() {
        assert!(Permutation::new([0, 1, 2]).is_identity());
        assert!(!Permutation::new([2, 0, 1]).is_identity());
    }
}