1#![allow(clippy::redundant_closure_call)]
4use crate::Tensor;
5use float8::F8E4M3 as f8e4m3;
6use half::{bf16, f16};
7use num_traits::float::Float;
8
/// Element-wise comparison operators.
///
/// Derives `Debug` like the other op enums in this module (`ReduceOp`, `BinaryOp`,
/// `UnaryOp`), so comparison ops can be printed and used in `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than or equal.
    Le,
    /// Greater than or equal.
    Ge,
    /// Strictly less than.
    Lt,
    /// Strictly greater than.
    Gt,
}
18
/// Reduction operators (recorded in `Op::Reduce`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}

impl ReduceOp {
    /// Lower-case name of the reduction, e.g. `"argmax"` for `ReduceOp::ArgMax`.
    pub(crate) fn name(&self) -> &'static str {
        // Arms listed in declaration order; each maps to a fixed literal.
        let label = match *self {
            ReduceOp::Sum => "sum",
            ReduceOp::Min => "min",
            ReduceOp::Max => "max",
            ReduceOp::ArgMin => "argmin",
            ReduceOp::ArgMax => "argmax",
        };
        label
    }
}
39
/// Binary operations, as recorded in `Op::Binary`.
///
/// Each variant has a matching marker type (`Add`, `Mul`, ...) in this module that
/// implements `BinaryOpT` with the per-dtype scalar kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `v1 + v2`.
    Add,
    /// `v1 * v2`.
    Mul,
    /// `v1 - v2`.
    Sub,
    /// `v1 / v2`.
    Div,
    /// Larger of the two operands (see the `Maximum` kernel for NaN handling).
    Maximum,
    /// Smaller of the two operands (see the `Minimum` kernel for NaN handling).
    Minimum,
}
50
/// Unary operations, as recorded in `Op::Unary`.
///
/// Each variant has a matching marker type in this module implementing `UnaryOpT`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Natural exponential.
    Exp,
    /// Natural logarithm (the kernels call `ln`).
    Log,
    Sin,
    Cos,
    /// Absolute value (identity for unsigned integers).
    Abs,
    /// Negation.
    Neg,
    /// Reciprocal `1 / v`.
    Recip,
    /// Square `v * v`.
    Sqr,
    Sqrt,
    /// GELU, tanh approximation.
    Gelu,
    /// GELU computed with `erf`: `0.5 * v * (1 + erf(v / sqrt(2)))`.
    GeluErf,
    /// Error function.
    Erf,
    /// `max(v, 0)`.
    Relu,
    /// `v / (1 + exp(-v))`.
    Silu,
    Tanh,
    Floor,
    Ceil,
    Round,
    /// `1`, `-1` or `0` depending on the sign of the input.
    Sign,
}
74
/// Record of the operation that produced a tensor, holding the input tensors
/// (and any scalar parameters) it was computed from.
///
/// Values of this type are stored in `BackpropOp`, which only builds one when an
/// input reports `track_op()`. NOTE(review): presumably consumed by the backward
/// pass to compute gradients — confirm against the autograd code.
#[derive(Clone)]
pub enum Op {
    /// A `BinaryOp` applied to two tensors.
    Binary(Tensor, Tensor, BinaryOp),
    /// A `UnaryOp` applied to one tensor.
    Unary(Tensor, UnaryOp),
    /// A comparison; only one tensor operand is recorded.
    Cmp(Tensor, CmpOp),
    /// A reduction; the `Vec<usize>` presumably lists the reduced dims — confirm.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    /// Matrix product of two tensors.
    Matmul(Tensor, Tensor),
    /// Gather along a dim; presumably `(source, indexes, dim)` — confirm.
    Gather(Tensor, Tensor, usize),
    /// Scatter along a dim; presumably `(destination, indexes, source, dim)` — confirm.
    Scatter(Tensor, Tensor, Tensor, usize),
    /// Scatter-add along a dim; field order presumably as in `Scatter` — confirm.
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    /// Index-select along a dim; presumably `(source, indexes, dim)` — confirm.
    IndexSelect(Tensor, Tensor, usize),
    /// Index-add along a dim; presumably `(destination, indexes, source, dim)` — confirm.
    IndexAdd(Tensor, Tensor, Tensor, usize),
    /// Element-wise select; presumably `(condition, on_true, on_false)` — confirm.
    WhereCond(Tensor, Tensor, Tensor),

