use super::*;
use crate::engine::vector::layer::{FpToFxp, FxpToFp};
use crate::prelude::VeScalar;
use crate::scalar::Opt;
impl LogicBinaryOpI32 {
    /// Returns the scalar kernel for this bitwise / shift op on `i32` lanes.
    ///
    /// For the shift ops the shift amount is `b` reinterpreted as `u32`, and only its low
    /// five bits are used (amount modulo 32). That is exactly what a bare `<<` / `>>`
    /// produces in release builds. Unlike a bare shift, it never panics in debug builds
    /// when `b >= 32` or `b` is negative.
    pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
        match self {
            LogicBinaryOpI32::BitAnd => |a, b| a & b,
            LogicBinaryOpI32::BitOr => |a, b| a | b,
            LogicBinaryOpI32::BitXor => |a, b| a ^ b,
            LogicBinaryOpI32::LeftShift => |a, b| a.wrapping_shl(b as u32),
            // Logical shift: shift the unsigned bit pattern so zeros are shifted in.
            LogicBinaryOpI32::LogicRightShift => {
                |a, b| (a as u32).wrapping_shr(b as u32) as i32
            }
            // Arithmetic shift: signed right shift replicates the sign bit.
            LogicBinaryOpI32::ArithRightShift => |a, b| a.wrapping_shr(b as u32),
        }
    }
}
impl LogicBinaryOpF32 {
pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
match self {
LogicBinaryOpF32::BitAnd => |a, b| f32::from_bits(a.to_bits() & b.to_bits()),
LogicBinaryOpF32::BitOr => |a, b| f32::from_bits(a.to_bits() | b.to_bits()),
LogicBinaryOpF32::BitXor => |a, b| f32::from_bits(a.to_bits() ^ b.to_bits()),
}
}
}
impl LogicOpI {
    /// Builds the `Opt`-aware binary function for this integer logic op by handing the
    /// scalar kernel of `self.op` to `self.arg_mode`.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
        self.arg_mode.apply_opt(self.op.op_fn())
    }
}
impl LogicOpF {
    /// Builds the `Opt`-aware binary function for this float logic op by handing the
    /// scalar kernel of `self.op` to `self.arg_mode`.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
        self.arg_mode.apply_opt(self.op.op_fn())
    }
}
impl FxpBinaryOp {
    /// Returns the scalar kernel implementing this fixed-point op on `i32` lanes.
    ///
    /// Shift ops read the shift amount from `b` reinterpreted as `u32`:
    /// * `LeftShift`, `LogicRightShift` and `ArithRightShift` use only the low five bits of
    ///   the amount (modulo 32). This matches a bare `<<` / `>>` in release builds, and
    ///   they never panic in debug builds.
    /// * `LeftShiftSat` computes `a * 2^shift` exactly and clamps it to the `i32` range.
    ///   Any amount `>= 32` (including a negative `b`) saturates every non-zero `a`.
    ///
    /// `ArithRightShiftRound` is not implemented. The returned function panics via
    /// `todo!` when it is *called*; obtaining it does not panic.
    pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
        match self {
            FxpBinaryOp::AddFxp => |a, b| a.wrapping_add(b),
            FxpBinaryOp::AddFxpSat => |a, b| a.saturating_add(b),
            FxpBinaryOp::SubFxp => |a, b| a.wrapping_sub(b),
            FxpBinaryOp::SubFxpSat => |a, b| a.saturating_sub(b),
            FxpBinaryOp::LeftShift => |a, b| a.wrapping_shl(b as u32),
            FxpBinaryOp::LeftShiftSat => |a, b| {
                let shift = b as u32;
                if a == 0 {
                    0
                } else if shift >= 32 {
                    // Every non-zero value overflows i32 when shifted this far.
                    if a > 0 {
                        i32::MAX
                    } else {
                        i32::MIN
                    }
                } else {
                    // Widen first: |a| < 2^31 and shift <= 31, so the i64 shift is exact.
                    // Multiplying by `1 << shift` instead would make shift == 31 use
                    // i32::MIN as the multiplier and flip the saturation direction.
                    (i64::from(a) << shift).clamp(i64::from(i32::MIN), i64::from(i32::MAX))
                        as i32
                }
            },
            // Computes round(a * b / 2^31), rounding halves toward +infinity.
            // i32::MIN * i32::MIN is the only input whose rounded result exceeds
            // i32::MAX, so it is saturated explicitly.
            FxpBinaryOp::MulFxp => |a, b| {
                if a == i32::MIN && b == i32::MIN {
                    i32::MAX
                } else {
                    let product = i64::from(a) * i64::from(b);
                    (((product >> 30) + 1) >> 1) as i32
                }
            },
            FxpBinaryOp::MulInt => |a, b| a.wrapping_mul(b),
            // Logical shift: shift the unsigned bit pattern so zeros are shifted in.
            FxpBinaryOp::LogicRightShift => |a, b| (a as u32).wrapping_shr(b as u32) as i32,
            FxpBinaryOp::ArithRightShift => |a, b| a.wrapping_shr(b as u32),
            FxpBinaryOp::ArithRightShiftRound => {
                |_a, _b| todo!("ArithRightShiftRound not implemented")
            }
        }
    }
}
impl FxpOp {
    /// Builds the `Opt`-aware binary function for this fixed-point op by handing the
    /// scalar kernel of `self.op` to `self.arg_mode`.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
        self.arg_mode.apply_opt(self.op.op_fn())
    }
}
impl FpUnaryOp {
pub fn op_fn(&self) -> fn(f32) -> f32 {
match self {
FpUnaryOp::Exp => |x| x.exp(),
FpUnaryOp::NegExp => |x| (-x).exp(),
FpUnaryOp::Sqrt => |x| x.sqrt(),
FpUnaryOp::Tanh => |x| x.tanh(),
FpUnaryOp::Sigmoid => |x| 1.0 / (1.0 + (-x).exp()),
FpUnaryOp::Erf => |_x| todo!("Erf not implemented"),
FpUnaryOp::Log => |x| x.ln(),
FpUnaryOp::Sin => |x| x.sin(),
FpUnaryOp::Cos => |x| x.cos(),
}
}
}
impl FpBinaryOp {
pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
match self {
FpBinaryOp::AddF => |a, b| a + b,
FpBinaryOp::SubF => |a, b| a - b,
FpBinaryOp::MulF(_) => |a, b| a * b,
FpBinaryOp::MaskMulF(_) => |_a, _b| todo!("MaskMulF not implemented"),
FpBinaryOp::DivF => |a, b| a / b,
}
}
}
impl FpTernaryOp {
pub fn op_fn(&self) -> fn(f32, f32, f32) -> f32 {
match self {
FpTernaryOp::FmaF => |a, b, c| a.mul_add(b, c),
FpTernaryOp::MaskFmaF => |_a, _b, _c| todo!("MaskFmaF not implemented"),
}
}
}
impl FpOp {
    /// Returns the `Opt`-aware unary function of an `FpOp::UnaryOp`.
    ///
    /// # Panics
    /// Panics if `self` is not the `UnaryOp` variant.
    pub fn unary_op_opt(&self) -> Box<dyn Fn(Opt<f32>) -> Opt<f32>> {
        let FpOp::UnaryOp { op } = self else {
            panic!("unary_op_opt called on non-unary FpOp");
        };
        Box::new(op.unary_op_fn())
    }

