use bhc_core::{Expr, Literal, Var, VarId};
use bhc_index::Idx;
use bhc_intern::Symbol;
use bhc_span::Span;
use bhc_types::Ty;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::{
DType, Dim, FoldFn, Layout, MapFn, Permutation, ReduceOp, Shape, SliceRange, SliceSpec,
Strides, TensorId, TensorMeta, TensorOp, TensorRef, ZipFn,
};
/// Mutable state threaded through the lowering of Core expressions into
/// tensor operations.
#[derive(Debug)]
pub struct LowerContext {
    /// Next value handed out by `fresh_tensor_id`; incremented on every call.
    next_tensor_id: u32,
    /// Buffer-id counter. It is initialised by the constructors but never read
    /// in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    next_buffer_id: u32,
    /// Tensors registered for Core variables, keyed by variable id.
    var_tensors: FxHashMap<VarId, TensorRef>,
    /// Tensor ops emitted so far, in emission order.
    ops: Vec<TensorOp>,
    /// Symbols recognised as tensor builtins during application lowering.
    builtins: BuiltinTable,
}
#[derive(Debug)]
struct BuiltinTable {
map: Option<Symbol>,
zip_with: Option<Symbol>,
sum: Option<Symbol>,
product: Option<Symbol>,
foldl: Option<Symbol>,
#[allow(dead_code)]
foldr: Option<Symbol>,
reshape: Option<Symbol>,
transpose: Option<Symbol>,
slice: Option<Symbol>,
broadcast: Option<Symbol>,
matmul: Option<Symbol>,
dot: Option<Symbol>,
}
impl BuiltinTable {
fn new() -> Self {
Self {
map: None,
zip_with: None,
sum: None,
product: None,
foldl: None,
foldr: None,
reshape: None,
transpose: None,
slice: None,
broadcast: None,
matmul: None,
dot: None,
}
}
#[allow(clippy::too_many_arguments)]
fn with_symbols(
map: Symbol,
zip_with: Symbol,
sum: Symbol,
product: Symbol,
foldl: Symbol,
foldr: Symbol,
reshape: Symbol,
transpose: Symbol,
slice: Symbol,
broadcast: Symbol,
matmul: Symbol,
dot: Symbol,
) -> Self {
Self {
map: Some(map),
zip_with: Some(zip_with),
sum: Some(sum),
product: Some(product),
foldl: Some(foldl),
foldr: Some(foldr),
reshape: Some(reshape),
transpose: Some(transpose),
slice: Some(slice),
broadcast: Some(broadcast),
matmul: Some(matmul),
dot: Some(dot),
}
}
fn is_map(&self, sym: Symbol) -> bool {
self.map.is_some_and(|s| s == sym)
}
fn is_zip_with(&self, sym: Symbol) -> bool {
self.zip_with.is_some_and(|s| s == sym)
}
fn is_sum(&self, sym: Symbol) -> bool {
self.sum.is_some_and(|s| s == sym)
}
fn is_product(&self, sym: Symbol) -> bool {
self.product.is_some_and(|s| s == sym)
}
fn is_foldl(&self, sym: Symbol) -> bool {
self.foldl.is_some_and(|s| s == sym)
}
fn is_reshape(&self, sym: Symbol) -> bool {
self.reshape.is_some_and(|s| s == sym)
}
fn is_transpose(&self, sym: Symbol) -> bool {
self.transpose.is_some_and(|s| s == sym)
}
fn is_slice(&self, sym: Symbol) -> bool {
self.slice.is_some_and(|s| s == sym)
}
fn is_broadcast(&self, sym: Symbol) -> bool {
self.broadcast.is_some_and(|s| s == sym)
}
fn is_matmul(&self, sym: Symbol) -> bool {
self.matmul.is_some_and(|s| s == sym)
}
fn is_dot(&self, sym: Symbol) -> bool {
self.dot.is_some_and(|s| s == sym)
}
}
/// Outcome of lowering a single Core expression.
#[derive(Debug, Clone)]
pub enum LowerResult {
    /// The expression denotes a tensor produced or looked up during lowering.
    Tensor(TensorRef),
    /// The expression denotes a scalar (a literal or an unregistered variable).
    Scalar(ScalarValue),
    /// The expression is a lambda; it is not lowered further.
    Function,
    /// The expression could not be lowered to tensor form.
    NotTensor,
}
/// A scalar encountered while lowering.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    /// Integer literal (both `Int` and `Integer` Core literals land here).
    Int(i64),
    /// Floating-point literal (both `Float` and `Double` Core literals).
    Float(f64),
    /// A variable with no tensor registered for it.
    Var(VarId),
}
impl LowerContext {
#[must_use]
pub fn new() -> Self {
Self {
next_tensor_id: 0,
next_buffer_id: 0,
var_tensors: FxHashMap::default(),
ops: Vec::new(),
builtins: BuiltinTable::new(),
}
}
#[must_use]
#[allow(clippy::too_many_arguments)]
pub fn with_builtins(
map: Symbol,
zip_with: Symbol,
sum: Symbol,
product: Symbol,
foldl: Symbol,
foldr: Symbol,
reshape: Symbol,
transpose: Symbol,
slice: Symbol,
broadcast: Symbol,
matmul: Symbol,
dot: Symbol,
) -> Self {
Self {
next_tensor_id: 0,
next_buffer_id: 0,
var_tensors: FxHashMap::default(),
ops: Vec::new(),
builtins: BuiltinTable::with_symbols(
map, zip_with, sum, product, foldl, foldr, reshape, transpose, slice, broadcast,
matmul, dot,
),
}
}
fn fresh_tensor_id(&mut self) -> TensorId {
let id = TensorId::new(self.next_tensor_id as usize);
self.next_tensor_id += 1;
id
}
pub fn register_tensor(&mut self, var_id: VarId, tensor: TensorRef) {
self.var_tensors.insert(var_id, tensor);
}
pub fn lookup_tensor(&self, var_id: VarId) -> Option<&TensorRef> {
self.var_tensors.get(&var_id)
}
#[must_use]
pub fn into_ops(self) -> Vec<TensorOp> {
self.ops
}
#[must_use]
pub fn ops(&self) -> &[TensorOp] {
&self.ops
}
pub fn lower_expr(&mut self, expr: &Expr) -> LowerResult {
match expr {
Expr::Var(var, _) => self.lower_var(var),
Expr::Lit(lit, ty, _) => self.lower_lit(lit, ty),
Expr::App(f, arg, span) => self.lower_app(f, arg, *span),
Expr::Let(bind, body, _) => self.lower_let(bind, body),
Expr::Lam(_, _, _) => LowerResult::Function,
Expr::TyLam(_, body, _) => self.lower_expr(body),
Expr::TyApp(f, _, _) => self.lower_expr(f),
Expr::Case(_, _, _, _) => LowerResult::NotTensor,
Expr::Cast(e, _, _) => self.lower_expr(e),
Expr::Tick(_, e, _) => self.lower_expr(e),
Expr::Lazy(e, _) => self.lower_expr(e),
Expr::Type(_, _) | Expr::Coercion(_, _) => LowerResult::NotTensor,
}
}
fn lower_var(&self, var: &Var) -> LowerResult {
if let Some(tensor) = self.var_tensors.get(&var.id) {
LowerResult::Tensor(tensor.clone())
} else {
LowerResult::Scalar(ScalarValue::Var(var.id))
}
}
fn lower_lit(&self, lit: &Literal, _ty: &Ty) -> LowerResult {
match lit {
Literal::Int(n) => LowerResult::Scalar(ScalarValue::Int(*n)),
Literal::Integer(n) => LowerResult::Scalar(ScalarValue::Int(*n as i64)),
Literal::Float(f) => LowerResult::Scalar(ScalarValue::Float(f64::from(*f))),
Literal::Double(d) => LowerResult::Scalar(ScalarValue::Float(*d)),
Literal::Char(_) | Literal::String(_) => LowerResult::NotTensor,
}
}
fn lower_app(&mut self, f: &Expr, arg: &Expr, span: Span) -> LowerResult {
let (func, args) = collect_app_args(f, arg);
if let Some(result) = self.try_lower_builtin(&func, &args, span) {
return result;
}
LowerResult::NotTensor
}
fn try_lower_builtin(
&mut self,
func: &Expr,
args: &[&Expr],
span: Span,
) -> Option<LowerResult> {
let func_name = match func {
Expr::Var(var, _) => var.name,
_ => return None,
};
if self.builtins.is_map(func_name) && args.len() == 2 {
return Some(self.lower_map(args[0], args[1], span));
}
if self.builtins.is_zip_with(func_name) && args.len() == 3 {
return Some(self.lower_zip_with(args[0], args[1], args[2], span));
}
if self.builtins.is_sum(func_name) && args.len() == 1 {
return Some(self.lower_reduce(ReduceOp::Sum, args[0], span));
}
if self.builtins.is_product(func_name) && args.len() == 1 {
return Some(self.lower_reduce(ReduceOp::Prod, args[0], span));
}
if self.builtins.is_foldl(func_name) && args.len() == 3 {
return Some(self.lower_fold(args[0], args[1], args[2], span));
}
if self.builtins.is_reshape(func_name) && args.len() == 2 {
return Some(self.lower_reshape(args[0], args[1], span));
}
if self.builtins.is_transpose(func_name) && args.len() == 1 {
return Some(self.lower_transpose(args[0], span));
}
if self.builtins.is_slice(func_name) && args.len() == 2 {
return Some(self.lower_slice(args[0], args[1], span));
}
if self.builtins.is_broadcast(func_name) && args.len() == 2 {
return Some(self.lower_broadcast(args[0], args[1], span));
}
if self.builtins.is_matmul(func_name) && args.len() == 2 {
return Some(self.lower_matmul(args[0], args[1], span));
}
if self.builtins.is_dot(func_name) && args.len() == 2 {
return Some(self.lower_dot(args[0], args[1], span));
}
None
}
fn lower_map(&mut self, f: &Expr, xs: &Expr, span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let map_fn = MapFn {
name: extract_fn_name(f),
span,
};
let output = self.make_output_tensor(&xs_tensor.meta);
let op = TensorOp::Map(map_fn, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_zip_with(&mut self, f: &Expr, xs: &Expr, ys: &Expr, span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let ys_result = self.lower_expr(ys);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let ys_tensor = match ys_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let zip_fn = ZipFn {
name: extract_fn_name(f),
span,
};
let output = self.make_output_tensor(&xs_tensor.meta);
let op = TensorOp::ZipWith(zip_fn, xs_tensor, ys_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_reduce(&mut self, reduce_op: ReduceOp, xs: &Expr, _span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let output_meta = TensorMeta {
dtype: xs_tensor.meta.dtype,
shape: Shape::scalar(),
strides: Strides::new([]),
layout: Layout::Contiguous,
alias: None,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::ReduceAll(reduce_op, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_fold(&mut self, f: &Expr, z: &Expr, xs: &Expr, span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let z_result = self.lower_expr(z);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let z_tensor = match z_result {
LowerResult::Tensor(t) => t,
LowerResult::Scalar(_) => {
self.make_scalar_tensor()
}
_ => return LowerResult::NotTensor,
};
let fold_fn = FoldFn {
name: extract_fn_name(f),
span,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: z_tensor.meta.clone(),
};
let op = TensorOp::Fold(fold_fn, z_tensor, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_reshape(&mut self, shape_expr: &Expr, xs: &Expr, _span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let new_shape = extract_shape(shape_expr).unwrap_or_else(|| xs_tensor.meta.shape.clone());
let strides = Strides::contiguous(&new_shape, xs_tensor.meta.dtype.size_bytes())
.unwrap_or_else(|| Strides::new([]));
let output_meta = TensorMeta {
dtype: xs_tensor.meta.dtype,
shape: new_shape.clone(),
strides,
layout: Layout::Contiguous,
alias: xs_tensor.meta.alias,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::Reshape(new_shape, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_transpose(&mut self, xs: &Expr, _span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let rank = xs_tensor.meta.shape.rank();
let perm: SmallVec<[usize; 4]> = (0..rank).rev().collect();
let permutation = Permutation::new(perm.clone());
let new_shape = apply_permutation_to_shape(&xs_tensor.meta.shape, &perm);
let new_strides = apply_permutation_to_strides(&xs_tensor.meta.strides, &perm);
let output_meta = TensorMeta {
dtype: xs_tensor.meta.dtype,
shape: new_shape,
strides: new_strides,
layout: Layout::Strided,
alias: xs_tensor.meta.alias,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::Transpose(permutation, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_slice(&mut self, spec_expr: &Expr, xs: &Expr, _span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let slice_spec = extract_slice_spec(spec_expr)
.unwrap_or_else(|| make_identity_slice(xs_tensor.meta.shape.rank()));
let new_shape = compute_slice_output_shape(&slice_spec, &xs_tensor.meta.shape);
let output_meta = TensorMeta {
dtype: xs_tensor.meta.dtype,
shape: new_shape,
strides: xs_tensor.meta.strides.clone(), layout: Layout::Strided,
alias: xs_tensor.meta.alias,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::Slice(slice_spec, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_broadcast(&mut self, shape_expr: &Expr, xs: &Expr, _span: Span) -> LowerResult {
let xs_result = self.lower_expr(xs);
let xs_tensor = match xs_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let target_shape =
extract_shape(shape_expr).unwrap_or_else(|| xs_tensor.meta.shape.clone());
let broadcast_strides = compute_broadcast_strides(&xs_tensor.meta, &target_shape);
let output_meta = TensorMeta {
dtype: xs_tensor.meta.dtype,
shape: target_shape.clone(),
strides: broadcast_strides,
layout: Layout::Strided,
alias: xs_tensor.meta.alias,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::Broadcast(target_shape, xs_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_matmul(&mut self, a: &Expr, b: &Expr, _span: Span) -> LowerResult {
let a_result = self.lower_expr(a);
let b_result = self.lower_expr(b);
let a_tensor = match a_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let b_tensor = match b_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let a_dims = a_tensor.meta.shape.dims();
let b_dims = b_tensor.meta.shape.dims();
let (m, n) = if a_dims.len() >= 2 && b_dims.len() >= 2 {
let m = a_dims[a_dims.len() - 2];
let n = b_dims[b_dims.len() - 1];
(m, n)
} else {
return LowerResult::NotTensor;
};
let output_shape = Shape::new([m, n]);
let output_strides = Strides::contiguous(&output_shape, a_tensor.meta.dtype.size_bytes())
.unwrap_or_else(|| Strides::new([]));
let output_meta = TensorMeta {
dtype: a_tensor.meta.dtype,
shape: output_shape,
strides: output_strides,
layout: Layout::Contiguous,
alias: None,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::MatMul(a_tensor, b_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_dot(&mut self, a: &Expr, b: &Expr, _span: Span) -> LowerResult {
let a_result = self.lower_expr(a);
let b_result = self.lower_expr(b);
let a_tensor = match a_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let b_tensor = match b_result {
LowerResult::Tensor(t) => t,
_ => return LowerResult::NotTensor,
};
let output_meta = TensorMeta {
dtype: a_tensor.meta.dtype,
shape: Shape::scalar(),
strides: Strides::new([]),
layout: Layout::Contiguous,
alias: None,
};
let output = TensorRef {
id: self.fresh_tensor_id(),
meta: output_meta,
};
let op = TensorOp::Dot(a_tensor, b_tensor);
self.ops.push(op);
LowerResult::Tensor(output)
}
fn lower_let(&mut self, bind: &bhc_core::Bind, body: &Expr) -> LowerResult {
match bind {
bhc_core::Bind::NonRec(var, rhs) => {
let rhs_result = self.lower_expr(rhs);
if let LowerResult::Tensor(tensor) = rhs_result {
self.register_tensor(var.id, tensor);
}
}
bhc_core::Bind::Rec(bindings) => {
for (var, rhs) in bindings {
let rhs_result = self.lower_expr(rhs);
if let LowerResult::Tensor(tensor) = rhs_result {
self.register_tensor(var.id, tensor);
}
}
}
}
self.lower_expr(body)
}
fn make_output_tensor(&mut self, input_meta: &TensorMeta) -> TensorRef {
TensorRef {
id: self.fresh_tensor_id(),
meta: input_meta.clone(),
}
}
fn make_scalar_tensor(&mut self) -> TensorRef {
TensorRef {
id: self.fresh_tensor_id(),
meta: TensorMeta {
dtype: DType::Float64,
shape: Shape::scalar(),
strides: Strides::new([]),
layout: Layout::Contiguous,
alias: None,
},
}
}
}
fn apply_permutation_to_shape(shape: &Shape, perm: &[usize]) -> Shape {
let dims = shape.dims();
let new_dims: SmallVec<[Dim; 4]> = perm.iter().map(|&i| dims[i]).collect();
Shape::new(new_dims)
}
/// Reorders `strides` so that output axis `k` takes input axis `perm[k]`'s stride.
///
/// Elsewhere in this module, empty strides (`Strides::new([])`) mean "unknown".
/// For example, `Strides::contiguous` may yield nothing for the requested
/// shape. When the stride count does not match the permutation length, or
/// `perm` names a missing axis, unknown strides are returned instead of
/// panicking on an out-of-bounds index.
fn apply_permutation_to_strides(strides: &Strides, perm: &[usize]) -> Strides {
    let vals = strides.values();
    if vals.len() != perm.len() {
        return Strides::new([]);
    }
    let permuted: Option<SmallVec<[i64; 4]>> =
        perm.iter().map(|&axis| vals.get(axis).copied()).collect();
    match permuted {
        Some(values) => Strides::new(values),
        None => Strides::new([]),
    }
}
/// A slice spec selecting every element of each of `rank` axes
/// (no start, no stop, step 1).
fn make_identity_slice(rank: usize) -> SliceSpec {
    let whole_axis = || SliceRange {
        start: None,
        stop: None,
        step: 1,
    };
    SliceSpec {
        ranges: std::iter::repeat_with(whole_axis).take(rank).collect(),
    }
}
/// Computes the shape produced by applying `slice` to `input_shape`.
///
/// - An axis with no bounds and step 1 keeps its dimension.
/// - A static axis of extent `n` gets `ceil((stop - start) / |step|)`. Missing
///   bounds default to `0` and `n`, and a step of 0 is treated as 1.
/// - Both bounds are clamped into `[0, n]`, so an out-of-range bound can never
///   yield an extent larger than the axis. NOTE(review): negative bounds are
///   clamped to 0, not interpreted as counting from the end; confirm the
///   intended `SliceRange` semantics.
/// - Dynamic axes stay dynamic.
///
/// Ranges beyond the input rank are ignored, and input axes beyond the spec are
/// copied through unchanged.
fn compute_slice_output_shape(slice: &SliceSpec, input_shape: &Shape) -> Shape {
    let dims = input_shape.dims();
    let mut new_dims: SmallVec<[Dim; 4]> = SmallVec::new();
    for (dim, range) in dims.iter().zip(slice.ranges.iter()) {
        let new_dim = match (dim, range.start, range.stop, range.step) {
            (d, None, None, 1) => *d,
            (Dim::Static(n), start, stop, step) => {
                let n = *n;
                // Clamp before the cast so negatives do not wrap to huge usizes.
                let s = start.map_or(0, |x| (x.max(0) as usize).min(n));
                let e = stop.map_or(n, |x| (x.max(0) as usize).min(n));
                let st = (step.unsigned_abs() as usize).max(1);
                Dim::Static((e.saturating_sub(s) + st - 1) / st)
            }
            (Dim::Dynamic(sym), _, _, _) => Dim::Dynamic(*sym),
        };
        new_dims.push(new_dim);
    }
    new_dims.extend(dims.iter().skip(slice.ranges.len()).copied());
    Shape::new(new_dims)
}
impl Default for LowerContext {
fn default() -> Self {
Self::new()
}
}
/// Flattens a curried application spine `((h a1) a2) ... an`.
///
/// `f` is the function position and `arg` the last argument of the outermost
/// application. Returns the head `h` and the arguments in source order
/// `[a1, ..., an]`.
fn collect_app_args<'a>(f: &'a Expr, arg: &'a Expr) -> (&'a Expr, Vec<&'a Expr>) {
    // Walk leftwards down the spine, gathering arguments innermost-last.
    let mut head = f;
    let mut reversed_args: Vec<&'a Expr> = vec![arg];
    while let Expr::App(callee, operand, _) = head {
        reversed_args.push(operand.as_ref());
        head = callee.as_ref();
    }
    // The walk collected them outermost-first; restore source order.
    reversed_args.reverse();
    (head, reversed_args)
}
/// Returns the symbol naming the function in `expr` when it is a plain
/// variable. Any other expression (lambdas included) maps to the raw symbol 0.
fn extract_fn_name(expr: &Expr) -> Symbol {
    if let Expr::Var(var, _) = expr {
        return var.name;
    }
    // SAFETY: NOTE(review) — assumes raw id 0 is a valid/reserved placeholder
    // symbol in the interner; confirm against `Symbol::from_raw`'s contract.
    unsafe { Symbol::from_raw(0) }
}
/// Intended to read a static shape out of a shape-literal expression.
///
/// Currently a stub that never recognises anything; callers fall back to the
/// input tensor's shape.
fn extract_shape(_expr: &Expr) -> Option<Shape> {
    Option::None
}

/// Intended to read a slice specification out of an expression.
///
/// Currently a stub that never recognises anything; callers fall back to an
/// identity slice.
fn extract_slice_spec(_expr: &Expr) -> Option<SliceSpec> {
    Option::None
}
/// Strides that view `source` as a tensor of `target_shape`, aligning axes
/// from the right (NumPy-style).
///
/// - Target axes with no counterpart in the source get stride 0.
/// - A static-1 source axis stretched to a different extent gets stride 0.
/// - Every other axis keeps its source stride. NOTE(review): incompatible
///   extents are not rejected here, and when the target rank is lower than the
///   source rank the leading source axes are ignored.
///
/// Empty strides mean "unknown" elsewhere in this module (the
/// `Strides::new([])` fallbacks). If the source stride count does not match its
/// rank, unknown strides are returned instead of indexing out of bounds.
fn compute_broadcast_strides(source: &TensorMeta, target_shape: &Shape) -> Strides {
    let src_dims = source.shape.dims();
    let src_strides = source.strides.values();
    if src_strides.len() != src_dims.len() {
        return Strides::new([]);
    }
    let tgt_dims = target_shape.dims();
    let source_rank = src_dims.len();
    let target_rank = tgt_dims.len();
    let strides: SmallVec<[i64; 4]> = tgt_dims
        .iter()
        .enumerate()
        .map(|(i, &tgt_dim)| {
            // Source axis aligned with target axis `i`, if any.
            let Some(src_idx) = (i + source_rank).checked_sub(target_rank) else {
                return 0;
            };
            let src_dim = src_dims[src_idx];
            if src_dim != tgt_dim && src_dim == Dim::Static(1) {
                0
            } else {
                src_strides[src_idx]
            }
        })
        .collect();
    Strides::new(strides)
}
pub fn lower_module(module: &bhc_core::CoreModule) -> Vec<TensorOp> {
let mut ctx = LowerContext::new();
for bind in &module.bindings {
match bind {
bhc_core::Bind::NonRec(var, rhs) => {
let result = ctx.lower_expr(rhs);
if let LowerResult::Tensor(tensor) = result {
ctx.register_tensor(var.id, tensor);
}
}
bhc_core::Bind::Rec(bindings) => {
for (var, rhs) in bindings {
let result = ctx.lower_expr(rhs);
if let LowerResult::Tensor(tensor) = result {
ctx.register_tensor(var.id, tensor);
}
}
}
}
}
ctx.into_ops()
}
#[cfg(test)]
mod tests {
    use super::*;
    use bhc_core::VarId;

    /// Builds a contiguous tensor reference with static dims `shape`.
    fn make_tensor_ref(id: u32, shape: &[usize], dtype: DType) -> TensorRef {
        let dims: SmallVec<[Dim; 4]> = shape.iter().map(|&d| Dim::Static(d)).collect();
        let shape = Shape::new(dims);
        let strides =
            Strides::contiguous(&shape, dtype.size_bytes()).unwrap_or_else(|| Strides::new([]));
        TensorRef {
            id: TensorId::new(id as usize),
            meta: TensorMeta {
                dtype,
                shape,
                strides,
                layout: Layout::Contiguous,
                alias: None,
            },
        }
    }

    #[test]
    fn test_lower_context_creation() {
        let ctx = LowerContext::new();
        assert!(ctx.ops().is_empty());
    }

    #[test]
    fn test_register_and_lookup_tensor() {
        let mut ctx = LowerContext::new();
        let tensor = make_tensor_ref(0, &[100], DType::Float32);
        let var_id = VarId::new(42);
        ctx.register_tensor(var_id, tensor.clone());
        let found = ctx.lookup_tensor(var_id);
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, tensor.id);
    }

    #[test]
    fn test_fresh_ids_increment() {
        let mut ctx = LowerContext::new();
        let id1 = ctx.fresh_tensor_id();
        let id2 = ctx.fresh_tensor_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_collect_app_args() {
        let f_var = bhc_core::Var::new(unsafe { Symbol::from_raw(1) }, VarId::new(0), Ty::Error);
        let f_expr = Expr::Var(f_var, Span::default());
        let x = Expr::Lit(Literal::Int(1), Ty::Error, Span::default());
        let y = Expr::Lit(Literal::Int(2), Ty::Error, Span::default());
        let z = Expr::Lit(Literal::Int(3), Ty::Error, Span::default());
        let app1 = Expr::App(Box::new(f_expr), Box::new(x), Span::default());
        let app2 = Expr::App(Box::new(app1), Box::new(y), Span::default());
        let app3 = Expr::App(Box::new(app2), Box::new(z), Span::default());
        // Fail loudly instead of silently passing if the spine is not an App.
        let Expr::App(f, arg, _) = &app3 else {
            panic!("expected an application node");
        };
        let (func, args) = collect_app_args(f, arg);
        assert_eq!(args.len(), 3);
        assert!(matches!(func, Expr::Var(_, _)));
        // Arguments must come back in source order.
        assert!(matches!(args[0], Expr::Lit(Literal::Int(1), _, _)));
        assert!(matches!(args[1], Expr::Lit(Literal::Int(2), _, _)));
        assert!(matches!(args[2], Expr::Lit(Literal::Int(3), _, _)));
    }

    #[test]
    fn test_lower_literal_int() {
        let ctx = LowerContext::new();
        let lit = Literal::Int(42);
        let result = ctx.lower_lit(&lit, &Ty::Error);
        match result {
            LowerResult::Scalar(ScalarValue::Int(n)) => assert_eq!(n, 42),
            _ => panic!("Expected scalar int"),
        }
    }

    #[test]
    fn test_lower_literal_float() {
        let ctx = LowerContext::new();
        // 2.5 is exactly representable and avoids clippy::approx_constant
        // (deny-by-default), which rejects PI approximations such as 3.14.
        let lit = Literal::Double(2.5);
        let result = ctx.lower_lit(&lit, &Ty::Error);
        match result {
            LowerResult::Scalar(ScalarValue::Float(f)) => {
                assert!((f - 2.5).abs() < f64::EPSILON);
            }
            _ => panic!("Expected scalar float"),
        }
    }

    #[test]
    fn test_broadcast_strides_expansion() {
        let source_meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::new([Dim::Static(1), Dim::Static(10)]),
            strides: Strides::new([10, 1]),
            layout: Layout::Contiguous,
            alias: None,
        };
        let target_shape = Shape::new([Dim::Static(5), Dim::Static(1), Dim::Static(10)]);
        let broadcast = compute_broadcast_strides(&source_meta, &target_shape);
        assert_eq!(broadcast.values()[0], 0);
        assert_eq!(broadcast.values()[1], 10);
        assert_eq!(broadcast.values()[2], 1);
    }

    #[test]
    fn test_broadcast_strides_broadcasting() {
        let source_meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::new([Dim::Static(1), Dim::Static(5)]),
            strides: Strides::new([5, 1]),
            layout: Layout::Contiguous,
            alias: None,
        };
        let target_shape = Shape::new([Dim::Static(3), Dim::Static(5)]);
        let broadcast = compute_broadcast_strides(&source_meta, &target_shape);
        assert_eq!(broadcast.values()[0], 0);
        assert_eq!(broadcast.values()[1], 1);
    }

    #[test]
    fn test_permutation_to_shape() {
        let shape = Shape::new([Dim::Static(3), Dim::Static(4)]);
        let perm = [1, 0];
        let new_shape = apply_permutation_to_shape(&shape, &perm);
        assert_eq!(new_shape.dims().len(), 2);
        assert_eq!(new_shape.dims()[0], Dim::Static(4));
        assert_eq!(new_shape.dims()[1], Dim::Static(3));
    }

    #[test]
    fn test_permutation_to_strides() {
        let strides = Strides::new([4, 1]);
        let perm = [1, 0];
        let new_strides = apply_permutation_to_strides(&strides, &perm);
        assert_eq!(new_strides.values().len(), 2);
        assert_eq!(new_strides.values()[0], 1);
        assert_eq!(new_strides.values()[1], 4);
    }

    #[test]
    fn test_transpose_3d() {
        let shape = Shape::new([Dim::Static(2), Dim::Static(3), Dim::Static(4)]);
        let strides = Strides::new([12, 4, 1]);
        let perm: SmallVec<[usize; 4]> = [2, 1, 0].into_iter().collect();
        let new_shape = apply_permutation_to_shape(&shape, &perm);
        let new_strides = apply_permutation_to_strides(&strides, &perm);
        assert_eq!(new_shape.dims()[0], Dim::Static(4));
        assert_eq!(new_shape.dims()[1], Dim::Static(3));
        assert_eq!(new_shape.dims()[2], Dim::Static(2));
        assert_eq!(new_strides.values()[0], 1);
        assert_eq!(new_strides.values()[1], 4);
        assert_eq!(new_strides.values()[2], 12);
    }

    #[test]
    fn test_identity_slice() {
        let slice = make_identity_slice(3);
        assert_eq!(slice.ranges.len(), 3);
        for r in slice.ranges.iter() {
            assert_eq!(r.start, None);
            assert_eq!(r.stop, None);
            assert_eq!(r.step, 1);
        }
    }

    #[test]
    fn test_slice_output_shape_identity() {
        let shape = Shape::new([Dim::Static(10), Dim::Static(20)]);
        let slice = make_identity_slice(2);
        let output = compute_slice_output_shape(&slice, &shape);
        assert_eq!(output.dims()[0], Dim::Static(10));
        assert_eq!(output.dims()[1], Dim::Static(20));
    }

    #[test]
    fn test_slice_output_shape_with_bounds() {
        let shape = Shape::new([Dim::Static(10), Dim::Static(20)]);
        let slice = SliceSpec {
            ranges: smallvec::smallvec![
                SliceRange {
                    start: Some(2),
                    stop: Some(8),
                    step: 1
                },
                SliceRange {
                    start: Some(5),
                    stop: Some(15),
                    step: 1
                },
            ],
        };
        let output = compute_slice_output_shape(&slice, &shape);
        assert_eq!(output.dims()[0], Dim::Static(6));
        assert_eq!(output.dims()[1], Dim::Static(10));
    }

    #[test]
    fn test_slice_output_shape_with_step() {
        let shape = Shape::new([Dim::Static(10)]);
        let slice = SliceSpec {
            ranges: smallvec::smallvec![SliceRange {
                start: Some(0),
                stop: Some(10),
                step: 2
            },],
        };
        let output = compute_slice_output_shape(&slice, &shape);
        assert_eq!(output.dims()[0], Dim::Static(5));
    }

    #[test]
    fn test_reshape_contiguous_is_metadata_only() {
        use crate::fusion::is_reshape_metadata_only;
        let tensor = make_tensor_ref(0, &[2, 3, 4], DType::Float32);
        assert!(is_reshape_metadata_only(&tensor));
        // Row stride 8 on a 4-wide row means rows are padded, so not contiguous.
        let strided_meta = TensorMeta {
            dtype: DType::Float32,
            shape: Shape::new([Dim::Static(3), Dim::Static(4)]),
            strides: Strides::new([8, 1]),
            layout: Layout::Strided,
            alias: None,
        };
        let strided = TensorRef {
            id: TensorId::new(1),
            meta: strided_meta,
        };
        assert!(!is_reshape_metadata_only(&strided));
    }
}