    // allow(dead_code): presumably some fields are never read — confirm.
    #[allow(dead_code)]
    /// 1D convolution of `arg` with `kernel`.
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    /// 1D transposed convolution of `arg` with `kernel`.
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    /// 2D convolution of `arg` with `kernel`; one padding/stride/dilation value
    /// is stored for both spatial dims.
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    /// 2D transposed convolution of `arg` with `kernel`.
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    /// 2D average pooling; tuples presumably `(h, w)` — confirm.
    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    /// 2D max pooling; tuples presumably `(h, w)` — confirm.
    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    /// Nearest-neighbour upsampling of one spatial dim to `target_size`.
    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    /// Nearest-neighbour upsampling to `target_h` x `target_w`.
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },
    /// Bilinear upsampling to `target_h` x `target_w`.
    UpsampleBilinear2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
        align_corners: bool,
    },

    /// Concatenation of several tensors along the given dim.
    Cat(Vec<Tensor>, usize),

    /// Affine map with scalar `mul` and `add` (presumably `arg * mul + add` — confirm).
    #[allow(dead_code)]
    Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    /// Dtype conversion.
    ToDType(Tensor),
    /// Copy.
    Copy(Tensor),
    /// Broadcast to a larger shape.
    Broadcast(Tensor),
    /// Narrow; presumably `(arg, dim, start, len)` — confirm.
    Narrow(Tensor, usize, usize, usize),
    /// Slice-scatter; the `usize` is presumably the start offset along dim 0 — confirm.
    SliceScatter0(Tensor, Tensor, usize),
    /// Reshape.
    Reshape(Tensor),
    /// Device transfer.
    ToDevice(Tensor),
    /// Swap of two dims.
    Transpose(Tensor, usize, usize),
    /// Dimension permutation.
    Permute(Tensor, Vec<usize>),
    /// ELU with the given scalar parameter.
    Elu(Tensor, f64),
    /// Power with a scalar exponent.
    Powf(Tensor, f64),
    /// User-defined op with one tensor input.
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    /// User-defined op with two tensor inputs.
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    /// User-defined op with three tensor inputs.
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}
191
/// Per-dtype scalar kernels for a unary element-wise operation.
///
/// Implementors provide one scalar function per supported element type. They can
/// also opt into slice-level kernels through the `*_VEC` flags.
pub trait UnaryOpT {
    /// Operation name, e.g. `"exp"`.
    const NAME: &'static str;
    /// Kernel identifier; implementations in this module use `"u"` followed by `NAME`.
    const KERNEL: &'static str;
    /// A value of the implementing (unit) type.
    const V: Self;
    fn bf16(v1: bf16) -> bf16;
    fn f16(v1: f16) -> f16;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i16(v1: i16) -> i16;
    fn i32(v1: i32) -> i32;
    fn i64(v1: i64) -> i64;
    fn f8e4m3(v1: f8e4m3) -> f8e4m3;

