//! Tensor intermediate representation.
//!
//! The crate root defines the core tensor IR data types:
//! identifiers ([`TensorId`], [`KernelId`], [`BufferId`]);
//! tensor metadata ([`DType`], [`Shape`], [`Strides`], [`Layout`], [`TensorMeta`]);
//! tensor operations ([`TensorOp`] and its operator enums);
//! kernels ([`Kernel`]) with their allocation and fusion annotations;
//! and the [`TensorIrError`] error type.
#![warn(missing_docs)]
pub mod fusion;
pub mod lower;
use bhc_index::Idx;
use bhc_intern::Symbol;
use bhc_span::Span;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
/// Identifier of a tensor value in the IR (see [`TensorRef::id`]).
///
/// Stored as a `u32` index; the `Idx` impl converts to and from `usize`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorId(u32);
impl Idx for TensorId {
    /// Builds an id from a `usize` index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in a `u32`. A silent `as` truncation
    /// would instead alias distinct ids.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::MAX as usize,
            "TensorId index {} does not fit in u32",
            idx
        );
        Self(idx as u32)
    }

    /// Returns the index as a `usize` (always lossless from `u32`).
    fn index(self) -> usize {
        self.0 as usize
    }
}
/// Identifier of a [`Kernel`].
///
/// Stored as a `u32` index; the `Idx` impl converts to and from `usize`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct KernelId(u32);
impl Idx for KernelId {
    /// Builds an id from a `usize` index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in a `u32`. A silent `as` truncation
    /// would instead alias distinct ids.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::MAX as usize,
            "KernelId index {} does not fit in u32",
            idx
        );
        Self(idx as u32)
    }

    /// Returns the index as a `usize` (always lossless from `u32`).
    fn index(self) -> usize {
        self.0 as usize
    }
}
/// Identifier of a memory buffer (see [`AllocInfo::buffer`] and [`TensorMeta::alias`]).
///
/// Stored as a `u32` index; the `Idx` impl converts to and from `usize`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BufferId(u32);
impl Idx for BufferId {
    /// Builds an id from a `usize` index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in a `u32`. A silent `as` truncation
    /// would instead alias distinct ids.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::MAX as usize,
            "BufferId index {} does not fit in u32",
            idx
        );
        Self(idx as u32)
    }

    /// Returns the index as a `usize` (always lossless from `u32`).
    fn index(self) -> usize {
        self.0 as usize
    }
}
/// Element data type of a tensor.
///
/// Per-element storage sizes are given by [`DType::size_bytes`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// Boolean, stored in one byte.
    Bool,
    /// Signed 8-bit integer.
    Int8,
    /// Signed 16-bit integer.
    Int16,
    /// Signed 32-bit integer.
    Int32,
    /// Signed 64-bit integer.
    Int64,
    /// Unsigned 8-bit integer.
    UInt8,
    /// Unsigned 16-bit integer.
    UInt16,
    /// Unsigned 32-bit integer.
    UInt32,
    /// Unsigned 64-bit integer.
    UInt64,
    /// 16-bit floating point.
    Float16,
    /// 32-bit floating point.
    Float32,
    /// 64-bit floating point.
    Float64,
    /// bfloat16 floating point (16 bits).
    BFloat16,
    /// Complex number occupying 8 bytes (presumably two `Float32` parts).
    Complex64,
    /// Complex number occupying 16 bytes (presumably two `Float64` parts).
    Complex128,
}
impl DType {
    /// Returns the storage size of a single element of this type, in bytes.
    #[must_use]
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::Complex128 => 16,
            Self::Int64 | Self::UInt64 | Self::Float64 | Self::Complex64 => 8,
            Self::Int32 | Self::UInt32 | Self::Float32 => 4,
            Self::Int16 | Self::UInt16 | Self::Float16 | Self::BFloat16 => 2,
            Self::Bool | Self::Int8 | Self::UInt8 => 1,
        }
    }

    /// Returns the alignment in bytes used for this type.
    ///
    /// This is currently the same value as [`DType::size_bytes`].
    #[must_use]
    pub const fn alignment(self) -> usize {
        self.size_bytes()
    }

    /// Returns `true` for real floating-point types (`Float16`, `BFloat16`,
    /// `Float32`, `Float64`).
    ///
    /// Complex types are not counted as floating point.
    #[must_use]
    pub const fn is_float(self) -> bool {
        matches!(
            self,
            Self::Float16 | Self::BFloat16 | Self::Float32 | Self::Float64
        )
    }

    /// Returns `true` for signed and unsigned integer types.
    ///
    /// `Bool` is not counted as an integer.
    #[must_use]
    pub const fn is_integer(self) -> bool {
        let signed_int = matches!(self, Self::Int8 | Self::Int16 | Self::Int32 | Self::Int64);
        let unsigned_int =
            matches!(self, Self::UInt8 | Self::UInt16 | Self::UInt32 | Self::UInt64);
        signed_int || unsigned_int
    }

    /// Returns `true` for signed integers, all floating-point types and both
    /// complex types; `false` for `Bool` and unsigned integers.
    #[must_use]
    pub const fn is_signed(self) -> bool {
        if self.is_float() {
            return true;
        }
        matches!(
            self,
            Self::Int8
                | Self::Int16
                | Self::Int32
                | Self::Int64
                | Self::Complex64
                | Self::Complex128
        )
    }
}
/// A single tensor dimension: either a statically known extent or a
/// symbolic (dynamic) one.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Dim {
    /// Extent known statically.
    Static(usize),
    /// Extent not known statically, identified by a symbol.
    Dynamic(Symbol),
}
impl Dim {
    /// Returns the extent of a static dimension, or `None` for a dynamic one.
    #[must_use]
    pub const fn static_value(&self) -> Option<usize> {
        if let Self::Static(extent) = self {
            Some(*extent)
        } else {
            None
        }
    }

