use crate::error::{Error, Result};
use crate::ir::model::expr::Expr;
use crate::ir::model::node::Node;
use crate::ir::model::program::{BufferDecl, Program};
use crate::ir::model::types::{BufferAccess, DataType};
use std::collections::HashMap;
/// Function pointer used to resolve an op id to a [`Program`].
///
/// Takes the op id as a string slice and returns the matching program, or
/// `None` when the resolver has no program for that id.
pub type OpResolver = fn(&str) -> Option<Program>;
/// Inlines calls in `program`, resolving op ids through [`default_resolver`].
///
/// This is a convenience wrapper around [`inline_calls_with_resolver`].
///
/// # Errors
///
/// Returns whatever error [`inline_calls_with_resolver`] reports.
#[inline]
pub fn inline_calls(program: &Program) -> Result<Program> {
    inline_calls_with_resolver(program, default_resolver)
}
/// Inlines calls in `program`, resolving op ids through `resolver`.
///
/// The entry nodes of `program` are run through a fresh [`InlineCtx`]; the
/// returned program reuses a copy of the original buffer declarations and the
/// original workgroup size, with the processed nodes as its entry.
///
/// # Errors
///
/// Propagates any error produced by `InlineCtx::inline_nodes`.
#[inline]
pub fn inline_calls_with_resolver(program: &Program, resolver: OpResolver) -> Result<Program> {
    // Process the entry first so a failure returns before anything is copied.
    let mut inliner = InlineCtx::new(resolver);
    let inlined_entry = inliner.inline_nodes(program.entry())?;

    let buffers = program.buffers().to_vec();
    let workgroup_size = program.workgroup_size();
    Ok(Program::new(buffers, workgroup_size, inlined_entry))
}
/// Default [`OpResolver`]: looks `op_id` up in the crate's op registry.
///
/// Delegates directly to `crate::ops::registry::lookup_program` and returns
/// its result unchanged.
#[inline]
pub fn default_resolver(op_id: &str) -> Option<Program> {
    crate::ops::registry::lookup_program(op_id)
}
/// State carried through a single inlining pass.
///
/// Created from an [`OpResolver`] by `inline_calls_with_resolver`; its methods
/// are not defined in this file.
pub struct InlineCtx {
// Callback that maps an op id to the program to inline.
resolver: OpResolver,
// NOTE(review): presumably the op ids currently being expanded (e.g. for
// recursion detection) — confirm against the impl.
stack: Vec<String>,
// NOTE(review): presumably a counter used to give each inlined call a
// unique id — confirm against the impl.
next_call_id: usize,
}
// Private child modules of the inliner.
// NOTE(review): `InlineCtx::new` and `InlineCtx::inline_nodes` are called in
// this file but not defined here, so they presumably live in these modules —
// confirm.
mod expand;
mod impl_inlinectx;
/// Pairs the callee's input buffers with the call arguments, keyed by buffer name.
///
/// Input buffers (see [`input_buffers`]) are ordered by their binding and then
/// matched positionally with `args`.
///
/// Pairing stops at the shorter of the two sequences, so surplus arguments or
/// surplus input buffers are silently left out of the map. If two input
/// buffers share a name, the one later in binding order wins.
#[inline]
pub fn input_arg_map(callee: &Program, args: Vec<Expr>) -> HashMap<String, Expr> {
    let mut ordered = input_buffers(callee);
    // Stable sort: buffers with equal bindings keep their declaration order.
    ordered.sort_by_key(|decl| decl.binding());

    let mut by_name = HashMap::new();
    for (decl, arg) in ordered.into_iter().zip(args) {
        by_name.insert(decl.name().to_string(), arg);
    }
    by_name
}
/// Collects the callee's input buffer declarations.
///
/// A buffer counts as an input when its access is `BufferAccess::ReadOnly` or
/// `BufferAccess::Uniform`. The result preserves declaration order.
#[inline]
pub fn input_buffers(callee: &Program) -> Vec<&BufferDecl> {
    let mut readable = Vec::new();
    for decl in callee.buffers().iter() {
        if matches!(
            decl.access(),
            BufferAccess::ReadOnly | BufferAccess::Uniform
        ) {
            readable.push(decl);
        }
    }
    readable
}
#[inline]
pub fn output_buffer<'a>(op_id: &str, program: &'a Program) -> Result<&'a BufferDecl> {
let outputs: Vec<&BufferDecl> = program
.buffers()
.iter()
.filter(|buf| buf.is_output())
.collect();
match outputs.as_slice() {
[output] => Ok(output),
[] => Err(Error::InlineNoOutput {
op_id: op_id.to_string(),
}),
outputs => Err(Error::InlineOutputCountMismatch {
op_id: op_id.to_string(),
got: outputs.len(),
}),
}
}
/// Builds the zero literal expression used for a value of type `ty`.
///
/// * `Bool` yields `false`.
/// * `I32` yields an `i32` zero.
/// * Every floating-point type yields an `f32` zero.
/// * Every remaining type yields a `u32` zero.
///
/// NOTE(review): non-`f32` floats and the wide/aggregate types (`U64`, vectors,
/// `Bytes`, `Array`, `Tensor`) reuse the `f32` / `u32` literal — confirm that
/// consumers of this expression handle the width difference.
#[inline]
pub fn zero_value(ty: DataType) -> Expr {
    match ty {
        DataType::Bool => Expr::LitBool(false),
        DataType::I32 => Expr::i32(0),
        DataType::F16 | DataType::BF16 | DataType::F32 | DataType::F64 => Expr::f32(0.0),
        DataType::U32
        | DataType::U64
        | DataType::Vec2U32
        | DataType::Vec4U32
        | DataType::Bytes
        | DataType::Tensor
        | DataType::Array { .. } => Expr::u32(0),
    }
}