    /// Returns the `Opt`-aware binary function of an `FpOp::BinaryOp`, routing the scalar
    /// kernel through the variant's stored `mode`.
    ///
    /// # Panics
    /// Panics if `self` is not the `BinaryOp` variant.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
        let FpOp::BinaryOp { op, mode } = self else {
            panic!("binary_op_opt called on non-binary FpOp");
        };
        mode.apply_opt(op.op_fn())
    }

    /// Returns the `Opt`-aware ternary function of an `FpOp::TernaryOp`, routing the scalar
    /// kernel through the variant's stored `mode`.
    ///
    /// # Panics
    /// Panics if `self` is not the `TernaryOp` variant.
    pub fn ternary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>, Opt<f32>) -> Opt<f32>> {
        let FpOp::TernaryOp { op, mode } = self else {
            panic!("ternary_op_opt called on non-ternary FpOp");
        };
        mode.apply_opt(op.op_fn())
    }
}
impl ClipBinaryOpI32 {
    /// Returns the scalar kernel for this clip-unit op on `i32` lanes.
    ///
    /// `AbsMin` / `AbsMax` compare magnitudes with `unsigned_abs`. That keeps `i32::MIN`
    /// correct: its magnitude 2^31 does not fit in `i32`, so `i32::abs` would panic in
    /// debug builds and wrap to a negative value in release builds. On equal magnitudes
    /// both return `b`.
    pub fn op_fn(&self) -> fn(i32, i32) -> i32 {
        match self {
            ClipBinaryOpI32::AddFxp => |a, b| a.wrapping_add(b),
            ClipBinaryOpI32::AddFxpSat => |a, b| a.saturating_add(b),
            ClipBinaryOpI32::Min => |a, b| a.min(b),
            ClipBinaryOpI32::Max => |a, b| a.max(b),
            ClipBinaryOpI32::AbsMin => {
                |a, b| if a.unsigned_abs() < b.unsigned_abs() { a } else { b }
            }
            ClipBinaryOpI32::AbsMax => {
                |a, b| if a.unsigned_abs() > b.unsigned_abs() { a } else { b }
            }
        }
    }
}
impl ClipBinaryOpF32 {
pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
match self {
ClipBinaryOpF32::Add => |a, b| a + b,
ClipBinaryOpF32::Min => |a, b| a.min(b),
ClipBinaryOpF32::Max => |a, b| a.max(b),
ClipBinaryOpF32::AbsMin => |a, b| if a.abs() < b.abs() { a } else { b },
ClipBinaryOpF32::AbsMax => |a, b| if a.abs() > b.abs() { a } else { b },
}
}
}
impl ClipOpI {
    /// Builds the `Opt`-aware binary function for this integer clip op by handing the
    /// scalar kernel of `self.op` to `self.mode`.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
        self.mode.apply_opt(self.op.op_fn())
    }
}
impl ClipOpF {
    /// Builds the `Opt`-aware binary function for this float clip op by handing the
    /// scalar kernel of `self.op` to `self.mode`.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
        self.mode.apply_opt(self.op.op_fn())
    }
}
impl FxpToFp {
    /// Returns a per-element fixed-point → float converter.
    ///
    /// `int_width()` is read once, up front, and the returned closure owns that value.
    /// Each element is converted with `crate::float::fixedpoint_to_float`.
    pub fn op_fn(&self) -> impl Fn(i32) -> f32 {
        let width = self.int_width();
        move |raw: i32| -> f32 { crate::float::fixedpoint_to_float(raw, width) }
    }
}
impl FpToFxp {
    /// Returns a per-element float → fixed-point converter.
    ///
    /// `int_width()` is read once, up front, and the returned closure owns that value.
    /// Each element is converted with `crate::float::float_to_fixedpoint`.
    pub fn op_fn(&self) -> impl Fn(f32) -> i32 {
        let width = self.int_width();
        move |value: f32| -> i32 { crate::float::float_to_fixedpoint(value, width) }
    }
}
/// An op that converts vector-engine scalars of type `D` into type `D2`.
///
/// In this file it is implemented by `FxpToFp` (`i32` -> `f32`) and `FpToFxp`
/// (`f32` -> `i32`). Both forward to their inherent `op_fn`.
pub trait HasConversionOp<D: VeScalar, D2: VeScalar>: Clone + Copy {
    /// Returns a function that converts a single `D` element into a `D2` element.
    fn conversion_op_fn(&self) -> impl Fn(D) -> D2;
}
impl HasConversionOp<i32, f32> for FxpToFp {
fn conversion_op_fn(&self) -> impl Fn(i32) -> f32 {
self.op_fn()
}
}
impl HasConversionOp<f32, i32> for FpToFxp {
fn conversion_op_fn(&self) -> impl Fn(f32) -> i32 {
self.op_fn()
}
}
/// Lifts a plain reduction `reduce_fn` on `D` to a reduction on `Opt<D>`.
///
/// `Opt::Uninit` acts as an identity element. If either operand is uninitialized, the
/// other operand is returned unchanged, so two `Uninit`s give `Uninit`. `reduce_fn` is
/// applied only when both operands are `Opt::Init`.
fn lift_reduce_fn<D: Copy>(reduce_fn: impl Fn(D, D) -> D + 'static) -> impl Fn(Opt<D>, Opt<D>) -> Opt<D> {
    move |lhs: Opt<D>, rhs: Opt<D>| match (lhs, rhs) {
        (Opt::Init(x), Opt::Init(y)) => Opt::Init(reduce_fn(x, y)),
        (Opt::Uninit, other) | (other, Opt::Uninit) => other,
    }
}
impl IntraSliceReduceOpI32 {
    /// Returns the scalar `i32` combiner for this intra-slice reduction.
    pub fn reduce_fn(&self) -> fn(i32, i32) -> i32 {
        match *self {
            Self::AddSat => i32::saturating_add,
            Self::Max => |acc, item| acc.max(item),
            Self::Min => |acc, item| acc.min(item),
        }
    }