    /// Returns `true` unless this dimension is [`Dim::Dynamic`].
    #[must_use]
    pub const fn is_static(&self) -> bool {
        !matches!(self, Self::Dynamic(_))
    }
}
/// The shape of a tensor: an ordered list of [`Dim`]s.
///
/// A shape with no dimensions is a scalar. [`Strides::contiguous`] treats the
/// last dimension as the innermost one.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape(SmallVec<[Dim; 4]>);
impl Shape {
    /// Creates a shape from a sequence of dimensions.
    #[must_use]
    pub fn new(dims: impl IntoIterator<Item = Dim>) -> Self {
        Self(dims.into_iter().collect())
    }

    /// Creates a fully static shape from concrete extents.
    #[must_use]
    pub fn from_static(dims: impl IntoIterator<Item = usize>) -> Self {
        Self(dims.into_iter().map(Dim::Static).collect())
    }

    /// Creates the rank-0 (scalar) shape.
    #[must_use]
    pub fn scalar() -> Self {
        Self(SmallVec::new())
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.0.len()
    }

    /// Returns the dimensions as a slice.
    #[must_use]
    pub fn dims(&self) -> &[Dim] {
        &self.0
    }

    /// Returns the total number of elements (the product of all extents).
    ///
    /// A scalar shape has one element. Returns `None` if any dimension is
    /// dynamic or if the product overflows `usize`.
    #[must_use]
    pub fn num_elements(&self) -> Option<usize> {
        // `checked_mul` avoids a debug-build panic / release-build wraparound
        // on very large static shapes.
        self.0
            .iter()
            .try_fold(1usize, |acc, dim| acc.checked_mul(dim.static_value()?))
    }

