1#![allow(clippy::redundant_closure_call)]
4use crate::Tensor;
5use float8::F8E4M3 as f8e4m3;
6use half::{bf16, f16};
7use num_traits::float::Float;
8
/// Comparison operators; carried by the `Op::Cmp` variant in this file.
///
/// `Debug` is derived so that this enum can be printed like the other operator
/// enums declared here (`ReduceOp`, `BinaryOp`, `UnaryOp`), which all derive it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
18
/// Reduction kinds; each one has a fixed lowercase label, see [`ReduceOp::name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}

impl ReduceOp {
    /// Returns the lowercase label of this reduction, e.g. `"argmax"` for
    /// `ReduceOp::ArgMax`.
    pub(crate) fn name(&self) -> &'static str {
        match *self {
            ReduceOp::Sum => "sum",
            ReduceOp::Min => "min",
            ReduceOp::Max => "max",
            ReduceOp::ArgMin => "argmin",
            ReduceOp::ArgMax => "argmax",
        }
    }
}
39
/// Tag identifying a two-operand operation; stored in `Op::Binary`.
///
/// The variant names mirror the marker structs `Add`, `Mul`, `Sub`, `Div`,
/// `Maximum` and `Minimum` that implement `BinaryOpT` in this file.
// NOTE(review): presumably each variant maps to the marker struct of the same
// name; the mapping itself is not visible in this file — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}
50
/// Tag identifying a one-operand operation; stored in `Op::Unary`.
///
/// The variant names mirror the marker structs of the same name that implement
/// `UnaryOpT` in this file.
// NOTE(review): presumably each variant maps to the marker struct of the same
// name; the mapping itself is not visible in this file — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Silu,
    Tanh,
    Floor,
    Ceil,
    Round,
    Sign,
}
74
/// An operation together with the tensors and parameters it was applied to.
///
/// `BackpropOp` (declared in this file) wraps an `Option<Op>` and only builds an
/// `Op` when at least one argument reports `track_op()`.
// NOTE(review): the meaning of the positional `usize` / `Vec<usize>` payloads
// (presumably dimensions, offsets or lengths) is not established in this file —
// confirm against the code that constructs and consumes each variant.
#[derive(Clone)]
pub enum Op {
    // Element-wise and reduction style operations, tagged with the operator enums
    // declared above.
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    Cmp(Tensor, CmpOp),
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    // Indexing style operations; each carries a trailing `usize`.
    Gather(Tensor, Tensor, usize),
    Scatter(Tensor, Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    WhereCond(Tensor, Tensor, Tensor),

    // Convolution variants. The `dead_code` allowance suggests some fields are
    // stored without being read everywhere — TODO confirm.
    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    // Pooling variants; kernel size and stride are two-element tuples.
    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    // Upsampling variants, parameterised by the target size(s).
    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },
    UpsampleBilinear2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
        align_corners: bool,
    },

    Cat(Vec<Tensor>, usize),

    // Carries a multiplier and an addend alongside the argument.
    #[allow(dead_code)] Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    // Single-tensor variants.
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    Narrow(Tensor, usize, usize, usize),
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    // User-defined operations with one, two or three tensor arguments; the
    // implementation is a shared, thread-safe trait object.
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}
191
192pub trait UnaryOpT {
193 const NAME: &'static str;
194 const KERNEL: &'static str;
195 const V: Self;
196 fn bf16(v1: bf16) -> bf16;
197 fn f16(v1: f16) -> f16;
198 fn f32(v1: f32) -> f32;
199 fn f64(v1: f64) -> f64;
200 fn u8(v1: u8) -> u8;
201 fn u32(v1: u32) -> u32;
202 fn i16(v1: i16) -> i16;
203 fn i32(v1: i32) -> i32;
204 fn i64(v1: i64) -> i64;
205 fn f8e4m3(v1: f8e4m3) -> f8e4m3;
206
207 #[inline(always)]
208 fn bf16_vec(xs: &[bf16], ys: &mut [bf16]) {
209 xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::bf16(x))
210 }
211 #[inline(always)]
212 fn f16_vec(xs: &[f16], ys: &mut [f16]) {
213 xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::f16(x))
214 }
215 #[inline(always)]
216 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
217 xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::f32(x))
218 }
219 #[inline(always)]
220 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
221 xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::f64(x))
222 }
223}
224
225pub trait BinaryOpT {
226 const NAME: &'static str;
227 const KERNEL: &'static str;
228 const V: Self;
229 fn bf16(v1: bf16, v2: bf16) -> bf16;
230 fn f16(v1: f16, v2: f16) -> f16;
231 fn f32(v1: f32, v2: f32) -> f32;
232 fn f64(v1: f64, v2: f64) -> f64;
233 fn u8(v1: u8, v2: u8) -> u8;
234 fn u32(v1: u32, v2: u32) -> u32;
235 fn i16(v1: i16, v2: i16) -> i16;
236 fn i32(v1: i32, v2: i32) -> i32;
237 fn i64(v1: i64, v2: i64) -> i64;
238 fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3;
239
240 #[inline(always)]
241 fn bf16_vec(xs1: &[bf16], xs2: &[bf16], ys: &mut [bf16]) {
242 xs1.iter()
243 .zip(xs2)
244 .zip(ys)
245 .for_each(|((&a, &b), y)| *y = Self::bf16(a, b))
246 }
247 #[inline(always)]
248 fn f16_vec(xs1: &[f16], xs2: &[f16], ys: &mut [f16]) {
249 xs1.iter()
250 .zip(xs2)
251 .zip(ys)
252 .for_each(|((&a, &b), y)| *y = Self::f16(a, b))
253 }
254 #[inline(always)]
255 fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
256 xs1.iter()
257 .zip(xs2)
258 .zip(ys)
259 .for_each(|((&a, &b), y)| *y = Self::f32(a, b))
260 }
261 #[inline(always)]
262 fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
263 xs1.iter()
264 .zip(xs2)
265 .zip(ys)
266 .for_each(|((&a, &b), y)| *y = Self::f64(a, b))
267 }
268 #[inline(always)]
269 fn u8_vec(xs1: &[u8], xs2: &[u8], ys: &mut [u8]) {
270 xs1.iter()
271 .zip(xs2)
272 .zip(ys)
273 .for_each(|((&a, &b), y)| *y = Self::u8(a, b))
274 }
275 #[inline(always)]
276 fn u32_vec(xs1: &[u32], xs2: &[u32], ys: &mut [u32]) {
277 xs1.iter()
278 .zip(xs2)
279 .zip(ys)
280 .for_each(|((&a, &b), y)| *y = Self::u32(a, b))
281 }
282 #[inline(always)]
283 fn i16_vec(xs1: &[i16], xs2: &[i16], ys: &mut [i16]) {
284 xs1.iter()
285 .zip(xs2)
286 .zip(ys)
287 .for_each(|((&a, &b), y)| *y = Self::i16(a, b))
288 }
289 #[inline(always)]
290 fn i32_vec(xs1: &[i32], xs2: &[i32], ys: &mut [i32]) {
291 xs1.iter()
292 .zip(xs2)
293 .zip(ys)
294 .for_each(|((&a, &b), y)| *y = Self::i32(a, b))
295 }
296 #[inline(always)]
297 fn i64_vec(xs1: &[i64], xs2: &[i64], ys: &mut [i64]) {
298 xs1.iter()
299 .zip(xs2)
300 .zip(ys)
301 .for_each(|((&a, &b), y)| *y = Self::i64(a, b))
302 }
303
304 #[inline(always)]
308 fn bf16_scalar_vec(scalar: bf16, xs: &[bf16], ys: &mut [bf16]) {
309 xs.iter()
310 .zip(ys)
311 .for_each(|(&x, y)| *y = Self::bf16(x, scalar))
312 }
313 #[inline(always)]
314 fn f16_scalar_vec(scalar: f16, xs: &[f16], ys: &mut [f16]) {
315 xs.iter()
316 .zip(ys)
317 .for_each(|(&x, y)| *y = Self::f16(x, scalar))
318 }
319 #[inline(always)]
320 fn f32_scalar_vec(scalar: f32, xs: &[f32], ys: &mut [f32]) {
321 xs.iter()
322 .zip(ys)
323 .for_each(|(&x, y)| *y = Self::f32(x, scalar))
324 }
325 #[inline(always)]
326 fn f64_scalar_vec(scalar: f64, xs: &[f64], ys: &mut [f64]) {
327 xs.iter()
328 .zip(ys)
329 .for_each(|(&x, y)| *y = Self::f64(x, scalar))
330 }
331 #[inline(always)]
332 fn u8_scalar_vec(scalar: u8, xs: &[u8], ys: &mut [u8]) {
333 xs.iter()
334 .zip(ys)
335 .for_each(|(&x, y)| *y = Self::u8(x, scalar))
336 }
337 #[inline(always)]
338 fn u32_scalar_vec(scalar: u32, xs: &[u32], ys: &mut [u32]) {
339 xs.iter()
340 .zip(ys)
341 .for_each(|(&x, y)| *y = Self::u32(x, scalar))
342 }
343 #[inline(always)]
344 fn i16_scalar_vec(scalar: i16, xs: &[i16], ys: &mut [i16]) {
345 xs.iter()
346 .zip(ys)
347 .for_each(|(&x, y)| *y = Self::i16(x, scalar))
348 }
349 #[inline(always)]
350 fn i32_scalar_vec(scalar: i32, xs: &[i32], ys: &mut [i32]) {
351 xs.iter()
352 .zip(ys)
353 .for_each(|(&x, y)| *y = Self::i32(x, scalar))
354 }
355 #[inline(always)]
356 fn i64_scalar_vec(scalar: i64, xs: &[i64], ys: &mut [i64]) {
357 xs.iter()
358 .zip(ys)
359 .for_each(|(&x, y)| *y = Self::i64(x, scalar))
360 }
361}
362
// Zero-sized marker types. Each one names an element-wise operation and is
// given a `BinaryOpT` or `UnaryOpT` implementation further down in this file.

