1use std::sync::Arc;
14
15use morok_dtype::DType;
16use snafu::ensure;
17
18use crate::error::{InvalidDTypeForUnaryOpSnafu, WhereConditionNotBoolSnafu};
19use crate::op::Op;
20use crate::types::{BinaryOp, TernaryOp, UnaryOp};
21use crate::uop::UOp;
22use crate::{IntoUOp, Result};
23
macro_rules! binary_arith_ops {
    // Emits one fallible element-wise arithmetic builder per `name => Variant` pair.
    // Both operands are promoted to a common dtype, the shapes are checked, and a
    // `BinaryOp::Variant` node carrying the promoted dtype is returned.
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let op = BinaryOp::$variant;
                let (a, b, promoted) = Self::promote_and_cast(self.clone(), rhs.clone())?;
                Self::validate_binary_shapes(&a, &b, op)?;
                let node = Op::Binary(op, a, b);
                Ok(Self::new(node, promoted))
            }
        )+
    };
}
41
macro_rules! division_ops {
    // Emits fallible division-like builders. The divisor is screened by
    // `check_division_by_zero` before any promotion happens; after that the
    // builder behaves like the plain arithmetic ones (promote, check shapes, build).
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                Self::check_division_by_zero(rhs)?;
                let op = BinaryOp::$variant;
                let (dividend, divisor, promoted) =
                    Self::promote_and_cast(self.clone(), rhs.clone())?;
                Self::validate_binary_shapes(&dividend, &divisor, op)?;
                let node = Op::Binary(op, dividend, divisor);
                Ok(Self::new(node, promoted))
            }
        )+
    };
}
56
macro_rules! bitwise_binary_ops {
    // Emits fallible bitwise builders (`and`/`or`/`xor`). The operands are promoted
    // to a common dtype, which must then pass `check_bitwise_dtype`.
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Bitwise `", stringify!($op), "` of two operands after dtype promotion.")]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Fails if the operands cannot be promoted to a common dtype, if that dtype"]
            #[doc = "is rejected by `check_bitwise_dtype`, or if the operand shapes are incompatible."]
            // `#[track_caller]` keeps these builders consistent with the arithmetic,
            // division, and comparison builders, so caller-location tracking propagates
            // through them the same way.
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
                Self::check_bitwise_dtype(dtype.clone(), BinaryOp::$op)?;
                Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::$op)?;
                Ok(Self::new(Op::Binary(BinaryOp::$op, lhs, rhs), dtype))
            }
        )+
    };
}
70
macro_rules! shift_ops {
    // Emits fallible shift builders (`shl`/`shr`). Unlike the other binary builders,
    // no dtype promotion is performed: the result takes the left operand's dtype.
    ($($method:ident => $op:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Shift (`", stringify!($op), "`) of `self` by `rhs`; the result keeps `self`'s dtype.")]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Fails if `self`'s dtype is rejected by `check_bitwise_dtype` or if the"]
            #[doc = "operand shapes are incompatible."]
            // `#[track_caller]` for consistency with the other binary builders.
            #[track_caller]
            pub fn $method(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let dtype = self.dtype();
                Self::check_bitwise_dtype(dtype.clone(), BinaryOp::$op)?;
                // NOTE(review): `rhs` is neither dtype-checked nor cast to `dtype` here;
                // confirm that callers always pass an integer-typed shift amount.
                Self::validate_binary_shapes(self, rhs, BinaryOp::$op)?;
                Ok(Self::new(Op::Binary(BinaryOp::$op, self.clone(), rhs.clone()), dtype))
            }
        )+
    };
}
84
macro_rules! cmp_ops {
    // Emits fallible comparison builders. Operands are promoted and shape-checked
    // like arithmetic operands, but the node's dtype is boolean: a bool vector whose
    // width matches the promoted dtype's lane count, or a scalar bool otherwise.
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
                let op = BinaryOp::$variant;
                let (a, b, promoted) = Self::promote_and_cast(self.clone(), rhs.clone())?;
                Self::validate_binary_shapes(&a, &b, op)?;
                let mask_dtype = match promoted.vcount() {
                    lanes if lanes > 1 => DType::Bool.vec(lanes),
                    _ => DType::Bool,
                };
                Ok(Self::new(Op::Binary(op, a, b), mask_dtype))
            }
        )+
    };
}
103
macro_rules! transcendental_ops {
    // Emits fallible float-only unary builders (sqrt, exp, log, trig, ...).
    // A non-float operand yields `InvalidDTypeForUnaryOp`; otherwise the unary node
    // keeps the operand's dtype.
    ($($name:ident => $variant:ident),+ $(,)?) => {
        $(
            #[track_caller]
            pub fn $name(self: &Arc<Self>) -> Result<Arc<Self>> {
                let operand_dtype = self.dtype();
                ensure!(
                    operand_dtype.is_float(),
                    InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::$variant, dtype: operand_dtype }
                );
                let node = Op::Unary(UnaryOp::$variant, self.clone());
                Ok(Self::new(node, operand_dtype))
            }
        )+
    };
}
117
macro_rules! scalar_ops {
    // Emits `*_scalar` convenience builders. The scalar is converted into a node of
    // the left operand's dtype, and the call is forwarded to the matching fallible
    // tensor-tensor builder.
    ($($name:ident => $target:ident),+ $(,)?) => {
        $(
            pub fn $name<V: IntoUOp>(lhs: Arc<Self>, rhs: V) -> Result<Arc<Self>> {
                let target_dtype = lhs.dtype();
                let rhs_node = rhs.into_uop(target_dtype);
                Self::$target(&lhs, &rhs_node)
            }
        )+
    };
}
129
macro_rules! panicking_binary_wrapper {
    // Emits infallible counterparts of fallible binary builders: each wrapper calls
    // the fallible builder and panics with `"<name>: type mismatch"` (followed by the
    // error's Debug output) if it returns an error.
    ($($name:ident => $fallible:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Infallible counterpart of `", stringify!($fallible), "`.")]
            #[doc = ""]
            #[doc = "Intended for pattern rewrites whose operands have already been validated."]
            #[doc = ""]
            #[doc = "# Panics"]
            #[doc = ""]
            #[doc = concat!("Panics whenever `", stringify!($fallible), "` returns an error (for example a dtype,")]
            #[doc = "shape, or division-by-zero failure); the panic message is labelled \"type mismatch\"."]
            #[track_caller]
            pub fn $name(self: &Arc<Self>, rhs: &Arc<Self>) -> Arc<Self> {
                let outcome = Self::$fallible(self, rhs);
                outcome.expect(concat!(stringify!($name), ": type mismatch"))
            }
        )+
    };
}
152
153impl UOp {
154 binary_arith_ops! {
159 try_add => Add,
160 try_sub => Sub,
161 try_mul => Mul,
162 }
163
164 division_ops! {
165 try_mod => Mod,
166 }
167
168 #[track_caller]
172 pub fn try_div(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
173 Self::check_division_by_zero(rhs)?;
174 let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
175
176 let op = if dtype.is_float() { BinaryOp::Fdiv } else { BinaryOp::Idiv };
178
179 Self::validate_binary_shapes(&lhs, &rhs, op)?;
180 Ok(Self::new(Op::Binary(op, lhs, rhs), dtype))
181 }
182
183 pub fn try_max(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
185 let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
186 Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Max)?;
187 Ok(Self::new(Op::Binary(BinaryOp::Max, lhs, rhs), dtype))
188 }
189
190 pub fn try_pow(self: &Arc<Self>, rhs: &Arc<Self>) -> Result<Arc<Self>> {
192 let (lhs, rhs, dtype) = Self::promote_and_cast(self.clone(), rhs.clone())?;
193 Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Pow)?;
194 Ok(Self::new(Op::Binary(BinaryOp::Pow, lhs, rhs), dtype))
195 }
196
197 #[track_caller]
209 pub fn neg(self: &Arc<Self>) -> Arc<Self> {
210 if self.dtype.is_bool() {
212 return self.not();
213 }
214 use crate::types::ConstValue;
215 let dtype = self.dtype.clone();
216 let neg_one = if dtype.is_float() { ConstValue::Float(-1.0) } else { ConstValue::Int(-1) };
219 let mut neg_one_uop = Self::const_(dtype.clone(), neg_one);
220
221 if let Ok(Some(shape)) = self.shape()
224 && !shape.is_empty()
225 {
226 use crate::sint::SInt;
227 use smallvec::SmallVec;
228 let ones: SmallVec<[SInt; 4]> = shape.iter().map(|_| SInt::from(1)).collect();
229 neg_one_uop = neg_one_uop.try_reshape(&ones).expect("neg: reshape failed");
230 neg_one_uop = neg_one_uop.try_expand(shape).expect("neg: expand failed");
231 }
232
233 self.mul(&neg_one_uop)
234 }
235
236 #[track_caller]
238 pub fn abs(self: &Arc<Self>) -> Arc<Self> {
239 let dtype = self.dtype.clone();
240 Self::new(Op::Unary(UnaryOp::Abs, self.clone()), dtype)
241 }
242
243 #[track_caller]
245 pub fn square(self: &Arc<Self>) -> Arc<Self> {
246 let dtype = self.dtype();
247 Self::new(Op::Unary(UnaryOp::Square, self.clone()), dtype)
248 }
249
250 pub fn sign(self: &Arc<Self>) -> Arc<Self> {
252 let dtype = self.dtype();
253 Self::new(Op::Unary(UnaryOp::Sign, self.clone()), dtype)
254 }
255
256 scalar_ops! {
261 try_add_scalar => try_add,
262 try_sub_scalar => try_sub,
263 try_mul_scalar => try_mul,
264 try_mod_scalar => try_mod,
265 }
266
267 transcendental_ops! {
272 try_sqrt => Sqrt,
273 try_rsqrt => Rsqrt,
274 try_exp => Exp,
275 try_exp2 => Exp2,
276 try_log => Log,
277 try_log2 => Log2,
278 try_sin => Sin,
279 try_cos => Cos,
280 try_tan => Tan,
281 }
282
283 #[track_caller]
285 pub fn erf(self: &Arc<Self>) -> Result<Arc<Self>> {
286 let dtype = self.dtype();
287 ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::Erf, dtype });
288 Ok(Self::new(Op::Unary(UnaryOp::Erf, self.clone()), dtype))
289 }
290
291 #[track_caller]
293 pub fn try_reciprocal(operand: &Arc<Self>) -> Result<Arc<Self>> {
294 let dtype = operand.dtype();
295 ensure!(dtype.is_float(), InvalidDTypeForUnaryOpSnafu { operation: UnaryOp::Reciprocal, dtype });
296 Ok(Self::new(Op::Unary(UnaryOp::Reciprocal, operand.clone()), dtype))
297 }
298
299 #[track_caller]
305 pub fn trunc(operand: Arc<Self>) -> Arc<Self> {
306 let dtype = operand.dtype();
307 Self::new(Op::Unary(UnaryOp::Trunc, operand), dtype)
308 }
309
310 #[track_caller]
312 pub fn floor(operand: Arc<Self>) -> Arc<Self> {
313 let dtype = operand.dtype();
314 Self::new(Op::Unary(UnaryOp::Floor, operand), dtype)
315 }
316
317 #[track_caller]
319 pub fn ceil(operand: Arc<Self>) -> Arc<Self> {
320 let dtype = operand.dtype();
321 Self::new(Op::Unary(UnaryOp::Ceil, operand), dtype)
322 }
323
324 pub fn round(operand: Arc<Self>) -> Arc<Self> {
326 let dtype = operand.dtype();
327 Self::new(Op::Unary(UnaryOp::Round, operand), dtype)
328 }
329
330 bitwise_binary_ops! {
335 try_and_op => And,
336 try_or_op => Or,
337 try_xor_op => Xor,
338 }
339
340 shift_ops! {
341 try_shl_op => Shl,
342 try_shr_op => Shr,
343 }
344
345 #[track_caller]
347 pub fn not(self: &Arc<Self>) -> Arc<Self> {
348 let dtype = self.dtype.clone();
349 Self::new(Op::Unary(UnaryOp::Not, self.clone()), dtype)
350 }
351
352 cmp_ops! {
357 try_cmplt => Lt,
358 try_cmple => Le,
359 try_cmpeq => Eq,
360 try_cmpne => Ne,
361 try_cmpgt => Gt,
362 try_cmpge => Ge,
363 }
364
365 #[track_caller]
374 pub fn try_where(condition: Arc<Self>, true_val: Arc<Self>, false_val: Arc<Self>) -> Result<Arc<Self>> {
375 let cond_dtype = condition.dtype();
376 ensure!(cond_dtype.is_bool(), WhereConditionNotBoolSnafu { actual: cond_dtype });
377
378 let dtype = if matches!(true_val.op, Op::Invalid) { false_val.dtype() } else { true_val.dtype() };
382 let true_val = if matches!(true_val.op, Op::Invalid) && true_val.dtype() != dtype {
383 Self::new(Op::Invalid, dtype.clone())
384 } else {
385 true_val
386 };
387 let false_val = if matches!(false_val.op, Op::Invalid) && false_val.dtype() != dtype {
388 Self::new(Op::Invalid, dtype.clone())
389 } else {
390 false_val
391 };
392 Self::validate_ternary_shapes(&true_val, &false_val)?;
393 Ok(Self::new(Op::Ternary(TernaryOp::Where, condition, true_val, false_val), dtype))
394 }
395
396 pub fn try_mulacc(a: Arc<Self>, b: Arc<Self>, c: Arc<Self>) -> Result<Arc<Self>> {
401 if a.dtype() != b.dtype() || a.dtype() != c.dtype() {
403 return crate::error::MulAccDtypeMismatchSnafu {
404 a_dtype: a.dtype(),
405 b_dtype: b.dtype(),
406 c_dtype: c.dtype(),
407 }
408 .fail();
409 }
410 let dtype = a.dtype();
411 Self::validate_ternary_shapes(&a, &b)?;
413 Self::validate_ternary_shapes(&a, &c)?;
414 Ok(Self::new(Op::Ternary(TernaryOp::MulAcc, a, b, c), dtype))
415 }
416
417 panicking_binary_wrapper! {
429 add => try_add,
431 sub => try_sub,
432 mul => try_mul,
433 idiv => try_div,
434 mod_ => try_mod,
435 max => try_max,
436
437 and_ => try_and_op,
439 or_ => try_or_op,
440 xor => try_xor_op,
441 shl => try_shl_op,
442 shr => try_shr_op,
443
444 lt => try_cmplt,
446 le => try_cmple,
447 gt => try_cmpgt,
448 ge => try_cmpge,
449 eq => try_cmpeq,
450 ne => try_cmpne,
451 }
452
453 pub fn alu(op: BinaryOp, lhs: Arc<Self>, rhs: Arc<Self>) -> Arc<Self> {
459 let dtype = if op.is_comparison() { DType::Bool } else { lhs.dtype() };
460 Self::new(Op::Binary(op, lhs, rhs), dtype)
461 }
462
463 pub fn threefry(lhs: Arc<Self>, rhs: Arc<Self>) -> Result<Arc<Self>> {
469 let dtype = DType::UInt64; Self::validate_binary_shapes(&lhs, &rhs, BinaryOp::Threefry)?;
471 Ok(Self::new(Op::Binary(BinaryOp::Threefry, lhs, rhs), dtype))
472 }
473}