    /// Returns `true` if this is the rank-0 (scalar) shape.
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns `true` if every dimension is static.
    #[must_use]
    pub fn is_static(&self) -> bool {
        self.0.iter().all(Dim::is_static)
    }
}
/// Per-dimension strides of a tensor, stored as `i64`.
///
/// [`Strides::contiguous`] expresses them in units of its `elem_size`
/// argument, which is bytes when it is called from
/// [`TensorMeta::new_contiguous`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Strides(SmallVec<[i64; 4]>);
impl Strides {
#[must_use]
pub fn new(strides: impl IntoIterator<Item = i64>) -> Self {
Self(strides.into_iter().collect())
}
#[must_use]
pub fn contiguous(shape: &Shape, elem_size: usize) -> Option<Self> {
let mut strides = SmallVec::with_capacity(shape.rank());
let mut stride = elem_size as i64;
for dim in shape.dims().iter().rev() {
strides.push(stride);
stride *= dim.static_value()? as i64;
}
strides.reverse();
Some(Self(strides))
}
#[must_use]
pub fn values(&self) -> &[i64] {
&self.0
}
#[must_use]
pub fn is_contiguous(&self, shape: &Shape, elem_size: usize) -> bool {
if let Some(contiguous) = Self::contiguous(shape, elem_size) {
self.0 == contiguous.0
} else {
false
}
}
}
/// Storage layout classification of a tensor.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Layout {
    /// Contiguous storage (the layout chosen by [`TensorMeta::new_contiguous`]).
    Contiguous,
    /// Strided storage, presumably described by the accompanying
    /// [`TensorMeta::strides`].
    Strided,
    /// Tiled storage described by [`TileInfo`].
    Tiled(TileInfo),
}
/// Parameters of a [`Layout::Tiled`] layout.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TileInfo {
    /// Tile sizes, presumably one per tiled dimension.
    pub tile_sizes: SmallVec<[usize; 4]>,
    /// Order of dimensions within a tile. This is presumably a permutation of
    /// dimension indices; confirm with the lowering code.
    pub inner_order: SmallVec<[usize; 4]>,
}
/// Element type, shape and storage description of a tensor.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Element type.
    pub dtype: DType,
    /// Logical shape.
    pub shape: Shape,
    /// Per-dimension strides.
    pub strides: Strides,
    /// Storage layout.
    pub layout: Layout,
    /// Buffer whose storage this tensor aliases, if any.
    ///
    /// This is `None` for metadata built by [`TensorMeta::new_contiguous`].
    pub alias: Option<BufferId>,
}
impl TensorMeta {
    /// Builds metadata for a contiguous, non-aliased tensor.
    ///
    /// Strides come from [`Strides::contiguous`] with the element size in
    /// bytes. Returns `None` when those strides cannot be computed (for
    /// example, when `shape` has a dynamic dimension).
    #[must_use]
    pub fn new_contiguous(dtype: DType, shape: Shape) -> Option<Self> {
        let strides = Strides::contiguous(&shape, dtype.size_bytes())?;
        Some(Self {
            dtype,
            shape,
            strides,
            layout: Layout::Contiguous,
            alias: None,
        })
    }

