use crate::{
AccessPattern, Alloc, BinOp, Body, Loop, LoopAttrs, LoopIR, LoopId, LoopMetadata, LoopType,
MemRef, Op, Param, ScalarType, Stmt, TargetArch, TripCount, Value, ValueId,
};
use bhc_index::Idx;
use bhc_intern::Symbol;
use bhc_tensor_ir::{
BufferId, Kernel, KernelBody, LoopNest as TensorLoopNest, ReduceOp as TensorReduceOp, TensorOp,
TensorRef,
};
use rustc_hash::FxHashMap;
use thiserror::Error;
/// Failures that can occur while lowering Tensor IR kernels to Loop IR.
#[derive(Clone, Debug, Error)]
pub enum LowerError {
    /// A tensor operation for which no Loop IR lowering exists.
    #[error("unsupported tensor operation: {op}")]
    UnsupportedOp { op: String },

    /// Two shapes that were required to agree differ.
    #[error("shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch { expected: Vec<usize>, got: Vec<usize> },

    /// The kernel is structurally unsuitable for lowering (e.g. it has no outputs).
    #[error("invalid kernel structure: {reason}")]
    InvalidKernel { reason: String },
}
#[derive(Clone, Debug)]
pub struct LowerConfig {
pub target: TargetArch,
pub enable_vectorization: bool,
pub enable_parallelization: bool,
pub vectorize_threshold: usize,
pub parallelize_threshold: usize,
}
impl Default for LowerConfig {
fn default() -> Self {
Self {
target: TargetArch::default(),
enable_vectorization: true,
enable_parallelization: true,
vectorize_threshold: 4,
parallelize_threshold: 1024,
}
}
}
/// Mutable state threaded through the lowering of a single kernel.
struct LowerContext {
    /// Configuration this kernel is being lowered with.
    config: LowerConfig,
    /// Raw index of the next `ValueId` to hand out.
    next_value: u32,
    /// Raw index of the next `LoopId` to hand out.
    next_loop: u32,
    /// Tensor id (widened to `u64`) -> value id assigned to its parameter.
    tensor_values: FxHashMap<u64, ValueId>,
    /// Allocations attached to the lowered function.
    allocations: Vec<Alloc>,
    /// Metadata for every loop created, in creation order.
    loop_metadata: Vec<LoopMetadata>,
    /// Function parameters, in the order they were registered.
    params: Vec<Param>,
}

impl LowerContext {
    /// Creates an empty context whose id counters both start at zero.
    fn new(config: LowerConfig) -> Self {
        LowerContext {
            config,
            next_value: 0,
            next_loop: 0,
            tensor_values: Default::default(),
            allocations: vec![],
            loop_metadata: vec![],
            params: vec![],
        }
    }

    /// Hands out the next unused `ValueId`.
    fn fresh_value(&mut self) -> ValueId {
        let raw = self.next_value;
        self.next_value = raw + 1;
        ValueId::new(raw as usize)
    }

