use crate::error::{Error, Result};
use crate::ir_inner::model::expr::Expr;
use crate::ir_inner::model::expr::Ident;
use crate::ir_inner::model::node::Node;
use crate::ir_inner::model::program::{BufferDecl, Program};
use crate::ir_inner::model::types::{BufferAccess, DataType};
use rustc_hash::FxHashMap as HashMap;
/// Callback that maps an operation id (the `op_id` carried by an `Expr::Call`)
/// to the callee [`Program`] implementing it, or `None` when the id is unknown.
pub type OpResolver = for<'a> fn(&'a str) -> Option<Program>;
/// Inlines operation calls in `program` using [`default_resolver`].
///
/// [`default_resolver`] resolves no operation, so this only succeeds on input
/// whose calls never need resolving. Use [`inline_calls_with_resolver`] to
/// supply callee programs.
///
/// # Errors
///
/// Returns [`Error::InlineUnknownOp`] when `program` contains a call that must
/// be resolved, and propagates any other error from
/// [`inline_calls_with_resolver`].
// `Result` is already `#[must_use]`; repeating the attribute trips clippy's
// `double_must_use` lint.
#[inline]
pub fn inline_calls(program: &Program) -> Result<Program> {
    inline_calls_with_resolver(program, default_resolver)
}
/// Inlines operation calls in `program`, looking callee bodies up through
/// `resolver`.
///
/// The returned program reuses `program`'s buffer declarations and workgroup
/// size unchanged. Its entry is rebuilt via `Program::wrapped` from the
/// inlined node list.
///
/// # Errors
///
/// Propagates any error raised while expanding the entry nodes. For example,
/// [`Error::InlineUnknownOp`] is returned when `resolver` yields `None` for a
/// called operation.
// `Result` is already `#[must_use]`; repeating the attribute trips clippy's
// `double_must_use` lint.
#[inline]
pub fn inline_calls_with_resolver(program: &Program, resolver: OpResolver) -> Result<Program> {
    let mut ctx = InlineCtx::new(resolver);
    let entry = ctx.inline_nodes(program.entry())?;
    Ok(Program::wrapped(
        program.buffers().to_vec(),
        program.workgroup_size(),
        entry,
    ))
}
/// Resolver that recognises no operation: every `op_id` yields `None`.
///
/// This is the resolver [`inline_calls`] uses.
#[must_use]
#[inline]
pub fn default_resolver(_op_id: &str) -> Option<Program> {
    // Nothing is registered, so there is never a callee to hand back.
    Option::None
}
/// Mutable state for one inlining pass.
///
/// [`inline_calls_with_resolver`] builds one with `InlineCtx::new(resolver)`
/// and then calls `inline_nodes` on the program's entry nodes.
pub struct InlineCtx {
    /// Callback used to look up the callee [`Program`] for an operation id.
    resolver: OpResolver,
    /// NOTE(review): presumably the chain of op ids currently being expanded,
    /// used to detect recursive calls — confirm in `impl_inlinectx`.
    stack: Vec<String>,
    /// NOTE(review): presumably a counter for generating unique per-call-site
    /// identifiers — confirm in `impl_inlinectx`.
    next_call_id: usize,
}
mod expand;
mod impl_inlinectx;
/// Binds call arguments to the callee's input buffer names.
///
/// The callee's input buffers (see [`input_buffers`]) are ordered by binding
/// slot, and the i-th element of `args` is paired with the i-th buffer's name.
/// When the counts differ, the surplus on the longer side is silently dropped.
///
/// NOTE(review): arity is not checked here; presumably the caller validates it.
#[inline]
pub(crate) fn input_arg_map(callee: &Program, args: Vec<Expr>) -> HashMap<Ident, Expr> {
    let mut params = input_buffers(callee);
    // Stable sort: declarations that share a binding keep their relative order.
    params.sort_by_key(|decl| decl.binding());
    params
        .into_iter()
        .map(|decl| Ident::from(decl.name()))
        .zip(args)
        .collect()
}
/// Returns the callee's input buffers, in the order `Program::buffers` yields
/// them.
///
/// A buffer counts as an input when its access is [`BufferAccess::ReadOnly`]
/// or [`BufferAccess::Uniform`].
#[must_use]
#[inline]
pub(crate) fn input_buffers(callee: &Program) -> Vec<&BufferDecl> {
    let mut inputs = Vec::new();
    for decl in callee.buffers().iter() {
        let readable = matches!(decl.access(), BufferAccess::ReadOnly | BufferAccess::Uniform);
        if readable {
            inputs.push(decl);
        }
    }
    inputs
}
#[inline]
#[must_use]
pub fn output_buffer<'a>(op_id: &str, program: &'a Program) -> Result<&'a BufferDecl> {
let outputs: Vec<&BufferDecl> = program
.buffers()
.iter()
.filter(|buf| buf.is_output())
.collect();
match outputs.as_slice() {
[output] => Ok(output),
[] => Err(Error::InlineNoOutput {
op_id: op_id.to_string(),
}),
outputs => Err(Error::InlineOutputCountMismatch {
op_id: op_id.to_string(),
got: outputs.len(),
}),
}
}
/// Builds the zero literal expression used for a value of type `ty`.
///
/// - `Bool` → `Expr::LitBool(false)`
/// - `I32` → `Expr::i32(0)`
/// - `F16`, `BF16`, `F32`, `F64` → `Expr::f32(0.0)`
/// - every other type (unsigned scalars and vectors, `Bytes`, `Array`,
///   `Tensor`, and any remaining variant) → `Expr::u32(0)`
#[inline]
pub fn zero_value(ty: DataType) -> Expr {
    match ty {
        DataType::Bool => Expr::LitBool(false),
        DataType::I32 => Expr::i32(0),
        // NOTE(review): 16- and 64-bit floats also get an f32 literal; confirm
        // the IR has no narrower/wider float literal to use instead.
        DataType::F16 | DataType::BF16 | DataType::F32 | DataType::F64 => Expr::f32(0.0),
        // Unsigned scalars/vectors, raw bytes, arrays, tensors and anything
        // else share the unsigned zero.
        _ => Expr::u32(0),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir_inner::model::expr::Expr;
    use crate::ir_inner::model::node::Node;
    use crate::ir_inner::model::program::BufferDecl;

    /// Caller program: `out[0] = add_one(A[0])`.
    fn make_caller() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("A", 0, BufferAccess::ReadOnly, DataType::U32).with_count(1),
                BufferDecl::output("out", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::store(
                "out",
                Expr::u32(0),
                Expr::Call {
                    op_id: "add_one".into(),
                    args: vec![Expr::Load {
                        buffer: "A".into(),
                        index: Box::new(Expr::u32(0)),
                    }],
                },
            )],
        )
    }

    /// Callee registered as `add_one`: `result[0] = x[0] + 1`, with one
    /// read-only input `x` and one output `result`.
    fn make_callee() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("x", 0, BufferAccess::ReadOnly, DataType::U32).with_count(1),
                BufferDecl::output("result", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::store(
                "result",
                Expr::u32(0),
                Expr::add(
                    Expr::Load {
                        buffer: "x".into(),
                        index: Box::new(Expr::u32(0)),
                    },
                    Expr::u32(1),
                ),
            )],
        )
    }

    /// Resolver that knows only the `add_one` operation.
    fn test_resolver(op_id: &str) -> Option<Program> {
        if op_id == "add_one" {
            Some(make_callee())
        } else {
            None
        }
    }

    #[test]
    fn test_inline_call_success() {
        let caller = make_caller();
        let inlined = inline_calls_with_resolver(&caller, test_resolver).unwrap();
        let nodes = inlined.entry();
        // The Debug rendering is a cheap structural check: no `Expr::Call`
        // may survive inlining.
        let dump = format!("{nodes:?}");
        assert!(!dump.contains("Call {"), "Expr::Call should be inlined out: {dump}");
    }

    #[test]
    fn test_inline_unknown_op() {
        let caller = make_caller();
        let err = inline_calls(&caller).unwrap_err();
        assert!(matches!(err, Error::InlineUnknownOp { .. }));
    }

    #[test]
    fn test_output_buffer_requires_exactly_one_output() {
        let callee = make_callee();
        assert!(matches!(
            output_buffer("add_one", &callee),
            Ok(buf) if buf.is_output()
        ));

        let no_output = Program::wrapped(
            vec![
                BufferDecl::storage("x", 0, BufferAccess::ReadOnly, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            Vec::<Node>::new(),
        );
        assert!(matches!(
            output_buffer("none", &no_output),
            Err(Error::InlineNoOutput { .. })
        ));

        let two_outputs = Program::wrapped(
            vec![
                BufferDecl::output("a", 0, DataType::U32).with_count(1),
                BufferDecl::output("b", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            Vec::<Node>::new(),
        );
        assert!(matches!(
            output_buffer("two", &two_outputs),
            Err(Error::InlineOutputCountMismatch { got: 2, .. })
        ));
    }

    #[test]
    fn test_zero_value() {
        assert_eq!(zero_value(DataType::I32), Expr::i32(0));
        assert_eq!(zero_value(DataType::F32), Expr::f32(0.0));
        assert_eq!(zero_value(DataType::Bool), Expr::LitBool(false));
        assert_eq!(zero_value(DataType::U32), Expr::u32(0));
    }
}