    /// Returns `num_elements * element size`, in bytes.
    ///
    /// The result is derived from the shape and dtype only; strides are not
    /// considered. Returns `None` if the shape has a dynamic dimension or the
    /// byte count overflows `usize`.
    #[must_use]
    pub fn size_bytes(&self) -> Option<usize> {
        self.shape
            .num_elements()?
            .checked_mul(self.dtype.size_bytes())
    }
}
/// A tensor operand: its identifier together with its metadata.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TensorRef {
    /// Identifier of the tensor.
    pub id: TensorId,
    /// Element type, shape and storage description of the tensor.
    pub meta: TensorMeta,
}
/// A tensor-level operation in the IR; operands are [`TensorRef`]s.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TensorOp {
    /// A constant tensor described by a [`ConstantOp`].
    Constant(ConstantOp),
    /// Applies a [`UnaryOp`] to a tensor.
    Unary(UnaryOp, TensorRef),
    /// Applies a [`BinaryOp`] to two tensors.
    Binary(BinaryOp, TensorRef, TensorRef),
    /// Applies a user function ([`MapFn`]) to a tensor.
    Map(MapFn, TensorRef),
    /// Combines two tensors with a user function ([`ZipFn`]).
    ZipWith(ZipFn, TensorRef, TensorRef),
    /// Reduces a tensor along one [`Axis`].
    Reduce(ReduceOp, Axis, TensorRef),
    /// Reduces a tensor over all of its elements.
    ReduceAll(ReduceOp, TensorRef),
    /// Scan (prefix reduction) along one [`Axis`].
    Scan(ReduceOp, Axis, TensorRef),
    /// Fold with a user function ([`FoldFn`]).
    ///
    /// The two tensors are presumably the initial accumulator and the input;
    /// confirm the order with the lowering code.
    Fold(FoldFn, TensorRef, TensorRef),
    /// Changes a tensor's shape to the given [`Shape`].
    Reshape(Shape, TensorRef),
    /// Selects part of a tensor as described by a [`SliceSpec`].
    Slice(SliceSpec, TensorRef),
    /// Reorders a tensor's axes by a [`Permutation`].
    Transpose(Permutation, TensorRef),
    /// Broadcasts a tensor to the given [`Shape`].
    Broadcast(Shape, TensorRef),
    /// Concatenates tensors along an [`Axis`].
    Concat(Axis, Vec<TensorRef>),
    /// Splits a tensor along an [`Axis`].
    ///
    /// The `Vec<usize>` presumably gives section sizes or split points; confirm.
    Split(Axis, Vec<usize>, TensorRef),
    /// Matrix multiplication.
    MatMul(TensorRef, TensorRef),
    /// Batched matrix multiplication.
    BatchMatMul(TensorRef, TensorRef),
    /// Dot product.
    Dot(TensorRef, TensorRef),
    /// Outer product.
    Outer(TensorRef, TensorRef),
    /// Convolution configured by a [`ConvSpec`]; operands are presumably
    /// `(input, weights)`.
    Conv(ConvSpec, TensorRef, TensorRef),
    /// Gather along an [`Axis`]; operands are presumably `(source, indices)`.
    Gather(Axis, TensorRef, TensorRef),
    /// Scatter along an [`Axis`]; operands are presumably
    /// `(target, indices, updates)`.
    Scatter(Axis, TensorRef, TensorRef, TensorRef),
}
/// Operations that produce a tensor without tensor operands.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ConstantOp {
    /// A tensor with the given metadata, filled with zeros.
    Zeros(TensorMeta),
    /// A tensor with the given metadata, filled with ones.
    Ones(TensorMeta),
    /// A tensor with the given metadata, filled with the given scalar.
    Full(TensorMeta, ScalarValue),
    /// A range of values of the given dtype.
    ///
    /// The three `i64`s are presumably `(start, stop, step)`; confirm against
    /// the lowering code.
    Range(DType, i64, i64, i64),
    /// An identity matrix of the given dtype; the `usize` is presumably its
    /// side length `n` (an `n x n` matrix).
    Eye(DType, usize),
}
/// Unary operators used by [`TensorOp::Unary`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryOp {
    /// Negation.
    Neg,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Reciprocal square root, `1 / sqrt(x)`.
    Rsqrt,
    /// Exponential.
    Exp,
    /// Logarithm (presumably natural).
    Log,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid.
    Sigmoid,
    /// Rectified linear unit, `max(x, 0)`.
    Relu,
    /// Round toward positive infinity.
    Ceil,
    /// Round toward negative infinity.
    Floor,
    /// Round to the nearest integer; the tie-breaking rule is not specified here.
    Round,
    /// Negation of booleans/integers; logical vs. bitwise is not specified here.
    Not,
}
/// Binary operators used by [`TensorOp::Binary`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division; integer rounding behavior is not specified here.
    Div,
    /// Remainder/modulo; the sign convention is not specified here.
    Mod,
    /// Exponentiation.
    Pow,
    /// Maximum of the two operands.
    Max,
    /// Minimum of the two operands.
    Min,
    /// Equality comparison.
    Eq,
    /// Inequality comparison.
    Ne,
    /// Less-than comparison.
    Lt,
    /// Less-than-or-equal comparison.
    Le,
    /// Greater-than comparison.
    Gt,
    /// Greater-than-or-equal comparison.
    Ge,
    /// AND; logical vs. bitwise is not specified here.
    And,
    /// OR; logical vs. bitwise is not specified here.
    Or,
    /// Exclusive OR; logical vs. bitwise is not specified here.
    Xor,
    /// Left shift.
    Shl,
    /// Right shift; arithmetic vs. logical is not specified here.
    Shr,
}
/// Reduction operators used by [`TensorOp::Reduce`], [`TensorOp::ReduceAll`]
/// and [`TensorOp::Scan`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ReduceOp {
    /// Sum.
    Sum,
    /// Product.
    Prod,
    /// Maximum.
    Max,
    /// Minimum.
    Min,
    /// True if all reduced elements are true.
    All,
    /// True if any reduced element is true.
    Any,
    /// Arithmetic mean.
    Mean,
}
/// An axis index into a tensor's dimensions.
///
/// Negative values count from the end (`-1` is the last axis); see
/// [`Axis::normalize`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Axis(pub i32);
impl Axis {
    /// Wraps a raw axis value; negative values count from the end.
    #[must_use]
    pub const fn new(axis: i32) -> Self {
        Self(axis)
    }