    /// Returns the combiner lifted to `Opt<i32>`, where `Opt::Uninit` operands are skipped
    /// (see `lift_reduce_fn`).
    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
        let scalar = self.reduce_fn();
        Box::new(lift_reduce_fn(scalar))
    }

    /// Returns the neutral starting value for this reduction: 0 for a sum, `i32::MIN` for
    /// a max and `i32::MAX` for a min.
    pub fn identity(&self) -> i32 {
        match *self {
            Self::Min => i32::MAX,
            Self::Max => i32::MIN,
            Self::AddSat => 0,
        }
    }
}
impl IntraSliceReduceOpF32 {
    /// Returns the scalar `f32` combiner for this intra-slice reduction.
    pub fn reduce_fn(&self) -> fn(f32, f32) -> f32 {
        match *self {
            Self::Add => |acc, item| acc + item,
            Self::Max => f32::max,
            Self::Min => f32::min,
        }
    }

    /// Returns the combiner lifted to `Opt<f32>`, where `Opt::Uninit` operands are skipped
    /// (see `lift_reduce_fn`).
    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
        let scalar = self.reduce_fn();
        Box::new(lift_reduce_fn(scalar))
    }

    /// Returns the neutral starting value for this reduction: 0.0 for a sum, negative
    /// infinity for a max and positive infinity for a min.
    pub fn identity(&self) -> f32 {
        match *self {
            Self::Min => f32::INFINITY,
            Self::Max => f32::NEG_INFINITY,
            Self::Add => 0.0,
        }
    }
}
impl InterSliceReduceOpI32 {
    /// Returns the scalar `i32` combiner for this inter-slice reduction.
    ///
    /// `Add` wraps on overflow and `AddSat` saturates.
    pub fn reduce_fn(&self) -> fn(i32, i32) -> i32 {
        match *self {
            Self::Add => i32::wrapping_add,
            Self::AddSat => i32::saturating_add,
            Self::Max => |acc, item| acc.max(item),
            Self::Min => |acc, item| acc.min(item),
        }
    }