// Markers that receive a `BinaryOpT` implementation (via `bin_op!`).
pub struct Add;
pub struct Div;
pub struct Mul;
pub struct Sub;
pub struct Maximum;
pub struct Minimum;
// Markers that receive a `UnaryOpT` implementation (via `unary_op!` or a
// hand-written `impl`).
pub struct Exp;
pub struct Log;
pub struct Sin;
pub struct Cos;
pub struct Abs;
pub struct Neg;
pub struct Recip;
pub struct Sqr;
pub struct Sqrt;
pub struct Gelu;
pub struct GeluErf;
pub struct Erf;
pub struct Relu;
pub struct Silu;
pub struct Tanh;
pub struct Floor;
pub struct Ceil;
pub struct Round;
pub struct Sign;
388
389macro_rules! bin_op {
393 ($op:ident, $name:ident, $e:expr $(, $vec_op:ident)?) => {
394 impl BinaryOpT for $op {
395 const NAME: &'static str = stringify!($name);
396 const KERNEL: &'static str = concat!("b", stringify!($name));
397 const V: Self = $op;
398 #[inline(always)]
399 fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) }
400 #[inline(always)]
401 fn f16(v1: f16, v2: f16) -> f16 { $e(v1, v2) }
402 #[inline(always)]
403 fn f32(v1: f32, v2: f32) -> f32 { $e(v1, v2) }
404 #[inline(always)]
405 fn f64(v1: f64, v2: f64) -> f64 { $e(v1, v2) }
406 #[inline(always)]
407 fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) }
408 #[inline(always)]
409 fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) }
410 #[inline(always)]
411 fn i16(v1: i16, v2: i16) -> i16 { $e(v1, v2) }
412 #[inline(always)]
413 fn i32(v1: i32, v2: i32) -> i32 { $e(v1, v2) }
414 #[inline(always)]
415 fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) }
416 #[inline(always)]
417 fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3 { $e(v1, v2) }
418 $(
419 #[inline(always)]
420 fn f32_vec(lhs: &[f32], rhs: &[f32], res: &mut [f32]) {
421 <f32 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
422 }
423 #[inline(always)]
424 fn f64_vec(lhs: &[f64], rhs: &[f64], res: &mut [f64]) {
425 <f64 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
426 }
427 #[inline(always)]
428 fn bf16_vec(lhs: &[bf16], rhs: &[bf16], res: &mut [bf16]) {
429 <bf16 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
430 }
431 #[inline(always)]
432 fn f16_vec(lhs: &[f16], rhs: &[f16], res: &mut [f16]) {
433 <f16 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
434 }
435 #[inline(always)]
436 fn bf16_scalar_vec(scalar: bf16, xs: &[bf16], ys: &mut [bf16]) {
437 <bf16 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
438 }
439 #[inline(always)]
440 fn f16_scalar_vec(scalar: f16, xs: &[f16], ys: &mut [f16]) {
441 <f16 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
442 }
443 #[inline(always)]
444 fn f32_scalar_vec(scalar: f32, xs: &[f32], ys: &mut [f32]) {
445 <f32 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
446 }
447 #[inline(always)]
448 fn f64_scalar_vec(scalar: f64, xs: &[f64], ys: &mut [f64]) {
449 <f64 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
450 }
451 )?
452 }
453 };
454}
455
456bin_op!(Add, add, |v1, v2| v1 + v2, vec_add);
457bin_op!(Sub, sub, |v1, v2| v1 - v2);
458bin_op!(Mul, mul, |v1, v2| v1 * v2);
459bin_op!(Div, div, |v1, v2| v1 / v2);
460bin_op!(Minimum, minimum, |v1, v2| if v1 > v2 { v2 } else { v1 });
461bin_op!(Maximum, maximum, |v1, v2| if v1 < v2 { v2 } else { v1 });
462
// Implements `UnaryOpT` for the marker type `$op`.
//
// Common arguments:
// * `$op`   - marker struct to implement the trait for;
// * `$name` - string literal used for `NAME`; `KERNEL` is "u" followed by it;
// * `$a`    - identifier that names the input value inside `$e`;
// * `$e`    - expression evaluated for the `bf16`, `f16`, `f32`, `f64` and
//             `f8e4m3` kernels.
//
// In both arms the integer kernels (`u8`, `u32`, `i16`, `i32`, `i64`) are
// `todo!` stubs and therefore panic when called.
//
// The second arm additionally takes `$f32_vec` / `$f64_vec`, the names of slice
// functions looked up in `crate::mkl` (feature "mkl") or `crate::accelerate`
// (feature "accelerate"), and uses them to override `f32_vec` / `f64_vec`.
// NOTE(review): if both features were enabled at once, both cfg-gated
// definitions of `f32_vec` / `f64_vec` would be emitted in the same impl —
// presumably the features are meant to be mutually exclusive; confirm.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    // Arm 1: scalar kernels only.
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i16(_: i16) -> i16 {
                todo!("no unary function for i16")
            }
            #[inline(always)]
            fn i32(_: i32) -> i32 {
                todo!("no unary function for i32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
            #[inline(always)]
            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
                $e
            }
        }
    };

    // Arm 2: scalar kernels plus feature-gated slice kernels for f32 / f64.
    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i16(_: i16) -> i16 {
                todo!("no unary function for i16")
            }
            #[inline(always)]
            fn i32(_: i32) -> i32 {
                todo!("no unary function for i32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
            #[inline(always)]
            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
                $e
            }

            // Slice kernels forwarded to `crate::mkl`.
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            // Slice kernels forwarded to `crate::accelerate`.
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
583
// `UnaryOpT` implementations generated by `unary_op!`. The six-argument form
// also names the `vs_*` (f32) and `vd_*` (f64) slice kernels used when the
// "mkl" or "accelerate" feature is enabled; `Neg` and `Recip` use the
// four-argument form and keep the default element-by-element helpers.
// For all of these, the integer kernels are `todo!` stubs that panic.
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
593
// sqrt(2 / pi), used by the tanh-based `Gelu` kernels in this file.
// The literal deliberately carries more digits than `f32` / `f64` can hold so
// the same text serves both constants; that is why `excessive_precision` is
// allowed on each of them.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
600
601impl UnaryOpT for Gelu {
605 const NAME: &'static str = "gelu";
606 const V: Self = Gelu;
607 #[inline(always)]
608 fn bf16(v: bf16) -> bf16 {
609 bf16::from_f32_const(0.5)
610 * v
611 * (bf16::ONE
612 + bf16::tanh(
613 bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
614 * v
615 * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
616 ))
617 }
618 #[inline(always)]
619 fn f16(v: f16) -> f16 {
620 f16::from_f32_const(0.5)
621 * v
622 * (f16::ONE
623 + f16::tanh(
624 f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
625 * v
626 * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
627 ))
628 }
629 #[inline(always)]
630 fn f32(v: f32) -> f32 {
631 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
632 }
633 #[inline(always)]
634 fn f64(v: f64) -> f64 {
635 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
636 }
637 #[inline(always)]
638 fn u8(_: u8) -> u8 {
639 0
640 }
641 #[inline(always)]
642 fn u32(_: u32) -> u32 {
643 0
644 }
645 #[inline(always)]
646 fn i16(_: i16) -> i16 {
647 0
648 }
649 #[inline(always)]
650 fn i32(_: i32) -> i32 {
651 0
652 }
653 #[inline(always)]
654 fn i64(_: i64) -> i64 {
655 0
656 }
657 #[inline(always)]
658 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
659 f8e4m3::from_f32(0.5)
660 * v
661 * (f8e4m3::ONE
662 + f8e4m3::tanh(
663 f8e4m3::from_f32(SQRT_TWO_OVER_PI_F32)
664 * v
665 * (f8e4m3::ONE + f8e4m3::from_f32(0.044715) * v * v),
666 ))
667 }
668 const KERNEL: &'static str = "ugelu";
669
670 #[cfg(feature = "mkl")]
671 #[inline(always)]
672 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
673 crate::mkl::vs_gelu(xs, ys)
674 }
675
676 #[cfg(feature = "mkl")]
677 #[inline(always)]
678 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
679 crate::mkl::vd_gelu(xs, ys)
680 }
681
682 #[cfg(feature = "accelerate")]
683 #[inline(always)]
684 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
685 crate::accelerate::vs_gelu(xs, ys)
686 }
687
688 #[cfg(feature = "accelerate")]
689 #[inline(always)]
690 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
691 crate::accelerate::vd_gelu(xs, ys)
692 }
693}
694
695impl UnaryOpT for Erf {
698 const NAME: &'static str = "erf";
699 const KERNEL: &'static str = "uerf";
700 const V: Self = Erf;
701 #[inline(always)]
702 fn bf16(v: bf16) -> bf16 {
703 bf16::from_f64(Self::f64(v.to_f64()))
704 }
705 #[inline(always)]
706 fn f16(v: f16) -> f16 {
707 f16::from_f64(Self::f64(v.to_f64()))
708 }
709 #[inline(always)]
710 fn f32(v: f32) -> f32 {
711 crate::cpu::erf::erf_f32(v)
712 }
713 #[inline(always)]
714 fn f64(v: f64) -> f64 {
715 crate::cpu::erf::erf_f64(v)
716 }
717 #[inline(always)]
718 fn u8(_: u8) -> u8 {
719 0
720 }
721 #[inline(always)]
722 fn u32(_: u32) -> u32 {
723 0
724 }
725 #[inline(always)]
726 fn i16(_: i16) -> i16 {
727 0
728 }
729 #[inline(always)]
730 fn i32(_: i32) -> i32 {
731 0
732 }
733 #[inline(always)]
734 fn i64(_: i64) -> i64 {
735 0
736 }
737 #[inline(always)]
738 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
739 f8e4m3::from_f64(Self::f64(v.to_f64()))
740 }
741}
742
743impl UnaryOpT for Silu {
745 const NAME: &'static str = "silu";
746 const V: Self = Silu;
747 #[inline(always)]
748 fn bf16(v: bf16) -> bf16 {
749 v / (bf16::ONE + (-v).exp())
750 }
751 #[inline(always)]
752 fn f16(v: f16) -> f16 {
753 v / (f16::ONE + (-v).exp())
754 }
755 #[inline(always)]
756 fn f32(v: f32) -> f32 {
757 v / (1.0 + (-v).exp())
758 }
759 #[inline(always)]
760 fn f64(v: f64) -> f64 {
761 v / (1.0 + (-v).exp())
762 }
763 #[inline(always)]
764 fn u8(_: u8) -> u8 {
765 0
766 }
767 #[inline(always)]
768 fn u32(_: u32) -> u32 {
769 0
770 }
771 #[inline(always)]
772 fn i16(_: i16) -> i16 {
773 0
774 }
775 #[inline(always)]
776 fn i32(_: i32) -> i32 {
777 0
778 }
779 #[inline(always)]
780 fn i64(_: i64) -> i64 {
781 0
782 }
783 #[inline(always)]
784 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
785 v / (f8e4m3::ONE + (-v).exp())
786 }
787 const KERNEL: &'static str = "usilu";
788
789 #[cfg(feature = "mkl")]
790 #[inline(always)]
791 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
792 crate::mkl::vs_silu(xs, ys)
793 }
794
795 #[cfg(feature = "mkl")]
796 #[inline(always)]
797 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
798 crate::mkl::vd_silu(xs, ys)
799 }
800
801 #[cfg(feature = "accelerate")]
802 #[inline(always)]
803 fn f32_vec(xs: &[f32], ys: &mut [f32]) {
804 crate::accelerate::vs_silu(xs, ys)
805 }
806
807 #[cfg(feature = "accelerate")]
808 #[inline(always)]
809 fn f64_vec(xs: &[f64], ys: &mut [f64]) {
810 crate::accelerate::vd_silu(xs, ys)
811 }
812}
813
814impl UnaryOpT for Abs {
815 const NAME: &'static str = "abs";
816 const KERNEL: &'static str = "uabs";
817 const V: Self = Abs;
818 #[inline(always)]
819 fn bf16(v: bf16) -> bf16 {
820 v.abs()
821 }
822 #[inline(always)]
823 fn f16(v: f16) -> f16 {
824 v.abs()
825 }
826 #[inline(always)]
827 fn f32(v: f32) -> f32 {
828 v.abs()
829 }
830 #[inline(always)]
831 fn f64(v: f64) -> f64 {
832 v.abs()
833 }
834 #[inline(always)]
835 fn u8(v: u8) -> u8 {
836 v
837 }
838 #[inline(always)]
839 fn u32(v: u32) -> u32 {
840 v
841 }
842 #[inline(always)]
843 fn i16(v: i16) -> i16 {
844 v.abs()
845 }
846 #[inline(always)]
847 fn i32(v: i32) -> i32 {
848 v.abs()
849 }
850 #[inline(always)]
851 fn i64(v: i64) -> i64 {
852 v.abs()
853 }
854 #[inline(always)]
855 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
856 v.abs()
857 }
858}
859
860impl UnaryOpT for Ceil {
861 const NAME: &'static str = "ceil";
862 const KERNEL: &'static str = "uceil";
863 const V: Self = Ceil;
864 #[inline(always)]
865 fn bf16(v: bf16) -> bf16 {
866 v.ceil()
867 }
868 #[inline(always)]
869 fn f16(v: f16) -> f16 {
870 v.ceil()
871 }
872 #[inline(always)]
873 fn f32(v: f32) -> f32 {
874 v.ceil()
875 }
876 #[inline(always)]
877 fn f64(v: f64) -> f64 {
878 v.ceil()
879 }
880 #[inline(always)]
881 fn u8(v: u8) -> u8 {
882 v
883 }
884 #[inline(always)]
885 fn u32(v: u32) -> u32 {
886 v
887 }
888 #[inline(always)]
889 fn i16(v: i16) -> i16 {
890 v
891 }
892 #[inline(always)]
893 fn i32(v: i32) -> i32 {
894 v
895 }
896 #[inline(always)]
897 fn i64(v: i64) -> i64 {
898 v
899 }
900 #[inline(always)]
901 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
902 v.ceil()
903 }
904}
905
906impl UnaryOpT for Floor {
907 const NAME: &'static str = "floor";
908 const KERNEL: &'static str = "ufloor";
909 const V: Self = Floor;
910 #[inline(always)]
911 fn bf16(v: bf16) -> bf16 {
912 v.floor()
913 }
914 #[inline(always)]
915 fn f16(v: f16) -> f16 {
916 v.floor()
917 }
918 #[inline(always)]
919 fn f32(v: f32) -> f32 {
920 v.floor()
921 }
922 #[inline(always)]
923 fn f64(v: f64) -> f64 {
924 v.floor()
925 }
926 #[inline(always)]
927 fn u8(v: u8) -> u8 {
928 v
929 }
930 #[inline(always)]
931 fn u32(v: u32) -> u32 {
932 v
933 }
934 #[inline(always)]
935 fn i16(v: i16) -> i16 {
936 v
937 }
938 #[inline(always)]
939 fn i32(v: i32) -> i32 {
940 v
941 }
942 #[inline(always)]
943 fn i64(v: i64) -> i64 {
944 v
945 }
946 #[inline(always)]
947 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
948 v.floor()
949 }
950}
951
952impl UnaryOpT for Round {
953 const NAME: &'static str = "round";
954 const KERNEL: &'static str = "uround";
955 const V: Self = Round;
956 #[inline(always)]
957 fn bf16(v: bf16) -> bf16 {
958 v.round()
959 }
960 #[inline(always)]
961 fn f16(v: f16) -> f16 {
962 v.round()
963 }
964 #[inline(always)]
965 fn f32(v: f32) -> f32 {
966 v.round()
967 }
968 #[inline(always)]
969 fn f64(v: f64) -> f64 {
970 v.round()
971 }
972 #[inline(always)]
973 fn u8(v: u8) -> u8 {
974 v
975 }
976 #[inline(always)]
977 fn u32(v: u32) -> u32 {
978 v
979 }
980 #[inline(always)]
981 fn i16(v: i16) -> i16 {
982 v
983 }
984 #[inline(always)]
985 fn i32(v: i32) -> i32 {
986 v
987 }
988 #[inline(always)]
989 fn i64(v: i64) -> i64 {
990 v
991 }
992 #[inline(always)]
993 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
994 v.round()
995 }
996}
997
998impl UnaryOpT for GeluErf {
999 const NAME: &'static str = "gelu_erf";
1000 const KERNEL: &'static str = "ugelu_erf";
1001 const V: Self = GeluErf;
1002 #[inline(always)]
1003 fn bf16(v: bf16) -> bf16 {
1004 bf16::from_f64(Self::f64(v.to_f64()))
1005 }
1006 #[inline(always)]
1007 fn f16(v: f16) -> f16 {
1008 f16::from_f64(Self::f64(v.to_f64()))
1009 }
1010 #[inline(always)]
1011 fn f32(v: f32) -> f32 {
1012 (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
1013 }
1014 #[inline(always)]
1015 fn f64(v: f64) -> f64 {
1016 (crate::cpu::erf::erf_f64(v * std::f64::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
1017 }
1018 #[inline(always)]
1019 fn u8(_: u8) -> u8 {
1020 0
1021 }
1022 #[inline(always)]
1023 fn u32(_: u32) -> u32 {
1024 0
1025 }
1026 #[inline(always)]
1027 fn i16(_: i16) -> i16 {
1028 0
1029 }
1030 #[inline(always)]
1031 fn i32(_: i32) -> i32 {
1032 0
1033 }
1034 #[inline(always)]
1035 fn i64(_: i64) -> i64 {
1036 0
1037 }
1038 #[inline(always)]
1039 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1040 f8e4m3::from_f32(Self::f32(v.to_f32()))
1041 }
1042}
1043
1044impl UnaryOpT for Relu {
1045 const NAME: &'static str = "relu";
1046 const KERNEL: &'static str = "urelu";
1047 const V: Self = Relu;
1048 #[inline(always)]
1049 fn bf16(v: bf16) -> bf16 {
1050 v.max(bf16::ZERO)
1051 }
1052 #[inline(always)]
1053 fn f16(v: f16) -> f16 {
1054 v.max(f16::ZERO)
1055 }
1056 #[inline(always)]
1057 fn f32(v: f32) -> f32 {
1058 v.max(0f32)
1059 }
1060 #[inline(always)]
1061 fn f64(v: f64) -> f64 {
1062 v.max(0f64)
1063 }
1064 #[inline(always)]
1065 fn u8(v: u8) -> u8 {
1066 v
1067 }
1068 #[inline(always)]
1069 fn u32(v: u32) -> u32 {
1070 v
1071 }
1072 #[inline(always)]
1073 fn i16(v: i16) -> i16 {
1074 v.max(0)
1075 }
1076 #[inline(always)]
1077 fn i32(v: i32) -> i32 {
1078 v.max(0)
1079 }
1080 #[inline(always)]
1081 fn i64(v: i64) -> i64 {
1082 v.max(0)
1083 }
1084 #[inline(always)]
1085 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1086 v.max(f8e4m3::ZERO)
1087 }
1088}
1089
/// Wrapper around an optional `Op`.
///
/// The constructors in this file store `Some(op)` only when at least one of the
/// argument tensors reports `track_op()`, and `None` otherwise.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
1094
1095impl BackpropOp {
1096 pub fn none() -> Self {
1097 BackpropOp(None)
1098 }
1099
1100 pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
1101 let op = if arg.track_op() {
1102 Some(f(arg.clone()))
1103 } else {
1104 None
1105 };
1106 Self(op)
1107 }
1108
1109 pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
1110 let op = if arg1.track_op() || arg2.track_op() {
1111 Some(f(arg1.clone(), arg2.clone()))
1112 } else {
1113 None
1114 };
1115 Self(op)
1116 }
1117
1118 pub(crate) fn new3(
1119 arg1: &Tensor,
1120 arg2: &Tensor,
1121 arg3: &Tensor,
1122 f: impl Fn(Tensor, Tensor, Tensor) -> Op,
1123 ) -> Self {
1124 let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
1125 Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
1126 } else {
1127 None
1128 };
1129 Self(op)
1130 }
1131
1132 pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
1133 let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
1134 let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
1135 Some(f(args))
1136 } else {
1137 None
1138 };
1139 Self(op)
1140 }
1141
1142 pub(crate) fn is_none(&self) -> bool {
1143 self.0.is_none()
1144 }
1145}
1146
1147impl std::ops::Deref for BackpropOp {
1148 type Target = Option<Op>;
1149 fn deref(&self) -> &Self::Target {
1150 &self.0
1151 }
1152}
1153
1154impl UnaryOpT for Sign {
1155 const NAME: &'static str = "sign";
1156 const KERNEL: &'static str = "usign";
1157 const V: Self = Sign;
1158 #[inline(always)]
1159 fn bf16(v: bf16) -> bf16 {
1160 bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
1161 }
1162 #[inline(always)]
1163 fn f16(v: f16) -> f16 {
1164 f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
1165 }
1166 #[inline(always)]
1167 fn f32(v: f32) -> f32 {
1168 f32::from(v > 0.) - f32::from(v < 0.)
1169 }
1170 #[inline(always)]
1171 fn f64(v: f64) -> f64 {
1172 f64::from(v > 0.) - f64::from(v < 0.)
1173 }
1174 #[inline(always)]
1175 fn u8(v: u8) -> u8 {
1176 u8::min(1, v)
1177 }
1178 #[inline(always)]
1179 fn u32(v: u32) -> u32 {
1180 u32::min(1, v)
1181 }
1182 #[inline(always)]
1183 fn i16(v: i16) -> i16 {
1184 (v > 0) as i16 - (v < 0) as i16
1185 }
1186 #[inline(always)]
1187 fn i32(v: i32) -> i32 {
1188 (v > 0) as i32 - (v < 0) as i32
1189 }
1190 #[inline(always)]
1191 fn i64(v: i64) -> i64 {
1192 (v > 0) as i64 - (v < 0) as i64
1193 }
1194 #[inline(always)]
1195 fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1196 if v > f8e4m3::ZERO {
1197 f8e4m3::ONE
1198 } else if v < f8e4m3::ZERO {
1199 -f8e4m3::ONE
1200 } else {
1201 f8e4m3::ZERO
1202 }
1203 }
1204}