use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec};
/// Input signature of the sign op: a single `i32` operand.
pub const INPUTS: &[DataType] = &[DataType::I32];
/// Output signature of the sign op: a single `i32` result.
pub const OUTPUTS: &[DataType] = &[DataType::I32];
/// Algebraic laws declared for the op.
///
/// Idempotence holds because sign maps every value into `{-1, 0, 1}`, and each
/// of those is its own sign: `sign(sign(x)) == sign(x)`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Idempotent];
/// Elementwise sign over an `i32` buffer.
///
/// Each output element becomes `-1`, `0`, or `1` depending on whether the
/// matching input element is negative, zero, or positive (see [`sign_expr`]).
#[derive(Debug, Clone, Copy, Default)]
pub struct Sign;

impl Sign {
    /// Operation descriptor registered under the id `primitive.math.sign`.
    ///
    /// It is built with `OpSpec::composition_inlinable` from [`INPUTS`],
    /// [`OUTPUTS`], [`LAWS`], and [`Sign::program`] as the IR builder.
    pub const SPEC: OpSpec =
        OpSpec::composition_inlinable("primitive.math.sign", INPUTS, OUTPUTS, LAWS, Self::program);

    /// Builds the IR program implementing elementwise sign.
    ///
    /// The program declares a read buffer `a` (index `0`) and an output buffer
    /// `out` (index `1`), both `i32`. It binds `idx` from `Expr::gid_x()`. When
    /// `idx < buf_len("out")`, it stores `sign_expr(a[idx])` into `out[idx]`.
    ///
    /// NOTE(review): only `out`'s length is checked before reading `a[idx]`.
    /// This presumes `a` is at least as long as `out`; confirm the caller/runtime
    /// guarantees that.
    #[must_use]
    pub fn program() -> Program {
        let lane = Expr::var("idx");

        let buffers = vec![
            BufferDecl::read("a", 0, DataType::I32),
            BufferDecl::output("out", 1, DataType::I32),
        ];

        let bind_lane = Node::let_bind("idx", Expr::gid_x());
        // Lanes at or past the end of `out` do nothing.
        let in_bounds = Expr::lt(lane.clone(), Expr::buf_len("out"));
        let write_sign = Node::store("out", lane.clone(), sign_expr(Expr::load("a", lane)));
        let guarded_write = Node::if_then(in_bounds, vec![write_sign]);

        // NOTE(review): `[64, 1, 1]` is presumably the dispatch/workgroup size — confirm.
        Program::new(buffers, [64, 1, 1], vec![bind_lane, guarded_write])
    }
}
/// Builds an IR expression computing the sign of an `i32` value.
///
/// The resulting expression evaluates to:
/// - `-1` when `value < 0`
/// - `1` when `value > 0`
/// - `0` otherwise (i.e. `value == 0`)
///
/// `value` is referenced twice in the generated expression (once per
/// comparison), so it is cloned for the first comparison.
#[must_use]
pub fn sign_expr(value: Expr) -> Expr {
    let is_negative = Expr::lt(value.clone(), Expr::i32(0));
    let is_positive = Expr::gt(value, Expr::i32(0));
    // Non-negative branch: 1 for positive values, 0 for zero.
    let non_negative_sign = Expr::select(is_positive, Expr::i32(1), Expr::i32(0));
    Expr::select(is_negative, Expr::i32(-1), non_negative_sign)
}