    /// Returns the combiner lifted to `Opt<i32>`, where `Opt::Uninit` operands are skipped
    /// (see `lift_reduce_fn`).
    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<i32>, Opt<i32>) -> Opt<i32>> {
        let scalar = self.reduce_fn();
        Box::new(lift_reduce_fn(scalar))
    }
}
impl InterSliceReduceOpF32 {
    /// Returns the scalar `f32` combiner for this inter-slice reduction.
    pub fn reduce_fn(&self) -> fn(f32, f32) -> f32 {
        match *self {
            Self::Add => |acc, item| acc + item,
            Self::Max => f32::max,
            Self::Min => f32::min,
            Self::Mul => |acc, item| acc * item,
        }
    }

    /// Returns the combiner lifted to `Opt<f32>`, where `Opt::Uninit` operands are skipped
    /// (see `lift_reduce_fn`).
    pub fn lifted_reduce_fn(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
        let scalar = self.reduce_fn();
        Box::new(lift_reduce_fn(scalar))
    }
}
impl FpDivBinaryOp {
pub fn op_fn(&self) -> fn(f32, f32) -> f32 {
match self {
FpDivBinaryOp::DivF => |a, b| a / b,
}
}
}
impl FpDivOp {
    /// Builds the `Opt`-aware binary function for this divide op by handing the scalar
    /// kernel of `self.op` to `self.mode`.
    pub fn binary_op_opt(&self) -> Box<dyn Fn(Opt<f32>, Opt<f32>) -> Opt<f32>> {
        self.mode.apply_opt(self.op.op_fn())
    }
}
/// An op that can produce an element-wise unary function over `Opt<D>` values.
///
/// `unary_op_fn` consumes `self` (the trait requires `Copy`). For example, the `FpUnaryOp`
/// impl in this file maps `Opt::Init(x)` to `Opt::Init(op(x))` and passes `Opt::Uninit`
/// through unchanged.
pub trait HasUnaryOp<D>: Clone + Copy {
    /// Returns the `Opt`-aware unary function for this op.
    fn unary_op_fn(self) -> impl Fn(Opt<D>) -> Opt<D>;
}
/// An op that can produce an element-wise binary function over `Opt<D>` values.
pub trait HasBinaryOp<D>: Clone + Copy {
    /// Returns the `Opt`-aware binary function for this op.
    ///
    /// `mode` is passed to `BinaryArgMode::apply_opt` together with the scalar kernel.
    /// `None` lets the implementor choose: most impls in this file default to
    /// `BinaryArgMode::Mode01`, while `FpDivOp` falls back to its own stored mode.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<D>, Opt<D>) -> Opt<D>;
}
/// An op that can produce an element-wise ternary function over `Opt<D>` values.
pub trait HasTernaryOp<D>: Clone + Copy {
    /// Returns the `Opt`-aware ternary function for this op.
    ///
    /// `mode` is passed to `TernaryArgMode::apply_opt` together with the scalar kernel.
    /// `None` lets the implementor choose; the `FpTernaryOp` impl in this file uses
    /// `TernaryArgMode::Mode012`.
    fn ternary_op_fn(self, mode: Option<TernaryArgMode>) -> impl Fn(Opt<D>, Opt<D>, Opt<D>) -> Opt<D>;
}
impl HasBinaryOp<i32> for LogicBinaryOpI32 {
    /// Lifts the integer logic kernel to `Opt<i32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}
impl HasBinaryOp<f32> for LogicBinaryOpF32 {
    /// Lifts the float bitwise kernel to `Opt<f32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}
impl HasBinaryOp<i32> for FxpBinaryOp {
    /// Lifts the fixed-point kernel to `Opt<i32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}
impl HasUnaryOp<f32> for FpUnaryOp {
    /// Applies the scalar kernel to initialized lanes; `Opt::Uninit` passes through.
    fn unary_op_fn(self) -> impl Fn(Opt<f32>) -> Opt<f32> {
        let kernel: fn(f32) -> f32 = self.op_fn();
        move |value: Opt<f32>| {
            if let Opt::Init(inner) = value {
                Opt::Init(kernel(inner))
            } else {
                Opt::Uninit
            }
        }
    }
}
impl HasBinaryOp<f32> for FpBinaryOp {
    /// Lifts the float arithmetic kernel to `Opt<f32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}
impl HasTernaryOp<f32> for FpTernaryOp {
    /// Lifts the ternary float kernel to `Opt<f32>`; `None` selects `TernaryArgMode::Mode012`.
    fn ternary_op_fn(self, mode: Option<TernaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>, Opt<f32>) -> Opt<f32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(TernaryArgMode::Mode012);
        arg_mode.apply_opt(kernel)
    }
}
impl HasBinaryOp<f32> for FpDivBinaryOp {
    /// Lifts the divide kernel to `Opt<f32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}
impl HasBinaryOp<f32> for FpDivOp {
    /// Lifts the divide kernel to `Opt<f32>`.
    ///
    /// An explicit `mode` overrides the mode stored on the op. With `None`, this defers to
    /// `FpDivOp::binary_op_opt`, which uses `self.mode`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
        if let Some(explicit) = mode {
            explicit.apply_opt(self.op.op_fn())
        } else {
            self.binary_op_opt()
        }
    }
}
impl HasBinaryOp<i32> for ClipBinaryOpI32 {
    /// Lifts the integer clip kernel to `Opt<i32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<i32>, Opt<i32>) -> Opt<i32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}
impl HasBinaryOp<f32> for ClipBinaryOpF32 {
    /// Lifts the float clip kernel to `Opt<f32>`; `None` selects `BinaryArgMode::Mode01`.
    fn binary_op_fn(self, mode: Option<BinaryArgMode>) -> impl Fn(Opt<f32>, Opt<f32>) -> Opt<f32> {
        let kernel = self.op_fn();
        let arg_mode = mode.unwrap_or(BinaryArgMode::Mode01);
        arg_mode.apply_opt(kernel)
    }
}