use crate::builder::BuildOptions;
use crate::region::wrap_anonymous;
use crate::tensor_ref::{TensorRef, TensorRefError};
use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Selects the right-hand operand used by [`f32_elementwise_mul`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum F32MulRhs<'a> {
    /// Reuse the left-hand load, i.e. multiply each input element by itself.
    SameInput,
    /// Load the right-hand operand from the named buffer at the same index.
    Buffer(&'a str),
}
/// Builds a program that stores `input[i] * rhs[i]` into `output[i]` for every
/// invocation index `i` along axis 0 that satisfies `i < n`.
///
/// * `op_id` - identifier forwarded to `wrap_anonymous` together with the body.
/// * `input` - name of the left-hand `F32` buffer (declared read-only).
/// * `rhs` - either the same load as the left-hand side or a second named
///   read-only `F32` buffer.
/// * `output` - name of the `F32` output buffer.
/// * `n` - element count applied to every declared buffer and used as the
///   upper bound of the index guard.
///
/// Buffers are declared with the numeric argument 0 for `input`, then 1 for the
/// optional right-hand buffer, and the next free number (1 or 2) for `output`
/// (presumably binding slots — confirm against `BufferDecl`).
#[must_use]
pub(crate) fn f32_elementwise_mul(
    op_id: &'static str,
    input: &str,
    rhs: F32MulRhs<'_>,
    output: &str,
    n: u32,
) -> Program {
    let idx = Expr::var("i");

    // The left operand is always `input[i]`; the right operand either reuses
    // that expression or loads from the second buffer at the same index.
    let left = Expr::load(input, idx.clone());
    let right = match rhs {
        F32MulRhs::SameInput => left.clone(),
        F32MulRhs::Buffer(name) => Expr::load(name, idx.clone()),
    };

    // Only indices below `n` perform the store.
    let in_bounds = Expr::lt(idx.clone(), Expr::u32(n));
    let store = Node::Store {
        buffer: output.into(),
        index: idx,
        value: Expr::mul(left, right),
    };
    let body = vec![
        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
        Node::if_then(in_bounds, vec![store]),
    ];

    let input_decl =
        BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::F32).with_count(n);
    let buffers = match rhs {
        F32MulRhs::SameInput => vec![
            input_decl,
            BufferDecl::output(output, 1, DataType::F32).with_count(n),
        ],
        F32MulRhs::Buffer(name) => vec![
            input_decl,
            BufferDecl::storage(name, 1, BufferAccess::ReadOnly, DataType::F32).with_count(n),
            BufferDecl::output(output, 2, DataType::F32).with_count(n),
        ],
    };

    // NOTE(review): `[64, 1, 1]` looks like the workgroup size — confirm
    // against `Program::wrapped`.
    Program::wrapped(buffers, [64, 1, 1], vec![wrap_anonymous(op_id, body)])
}
/// Fallible builder for a one-input elementwise `u32` program.
///
/// Describes `input` and `out` as 1-D `u32` tensor references of length `size`
/// and delegates to `crate::builder::build_elementwise_unary` with default
/// [`BuildOptions`], passing `op` through unchanged.
///
/// # Errors
///
/// Returns whatever [`TensorRefError`] the underlying builder reports.
pub(crate) fn try_u32_elementwise_unary<F>(
    op_id: &'static str,
    input: &str,
    out: &str,
    size: u32,
    op: F,
) -> Result<Program, TensorRefError>
where
    F: Fn(Expr) -> Expr,
{
    let source = TensorRef::u32_1d(input, size);
    let destination = TensorRef::u32_1d(out, size);
    let options = BuildOptions::default();

    crate::builder::build_elementwise_unary(op_id, source, destination, options, op)
}
/// Infallible counterpart of [`try_u32_elementwise_unary`].
///
/// On success the built program is returned as-is. On failure the error is
/// rendered into a `"Fix: …"` message and handed to
/// `crate::builder::invalid_output_program` along with `op_id`, `out` and
/// `DataType::U32`; that function's result is returned instead.
#[must_use]
pub(crate) fn u32_elementwise_unary<F>(
    op_id: &'static str,
    input: &str,
    out: &str,
    size: u32,
    op: F,
) -> Program
where
    F: Fn(Expr) -> Expr,
{
    match try_u32_elementwise_unary(op_id, input, out, size, op) {
        Ok(program) => program,
        Err(err) => {
            let message = format!("Fix: {err}");
            crate::builder::invalid_output_program(op_id, out, DataType::U32, message)
        }
    }
}
/// Fallible builder for a two-input elementwise `u32` program.
///
/// Describes `a`, `b` and `out` as 1-D `u32` tensor references of length
/// `size` and delegates to `crate::builder::build_elementwise_binary` with
/// default [`BuildOptions`], passing `op` through unchanged.
///
/// # Errors
///
/// Returns whatever [`TensorRefError`] the underlying builder reports.
pub(crate) fn try_u32_elementwise_binary<F>(
    op_id: &'static str,
    a: &str,
    b: &str,
    out: &str,
    size: u32,
    op: F,
) -> Result<Program, TensorRefError>
where
    F: Fn(Expr, Expr) -> Expr,
{
    let left = TensorRef::u32_1d(a, size);
    let right = TensorRef::u32_1d(b, size);
    let destination = TensorRef::u32_1d(out, size);
    let options = BuildOptions::default();

    crate::builder::build_elementwise_binary(op_id, left, right, destination, options, op)
}
/// Infallible counterpart of [`try_u32_elementwise_binary`].
///
/// On success the built program is returned as-is. On failure the error is
/// rendered into a `"Fix: …"` message and handed to
/// `crate::builder::invalid_output_program` along with `op_id`, `out` and
/// `DataType::U32`; that function's result is returned instead.
#[must_use]
pub(crate) fn u32_elementwise_binary<F>(
    op_id: &'static str,
    a: &str,
    b: &str,
    out: &str,
    size: u32,
    op: F,
) -> Program
where
    F: Fn(Expr, Expr) -> Expr,
{
    match try_u32_elementwise_binary(op_id, a, b, out, size, op) {
        Ok(program) => program,
        Err(err) => {
            let message = format!("Fix: {err}");
            crate::builder::invalid_output_program(op_id, out, DataType::U32, message)
        }
    }
}