use std::sync::Arc;
use morok_dtype::DType;
use snafu::ensure;
use crate::error::{InvalidDTypeForUnaryOpSnafu, WhereConditionNotBoolSnafu};
use crate::op::Op;
use crate::types::{BinaryOp, TernaryOp, UnaryOp};
use crate::uop::UOp;
use crate::{IntoUOp, Result};
/// Generates promoting binary arithmetic constructors on `UOp`.
///
/// Every `method => Op` pair expands to a fallible `#[track_caller]` method that
/// runs both operands through `promote_and_cast`, lets `validate_binary_shapes`
/// check them for `BinaryOp::Op`, and wraps them in an `Op::Binary` node whose
/// dtype is the promoted dtype returned by `promote_and_cast`.
macro_rules! binary_arith_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Fallible `BinaryOp::", stringify!($op), "` constructor with dtype promotion.")]
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let binop = BinaryOp::$op;
                let (a, b, promoted) = Self::promote_and_cast(Arc::clone(self), Arc::clone(rhs))?;
                Self::validate_binary_shapes(&a, &b, binop)?;
                let node = Op::Binary(binop, a, b);
                Ok(Self::new(node, promoted))
            }
        )+
    };
}
/// Generates promoting division-like constructors on `UOp`.
///
/// Identical to the arithmetic generator except that the (un-promoted) divisor
/// is first passed to `check_division_by_zero`, before any promotion happens.
macro_rules! division_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Fallible `BinaryOp::", stringify!($op), "` constructor with a divisor check and dtype promotion.")]
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                // The divisor is inspected as the caller passed it, before casting.
                Self::check_division_by_zero(rhs)?;
                let binop = BinaryOp::$op;
                let (dividend, divisor, promoted) =
                    Self::promote_and_cast(Arc::clone(self), Arc::clone(rhs))?;
                Self::validate_binary_shapes(&dividend, &divisor, binop)?;
                Ok(Self::new(Op::Binary(binop, dividend, divisor), promoted))
            }
        )+
    };
}
/// Generates promoting bitwise constructors (`And`/`Or`/`Xor`) on `UOp`.
///
/// Operands are promoted to a common dtype, the promoted dtype is checked with
/// `check_bitwise_dtype`, and shapes are validated before the node is built.
/// `#[track_caller]` matches the other constructor generators so that the
/// `#[track_caller]` panicking wrappers (`and_`, `or_`, `xor`) keep reporting the
/// user's call site instead of stopping inside this macro.
macro_rules! bitwise_binary_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Fallible bitwise `BinaryOp::", stringify!($op), "` constructor with dtype promotion.")]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Fails if promotion fails, if `check_bitwise_dtype` rejects the promoted dtype, or if the shapes are incompatible."]
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
                Self::check_bitwise_dtype(dtype.clone(), BinaryOp::$op)?;
                Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::$op)?;
                Ok(Self::new(Op::Binary(BinaryOp::$op, lhs, rhs), dtype))
            }
        )+
    };
}
/// Generates shift constructors (`Shl`/`Shr`) on `UOp`.
///
/// Unlike the other binary generators, operands are NOT promoted: the result
/// keeps `self`'s dtype and `rhs` is used as-is. Only `self`'s dtype is checked
/// with `check_bitwise_dtype`. `#[track_caller]` matches the other generators so
/// the panicking `shl`/`shr` wrappers keep reporting the user's call site.
macro_rules! shift_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Fallible `BinaryOp::", stringify!($op), "` constructor; the result keeps `self`'s dtype.")]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Fails if `check_bitwise_dtype` rejects `self`'s dtype or if the shapes are incompatible."]
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let dtype = self.dtype();
                Self::check_bitwise_dtype(dtype.clone(), BinaryOp::$op)?;
                Self::validate_binary_shapes(self, rhs, BinaryOp::$op)?;
                Ok(Self::new(Op::Binary(BinaryOp::$op, self.clone(), rhs.clone()), dtype))
            }
        )+
    };
}
/// Generates promoting comparison constructors on `UOp`.
///
/// Operands are promoted and shape-checked like arithmetic ops, but the node's
/// dtype is `Bool`, vectorised to the promoted dtype's lane count when that
/// count exceeds one.
macro_rules! cmp_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Fallible `BinaryOp::", stringify!($op), "` comparison; yields a bool (or bool vector) node.")]
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let binop = BinaryOp::$op;
                let (a, b, operand_dtype) =
                    Self::promote_and_cast(Arc::clone(self), Arc::clone(rhs))?;
                Self::validate_binary_shapes(&a, &b, binop)?;
                let out_dtype = match operand_dtype.vcount() {
                    lanes if lanes > 1 => DType::Bool.vec(lanes),
                    _ => DType::Bool,
                };
                Ok(Self::new(Op::Binary(binop, a, b), out_dtype))
            }
        )+
    };
}
/// Generates float-only unary constructors on `UOp`.
///
/// Each generated method rejects non-float dtypes with
/// `InvalidDTypeForUnaryOpSnafu` and otherwise builds an `Op::Unary` node that
/// keeps `self`'s dtype.
macro_rules! transcendental_ops {
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Fallible `UnaryOp::", stringify!($op), "` constructor; only float dtypes are accepted.")]
            #[track_caller]
            pub fn $method(self: &Arc<Self>) -> Result<Arc<Self>> {
                let dtype = self.dtype();
                ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::$op, dtype });
                let node = Op::Unary(UnaryOp::$op, Arc::clone(self));
                Ok(Self::new(node, dtype))
            }
        )+
    };
}
/// Generates `UOp`-with-scalar convenience constructors.
///
/// The scalar is turned into a UOp of `lhs`'s dtype via `IntoUOp::into_uop` and
/// then forwarded to the named fallible binary method. `#[track_caller]` keeps
/// the caller location flowing into the `#[track_caller]` methods it forwards to,
/// consistent with the other constructor generators.
macro_rules! scalar_ops {
    ($($method:ident => $op_method:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Converts `rhs` into a UOp of `lhs`'s dtype and applies `", stringify!($op_method), "`.")]
            #[track_caller]
            pub fn $method<T: IntoUOp>(lhs: Arc<Self>, rhs: T) -> Result<Arc<Self>> {
                let rhs_uop = rhs.into_uop(lhs.dtype());
                lhs.$op_method(&rhs_uop)
            }
        )+
    };
}
/// Generates panicking wrappers around fallible binary constructors.
///
/// Each `method => try_method` pair expands to a `#[track_caller]` method that
/// unwraps `try_method`'s result with `expect`, so any error it returns
/// (not only dtype problems) becomes a panic at the caller's location.
macro_rules! panicking_binary_wrapper {
    ($($method:ident => $try_method:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Panicking version of `", stringify!($try_method), "`.")]
            #[doc = ""]
            #[doc = "For use in pattern rewrites where the operands are already known to be valid."]
            #[doc = ""]
            #[doc = "# Panics"]
            #[doc = ""]
            #[doc = concat!("Panics whenever `", stringify!($try_method), "` returns an error, e.g. incompatible dtypes or shapes, or a divisor rejected by `check_division_by_zero` for the division wrappers.")]
            #[doc = concat!("The panic message is `", stringify!($method), ": type mismatch` followed by the `Debug` form of the underlying error, whatever its kind.")]
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Arc<Self> {
                self.$try_method(rhs).expect(concat!(stringify!($method), ": type mismatch"))
            }
        )+
    };
}
impl UOp {
binary_arith_ops! {
try_add => Add,
try_sub => Sub,
try_mul => Mul,
}
division_ops! {
try_mod => Mod,
}
#[track_caller]
pub fn try_div(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
Self::check_division_by_zero(rhs)?;
let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
let op = if dtype.is_float() { BinaryOp::Fdiv } else { BinaryOp::Idiv };
Self::validate_binary_shapes(&lhs, &rhs, op)?;
Ok(Self::new(Op::Binary(op, lhs, rhs), dtype))
}
pub fn try_max(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Max)?;
Ok(Self::new(Op::Binary(BinaryOp::Max, lhs, rhs), dtype))
}
pub fn try_pow(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Pow)?;
Ok(Self::new(Op::Binary(BinaryOp::Pow, lhs, rhs), dtype))
}
#[track_caller]
pub fn neg(self: &Arc<Self>) -> Arc<Self> {
if self.dtype.is_bool() {
return self.not();
}
use crate::types::ConstValue;
let dtype = self.dtype.clone();
let neg_one = if dtype.is_float() { ConstValue::Float(-1.0) } else { ConstValue::Int(-1) };
let mut neg_one_uop = Self::const_(dtype.clone(), neg_one);
if let Ok(Some(shape)) = self.shape()
&& !shape.is_empty()
{
use crate::sint::SInt;
use smallvec::SmallVec;
let ones: SmallVec<[SInt; 4]> = shape.iter().map(|_| SInt::from(1)).collect();
neg_one_uop = neg_one_uop.try_reshape(&ones).expect("neg: reshape failed");
neg_one_uop = neg_one_uop.try_expand(shape).expect("neg: expand failed");
}
self.mul(&neg_one_uop)
}
#[track_caller]
pub fn abs(self: &Arc<Self>) -> Arc<Self> {
let dtype = self.dtype.clone();
Self::new(Op::Unary(UnaryOp::Abs, self.clone()), dtype)
}
#[track_caller]
pub fn square(self: &Arc<Self>) -> Arc<Self> {
let dtype = self.dtype();
Self::new(Op::Unary(UnaryOp::Square, self.clone()), dtype)
}
pub fn sign(self: &Arc<Self>) -> Arc<Self> {
let dtype = self.dtype();
Self::new(Op::Unary(UnaryOp::Sign, self.clone()), dtype)
}
scalar_ops! {
try_add_scalar => try_add,
try_sub_scalar => try_sub,
try_mul_scalar => try_mul,
try_mod_scalar => try_mod,
}
transcendental_ops! {
try_sqrt => Sqrt,
try_rsqrt => Rsqrt,
try_exp => Exp,
try_exp2 => Exp2,
try_log => Log,
try_log2 => Log2,
try_sin => Sin,
try_cos => Cos,
try_tan => Tan,
}
#[track_caller]
pub fn erf(self: &Arc<Self>) -> Result<Arc<Self>> {
let dtype = self.dtype();
ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::Erf, dtype });
Ok(Self::new(Op::Unary(UnaryOp::Erf, self.clone()), dtype))
}
#[track_caller]
pub fn try_reciprocal(operand: &Arc<Self>) -> Result<Arc<Self>> {
let dtype = operand.dtype();
ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::Reciprocal, dtype });
Ok(Self::new(Op::Unary(UnaryOp::Reciprocal, operand.clone()), dtype))
}
#[track_caller]
pub fn trunc(operand: Arc<Self>) -> Arc<Self> {
let dtype = operand.dtype();
Self::new(Op::Unary(UnaryOp::Trunc, operand), dtype)
}
#[track_caller]
pub fn floor(operand: Arc<Self>) -> Arc<Self> {
let dtype = operand.dtype();
Self::new(Op::Unary(UnaryOp::Floor, operand), dtype)
}
#[track_caller]
pub fn ceil(operand: Arc<Self>) -> Arc<Self> {
let dtype = operand.dtype();
Self::new(Op::Unary(UnaryOp::Ceil, operand), dtype)
}
pub fn round(operand: Arc<Self>) -> Arc<Self> {
let dtype = operand.dtype();
Self::new(Op::Unary(UnaryOp::Round, operand), dtype)
}
bitwise_binary_ops! {
try_and_op => And,
try_or_op => Or,
try_xor_op => Xor,
}
shift_ops! {
try_shl_op => Shl,
try_shr_op => Shr,
}
#[track_caller]
pub fn not(self: &Arc<Self>) -> Arc<Self> {
let dtype = self.dtype.clone();
Self::new(Op::Unary(UnaryOp::Not, self.clone()), dtype)
}
cmp_ops! {
try_cmplt => Lt,
try_cmple => Le,
try_cmpeq => Eq,
try_cmpne => Ne,
try_cmpgt => Gt,
try_cmpge => Ge,
}
#[track_caller]
pub fn try_where(condition: Arc<Self>, true_val: Arc<Self>, false_val: Arc<Self>) -> Result<Arc<Self>> {
let cond_dtype = condition.dtype();
ensure!(cond_dtype.is_bool(), WhereConditionNotBoolSnafu { actual: cond_dtype });
let dtype = if matches!(true_val.op, Op::Invalid) { false_val.dtype() } else { true_val.dtype() };
let true_val = if matches!(true_val.op, Op::Invalid) && true_val.dtype() != dtype {
Self::new(Op::Invalid, dtype.clone())
} else {
true_val
};
let false_val = if matches!(false_val.op, Op::Invalid) && false_val.dtype() != dtype {
Self::new(Op::Invalid, dtype.clone())
} else {
false_val
};
Self::validate_ternary_shapes(&true_val, &false_val)?;
Ok(Self::new(Op::Ternary(TernaryOp::Where, condition, true_val, false_val), dtype))
}
pub fn try_mulacc(a: Arc<Self>, b: Arc<Self>, c: Arc<Self>) -> Result<Arc<Self>> {
if a.dtype() != b.dtype() || a.dtype() != c.dtype() {
return crate::error::MulAccDtypeMismatchSnafu {
a_dtype: a.dtype(),
b_dtype: b.dtype(),
c_dtype: c.dtype(),
}
.fail();
}
let dtype = a.dtype();
Self::validate_ternary_shapes(&a, &b)?;
Self::validate_ternary_shapes(&a, &c)?;
Ok(Self::new(Op::Ternary(TernaryOp::MulAcc, a, b, c), dtype))
}
panicking_binary_wrapper! {
add => try_add,
sub => try_sub,
mul => try_mul,
idiv => try_div,
mod_ => try_mod,
max => try_max,
and_ => try_and_op,
or_ => try_or_op,
xor => try_xor_op,
shl => try_shl_op,
shr => try_shr_op,
lt => try_cmplt,
le => try_cmple,
gt => try_cmpgt,
ge => try_cmpge,
eq => try_cmpeq,
ne => try_cmpne,
}
pub fn alu(op: BinaryOp, lhs: Arc<Self>, rhs: Arc<Self>) -> Arc<Self> {
let dtype = if op.is_comparison() { DType::Bool } else { lhs.dtype() };
Self::new(Op::Binary(op, lhs, rhs), dtype)
}
pub fn threefry(lhs: Arc<Self>, rhs: Arc<Self>) -> Result<Arc<Self>> {
let dtype = DType::UInt64; Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Threefry)?;
Ok(Self::new(Op::Binary(BinaryOp::Threefry, lhs, rhs), dtype))
}
}