    /// Hands out the next unused `LoopId`.
    fn fresh_loop(&mut self) -> LoopId {
        let raw = self.next_loop;
        self.next_loop = raw + 1;
        LoopId::new(raw as usize)
    }
}
/// Lowers every kernel in `kernels`, in order, each with its own copy of
/// `config`.
///
/// # Errors
///
/// Stops at, and returns, the first [`LowerError`] produced by
/// [`lower_kernel`].
pub fn lower_kernels(kernels: &[Kernel], config: LowerConfig) -> Result<Vec<LoopIR>, LowerError> {
    let mut lowered = Vec::with_capacity(kernels.len());
    for kernel in kernels {
        lowered.push(lower_kernel(kernel, config.clone())?);
    }
    Ok(lowered)
}
pub fn lower_kernel(kernel: &Kernel, config: LowerConfig) -> Result<LoopIR, LowerError> {
let mut ctx = LowerContext::new(config);
for (i, input) in kernel.inputs.iter().enumerate() {
let param = tensor_ref_to_param(input, i, &mut ctx);
ctx.params.push(param);
}
for (i, output) in kernel.outputs.iter().enumerate() {
let param = tensor_ref_to_param(output, kernel.inputs.len() + i, &mut ctx);
ctx.params.push(param);
}
let body = match &kernel.body {
KernelBody::Fused(ops) => lower_fused_ops(ops, kernel, &mut ctx)?,
KernelBody::LoopNest(nest) => lower_tensor_loop_nest(nest, &mut ctx)?,
};
Ok(LoopIR {
name: kernel.name,
params: ctx.params,
return_ty: LoopType::Void,
body,
allocs: ctx.allocations,
loop_info: ctx.loop_metadata,
})
}
/// Builds the pointer parameter `tensor_<index>` for `tensor`.
///
/// A fresh value id is allocated for the parameter and recorded in
/// `ctx.tensor_values`, keyed by the tensor id. The parameter type is a
/// pointer to the tensor's scalar element type.
fn tensor_ref_to_param(tensor: &TensorRef, index: usize, ctx: &mut LowerContext) -> Param {
    let value_id = ctx.fresh_value();
    ctx.tensor_values.insert(tensor.id.index() as u64, value_id);

    let pointee = LoopType::Scalar(ScalarType::from_dtype(tensor.meta.dtype));
    let name = Symbol::intern(&format!("tensor_{}", index));
    Param {
        name,
        ty: LoopType::Ptr(Box::new(pointee)),
        is_ptr: true,
    }
}
/// Lowers a fused chain of tensor ops into a loop nest.
///
/// One loop is generated per dimension of the kernel's first output, outermost
/// first. The per-element statements produced for `ops` are placed inside the
/// innermost loop.
///
/// # Errors
///
/// Returns `LowerError::InvalidKernel` if the kernel has no outputs, or if any
/// output dimension is not statically known. A dynamic extent cannot be
/// expressed as a loop bound here, and lowering it as a constant would silently
/// produce a loop that never runs. Errors from lowering the individual ops are
/// propagated.
fn lower_fused_ops(
    ops: &[TensorOp],
    kernel: &Kernel,
    ctx: &mut LowerContext,
) -> Result<Body, LowerError> {
    let output = kernel
        .outputs
        .first()
        .ok_or_else(|| LowerError::InvalidKernel {
            reason: "kernel has no outputs".to_string(),
        })?;

    let output_shape = output
        .meta
        .shape
        .dims()
        .iter()
        .enumerate()
        .map(|(axis, dim)| {
            dim.static_value().ok_or_else(|| LowerError::InvalidKernel {
                reason: format!(
                    "output dimension {} is dynamic; only static shapes are supported",
                    axis
                ),
            })
        })
        .collect::<Result<Vec<usize>, _>>()?;

    let (mut body, loop_vars) = generate_loop_nest(&output_shape, ctx)?;
    let inner_stmts = lower_fused_ops_body(ops, &loop_vars, kernel, ctx)?;
    insert_inner_stmts(&mut body, inner_stmts);
    Ok(body)
}
/// Builds a nest of loops covering `shape` (outermost dimension first). Each
/// loop runs `0..extent` with step 1 and directly contains the next one.
///
/// Every loop is tagged `INDEPENDENT`. The outermost loop additionally gets
/// `PARALLEL`, and the innermost gets `VECTORIZE`, when the matching
/// `LowerConfig` switch is on and the extent is at least its threshold. A
/// `LoopMetadata` entry with a static trip count is recorded per loop.
///
/// Returns the body holding the nest (empty for a rank-0 shape) together with
/// the induction variables, outermost first.
fn generate_loop_nest(
    shape: &[usize],
    ctx: &mut LowerContext,
) -> Result<(Body, Vec<ValueId>), LowerError> {
    let rank = shape.len();
    let mut induction_vars = Vec::with_capacity(rank);
    let mut nest = Vec::with_capacity(rank);

    for (depth, &extent) in shape.iter().enumerate() {
        let id = ctx.fresh_loop();
        let var = ctx.fresh_value();
        induction_vars.push(var);

        let is_outermost = depth == 0;
        let is_innermost = depth + 1 == rank;
        let wants_parallel = ctx.config.enable_parallelization
            && is_outermost
            && extent >= ctx.config.parallelize_threshold;
        let wants_vector = ctx.config.enable_vectorization
            && is_innermost
            && extent >= ctx.config.vectorize_threshold;

        let mut attrs = LoopAttrs::INDEPENDENT;
        if wants_parallel {
            attrs |= LoopAttrs::PARALLEL;
        }
        if wants_vector {
            attrs |= LoopAttrs::VECTORIZE;
        }

        ctx.loop_metadata.push(LoopMetadata {
            id,
            trip_count: TripCount::Static(extent),
            vector_width: None,
            parallel_chunk: None,
            unroll_factor: None,
            dependencies: Vec::new(),
        });
        nest.push(Loop {
            id,
            var,
            lower: Value::i64(0),
            upper: Value::i64(extent as i64),
            step: Value::i64(1),
            body: Body::new(),
            attrs,
        });
    }

    // Fold from the innermost loop outward, wrapping each loop in its parent.
    let mut body = Body::new();
    if let Some(outermost) = nest.into_iter().rev().reduce(|inner, mut outer| {
        outer.body.push(Stmt::Loop(inner));
        outer
    }) {
        body.push(Stmt::Loop(outermost));
    }
    Ok((body, induction_vars))
}
/// Lowers each fused op in order and collects the resulting per-element
/// statements, which belong in the innermost loop.
///
/// `_kernel` is currently unused.
///
/// # Errors
///
/// Stops at the first op that fails to lower and returns its error.
fn lower_fused_ops_body(
    ops: &[TensorOp],
    loop_vars: &[ValueId],
    _kernel: &Kernel,
    ctx: &mut LowerContext,
) -> Result<Vec<Stmt>, LowerError> {
    let mut body_stmts = Vec::new();
    ops.iter()
        .try_for_each(|op| lower_tensor_op(op, loop_vars, &mut body_stmts, ctx))?;
    Ok(body_stmts)
}
/// Lowers one tensor op into per-element statements appended to `stmts`.
///
/// `loop_vars` are the induction variables of the enclosing loop nest; they are
/// forwarded to the element loads.
///
/// NOTE(review): the arithmetic emitted here looks like placeholder code. Every
/// `Map` becomes a negation and every `ZipWith` an addition, regardless of the
/// function the op carries. No computed value is stored back to an output
/// buffer. `Broadcast`, `Reshape` and `Transpose` only emit a load of their
/// input, and the loaded value is discarded. Confirm before relying on this.
///
/// # Errors
///
/// Returns `LowerError::UnsupportedOp` for any other op kind, and propagates
/// errors from loading elements or lowering reductions.
fn lower_tensor_op(
    op: &TensorOp,
    loop_vars: &[ValueId],
    stmts: &mut Vec<Stmt>,
    ctx: &mut LowerContext,
) -> Result<(), LowerError> {
    match op {
        TensorOp::Map(_func, input) => {
            let operand = load_tensor_element(input, loop_vars, stmts, ctx)?;
            let dest = ctx.fresh_value();
            stmts.push(Stmt::Assign(dest, Op::Unary(crate::UnOp::Neg, operand)));
        }
        TensorOp::ZipWith(_func, lhs, rhs) => {
            let lhs_val = load_tensor_element(lhs, loop_vars, stmts, ctx)?;
            let rhs_val = load_tensor_element(rhs, loop_vars, stmts, ctx)?;
            let dest = ctx.fresh_value();
            stmts.push(Stmt::Assign(dest, Op::Binary(BinOp::Add, lhs_val, rhs_val)));
        }
        TensorOp::ReduceAll(reduce_op, input) => {
            lower_reduction(reduce_op, input, loop_vars, stmts, ctx)?;
        }
        TensorOp::Broadcast(_shape, input) => {
            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
        }
        TensorOp::Reshape(_shape, input) => {
            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
        }
        TensorOp::Transpose(_perm, input) => {
            let _ = load_tensor_element(input, loop_vars, stmts, ctx)?;
        }
        unsupported => {
            return Err(LowerError::UnsupportedOp {
                op: format!("{:?}", std::mem::discriminant(unsupported)),
            });
        }
    }
    Ok(())
}
fn load_tensor_element(
tensor: &TensorRef,
loop_vars: &[ValueId],
stmts: &mut Vec<Stmt>,
ctx: &mut LowerContext,
) -> Result<Value, LowerError> {
let elem_ty = ScalarType::from_dtype(tensor.meta.dtype);
let index = compute_linear_index(tensor, loop_vars)?;
let buffer_id = tensor
.meta
.alias
.unwrap_or(BufferId::new(tensor.id.index()));
let mem_ref = MemRef {
buffer: buffer_id,
index,
elem_ty: LoopType::Scalar(elem_ty),
access: compute_access_pattern(tensor),
};
let result = ctx.fresh_value();
stmts.push(Stmt::Assign(result, Op::Load(mem_ref)));
Ok(Value::Var(result, LoopType::Scalar(elem_ty)))
}
fn compute_linear_index(_tensor: &TensorRef, loop_vars: &[ValueId]) -> Result<Value, LowerError> {
if loop_vars.is_empty() {
return Ok(Value::i64(0));
}
let first_var = loop_vars[0];
Ok(Value::Var(first_var, LoopType::Scalar(ScalarType::I64)))
}
/// Classifies how `tensor` is walked along its last (innermost) stride.
///
/// - `Sequential` for a unit stride.
/// - `Strided(s)` for any other stride `s`.
/// - `Random` when the tensor has no strides at all.
fn compute_access_pattern(tensor: &TensorRef) -> AccessPattern {
    match tensor.meta.strides.values().last() {
        Some(&1) => AccessPattern::Sequential,
        Some(&other) => AccessPattern::Strided(other),
        None => AccessPattern::Random,
    }
}
/// Emits the per-element update step of a `ReduceAll` reduction.
///
/// First a comment statement naming the reduction is pushed. Then the current
/// input element is loaded and combined with an accumulator value using the
/// operator selected by `reduce_op`: `Sum` -> `Add`, `Prod` -> `Mul`,
/// `Min` -> `FMin`, `Max` -> `FMax`, and anything else -> `Add`.
///
/// NOTE(review): this looks incomplete.
/// - The identity value computed for the accumulator is discarded.
/// - The accumulator `ValueId` is read without ever being assigned.
/// - The combined result is never written back.
/// - The identity is always built with `Value::float`, and min/max use
///   `FMin`/`FMax`, even for integer element types.
///
/// Confirm before relying on this lowering.
///
/// # Errors
///
/// Propagates errors from loading the input element.
fn lower_reduction(
    reduce_op: &TensorReduceOp,
    input: &TensorRef,
    loop_vars: &[ValueId],
    stmts: &mut Vec<Stmt>,
    ctx: &mut LowerContext,
) -> Result<(), LowerError> {
    let elem_ty = ScalarType::from_dtype(input.meta.dtype);
    let bits = elem_ty.size_bytes() as u8 * 8;

    // (identity element, combining operator) for each reduction kind.
    let (_identity, combine) = match reduce_op {
        TensorReduceOp::Sum => (Value::float(0.0, bits), BinOp::Add),
        TensorReduceOp::Prod => (Value::float(1.0, bits), BinOp::Mul),
        TensorReduceOp::Min => (Value::float(f64::INFINITY, bits), BinOp::FMin),
        TensorReduceOp::Max => (Value::float(f64::NEG_INFINITY, bits), BinOp::FMax),
        _ => (Value::float(0.0, bits), BinOp::Add),
    };

    stmts.push(Stmt::Comment(format!(
        "reduction accumulator for {:?}",
        reduce_op
    )));

    let acc = ctx.fresh_value();
    let element = load_tensor_element(input, loop_vars, stmts, ctx)?;
    let updated = ctx.fresh_value();
    let acc_val = Value::Var(acc, LoopType::Scalar(elem_ty));
    stmts.push(Stmt::Assign(
        updated,
        Op::Binary(combine, acc_val, element),
    ));
    Ok(())
}
fn lower_tensor_loop_nest(
nest: &TensorLoopNest,
ctx: &mut LowerContext,
) -> Result<Body, LowerError> {
let mut loops = Vec::new();
for loop_spec in &nest.loops {
let loop_id = ctx.fresh_loop();
let loop_var = ctx.fresh_value();
let mut attrs = LoopAttrs::empty();
if loop_spec.parallel {
attrs |= LoopAttrs::PARALLEL;
}
if loop_spec.vectorize.is_some() {
attrs |= LoopAttrs::VECTORIZE;
}
let trip_count = loop_spec
.upper
.static_value()
.map(TripCount::Static)
.unwrap_or(TripCount::Dynamic);
let upper_bound = loop_spec.upper.static_value().unwrap_or(0) as i64;
ctx.loop_metadata.push(LoopMetadata {
id: loop_id,
trip_count,
vector_width: None,
parallel_chunk: None,
unroll_factor: None,
dependencies: Vec::new(),
});
loops.push(Loop {
id: loop_id,
var: loop_var,
lower: Value::i64(loop_spec.lower),
upper: Value::i64(upper_bound),
step: Value::i64(loop_spec.step),
body: Body::new(),
attrs,
});
}
let mut body = Body::new();
if loops.is_empty() {
return Ok(body);
}
let mut current_loop = loops.pop().unwrap();
while let Some(mut outer) = loops.pop() {
outer.body.push(Stmt::Loop(current_loop));
current_loop = outer;
}
body.push(Stmt::Loop(current_loop));
Ok(body)
}
/// Appends `stmts` to the body of the innermost loop reachable from `body`.
///
/// The descent follows the last statement of each body. It continues only while
/// that statement is a loop whose own body also ends in a loop. If `body` does
/// not end in a loop at all, `stmts` are appended to `body` itself.
fn insert_inner_stmts(body: &mut Body, stmts: Vec<Stmt>) {
    match body.stmts.last_mut() {
        Some(Stmt::Loop(lp)) => {
            let ends_in_loop = matches!(lp.body.stmts.last(), Some(Stmt::Loop(_)));
            if ends_in_loop {
                insert_inner_stmts(&mut lp.body, stmts);
            } else {
                lp.body.stmts.extend(stmts);
            }
        }
        _ => body.stmts.extend(stmts),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bhc_span::Span;
    use bhc_tensor_ir::{
        DType, FusionInfo, KernelId, Layout, MapFn, Shape, Strides, TensorId, TensorMeta,
    };

    /// Kernel applying a single `Map` to a contiguous 1-D `Float32` tensor of
    /// 1024 elements (tensor 0 in, tensor 1 out).
    fn make_test_kernel() -> Kernel {
        let meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::from_static([1024]),
            strides: Strides::new([1]),
            layout: Layout::Contiguous,
            alias: None,
        };
        let input = TensorRef {
            id: TensorId::new(0),
            meta: meta.clone(),
        };
        let output = TensorRef {
            id: TensorId::new(1),
            meta,
        };
        let map_fn = MapFn {
            name: Symbol::intern("f"),
            span: Span::DUMMY,
        };
        Kernel {
            id: KernelId::new(0),
            name: Symbol::intern("test_kernel"),
            inputs: vec![input.clone()],
            outputs: vec![output],
            body: KernelBody::Fused(vec![TensorOp::Map(map_fn, input)]),
            allocs: vec![],
            fusion_info: FusionInfo {
                original_ops: vec![],
                decisions: vec![],
                complete: true,
            },
        }
    }

    /// Returns the outermost statement of `loop_ir` as a loop, failing the test
    /// if it is anything else.
    fn outermost_loop(loop_ir: &LoopIR) -> &Loop {
        match loop_ir.body.stmts.first() {
            Some(Stmt::Loop(lp)) => lp,
            _ => panic!("expected the kernel body to start with a loop"),
        }
    }

    #[test]
    fn test_lower_simple_kernel() {
        let kernel = make_test_kernel();
        let loop_ir = lower_kernel(&kernel, LowerConfig::default())
            .expect("lowering a single-map kernel should succeed");
        assert_eq!(loop_ir.name.as_str(), "test_kernel");
        // One parameter for the input tensor plus one for the output tensor.
        assert_eq!(loop_ir.params.len(), 2);
        assert!(!loop_ir.body.stmts.is_empty());
    }

    #[test]
    fn test_lower_generates_loop_nest() {
        let kernel = make_test_kernel();
        let loop_ir = lower_kernel(&kernel, LowerConfig::default()).unwrap();
        assert!(matches!(loop_ir.body.stmts.first(), Some(Stmt::Loop(_))));
        // A rank-1 output yields exactly one loop, hence one metadata entry.
        assert_eq!(loop_ir.loop_info.len(), 1);
    }

    #[test]
    fn test_lower_marks_vectorizable() {
        let kernel = make_test_kernel();
        let config = LowerConfig {
            enable_vectorization: true,
            vectorize_threshold: 4,
            ..Default::default()
        };
        let loop_ir = lower_kernel(&kernel, config).unwrap();
        assert!(outermost_loop(&loop_ir).attrs.contains(LoopAttrs::VECTORIZE));
    }

    #[test]
    fn test_lower_respects_disabled_vectorization() {
        let kernel = make_test_kernel();
        let config = LowerConfig {
            enable_vectorization: false,
            ..Default::default()
        };
        let loop_ir = lower_kernel(&kernel, config).unwrap();
        assert!(!outermost_loop(&loop_ir).attrs.contains(LoopAttrs::VECTORIZE));
    }

    #[test]
    fn test_lower_rejects_kernel_without_outputs() {
        let mut kernel = make_test_kernel();
        kernel.outputs.clear();
        let result = lower_kernel(&kernel, LowerConfig::default());
        assert!(matches!(result, Err(LowerError::InvalidKernel { .. })));
    }

    #[test]
    fn test_sequential_access_pattern() {
        let meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::from_static([1024]),
            strides: Strides::new([1]),
            layout: Layout::Contiguous,
            alias: None,
        };
        let tensor = TensorRef {
            id: TensorId::new(0),
            meta,
        };
        let pattern = compute_access_pattern(&tensor);
        assert_eq!(pattern, AccessPattern::Sequential);
    }

    #[test]
    fn test_strided_access_pattern() {
        let meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::from_static([1024]),
            strides: Strides::new([4]),
            layout: Layout::Strided,
            alias: None,
        };
        let tensor = TensorRef {
            id: TensorId::new(0),
            meta,
        };
        let pattern = compute_access_pattern(&tensor);
        assert_eq!(pattern, AccessPattern::Strided(4));
    }
}