    // Each `*_VEC` flag defaults to `false`, and the matching `*_vec` default body
    // is a no-op that leaves `_ys` untouched. NOTE(review): presumably callers only
    // use a `*_vec` method when its flag is `true` — confirm at call sites.
    /// Whether `bf16_vec` is a real slice kernel.
    const BF16_VEC: bool = false;
    /// Applies the op to `_xs`, writing into `_ys`; no-op by default.
    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
    /// Whether `f16_vec` is a real slice kernel.
    const F16_VEC: bool = false;
    /// Applies the op to `_xs`, writing into `_ys`; no-op by default.
    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
    /// Whether `f32_vec` is a real slice kernel.
    const F32_VEC: bool = false;
    /// Applies the op to `_xs`, writing into `_ys`; no-op by default.
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
    /// Whether `f64_vec` is a real slice kernel.
    const F64_VEC: bool = false;
    /// Applies the op to `_xs`, writing into `_ys`; no-op by default.
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
218
/// Per-dtype scalar kernels for a binary element-wise operation.
///
/// Implementors provide one scalar function per supported element type. They can
/// also opt into slice-level kernels through the `*_VEC` flags.
pub trait BinaryOpT {
    /// Operation name, e.g. `"add"`.
    const NAME: &'static str;
    /// Kernel identifier; `bin_op!` uses `"b"` followed by `NAME`.
    const KERNEL: &'static str;
    /// A value of the implementing (unit) type.
    const V: Self;
    fn bf16(v1: bf16, v2: bf16) -> bf16;
    fn f16(v1: f16, v2: f16) -> f16;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i16(v1: i16, v2: i16) -> i16;
    fn i32(v1: i32, v2: i32) -> i32;
    fn i64(v1: i64, v2: i64) -> i64;
    fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3;

    // Each `*_VEC` flag defaults to `false`, and the matching `*_vec` default body
    // is a no-op. There are no slice hooks for i16, i32 or f8e4m3.
    // NOTE(review): presumably callers only use a `*_vec` method when its flag is
    // `true` — confirm at call sites.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
    const U8_VEC: bool = false;
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
249
// Zero-sized marker types, one per operation. Each implements `BinaryOpT`
// (binary ops) or `UnaryOpT` (unary ops) later in this module.

/// Marker for element-wise addition.
pub struct Add;
/// Marker for element-wise division.
pub struct Div;
/// Marker for element-wise multiplication.
pub struct Mul;
/// Marker for element-wise subtraction.
pub struct Sub;
/// Marker for element-wise maximum.
pub struct Maximum;
/// Marker for element-wise minimum.
pub struct Minimum;
/// Marker for `exp`.
pub struct Exp;
/// Marker for the natural logarithm.
pub struct Log;
/// Marker for `sin`.
pub struct Sin;
/// Marker for `cos`.
pub struct Cos;
/// Marker for absolute value.
pub struct Abs;
/// Marker for negation.
pub struct Neg;
/// Marker for reciprocal.
pub struct Recip;
/// Marker for squaring.
pub struct Sqr;
/// Marker for square root.
pub struct Sqrt;
/// Marker for GELU (tanh approximation).
pub struct Gelu;
/// Marker for GELU computed via `erf`.
pub struct GeluErf;
/// Marker for the error function.
pub struct Erf;
/// Marker for ReLU.
pub struct Relu;
/// Marker for SiLU.
pub struct Silu;
/// Marker for `tanh`.
pub struct Tanh;
/// Marker for `floor`.
pub struct Floor;
/// Marker for `ceil`.
pub struct Ceil;
/// Marker for `round`.
pub struct Round;
/// Marker for the sign function.
pub struct Sign;
275
/// Implements `BinaryOpT` for the unit type `$op`.
///
/// * `$name` - operation name; `KERNEL` becomes `"b"` followed by `$name`.
/// * `$e` - a closure `|v1, v2| ...` used as the scalar kernel for every dtype. It
///   is called immediately (`$e(v1, v2)`), which is why the file allows
///   `clippy::redundant_closure_call`.
/// * `$f32_vec` / `$f64_vec` - names of slice routines in `crate::mkl` (feature
///   `mkl`) or `crate::accelerate` (feature `accelerate`).
///
/// NOTE: both feature sections define `F32_VEC`, `F64_VEC`, `f32_vec` and
/// `f64_vec`. Enabling `mkl` and `accelerate` together would therefore define
/// them twice in the same impl and fail to compile.
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u8(v1: u8, v2: u8) -> u8 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i16(v1: i16, v2: i16) -> i16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i32(v1: i32, v2: i32) -> i32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3 {
                $e(v1, v2)
            }

