1#![allow(clippy::redundant_closure_call)]
4use crate::Tensor;
5use half::{bf16, f16};
6use num_traits::float::Float;
7
/// Comparison operator kinds, recorded in `Op::Cmp`.
///
/// The variant names suggest `==`, `!=`, `<=`, `>=`, `<` and `>`; the comparison itself
/// is implemented elsewhere. `Debug` is derived for consistency with the other operator
/// enums in this module (`ReduceOp`, `BinaryOp`, `UnaryOp`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
17
/// Reduction operator kinds, recorded in `Op::Reduce` together with a `Vec<usize>`
/// (presumably the reduced dimensions — confirm at the construction site).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}

impl ReduceOp {
    /// Returns the lowercase identifier of this reduction, e.g. `"sum"` or `"argmax"`.
    pub(crate) fn name(&self) -> &'static str {
        match *self {
            Self::Sum => "sum",
            Self::Min => "min",
            Self::Max => "max",
            Self::ArgMin => "argmin",
            Self::ArgMax => "argmax",
        }
    }
}
38
/// Binary operator kinds, recorded in `Op::Binary` together with both operand tensors.
///
/// Each variant has a same-named marker type below that implements `BinaryOpT`
/// via `bin_op!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    /// Picks the larger of the two operands (see the `Maximum` closure in `bin_op!`).
    Maximum,
    /// Picks the smaller of the two operands (see the `Minimum` closure in `bin_op!`).
    Minimum,
}
49
/// Unary operator kinds, recorded in `Op::Unary` together with the operand tensor.
///
/// Every variant has a same-named marker type in this module that implements `UnaryOpT`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    /// Natural logarithm (`ln`).
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    /// Reciprocal, `1 / x`.
    Recip,
    /// Square, `x * x`.
    Sqr,
    Sqrt,
    /// GELU using the tanh approximation.
    Gelu,
    /// GELU computed with `erf`.
    GeluErf,
    Erf,
    Relu,
    /// `x / (1 + exp(-x))`.
    Silu,
    Tanh,
    Floor,
    Ceil,
    Round,
    /// `1`, `-1` or `0` depending on the sign of the input.
    Sign,
}
73
/// An operation recorded for a tensor: the operand tensor(s) it was computed from plus any
/// op-specific parameters.
///
/// Values are wrapped in `BackpropOp`, which only builds one when at least one operand
/// tracks ops. Positional field meanings below that are not spelled out by a field name are
/// inferred from variant names — NOTE(review): confirm against the construction sites.
#[derive(Clone)]
pub enum Op {
    /// Presumably `(lhs, rhs, op)`.
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    /// Only one operand tensor is stored.
    Cmp(Tensor, CmpOp),
    /// The `Vec<usize>` presumably lists the reduced dimensions.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    /// Presumably `(source, indexes, dim)`.
    Gather(Tensor, Tensor, usize),
    /// Presumably `(destination, indexes, source, dim)`.
    Scatter(Tensor, Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    /// Presumably `(source, indexes, dim)`.
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    /// Presumably `(condition, on_true, on_false)`.
    WhereCond(Tensor, Tensor, Tensor),

    // NOTE(review): the `allow(dead_code)` attributes on several variants below suggest
    // some fields are not read in every build configuration.
    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    /// Pooling parameters are `(height, width)`-style pairs — TODO confirm the order.
    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },

    /// Concatenated tensors; the `usize` is presumably the concatenation dimension.
    Cat(Vec<Tensor>, usize),

    /// Presumably `arg * mul + add` — confirm.
    #[allow(dead_code)]
    Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    /// Presumably `(arg, dim, start, len)`.
    Narrow(Tensor, usize, usize, usize),
    /// Presumably `(destination, source, start)` along dimension 0.
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    /// The two `usize`s are the swapped dimensions.
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    /// The `f64` is presumably the ELU `alpha`.
    Elu(Tensor, f64),
    /// The `f64` is presumably the exponent.
    Powf(Tensor, f64),
    /// User-defined operation with one tensor operand, shared behind an `Arc`.
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    /// User-defined operation with two tensor operands.
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    /// User-defined operation with three tensor operands.
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}
184
/// Per-dtype scalar implementation of a unary operator.
///
/// `NAME` is the operator name. `KERNEL` is a kernel identifier: `unary_op!` builds it as
/// `"u"` followed by the name. `V` is a value of the implementing marker type. Each dtype
/// function maps one input element to one output element.
///
/// The `*_VEC` constants default to `false`, and the matching `*_vec` defaults do nothing
/// (they leave `_ys` untouched). Implementations with a slice-at-a-time routine set the flag
/// to `true` and override the function. NOTE(review): callers presumably check the flag
/// before using `*_vec`.
pub trait UnaryOpT {
    const NAME: &'static str;
    const KERNEL: &'static str;
    const V: Self;
    fn bf16(v1: bf16) -> bf16;
    fn f16(v1: f16) -> f16;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i64(v1: i64) -> i64;

