1#![allow(clippy::redundant_closure_call)]
2use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
3use half::{bf16, f16};
4use num_traits::float::Float;
5
/// Elementwise comparison operations between two tensors.
// Consistency fix: derive `Debug` like the other op enums in this file
// (`ReduceOp`, `BinaryOp`, `UnaryOp`); adding a derive is backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
15
/// Reduction operations over tensor dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}

impl ReduceOp {
    /// Lowercase name of the reduction.
    pub(crate) fn name(&self) -> &'static str {
        match self {
            Self::Sum => "sum",
            Self::Min => "min",
            Self::Max => "max",
            Self::ArgMin => "argmin",
            Self::ArgMax => "argmax",
        }
    }
}
36
/// Elementwise binary operations applied to a pair of tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}
47
/// Elementwise unary operations applied to a single tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Tanh,
    Floor,
    Ceil,
    Round,
}
69
/// Operation recorded as the provenance of a tensor (wrapped in a
/// [`BackpropOp`]); each variant carries the input tensor(s) and the
/// parameters needed to replay or differentiate the op.
#[derive(Clone)]
pub enum Op {
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    Cmp(Tensor, CmpOp),
    // Reduction over the listed dimensions.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    // Indexing ops; the trailing usize is presumably the dimension the
    // indices apply to — confirm at call sites.
    Gather(Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    WhereCond(Tensor, Tensor, Tensor),

    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    UpsampleNearest1D(Tensor),
    UpsampleNearest2D(Tensor),

    // Concatenation of several tensors; the usize is presumably the dim.
    Cat(Vec<Tensor>, usize),

    // Field names suggest an affine transform of `arg` with `mul`/`add`.
    #[allow(dead_code)] Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    // NOTE(review): the three usize params look like dim/start/len — confirm
    // against `Tensor::narrow`.
    Narrow(Tensor, usize, usize, usize),
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    // User-provided custom ops with one, two or three tensor arguments.
    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ),
}
159
/// A user-defined operation taking a single tensor argument.
///
/// Implementors must provide a CPU forward pass; CUDA forward and the
/// backward pass have defaults that return an error.
pub trait CustomOp1 {
    /// Name of the op, used in the error messages below.
    fn name(&self) -> &'static str;

    /// Forward pass on CPU storage; returns the result storage and its shape.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// Forward pass on CUDA storage; by default errors out for ops that do
    /// not provide a CUDA implementation.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: returns the gradient of the argument given the op's
    /// result and the gradient flowing back into it. Defaults to
    /// `BackwardNotSupported`.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
184
/// A user-defined operation taking two tensor arguments.
///
/// Same contract as [`CustomOp1`], with one storage/layout pair per input.
pub trait CustomOp2 {
    /// Name of the op, used in the error messages below.
    fn name(&self) -> &'static str;

    /// Forward pass on CPU storage; returns the result storage and its shape.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// Forward pass on CUDA storage; errors out by default.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: one optional gradient per input argument.
    /// Defaults to `BackwardNotSupported`.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
222
/// A user-defined operation taking three tensor arguments.
///
/// Same contract as [`CustomOp1`], with one storage/layout pair per input.
pub trait CustomOp3 {
    /// Name of the op, used in the error messages below.
    fn name(&self) -> &'static str;

    /// Forward pass on CPU storage; returns the result storage and its shape.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// Forward pass on CUDA storage; errors out by default.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: one optional gradient per input argument.
    /// Defaults to `BackwardNotSupported`.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
265
/// Scalar kernels for a unary operation, one method per supported dtype,
/// plus optional slice-based vectorized variants.
pub trait UnaryOpT {
    /// Display name of the op.
    const NAME: &'static str;
    /// Kernel name; unary kernels are prefixed with "u" (see `unary_op!`).
    const KERNEL: &'static str;
    /// Value-level witness of the op type.
    const V: Self;
    fn bf16(v1: bf16) -> bf16;
    fn f16(v1: f16) -> f16;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i64(v1: i64) -> i64;

