use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec, U32_U32_INPUTS, U32_OUTPUTS};
/// Algebraic laws declared for the GCD op; this slice is passed to `Gcd::SPEC`.
const LAWS: &[AlgebraicLaw] = &[
// gcd(a, b) == gcd(b, a)
AlgebraicLaw::Commutative,
// gcd(gcd(a, b), c) == gcd(a, gcd(b, c))
AlgebraicLaw::Associative,
// gcd(a, 0) == a
AlgebraicLaw::Identity { element: 0 },
// gcd(a, a) == a
AlgebraicLaw::Idempotent,
];
/// Unit type that namespaces the `primitive.math.gcd` op: its `SPEC` constant
/// and the `program()` builder live in the inherent `impl` below.
#[derive(Debug, Clone, Copy, Default)]
pub struct Gcd;
impl Gcd {
pub const SPEC: OpSpec = OpSpec::composition_inlinable(
"primitive.math.gcd",
U32_U32_INPUTS,
U32_OUTPUTS,
LAWS,
Self::program,
);
#[must_use]
pub fn program() -> Program {
let idx = Expr::var("idx");
Program::new(
vec![
BufferDecl::read("a", 0, DataType::U32),
BufferDecl::read("b", 1, DataType::U32),
BufferDecl::output("out", 2, DataType::U32),
],
crate::ops::primitive::WORKGROUP_SIZE,
vec![
Node::let_bind("idx", Expr::gid_x()),
Node::if_then(
Expr::lt(idx.clone(), Expr::buf_len("out")),
vec![
Node::let_bind("x", Expr::load("a", idx.clone())),
Node::let_bind("y", Expr::load("b", idx.clone())),
Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![Node::if_then(
Expr::ne(Expr::var("y"), Expr::u32(0)),
vec![
Node::let_bind("t", Expr::var("y")),
Node::assign(
"y",
Expr::rem(Expr::var("x"), Expr::var("y")),
),
Node::assign("x", Expr::var("t")),
],
)],
),
Node::store("out", idx, Expr::var("x")),
],
),
],
)
}
}