    // Optional vectorized paths: `xs` is the input slice, `ys` the output slice.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
208
/// Per-dtype scalar implementation of a binary operator.
///
/// `NAME` is the operator name. `KERNEL` is a kernel identifier: `bin_op!` builds it as
/// `"b"` followed by the name. `V` is a value of the implementing marker type. Each dtype
/// function combines one element from each operand into one output element.
///
/// The `*_VEC` constants default to `false`, and the matching `*_vec` defaults do nothing
/// (they leave `_ys` untouched). Implementations with a slice-at-a-time routine set the flag
/// to `true` and override the function. NOTE(review): callers presumably check the flag
/// before using `*_vec`.
pub trait BinaryOpT {
    const NAME: &'static str;
    const KERNEL: &'static str;
    const V: Self;
    fn bf16(v1: bf16, v2: bf16) -> bf16;
    fn f16(v1: f16, v2: f16) -> f16;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i64(v1: i64, v2: i64) -> i64;

    // Optional vectorized paths: `xs1`/`xs2` are the operand slices, `ys` the output slice.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
    const U8_VEC: bool = false;
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
236
// Unit marker types, one per operator. Each implements `BinaryOpT` (via `bin_op!`) or
// `UnaryOpT` (via `unary_op!` or a hand-written impl) further down in this file.

// Binary operators.
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
// Unary operators.
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
262
// Implements `BinaryOpT` for the marker type `$op`.
//
// - `$op`: the marker type; `V` is set to its unit value.
// - `$name`: used as `NAME`, and prefixed with `"b"` to form `KERNEL`.
// - `$e`: a two-argument closure expression, called directly in every dtype method. The
//   same source expression is therefore type-checked separately for bf16, f16, f32, f64,
//   u8, u32 and i64. Calling a closure literal directly is what the crate-level
//   `allow(clippy::redundant_closure_call)` at the top of the file is for.
// - `$f32_vec` / `$f64_vec`: names of slice routines in `crate::mkl` or
//   `crate::accelerate`. With either feature enabled, `F32_VEC`/`F64_VEC` become `true`
//   and the `*_vec` methods forward to those routines.
//
// NOTE(review): enabling both `mkl` and `accelerate` would define the same associated items
// twice in one impl, so the two features presumably must not be combined.
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u8(v1: u8, v2: u8) -> u8 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }

            // MKL-backed vectorized f32/f64 paths.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            // Accelerate-backed vectorized f32/f64 paths.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
330
331bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
332bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
333bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
334bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
335bin_op!(
336 Minimum,
337 "minimum",
338 |v1, v2| if v1 > v2 { v2 } else { v1 },
339 vs_min,
340 vd_min
341);
342bin_op!(
343 Maximum,
344 "maximum",
345 |v1, v2| if v1 < v2 { v2 } else { v1 },
346 vs_max,
347 vd_max
348);
349
// Implements `UnaryOpT` for the marker type `$op` from a single expression `$e`.
//
// `$a` names the parameter that `$e` refers to, so the same expression is pasted into the
// bf16, f16, f32 and f64 methods. Integer dtypes are not supported: those methods call
// `todo!` and panic if invoked. `NAME` is `$name`, and `KERNEL` is `"u"` followed by
// `$name`.
//
// Arm 1: no vectorized routines (the trait's `*_VEC = false` defaults apply).
// Arm 2: additionally forwards the f32/f64 slice paths to `$f32_vec`/`$f64_vec` in
//        `crate::mkl` or `crate::accelerate` when the corresponding feature is enabled.
//        NOTE(review): enabling both features would duplicate associated items.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
        }
    };

    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }

            // MKL-backed vectorized f32/f64 paths.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            // Accelerate-backed vectorized f32/f64 paths.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
