use std::fmt::{self, Display, Formatter};
use crate::engine::vector::scalar::VeScalar;
use crate::scalar::Opt;
use furiosa_opt_macro::primitive;
/// Operand-routing mode, grouped by how many inputs the operation takes.
///
/// `Unary` carries no routing payload. `Binary` and `Ternary` wrap a
/// [`BinaryArgMode`] / [`TernaryArgMode`], which decides which input feeds
/// each operand slot of a two- or three-operand scalar op.
#[derive(Debug, Clone, Copy)]
pub enum ArgMode {
    /// No routing payload (presumably a single-input operation).
    Unary,
    /// Two-input routing; see [`BinaryArgMode`].
    Binary(BinaryArgMode),
    /// Three-input routing; see [`TernaryArgMode`].
    Ternary(TernaryArgMode),
}
/// Routing of two inputs `(a, b)` into the two operand slots of a binary op.
///
/// Each digit of a variant name picks the input for the matching operand
/// slot (`0` = `a`, `1` = `b`), as realised by `BinaryArgMode::apply_opt`.
/// For example, `Mode10` evaluates `op(b, a)`.
// NOTE(review): `#[primitive]` comes from `furiosa_opt_macro`; its generated
// code is not visible here, so keep variant names/order stable.
#[primitive(op::BinaryArgMode)]
#[derive(Debug, Clone, Copy)]
pub enum BinaryArgMode {
    /// `op(a, a)`; the second input is not read.
    Mode00,
    /// `op(a, b)`.
    Mode01,
    /// `op(b, a)`.
    Mode10,
    /// `op(b, b)`; the first input is not read.
    Mode11,
}
impl Display for BinaryArgMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::Mode00 => write!(f, "BinaryArgMode::Mode00"),
Self::Mode01 => write!(f, "BinaryArgMode::Mode01"),
Self::Mode10 => write!(f, "BinaryArgMode::Mode10"),
Self::Mode11 => write!(f, "BinaryArgMode::Mode11"),
}
}
}
impl BinaryArgMode {
    /// Lifts a binary scalar `op` to [`Opt`] inputs `(a, b)`, routing them
    /// into the operand slots as this mode dictates.
    ///
    /// The returned closure yields `Opt::Uninit` when any input the mode
    /// actually reads is `Opt::Uninit`. An input the mode does not read (the
    /// second for `Mode00`, the first for `Mode11`) is ignored entirely.
    pub fn apply_opt<D: VeScalar>(
        &self,
        op: impl Fn(D, D) -> D + 'static,
    ) -> Box<dyn Fn(Opt<D>, Opt<D>) -> Opt<D>> {
        match *self {
            BinaryArgMode::Mode00 => Box::new(move |lhs, _rhs| {
                if let Opt::Init(x) = lhs {
                    Opt::Init(op(x, x))
                } else {
                    Opt::Uninit
                }
            }),
            BinaryArgMode::Mode01 => Box::new(move |lhs, rhs| {
                if let (Opt::Init(x), Opt::Init(y)) = (lhs, rhs) {
                    Opt::Init(op(x, y))
                } else {
                    Opt::Uninit
                }
            }),
            BinaryArgMode::Mode10 => Box::new(move |lhs, rhs| {
                // Operands swapped: the second input goes into the first slot.
                if let (Opt::Init(x), Opt::Init(y)) = (lhs, rhs) {
                    Opt::Init(op(y, x))
                } else {
                    Opt::Uninit
                }
            }),
            BinaryArgMode::Mode11 => Box::new(move |_lhs, rhs| {
                if let Opt::Init(y) = rhs {
                    Opt::Init(op(y, y))
                } else {
                    Opt::Uninit
                }
            }),
        }
    }
}
/// Routing of three inputs `(m, o0, o1)` into the three operand slots of a
/// ternary op.
///
/// Inputs are indexed `0` = `m`, `1` = `o0`, `2` = `o1` (the parameter names
/// used by `TernaryArgMode::apply_opt`). Each digit of a variant name picks the
/// input for the matching operand slot, e.g. `Mode120` evaluates
/// `op(o0, o1, m)`. Only the combinations listed here are representable.
#[derive(Debug, Clone, Copy)]
pub enum TernaryArgMode {
    /// `op(m, o0, o1)`.
    Mode012,
    /// `op(m, m, o1)`; `o0` is not read.
    Mode002,
    /// `op(o0, m, o1)`.
    Mode102,
    /// `op(o0, o0, o1)`; `m` is not read.
    Mode112,
    /// `op(m, o1, m)`; `o0` is not read.
    Mode020,
    /// `op(m, o1, o0)`.
    Mode021,
    /// `op(o0, o1, m)`.
    Mode120,
}
impl Display for TernaryArgMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::Mode012 => write!(f, "TernaryArgMode::Mode012"),
Self::Mode002 => write!(f, "TernaryArgMode::Mode002"),
Self::Mode102 => write!(f, "TernaryArgMode::Mode102"),
Self::Mode112 => write!(f, "TernaryArgMode::Mode112"),
Self::Mode020 => write!(f, "TernaryArgMode::Mode020"),
Self::Mode021 => write!(f, "TernaryArgMode::Mode021"),
Self::Mode120 => write!(f, "TernaryArgMode::Mode120"),
}
}
}
impl TernaryArgMode {
pub fn apply_opt<D: VeScalar>(
&self,
op: impl Fn(D, D, D) -> D + 'static,
) -> Box<dyn Fn(Opt<D>, Opt<D>, Opt<D>) -> Opt<D>> {
match self {
TernaryArgMode::Mode012 => Box::new(move |m, o0, o1| match (m, o0, o1) {
(Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(m, o0, o1)),
_ => Opt::Uninit,
}),
TernaryArgMode::Mode002 => Box::new(move |m, _o0, o1| match (m, o1) {
(Opt::Init(m), Opt::Init(o1)) => Opt::Init(op(m, m, o1)),
_ => Opt::Uninit,
}),
TernaryArgMode::Mode102 => Box::new(move |m, o0, o1| match (m, o0, o1) {
(Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, m, o1)),
_ => Opt::Uninit,
}),
TernaryArgMode::Mode112 => Box::new(move |_m, o0, o1| match (o0, o1) {
(Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, o0, o1)),
_ => Opt::Uninit,
}),
TernaryArgMode::Mode020 => Box::new(move |m, _o0, o1| match (m, o1) {
(Opt::Init(m), Opt::Init(o1)) => Opt::Init(op(m, o1, m)),
_ => Opt::Uninit,
}),
TernaryArgMode::Mode021 => Box::new(move |m, o0, o1| match (m, o0, o1) {
(Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(m, o1, o0)),
_ => Opt::Uninit,
}),
TernaryArgMode::Mode120 => Box::new(move |m, o0, o1| match (m, o0, o1) {
(Opt::Init(m), Opt::Init(o0), Opt::Init(o1)) => Opt::Init(op(o0, o1, m)),
_ => Opt::Uninit,
}),
}
}
}