            // MKL-backed slice kernels for f32/f64.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            // Accelerate-backed slice kernels for f32/f64.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
355
// Arithmetic binary ops; the same closure serves every dtype.
bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
// For floats, if either operand is NaN then `v1 > v2` is false, so the scalar
// kernel returns `v1`. The result is NaN only when `v1` is the NaN.
bin_op!(
    Minimum,
    "minimum",
    |v1, v2| if v1 > v2 { v2 } else { v1 },
    vs_min,
    vd_min
);
// Same NaN behaviour as `Minimum`: a NaN comparison is false, so `v1` is returned.
bin_op!(
    Maximum,
    "maximum",
    |v1, v2| if v1 < v2 { v2 } else { v1 },
    vs_max,
    vd_max
);
374
/// Implements `UnaryOpT` for the unit type `$op`.
///
/// * `$name` - operation name; `KERNEL` becomes `"u"` followed by `$name`.
/// * `$a` - identifier bound to the scalar argument inside `$e`.
/// * `$e` - expression in terms of `$a`, used for bf16, f16, f32, f64 and f8e4m3.
///
/// Integer dtypes (u8, u32, i16, i32, i64) are not supported: their methods call
/// `todo!()` and panic if reached.
///
/// The first arm has no slice kernels. The second arm also takes `$f32_vec` /
/// `$f64_vec`, routines in `crate::mkl` or `crate::accelerate` enabled by the
/// corresponding feature. As in `bin_op!`, enabling both features at once would
/// define those items twice.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    // Scalar kernels only.
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i16(_: i16) -> i16 {
                todo!("no unary function for i16")
            }
            #[inline(always)]
            fn i32(_: i32) -> i32 {
                todo!("no unary function for i32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
            #[inline(always)]
            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
                $e
            }
        }
    };