454
// Float-only unary ops defined by a single expression in `v`. Those with trailing
// `vs_*`/`vd_*` names also get vectorized f32/f64 paths under `mkl`/`accelerate`.
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
// `Log` is the natural logarithm.
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
// No vectorized routine is wired up for `Neg` and `Recip`.
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
464
// sqrt(2 / pi): the scale factor inside the tanh approximation used by `Gelu` below.
// The literal carries more digits than f32/f64 can hold, hence the clippy allows.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
471
472impl UnaryOpT for Gelu {
476 const NAME: &'static str = "gelu";
477 const V: Self = Gelu;
478 #[inline(always)]
479 fn bf16(v: bf16) -> bf16 {
480 bf16::from_f32_const(0.5)
481 * v
482 * (bf16::ONE
483 + bf16::tanh(
484 bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
485 * v
486 * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
487 ))
488 }
489 #[inline(always)]
490 fn f16(v: f16) -> f16 {
491 f16::from_f32_const(0.5)
492 * v
493 * (f16::ONE
494 + f16::tanh(
495 f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
496 * v
497 * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
498 ))
499 }
500 #[inline(always)]
501 fn f32(v: f32) -> f32 {
502 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
503 }
504 #[inline(always)]
505 fn f64(v: f64) -> f64 {
506 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
507 }
508 #[inline(always)]
509 fn u8(_: u8) -> u8 {
510 0
511 }
512 #[inline(always)]
513 fn u32(_: u32) -> u32 {
514 0
515 }
516 #[inline(always)]
517 fn i64(_: i64) -> i64 {
518 0
519 }
520 const KERNEL: &'static str = "ugelu";
521
522 #[cfg(feature = "mkl")]
523 const F32_VEC: bool = true;
524
525 #[cfg(feature = "mkl")]
526 #[inline(always)]
527 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
528 crate::mkl::vs_gelu(xs, ys)
529 }
530
531 #[cfg(feature = "mkl")]
532 const F64_VEC: bool = true;
533
534 #[cfg(feature = "mkl")]
535 #[inline(always)]
536 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
537 crate::mkl::vd_gelu(xs, ys)
538 }
539
540 #[cfg(feature = "accelerate")]
541 const F32_VEC: bool = true;
542
543 #[cfg(feature = "accelerate")]
544 #[inline(always)]
545 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
546 crate::accelerate::vs_gelu(xs, ys)
547 }
548
549 #[cfg(feature = "accelerate")]
550 const F64_VEC: bool = true;
551
552 #[cfg(feature = "accelerate")]
553 #[inline(always)]
554 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
555 crate::accelerate::vd_gelu(xs, ys)
556 }
557}
558
559impl UnaryOpT for Erf {
562 const NAME: &'static str = "erf";
563 const KERNEL: &'static str = "uerf";
564 const V: Self = Erf;
565 #[inline(always)]
566 fn bf16(v: bf16) -> bf16 {
567 bf16::from_f64(Self::f64(v.to_f64()))
568 }
569 #[inline(always)]
570 fn f16(v: f16) -> f16 {
571 f16::from_f64(Self::f64(v.to_f64()))
572 }
573 #[inline(always)]
574 fn f32(v: f32) -> f32 {
575 Self::f64(v as f64) as f32
576 }
577 #[inline(always)]
578 fn f64(v: f64) -> f64 {
579 crate::cpu::erf::erf(v)
580 }
581 #[inline(always)]
582 fn u8(_: u8) -> u8 {
583 0
584 }
585 #[inline(always)]
586 fn u32(_: u32) -> u32 {
587 0
588 }
589 #[inline(always)]
590 fn i64(_: i64) -> i64 {
591 0
592 }
593}
594
595impl UnaryOpT for Silu {
597 const NAME: &'static str = "silu";
598 const V: Self = Silu;
599 #[inline(always)]
600 fn bf16(v: bf16) -> bf16 {
601 v / (bf16::ONE + (-v).exp())
602 }
603 #[inline(always)]
604 fn f16(v: f16) -> f16 {
605 v / (f16::ONE + (-v).exp())
606 }
607 #[inline(always)]
608 fn f32(v: f32) -> f32 {
609 v / (1.0 + (-v).exp())
610 }
611 #[inline(always)]
612 fn f64(v: f64) -> f64 {
613 v / (1.0 + (-v).exp())
614 }
615 #[inline(always)]
616 fn u8(_: u8) -> u8 {
617 0
618 }
619 #[inline(always)]
620 fn u32(_: u32) -> u32 {
621 0
622 }
623 #[inline(always)]
624 fn i64(_: i64) -> i64 {
625 0
626 }
627 const KERNEL: &'static str = "usilu";
628
629 #[cfg(feature = "mkl")]
630 const F32_VEC: bool = true;
631
632 #[cfg(feature = "mkl")]
633 #[inline(always)]
634 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
635 crate::mkl::vs_silu(xs, ys)
636 }
637
638 #[cfg(feature = "mkl")]
639 const F64_VEC: bool = true;
640
641 #[cfg(feature = "mkl")]
642 #[inline(always)]
643 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
644 crate::mkl::vd_silu(xs, ys)
645 }
646
647 #[cfg(feature = "accelerate")]
648 const F32_VEC: bool = true;
649
650 #[cfg(feature = "accelerate")]
651 #[inline(always)]
652 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
653 crate::accelerate::vs_silu(xs, ys)
654 }
655
656 #[cfg(feature = "accelerate")]
657 const F64_VEC: bool = true;
658
659 #[cfg(feature = "accelerate")]
660 #[inline(always)]
661 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
662 crate::accelerate::vd_silu(xs, ys)
663 }
664}
665
666impl UnaryOpT for Abs {
667 const NAME: &'static str = "abs";
668 const KERNEL: &'static str = "uabs";
669 const V: Self = Abs;
670 #[inline(always)]
671 fn bf16(v: bf16) -> bf16 {
672 v.abs()
673 }
674 #[inline(always)]
675 fn f16(v: f16) -> f16 {
676 v.abs()
677 }
678 #[inline(always)]
679 fn f32(v: f32) -> f32 {
680 v.abs()
681 }
682 #[inline(always)]
683 fn f64(v: f64) -> f64 {
684 v.abs()
685 }
686 #[inline(always)]
687 fn u8(v: u8) -> u8 {
688 v
689 }
690 #[inline(always)]
691 fn u32(v: u32) -> u32 {
692 v
693 }
694 #[inline(always)]
695 fn i64(v: i64) -> i64 {
696 v.abs()
697 }
698}
699
700impl UnaryOpT for Ceil {
701 const NAME: &'static str = "ceil";
702 const KERNEL: &'static str = "uceil";
703 const V: Self = Ceil;
704 #[inline(always)]
705 fn bf16(v: bf16) -> bf16 {
706 v.ceil()
707 }
708 #[inline(always)]
709 fn f16(v: f16) -> f16 {
710 v.ceil()
711 }
712 #[inline(always)]
713 fn f32(v: f32) -> f32 {
714 v.ceil()
715 }
716 #[inline(always)]
717 fn f64(v: f64) -> f64 {
718 v.ceil()
719 }
720 #[inline(always)]
721 fn u8(v: u8) -> u8 {
722 v
723 }
724 #[inline(always)]
725 fn u32(v: u32) -> u32 {
726 v
727 }
728 #[inline(always)]
729 fn i64(v: i64) -> i64 {
730 v
731 }
732}
733
734impl UnaryOpT for Floor {
735 const NAME: &'static str = "floor";
736 const KERNEL: &'static str = "ufloor";
737 const V: Self = Floor;
738 #[inline(always)]
739 fn bf16(v: bf16) -> bf16 {
740 v.floor()
741 }
742 #[inline(always)]
743 fn f16(v: f16) -> f16 {
744 v.floor()
745 }
746 #[inline(always)]
747 fn f32(v: f32) -> f32 {
748 v.floor()
749 }
750 #[inline(always)]
751 fn f64(v: f64) -> f64 {
752 v.floor()
753 }
754 #[inline(always)]
755 fn u8(v: u8) -> u8 {
756 v
757 }
758 #[inline(always)]
759 fn u32(v: u32) -> u32 {
760 v
761 }
762 #[inline(always)]
763 fn i64(v: i64) -> i64 {
764 v
765 }
766}
767
768impl UnaryOpT for Round {
769 const NAME: &'static str = "round";
770 const KERNEL: &'static str = "uround";
771 const V: Self = Round;
772 #[inline(always)]
773 fn bf16(v: bf16) -> bf16 {
774 v.round()
775 }
776 #[inline(always)]
777 fn f16(v: f16) -> f16 {
778 v.round()
779 }
780 #[inline(always)]
781 fn f32(v: f32) -> f32 {
782 v.round()
783 }
784 #[inline(always)]
785 fn f64(v: f64) -> f64 {
786 v.round()
787 }
788 #[inline(always)]
789 fn u8(v: u8) -> u8 {
790 v
791 }
792 #[inline(always)]
793 fn u32(v: u32) -> u32 {
794 v
795 }
796 #[inline(always)]
797 fn i64(v: i64) -> i64 {
798 v
799 }
800}
801
802impl UnaryOpT for GeluErf {
803 const NAME: &'static str = "gelu_erf";
804 const KERNEL: &'static str = "ugelu_erf";
805 const V: Self = GeluErf;
806 #[inline(always)]
807 fn bf16(v: bf16) -> bf16 {
808 bf16::from_f64(Self::f64(v.to_f64()))
809 }
810 #[inline(always)]
811 fn f16(v: f16) -> f16 {
812 f16::from_f64(Self::f64(v.to_f64()))
813 }
814 #[inline(always)]
815 fn f32(v: f32) -> f32 {
816 Self::f64(v as f64) as f32
817 }
818 #[inline(always)]
819 fn f64(v: f64) -> f64 {
820 (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
821 }
822 #[inline(always)]
823 fn u8(_: u8) -> u8 {
824 0
825 }
826 #[inline(always)]
827 fn u32(_: u32) -> u32 {
828 0
829 }
830 #[inline(always)]
831 fn i64(_: i64) -> i64 {
832 0
833 }
834}
835
836impl UnaryOpT for Relu {
837 const NAME: &'static str = "relu";
838 const KERNEL: &'static str = "urelu";
839 const V: Self = Relu;
840 #[inline(always)]
841 fn bf16(v: bf16) -> bf16 {
842 v.max(bf16::ZERO)
843 }
844 #[inline(always)]
845 fn f16(v: f16) -> f16 {
846 v.max(f16::ZERO)
847 }
848 #[inline(always)]
849 fn f32(v: f32) -> f32 {
850 v.max(0f32)
851 }
852 #[inline(always)]
853 fn f64(v: f64) -> f64 {
854 v.max(0f64)
855 }
856 #[inline(always)]
857 fn u8(v: u8) -> u8 {
858 v
859 }
860 #[inline(always)]
861 fn u32(v: u32) -> u32 {
862 v
863 }
864 #[inline(always)]
865 fn i64(v: i64) -> i64 {
866 v
867 }
868}
869
/// Optional record of the `Op` that produced a tensor.
///
/// The constructors in the impl below only build an `Op` when at least one operand tensor
/// tracks ops (`track_op()`), and store `None` otherwise.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
874
875impl BackpropOp {
876 pub(crate) fn none() -> Self {
877 BackpropOp(None)
878 }
879
880 pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
881 let op = if arg.track_op() {
882 Some(f(arg.clone()))
883 } else {
884 None
885 };
886 Self(op)
887 }
888
889 pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
890 let op = if arg1.track_op() || arg2.track_op() {
891 Some(f(arg1.clone(), arg2.clone()))
892 } else {
893 None
894 };
895 Self(op)
896 }
897
898 pub(crate) fn new3(
899 arg1: &Tensor,
900 arg2: &Tensor,
901 arg3: &Tensor,
902 f: impl Fn(Tensor, Tensor, Tensor) -> Op,
903 ) -> Self {
904 let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
905 Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
906 } else {
907 None
908 };
909 Self(op)
910 }
911
912 pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
913 let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
914 let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
915 Some(f(args))
916 } else {
917 None
918 };
919 Self(op)
920 }
921
922 pub(crate) fn is_none(&self) -> bool {
923 self.0.is_none()
924 }
925}
926
// Exposes the wrapped `Option<Op>` directly, so callers can use `Option` methods
// (e.g. `as_ref`, pattern matching via `&*op`) on a `BackpropOp`.
impl std::ops::Deref for BackpropOp {
    type Target = Option<Op>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
933
934impl UnaryOpT for Sign {
935 const NAME: &'static str = "sign";
936 const KERNEL: &'static str = "usign";
937 const V: Self = Sign;
938 #[inline(always)]
939 fn bf16(v: bf16) -> bf16 {
940 bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
941 }
942 #[inline(always)]
943 fn f16(v: f16) -> f16 {
944 f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
945 }
946 #[inline(always)]
947 fn f32(v: f32) -> f32 {
948 f32::from(v > 0.) - f32::from(v < 0.)
949 }
950 #[inline(always)]
951 fn f64(v: f64) -> f64 {
952 f64::from(v > 0.) - f64::from(v < 0.)
953 }
954 #[inline(always)]
955 fn u8(v: u8) -> u8 {
956 u8::min(1, v)
957 }
958 #[inline(always)]
959 fn u32(v: u32) -> u32 {
960 u32::min(1, v)
961 }
962 #[inline(always)]
963 fn i64(v: i64) -> i64 {
964 (v > 0) as i64 - (v < 0) as i64
965 }
966}