    /// Resolves this axis against a tensor of rank `rank`.
    ///
    /// Non-negative values are used as-is. Negative values count back from
    /// the end, so `-1` maps to `rank - 1` and `-rank` maps to `0`.
    ///
    /// Returns `None` if the resolved index falls outside `0..rank`, which is
    /// always the case when `rank == 0`.
    #[must_use]
    pub const fn normalize(self, rank: usize) -> Option<usize> {
        // Stay in unsigned arithmetic. Casting `rank` to `i32` would truncate
        // ranks above `i32::MAX` and yield a wrong axis.
        if self.0 >= 0 {
            let axis = self.0 as usize;
            if axis < rank {
                Some(axis)
            } else {
                None
            }
        } else {
            // Distance from the end: 1 for `-1`, up to 2^31 for `i32::MIN`.
            let from_end = self.0.unsigned_abs() as usize;
            if from_end <= rank {
                Some(rank - from_end)
            } else {
                None
            }
        }
    }
}
/// A scalar constant, used for example by [`ConstantOp::Full`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ScalarValue {
    /// Boolean value.
    Bool(bool),
    /// Integer value, stored as `i64`.
    Int(i64),
    /// Floating-point value, stored as `f64`.
    Float(f64),
}
/// Reference to a user function applied by [`TensorOp::Map`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MapFn {
    /// Name of the function.
    pub name: Symbol,
    /// Source span associated with this function reference.
    pub span: Span,
}
/// Reference to a user function applied by [`TensorOp::ZipWith`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZipFn {
    /// Name of the function.
    pub name: Symbol,
    /// Source span associated with this function reference.
    pub span: Span,
}
/// Reference to a user function applied by [`TensorOp::Fold`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FoldFn {
    /// Name of the function.
    pub name: Symbol,
    /// Source span associated with this function reference.
    pub span: Span,
}
/// Slice description used by [`TensorOp::Slice`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SliceSpec {
    /// Ranges to select, presumably one per dimension in dimension order.
    pub ranges: SmallVec<[SliceRange; 4]>,
}
/// A `start:stop:step` range along one dimension.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SliceRange {
    /// Start index; `None` presumably means "unbounded".
    pub start: Option<i64>,
    /// End index (presumably exclusive); `None` presumably means "unbounded".
    pub stop: Option<i64>,
    /// Step between selected indices.
    pub step: i64,
}
/// An ordering of axis indices, used by [`TensorOp::Transpose`].
///
/// [`Permutation::new`] does not check that the entries actually form a
/// permutation.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Permutation(SmallVec<[usize; 4]>);
impl Permutation {
#[must_use]
pub fn new(perm: impl IntoIterator<Item = usize>) -> Self {
Self(perm.into_iter().collect())
}
#[must_use]
pub fn as_slice(&self) -> &[usize] {
&self.0
}
#[must_use]
pub fn is_identity(&self) -> bool {
self.0.iter().enumerate().all(|(i, &p)| i == p)
}
}
/// Parameters of a convolution ([`TensorOp::Conv`]).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConvSpec {
    /// Padding per convolved dimension, presumably `(before, after)`.
    pub padding: SmallVec<[(usize, usize); 4]>,
    /// Stride per convolved dimension.
    pub strides: SmallVec<[usize; 4]>,
    /// Dilation per convolved dimension.
    pub dilation: SmallVec<[usize; 4]>,
    /// Number of groups (presumably as in grouped convolution).
    pub groups: usize,
}
/// A unit of tensor computation.
///
/// Holds the kernel's inputs, outputs, body, allocations and fusion record.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Kernel {
    /// Identifier of this kernel.
    pub id: KernelId,
    /// Kernel name.
    pub name: Symbol,
    /// Input tensors.
    pub inputs: Vec<TensorRef>,
    /// Output tensors.
    pub outputs: Vec<TensorRef>,
    /// The computation performed by the kernel.
    pub body: KernelBody,
    /// Buffer allocations associated with the kernel.
    pub allocs: Vec<AllocInfo>,
    /// Record of the fusion decisions behind this kernel.
    pub fusion_info: FusionInfo,
}
/// The computation performed by a [`Kernel`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum KernelBody {
    /// A sequence of fused tensor operations.
    Fused(Vec<TensorOp>),
    /// An explicit loop nest.
    LoopNest(LoopNest),
}
/// A nest of loops around a body of tensor operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopNest {
    /// The loops, presumably ordered outermost first.
    pub loops: Vec<LoopInfo>,
    /// Operations forming the loop body.
    pub body: Vec<TensorOp>,
}
/// One loop of a [`LoopNest`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoopInfo {
    /// Name of the loop variable.
    pub var: Symbol,
    /// Lower bound (presumably inclusive).
    pub lower: i64,
    /// Upper bound, possibly dynamic (presumably exclusive).
    pub upper: Dim,
    /// Step per iteration.
    pub step: i64,
    /// Whether the loop is marked for parallel execution.
    pub parallel: bool,
    /// Vectorization width, if the loop is marked for vectorization.
    pub vectorize: Option<usize>,
}
/// A buffer allocation associated with a [`Kernel`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AllocInfo {
    /// Buffer being allocated.
    pub buffer: BufferId,
    /// Allocation size (presumably in bytes).
    pub size: usize,
    /// Required alignment (presumably in bytes).
    pub alignment: usize,
    /// Memory region to allocate from.
    pub region: AllocRegion,
}
/// Memory region in which an allocation is placed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllocRegion {
    /// The "hot" arena; its exact semantics are defined by the runtime, not here.
    HotArena,
    /// Pinned memory (presumably page-locked host memory).
    Pinned,
    /// General-purpose memory.
    General,
    /// Memory on the given device.
    DeviceMemory(DeviceTarget),
}
/// Device selection for device-memory allocations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DeviceTarget {
    /// A CUDA device; the `u32` is presumably the device ordinal.
    Cuda(u32),
    /// A ROCm device; the `u32` is presumably the device ordinal.
    Rocm(u32),
    /// Any device.
    Any,
}
/// Record of the fusion decisions behind a [`Kernel`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FusionInfo {
    /// Names of the original operations covered by the kernel.
    pub original_ops: Vec<Symbol>,
    /// Individual fusion decisions.
    pub decisions: Vec<FusionDecision>,
    /// Whether fusion was complete. This presumably means no operation was
    /// left unfused; confirm with the fusion pass.
    pub complete: bool,
}
/// A single fusion decision.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum FusionDecision {
    /// The named operations were fused together.
    Fused(Vec<Symbol>),
    /// The named operation was materialized, for the given reason.
    Materialized(Symbol, MaterializeReason),
    /// Fusion of the named operation was blocked, for the given reason.
    Blocked(Symbol, FusionBlockReason),
}
/// Why an operation was materialized instead of fused.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum MaterializeReason {
    /// The value has multiple uses.
    MultipleUses,
    /// Materialization was explicitly requested.
    Explicit,
    /// Control flow required materialization.
    ControlFlow,
}
/// Why fusion of an operation was blocked.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum FusionBlockReason {
    /// Operand shapes did not match.
    ShapeMismatch,
    /// Operand dtypes did not match.
    DTypeMismatch,
    /// A data dependency prevented fusion.
    DataDependency,
    /// Side effects prevented fusion.
    SideEffects,
}
/// Errors reported by the tensor IR.
#[derive(Clone, Debug, thiserror::Error, Serialize, Deserialize)]
pub enum TensorIrError {
    /// A shape did not match the required shape.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch {
        /// The required shape.
        expected: Shape,
        /// The shape that was found.
        got: Shape,
    },
    /// An axis was out of range for a tensor's rank.
    #[error("invalid axis {axis} for tensor of rank {rank}")]
    InvalidAxis {
        /// The offending axis value.
        axis: i32,
        /// Rank of the tensor.
        rank: usize,
    },
    /// A dtype did not match the required dtype.
    #[error("dtype mismatch: expected {expected:?}, got {got:?}")]
    DTypeMismatch {
        /// The required dtype.
        expected: DType,
        /// The dtype that was found.
        got: DType,
    },
    /// Fusion failed for a pattern whose fusion is guaranteed.
    #[error("fusion failed for guaranteed pattern: {pattern}")]
    FusionFailed {
        /// Description of the pattern.
        pattern: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_sizes() {
        let expected = [(DType::Float32, 4), (DType::Float64, 8), (DType::Int32, 4)];
        for &(dtype, bytes) in expected.iter() {
            assert_eq!(dtype.size_bytes(), bytes);
        }
    }

    #[test]
    fn test_shape_num_elements() {
        let shape = Shape::from_static(vec![2, 3, 4]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), Some(24));
    }

    #[test]
    fn test_strides_contiguous() {
        let strides = Strides::contiguous(&Shape::from_static([2, 3, 4]), 4).unwrap();
        assert_eq!(strides.values(), &[48, 16, 4]);
    }

    #[test]
    fn test_axis_normalize() {
        let cases = [(-1, Some(2)), (1, Some(1)), (5, None)];
        for &(raw, expected) in cases.iter() {
            assert_eq!(Axis::new(raw).normalize(3), expected);
        }
    }

    #[test]
    fn test_permutation_identity() {
        assert!(Permutation::new([0, 1, 2]).is_identity());
        assert!(!Permutation::new([2, 0, 1]).is_identity());
    }
}