    // Scalar kernels plus MKL / Accelerate slice kernels for f32 and f64.
    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i16(_: i16) -> i16 {
                todo!("no unary function for i16")
            }
            #[inline(always)]
            fn i32(_: i32) -> i32 {
                todo!("no unary function for i32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
            #[inline(always)]
            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
                $e
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
503
// Unary ops with MKL / Accelerate slice kernels.
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
// Scalar-only unary ops (no slice kernels).
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
513
// sqrt(2 / pi) ~= 0.7978845608, the scale factor inside the tanh term of the GELU
// approximation. The literal carries more digits than either float type can hold,
// hence the clippy allow; the compiler rounds it to the nearest f32/f64.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
520
521impl UnaryOpT for Gelu {
525 const NAME: &'static str = "gelu";
526 const V: Self = Gelu;
527 #[inline(always)]
528 fn bf16(v: bf16) -> bf16 {
529 bf16::from_f32_const(0.5)
530 * v
531 * (bf16::ONE
532 + bf16::tanh(
533 bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
534 * v
535 * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
536 ))
537 }
538 #[inline(always)]
539 fn f16(v: f16) -> f16 {
540 f16::from_f32_const(0.5)
541 * v
542 * (f16::ONE
543 + f16::tanh(
544 f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
545 * v
546 * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
547 ))
548 }
549 #[inline(always)]
550 fn f32(v: f32) -> f32 {
551 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
552 }
553 #[inline(always)]
554 fn f64(v: f64) -> f64 {
555 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
556 }
557 #[inline(always)]
558 fn u8(_: u8) -> u8 {
559 0
560 }
561 #[inline(always)]
562 fn u32(_: u32) -> u32 {
563 0
564 }
565 #[inline(always)]
566 fn i16(_: i16) -> i16 {
567 0
568 }
569 #[inline(always)]
570 fn i32(_: i32) -> i32 {
571 0
572 }
573 #[inline(always)]
574 fn i64(_: i64) -> i64 {
575 0
576 }
577 #[inline(always)]
578 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
579 f8e4m3::from_f32(0.5)
580 * v
581 * (f8e4m3::ONE
582 + f8e4m3::tanh(
583 f8e4m3::from_f32(SQRT_TWO_OVER_PI_F32)
584 * v
585 * (f8e4m3::ONE + f8e4m3::from_f32(0.044715) * v * v),
586 ))
587 }
588 const KERNEL: &'static str = "ugelu";
589
590 #[cfg(feature = "mkl")]
591 const F32_VEC: bool = true;
592
593 #[cfg(feature = "mkl")]
594 #[inline(always)]
595 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
596 crate::mkl::vs_gelu(xs, ys)
597 }
598
599 #[cfg(feature = "mkl")]
600 const F64_VEC: bool = true;
601
602 #[cfg(feature = "mkl")]
603 #[inline(always)]
604 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
605 crate::mkl::vd_gelu(xs, ys)
606 }
607
608 #[cfg(feature = "accelerate")]
609 const F32_VEC: bool = true;
610
611 #[cfg(feature = "accelerate")]
612 #[inline(always)]
613 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
614 crate::accelerate::vs_gelu(xs, ys)
615 }
616
617 #[cfg(feature = "accelerate")]
618 const F64_VEC: bool = true;
619
620 #[cfg(feature = "accelerate")]
621 #[inline(always)]
622 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
623 crate::accelerate::vd_gelu(xs, ys)
624 }
625}
626
627impl UnaryOpT for Erf {
630 const NAME: &'static str = "erf";
631 const KERNEL: &'static str = "uerf";
632 const V: Self = Erf;
633 #[inline(always)]
634 fn bf16(v: bf16) -> bf16 {
635 bf16::from_f64(Self::f64(v.to_f64()))
636 }
637 #[inline(always)]
638 fn f16(v: f16) -> f16 {
639 f16::from_f64(Self::f64(v.to_f64()))
640 }
641 #[inline(always)]
642 fn f32(v: f32) -> f32 {
643 crate::cpu::erf::erf_f32(v)
644 }
645 #[inline(always)]
646 fn f64(v: f64) -> f64 {
647 crate::cpu::erf::erf_f64(v)
648 }
649 #[inline(always)]
650 fn u8(_: u8) -> u8 {
651 0
652 }
653 #[inline(always)]
654 fn u32(_: u32) -> u32 {
655 0
656 }
657 #[inline(always)]
658 fn i16(_: i16) -> i16 {
659 0
660 }
661 #[inline(always)]
662 fn i32(_: i32) -> i32 {
663 0
664 }
665 #[inline(always)]
666 fn i64(_: i64) -> i64 {
667 0
668 }
669 #[inline(always)]
670 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
671 f8e4m3::from_f64(Self::f64(v.to_f64()))
672 }
673}
674
675impl UnaryOpT for Silu {
677 const NAME: &'static str = "silu";
678 const V: Self = Silu;
679 #[inline(always)]
680 fn bf16(v: bf16) -> bf16 {
681 v / (bf16::ONE + (-v).exp())
682 }
683 #[inline(always)]
684 fn f16(v: f16) -> f16 {
685 v / (f16::ONE + (-v).exp())
686 }
687 #[inline(always)]
688 fn f32(v: f32) -> f32 {
689 v / (1.0 + (-v).exp())
690 }
691 #[inline(always)]
692 fn f64(v: f64) -> f64 {
693 v / (1.0 + (-v).exp())
694 }
695 #[inline(always)]
696 fn u8(_: u8) -> u8 {
697 0
698 }
699 #[inline(always)]
700 fn u32(_: u32) -> u32 {
701 0
702 }
703 #[inline(always)]
704 fn i16(_: i16) -> i16 {
705 0
706 }
707 #[inline(always)]
708 fn i32(_: i32) -> i32 {
709 0
710 }
711 #[inline(always)]
712 fn i64(_: i64) -> i64 {
713 0
714 }
715 #[inline(always)]
716 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
717 v / (f8e4m3::ONE + (-v).exp())
718 }
719 const KERNEL: &'static str = "usilu";
720
721 #[cfg(feature = "mkl")]
722 const F32_VEC: bool = true;
723
724 #[cfg(feature = "mkl")]
725 #[inline(always)]
726 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
727 crate::mkl::vs_silu(xs, ys)
728 }
729
730 #[cfg(feature = "mkl")]
731 const F64_VEC: bool = true;
732
733 #[cfg(feature = "mkl")]
734 #[inline(always)]
735 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
736 crate::mkl::vd_silu(xs, ys)
737 }
738
739 #[cfg(feature = "accelerate")]
740 const F32_VEC: bool = true;
741
742 #[cfg(feature = "accelerate")]
743 #[inline(always)]
744 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
745 crate::accelerate::vs_silu(xs, ys)
746 }
747
748 #[cfg(feature = "accelerate")]
749 const F64_VEC: bool = true;
750
751 #[cfg(feature = "accelerate")]
752 #[inline(always)]
753 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
754 crate::accelerate::vd_silu(xs, ys)
755 }
756}
757
758impl UnaryOpT for Abs {
759 const NAME: &'static str = "abs";
760 const KERNEL: &'static str = "uabs";
761 const V: Self = Abs;
762 #[inline(always)]
763 fn bf16(v: bf16) -> bf16 {
764 v.abs()
765 }
766 #[inline(always)]
767 fn f16(v: f16) -> f16 {
768 v.abs()
769 }
770 #[inline(always)]
771 fn f32(v: f32) -> f32 {
772 v.abs()
773 }
774 #[inline(always)]
775 fn f64(v: f64) -> f64 {
776 v.abs()
777 }
778 #[inline(always)]
779 fn u8(v: u8) -> u8 {
780 v
781 }
782 #[inline(always)]
783 fn u32(v: u32) -> u32 {
784 v
785 }
786 #[inline(always)]
787 fn i16(v: i16) -> i16 {
788 v.abs()
789 }
790 #[inline(always)]
791 fn i32(v: i32) -> i32 {
792 v.abs()
793 }
794 #[inline(always)]
795 fn i64(v: i64) -> i64 {
796 v.abs()
797 }
798 #[inline(always)]
799 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
800 v.abs()
801 }
802}
803
804impl UnaryOpT for Ceil {
805 const NAME: &'static str = "ceil";
806 const KERNEL: &'static str = "uceil";
807 const V: Self = Ceil;
808 #[inline(always)]
809 fn bf16(v: bf16) -> bf16 {
810 v.ceil()
811 }
812 #[inline(always)]
813 fn f16(v: f16) -> f16 {
814 v.ceil()
815 }
816 #[inline(always)]
817 fn f32(v: f32) -> f32 {
818 v.ceil()
819 }
820 #[inline(always)]
821 fn f64(v: f64) -> f64 {
822 v.ceil()
823 }
824 #[inline(always)]
825 fn u8(v: u8) -> u8 {
826 v
827 }
828 #[inline(always)]
829 fn u32(v: u32) -> u32 {
830 v
831 }
832 #[inline(always)]
833 fn i16(v: i16) -> i16 {
834 v
835 }
836 #[inline(always)]
837 fn i32(v: i32) -> i32 {
838 v
839 }
840 #[inline(always)]
841 fn i64(v: i64) -> i64 {
842 v
843 }
844 #[inline(always)]
845 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
846 v.ceil()
847 }
848}
849
850impl UnaryOpT for Floor {
851 const NAME: &'static str = "floor";
852 const KERNEL: &'static str = "ufloor";
853 const V: Self = Floor;
854 #[inline(always)]
855 fn bf16(v: bf16) -> bf16 {
856 v.floor()
857 }
858 #[inline(always)]
859 fn f16(v: f16) -> f16 {
860 v.floor()
861 }
862 #[inline(always)]
863 fn f32(v: f32) -> f32 {
864 v.floor()
865 }
866 #[inline(always)]
867 fn f64(v: f64) -> f64 {
868 v.floor()
869 }
870 #[inline(always)]
871 fn u8(v: u8) -> u8 {
872 v
873 }
874 #[inline(always)]
875 fn u32(v: u32) -> u32 {
876 v
877 }
878 #[inline(always)]
879 fn i16(v: i16) -> i16 {
880 v
881 }
882 #[inline(always)]
883 fn i32(v: i32) -> i32 {
884 v
885 }
886 #[inline(always)]
887 fn i64(v: i64) -> i64 {
888 v
889 }
890 #[inline(always)]
891 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
892 v.floor()
893 }
894}
895
896impl UnaryOpT for Round {
897 const NAME: &'static str = "round";
898 const KERNEL: &'static str = "uround";
899 const V: Self = Round;
900 #[inline(always)]
901 fn bf16(v: bf16) -> bf16 {
902 v.round()
903 }
904 #[inline(always)]
905 fn f16(v: f16) -> f16 {
906 v.round()
907 }
908 #[inline(always)]
909 fn f32(v: f32) -> f32 {
910 v.round()
911 }
912 #[inline(always)]
913 fn f64(v: f64) -> f64 {
914 v.round()
915 }
916 #[inline(always)]
917 fn u8(v: u8) -> u8 {
918 v
919 }
920 #[inline(always)]
921 fn u32(v: u32) -> u32 {
922 v
923 }
924 #[inline(always)]
925 fn i16(v: i16) -> i16 {
926 v
927 }
928 #[inline(always)]
929 fn i32(v: i32) -> i32 {
930 v
931 }
932 #[inline(always)]
933 fn i64(v: i64) -> i64 {
934 v
935 }
936 #[inline(always)]
937 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
938 v.round()
939 }
940}
941
942impl UnaryOpT for GeluErf {
943 const NAME: &'static str = "gelu_erf";
944 const KERNEL: &'static str = "ugelu_erf";
945 const V: Self = GeluErf;
946 #[inline(always)]
947 fn bf16(v: bf16) -> bf16 {
948 bf16::from_f64(Self::f64(v.to_f64()))
949 }
950 #[inline(always)]
951 fn f16(v: f16) -> f16 {
952 f16::from_f64(Self::f64(v.to_f64()))
953 }
954 #[inline(always)]
955 fn f32(v: f32) -> f32 {
956 (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
957 }
958 #[inline(always)]
959 fn f64(v: f64) -> f64 {
960 (crate::cpu::erf::erf_f64(v * std::f64::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
961 }
962 #[inline(always)]
963 fn u8(_: u8) -> u8 {
964 0
965 }
966 #[inline(always)]
967 fn u32(_: u32) -> u32 {
968 0
969 }
970 #[inline(always)]
971 fn i16(_: i16) -> i16 {
972 0
973 }
974 #[inline(always)]
975 fn i32(_: i32) -> i32 {
976 0
977 }
978 #[inline(always)]
979 fn i64(_: i64) -> i64 {
980 0
981 }
982 #[inline(always)]
983 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
984 f8e4m3::from_f32(Self::f32(v.to_f32()))
985 }
986}
987
988impl UnaryOpT for Relu {
989 const NAME: &'static str = "relu";
990 const KERNEL: &'static str = "urelu";
991 const V: Self = Relu;
992 #[inline(always)]
993 fn bf16(v: bf16) -> bf16 {
994 v.max(bf16::ZERO)
995 }
996 #[inline(always)]
997 fn f16(v: f16) -> f16 {
998 v.max(f16::ZERO)
999 }
1000 #[inline(always)]
1001 fn f32(v: f32) -> f32 {
1002 v.max(0f32)
1003 }
1004 #[inline(always)]
1005 fn f64(v: f64) -> f64 {
1006 v.max(0f64)
1007 }
1008 #[inline(always)]
1009 fn u8(v: u8) -> u8 {
1010 v
1011 }
1012 #[inline(always)]
1013 fn u32(v: u32) -> u32 {
1014 v
1015 }
1016 #[inline(always)]
1017 fn i16(v: i16) -> i16 {
1018 v.max(0)
1019 }
1020 #[inline(always)]
1021 fn i32(v: i32) -> i32 {
1022 v.max(0)
1023 }
1024 #[inline(always)]
1025 fn i64(v: i64) -> i64 {
1026 v.max(0)
1027 }
1028 #[inline(always)]
1029 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1030 v.max(f8e4m3::ZERO)
1031 }
1032}
1033
1034#[derive(Clone)]
1037pub struct BackpropOp(Option<Op>);
1038
1039impl BackpropOp {
1040 pub fn none() -> Self {
1041 BackpropOp(None)
1042 }
1043
1044 pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
1045 let op = if arg.track_op() {
1046 Some(f(arg.clone()))
1047 } else {
1048 None
1049 };
1050 Self(op)
1051 }
1052
1053 pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
1054 let op = if arg1.track_op() || arg2.track_op() {
1055 Some(f(arg1.clone(), arg2.clone()))
1056 } else {
1057 None
1058 };
1059 Self(op)
1060 }
1061
1062 pub(crate) fn new3(
1063 arg1: &Tensor,
1064 arg2: &Tensor,
1065 arg3: &Tensor,
1066 f: impl Fn(Tensor, Tensor, Tensor) -> Op,
1067 ) -> Self {
1068 let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
1069 Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
1070 } else {
1071 None
1072 };
1073 Self(op)
1074 }
1075
1076 pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
1077 let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
1078 let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
1079 Some(f(args))
1080 } else {
1081 None
1082 };
1083 Self(op)
1084 }
1085
1086 pub(crate) fn is_none(&self) -> bool {
1087 self.0.is_none()
1088 }
1089}
1090
1091impl std::ops::Deref for BackpropOp {
1092 type Target = Option<Op>;
1093 fn deref(&self) -> &Self::Target {
1094 &self.0
1095 }
1096}
1097
1098impl UnaryOpT for Sign {
1099 const NAME: &'static str = "sign";
1100 const KERNEL: &'static str = "usign";
1101 const V: Self = Sign;
1102 #[inline(always)]
1103 fn bf16(v: bf16) -> bf16 {
1104 bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
1105 }
1106 #[inline(always)]
1107 fn f16(v: f16) -> f16 {
1108 f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
1109 }
1110 #[inline(always)]
1111 fn f32(v: f32) -> f32 {
1112 f32::from(v > 0.) - f32::from(v < 0.)
1113 }
1114 #[inline(always)]
1115 fn f64(v: f64) -> f64 {
1116 f64::from(v > 0.) - f64::from(v < 0.)
1117 }
1118 #[inline(always)]
1119 fn u8(v: u8) -> u8 {
1120 u8::min(1, v)
1121 }
1122 #[inline(always)]
1123 fn u32(v: u32) -> u32 {
1124 u32::min(1, v)
1125 }
1126 #[inline(always)]
1127 fn i16(v: i16) -> i16 {
1128 (v > 0) as i16 - (v < 0) as i16
1129 }
1130 #[inline(always)]
1131 fn i32(v: i32) -> i32 {
1132 (v > 0) as i32 - (v < 0) as i32
1133 }
1134 #[inline(always)]
1135 fn i64(v: i64) -> i64 {
1136 (v > 0) as i64 - (v < 0) as i64
1137 }
1138 #[inline(always)]
1139 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1140 if v > f8e4m3::ZERO {
1141 f8e4m3::ONE
1142 } else if v < f8e4m3::ZERO {
1143 -f8e4m3::ONE
1144 } else {
1145 f8e4m3::ZERO
1146 }
1147 }
1148}