    // Vectorized variants. The defaults are no-ops; implementations that
    // provide a real slice kernel also set the matching `*_VEC` const to true
    // to advertise it.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
289
/// Scalar kernels for a binary operation, one method per supported dtype,
/// plus optional slice-based vectorized variants.
pub trait BinaryOpT {
    /// Display name of the op.
    const NAME: &'static str;
    /// Kernel name; binary kernels are prefixed with "b" (see `bin_op!`).
    const KERNEL: &'static str;
    /// Value-level witness of the op type.
    const V: Self;
    fn bf16(v1: bf16, v2: bf16) -> bf16;
    fn f16(v1: f16, v2: f16) -> f16;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i64(v1: i64, v2: i64) -> i64;

    // Vectorized variants. The defaults are no-ops; implementations that
    // provide a real slice kernel also set the matching `*_VEC` const to true
    // to advertise it.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
    const U8_VEC: bool = false;
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
317
// Marker types, one per binary op; each implements `BinaryOpT` below.
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
// Marker types, one per unary op; each implements `UnaryOpT` below.
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
341
// Implements `BinaryOpT` for a marker type. `$e` is the scalar closure applied
// for every dtype; `$f32_vec`/`$f64_vec` name the vectorized kernels in the
// `crate::mkl` / `crate::accelerate` modules, wired in when the corresponding
// cargo feature is enabled.
//
// NOTE(review): enabling both the "mkl" and "accelerate" features would define
// `F32_VEC`/`F64_VEC` and the `*_vec` methods twice — presumably the features
// are mutually exclusive; confirm in Cargo.toml.
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            // Binary kernels are named with a "b" prefix.
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
            // Scalar kernels: apply the `$e` closure for each dtype.
            #[inline(always)]
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u8(v1: u8, v2: u8) -> u8 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }

            // With the "mkl" feature: advertise and dispatch f32/f64 slices
            // to the MKL vector kernels.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            // With the "accelerate" feature: same, via Apple Accelerate.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
409
// Instantiate `BinaryOpT` for the arithmetic ops.
bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
// Comparison-based min/max: since `>`/`<` are false when a float operand is
// NaN, both closures return `v1` whenever the comparison fails — so a NaN in
// `v2` is ignored while a NaN in `v1` propagates.
bin_op!(
    Minimum,
    "minimum",
    |v1, v2| if v1 > v2 { v2 } else { v1 },
    vs_min,
    vd_min
);
bin_op!(
    Maximum,
    "maximum",
    |v1, v2| if v1 < v2 { v2 } else { v1 },
    vs_max,
    vd_max
);
428
// Implements `UnaryOpT` for a marker type. `$a`/`$e` bind the input variable
// name and the expression computing the result for the float dtypes.
//
// Two arms: the first is for ops with no vectorized kernels; the second also
// wires `$f32_vec`/`$f64_vec` to the mkl/accelerate modules when the matching
// cargo feature is enabled (see the NOTE on `bin_op!` about enabling both).
//
// NOTE(review): the integer dtype kernels are `todo!` and therefore panic at
// runtime if ever invoked on u8/u32/i64 data.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            // Unary kernels are named with a "u" prefix.
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
        }
    };

    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }

            // With the "mkl" feature: vectorized f32/f64 kernels via MKL.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            // With the "accelerate" feature: same, via Apple Accelerate.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
533
// Instantiate `UnaryOpT` for the simple float ops. All but `Neg` and `Recip`
// use the macro arm that also wires up mkl/accelerate vector kernels.
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
543
/// GELU activation using the tanh approximation:
/// `0.5 * v * (1 + tanh(sqrt(2/pi) * v * (1 + 0.044715 * v^2)))`.
/// Integer dtypes return 0.
impl UnaryOpT for Gelu {
    const NAME: &'static str = "gelu";
    const V: Self = Gelu;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f32_const(0.5)
            * v
            * (bf16::ONE
                + bf16::tanh(
                    (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
                        * v
                        * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
                ))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f32_const(0.5)
            * v
            * (f16::ONE
                + f16::tanh(
                    (f16::from_f32_const(2.0) / f16::PI).sqrt()
                        * v
                        * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
                ))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        0.5 * v
            * (1.0
                + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        0.5 * v
            * (1.0
                + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
    const KERNEL: &'static str = "ugelu";

    // With the "mkl" feature: vectorized f32/f64 kernels via MKL.
    #[cfg(feature = "mkl")]
    const F32_VEC: bool = true;

    #[cfg(feature = "mkl")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::mkl::vs_gelu(xs, ys)
    }

    #[cfg(feature = "mkl")]
    const F64_VEC: bool = true;

    #[cfg(feature = "mkl")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::mkl::vd_gelu(xs, ys)
    }

    // With the "accelerate" feature: same, via Apple Accelerate.
    #[cfg(feature = "accelerate")]
    const F32_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::accelerate::vs_gelu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F64_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::accelerate::vd_gelu(xs, ys)
    }
}
633
/// Error function. The bf16/f16/f32 kernels delegate to the f64
/// implementation in `crate::cpu::erf`; integer dtypes return 0.
impl UnaryOpT for Erf {
    const NAME: &'static str = "erf";
    const KERNEL: &'static str = "uerf";
    const V: Self = Erf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        crate::cpu::erf::erf(v)
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
667
/// Absolute value. Unsigned integer dtypes are returned unchanged.
/// NOTE(review): `i64::abs` overflows for `i64::MIN` (panics in debug builds).
impl UnaryOpT for Abs {
    const NAME: &'static str = "abs";
    const KERNEL: &'static str = "uabs";
    const V: Self = Abs;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.abs()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.abs()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.abs()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.abs()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v.abs()
    }
}
701
/// Round up to the nearest integer; integer dtypes are the identity.
impl UnaryOpT for Ceil {
    const NAME: &'static str = "ceil";
    const KERNEL: &'static str = "uceil";
    const V: Self = Ceil;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.ceil()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.ceil()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.ceil()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.ceil()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}
735
/// Round down to the nearest integer; integer dtypes are the identity.
impl UnaryOpT for Floor {
    const NAME: &'static str = "floor";
    const KERNEL: &'static str = "ufloor";
    const V: Self = Floor;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.floor()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.floor()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.floor()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.floor()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}
769
/// Round to the nearest integer (half away from zero, the `f64::round`
/// semantics); integer dtypes are the identity.
impl UnaryOpT for Round {
    const NAME: &'static str = "round";
    const KERNEL: &'static str = "uround";
    const V: Self = Round;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.round()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.round()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.round()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.round()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}
803
/// Exact GELU computed via the error function:
/// `0.5 * v * (1 + erf(v / sqrt(2)))`. Lower precisions delegate to the f64
/// kernel; integer dtypes return 0.
impl UnaryOpT for GeluErf {
    const NAME: &'static str = "gelu_erf";
    const KERNEL: &'static str = "ugelu_erf";
    const V: Self = GeluErf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
837
/// ReLU: `max(v, 0)` for float dtypes.
/// NOTE(review): the i64 kernel returns `v` unchanged rather than clamping
/// negatives to 0 — inconsistent with the float kernels; confirm intent.
impl UnaryOpT for Relu {
    const NAME: &'static str = "relu";
    const KERNEL: &'static str = "urelu";
    const V: Self = Relu;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.max(bf16::ZERO)
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.max(f16::ZERO)
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.max(0f32)
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.max(0f64)
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}
871
/// Wrapper around an optional [`Op`]: `None` when the tensor does not take
/// part in gradient tracking, `Some(op)` when the op must be recorded for the
/// backward pass.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
876
877impl BackpropOp {
878 pub(crate) fn none() -> Self {
879 BackpropOp(None)
880 }
881
882 pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
883 let op = if arg.track_op() {
884 Some(f(arg.clone()))
885 } else {
886 None
887 };
888 Self(op)
889 }
890
891 pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
892 let op = if arg1.track_op() || arg2.track_op() {
893 Some(f(arg1.clone(), arg2.clone()))
894 } else {
895 None
896 };
897 Self(op)
898 }
899
900 pub(crate) fn new3(
901 arg1: &Tensor,
902 arg2: &Tensor,
903 arg3: &Tensor,
904 f: impl Fn(Tensor, Tensor, Tensor) -> Op,
905 ) -> Self {
906 let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
907 Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
908 } else {
909 None
910 };
911 Self(op)
912 }
913
914 pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
915 let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
916 let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
917 Some(f(args))
918 } else {
919 None
920 };
921 Self(op)
922 }
923
924 pub(crate) fn is_none(&self) -> bool {
925 self.0.is_none()
926 }
927}
928
/// Deref to the inner `Option<Op>` so callers can inspect the recorded op
/// directly (pattern matching, `as_ref`, ...).
impl std::ops::Deref for BackpropOp {
    type Target = Option<Op>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}