candle_core/
op.rs

1//! Tensor Operation Enums and Traits
2//!
3#![allow(clippy::redundant_closure_call)]
4use crate::Tensor;
5use float8::F8E4M3 as f8e4m3;
6use half::{bf16, f16};
7use num_traits::float::Float;
8
/// Element-wise comparison operator, recorded by `Op::Cmp`.
///
/// Variant names follow the usual mnemonics (`Eq` `==`, `Ne` `!=`, `Le` `<=`, `Ge` `>=`,
/// `Lt` `<`, `Gt` `>`); the comparisons themselves are implemented outside this file.
///
/// NOTE(review): unlike the other op enums in this module this one does not derive `Debug`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
18
/// Reduction operator, recorded by `Op::Reduce` together with the reduced shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    // The arg-variants presumably produce indices rather than values — confirm in the backends.
    ArgMin,
    ArgMax,
}
27
28impl ReduceOp {
29    pub(crate) fn name(&self) -> &'static str {
30        match self {
31            Self::ArgMax => "argmax",
32            Self::ArgMin => "argmin",
33            Self::Min => "min",
34            Self::Max => "max",
35            Self::Sum => "sum",
36        }
37    }
38}
39
/// Element-wise binary operator, recorded by `Op::Binary`.
///
/// These ops return the same type as their input type. Each variant has a matching marker type
/// (`Add`, `Mul`, `Sub`, `Div`, `Maximum`, `Minimum`) implementing `BinaryOpT`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}
50
/// Element-wise unary operator with no extra argument, recorded by `Op::Unary`.
///
/// Parameterised unary ops (e.g. `Op::Elu`, `Op::Powf`) are separate `Op` variants instead.
/// Each variant has a same-named marker type implementing `UnaryOpT`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Silu,
    Tanh,
    Floor,
    Ceil,
    Round,
    Sign,
}
74
/// The operation that produced a tensor, together with its input tensor(s) and parameters.
///
/// Values of this type are stored through `BackpropOp` (a wrapper around `Option<Op>`), which
/// only records an op when at least one input tracks ops.
#[derive(Clone)]
pub enum Op {
    /// `(lhs, rhs, op)`.
    Binary(Tensor, Tensor, BinaryOp),
    /// `(input, op)`.
    Unary(Tensor, UnaryOp),
    Cmp(Tensor, CmpOp),
    // The third argument is the reduced shape with `keepdim=true`.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    // NOTE(review): for the indexing variants below the trailing `usize` is presumably the
    // dimension the op acts on and the second tensor the indices — confirm against the callers.
    Gather(Tensor, Tensor, usize),
    Scatter(Tensor, Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    // Presumably `(condition, on_true, on_false)` — confirm against the caller.
    WhereCond(Tensor, Tensor, Tensor),

    // The convolution variants carry `#[allow(dead_code)]`, presumably because some of their
    // fields are never read back.
    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    // Pooling: `kernel_size` and `stride` are `(usize, usize)` pairs, presumably (h, w).
    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },
    UpsampleBilinear2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
        align_corners: bool,
    },

    // `(inputs, dim)` — presumably the concatenation dimension.
    Cat(Vec<Tensor>, usize),

    #[allow(dead_code)] // add is currently unused.
    Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    // Presumably `(input, dim, start, len)` — confirm against `Tensor::narrow`.
    Narrow(Tensor, usize, usize, usize),
    // Presumably a slice-scatter along dimension 0 at the given start offset — confirm.
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    // `(input, dim1, dim2)`: the two dimensions being swapped.
    Transpose(Tensor, usize, usize),
    // `(input, dims)`: the permutation applied.
    Permute(Tensor, Vec<usize>),
    // The `f64` is presumably the ELU `alpha` parameter.
    Elu(Tensor, f64),
    // The `f64` is presumably the exponent.
    Powf(Tensor, f64),
    // User-defined ops; the op object is shared behind an `Arc`.
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}
191
/// Per-dtype implementation of a unary element-wise op.
///
/// Each supported dtype has a scalar function. The f32/f64/bf16/f16 types may additionally
/// provide a slice-level (vectorised) function, advertised through the matching `*_VEC` flag.
pub trait UnaryOpT {
    /// Op name, e.g. `"exp"`.
    const NAME: &'static str;
    /// Kernel name; every implementation in this file uses `"u"` followed by `NAME`.
    /// Presumably used to look up a backend kernel — confirm in the backends.
    const KERNEL: &'static str;
    /// The (unit) value of the implementing marker type.
    const V: Self;
    fn bf16(v1: bf16) -> bf16;
    fn f16(v1: f16) -> f16;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i16(v1: i16) -> i16;
    fn i32(v1: i32) -> i32;
    fn i64(v1: i64) -> i64;
    fn f8e4m3(v1: f8e4m3) -> f8e4m3;

    // There is no very good way to represent optional function in traits so we go for an explicit
    // boolean flag to mark the function as existing.
    // The default slice functions below do nothing; they are only meaningful when the matching
    // flag is `true`. Presumably `ys` must be at least as long as `xs` — confirm with callers.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
218
/// Per-dtype implementation of a binary element-wise op.
///
/// Each supported dtype has a scalar function taking both operands. Some dtypes may also
/// provide a slice-level (vectorised) function, advertised through the matching `*_VEC` flag.
pub trait BinaryOpT {
    /// Op name, e.g. `"add"`.
    const NAME: &'static str;
    /// Kernel name; `bin_op!` builds it as `"b"` followed by `NAME`.
    const KERNEL: &'static str;
    /// The (unit) value of the implementing marker type.
    const V: Self;
    fn bf16(v1: bf16, v2: bf16) -> bf16;
    fn f16(v1: f16, v2: f16) -> f16;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i16(v1: i16, v2: i16) -> i16;
    fn i32(v1: i32, v2: i32) -> i32;
    fn i64(v1: i64, v2: i64) -> i64;
    fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3;

    // Optional slice-level fast paths: each default function does nothing and is only meaningful
    // when the matching flag is `true`. No vectorised hook exists for i16, i32 or f8e4m3.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
    const U8_VEC: bool = false;
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
249
// Zero-sized marker types naming each op at the type level. Add..Minimum implement `BinaryOpT`
// and Exp..Sign implement `UnaryOpT`, further down in this file.
// Binary ops:
pub struct Add;
pub struct Div;
pub struct Mul;
pub struct Sub;
pub struct Maximum;
pub struct Minimum;
// Unary ops:
pub struct Exp;
pub struct Log;
pub struct Sin;
pub struct Cos;
pub struct Abs;
pub struct Neg;
pub struct Recip;
pub struct Sqr;
pub struct Sqrt;
pub struct Gelu;
pub struct GeluErf;
pub struct Erf;
pub struct Relu;
pub struct Silu;
pub struct Tanh;
pub struct Floor;
pub struct Ceil;
pub struct Round;
pub struct Sign;
275
/// Implements `BinaryOpT` for the zero-sized marker type `$op`.
///
/// * `$name`: the op name. The kernel name is `"b"` followed by it (e.g. `"badd"`).
/// * `$e`: a closure literal `|v1, v2| ...` that is called immediately for every scalar dtype.
///   This immediate call is what the file-level `clippy::redundant_closure_call` allow covers.
/// * `$f32_vec` / `$f64_vec`: names of slice routines looked up in `crate::mkl` or
///   `crate::accelerate` when the respective feature is enabled. Without either feature the
///   trait defaults (no vectorised path) apply.
///
/// NOTE(review): with both `mkl` and `accelerate` enabled, the generated impl would define
/// `F32_VEC`/`f32_vec` (and the f64 pair) twice and fail to compile. The two features are
/// presumably meant to be mutually exclusive.
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
            // Scalar paths: every dtype applies the same closure.
            #[inline(always)]
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u8(v1: u8, v2: u8) -> u8 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i16(v1: i16, v2: i16) -> i16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i32(v1: i32, v2: i32) -> i32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3 {
                $e(v1, v2)
            }

            // Intel MKL slice routines for f32/f64.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            // Apple Accelerate slice routines for f32/f64.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
355
// Arithmetic binary ops. The closure is the scalar implementation; the two trailing identifiers
// name the f32/f64 slice routines used with the `mkl` / `accelerate` features.
bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
// Element-wise min/max. In the scalar paths, a comparison involving NaN is false, so `v1` (the
// left operand) is returned whenever either operand is NaN. Ties also return `v1`. The
// vectorised mkl/accelerate paths may treat NaN differently — not verified here.
bin_op!(
    Minimum,
    "minimum",
    |v1, v2| if v1 > v2 { v2 } else { v1 },
    vs_min,
    vd_min
);
bin_op!(
    Maximum,
    "maximum",
    |v1, v2| if v1 < v2 { v2 } else { v1 },
    vs_max,
    vd_max
);
374
/// Implements `UnaryOpT` for the zero-sized marker type `$op`.
///
/// * `$name`: the op name. The kernel name is `"u"` followed by it (e.g. `"uexp"`).
/// * `$a`: the identifier `$e` uses for its argument. It is bound as the parameter of every
///   floating-point method, so `$e` must type-check for bf16, f16, f32, f64 and f8e4m3.
/// * `$e`: the per-element expression.
///
/// Integer dtypes are not supported: their methods call `todo!()`, i.e. they panic when called.
///
/// The first arm generates scalar code only. The second arm additionally takes `$f32_vec` /
/// `$f64_vec`, the names of slice routines in `crate::mkl` / `crate::accelerate` used when the
/// respective feature is enabled.
///
/// NOTE(review): enabling both `mkl` and `accelerate` would define `F32_VEC`/`f32_vec` twice in
/// the second arm and fail to compile. Also, `$e` is spliced in as a plain expression rather than
/// a closure, so the `redundant_closure_call` allow below appears unnecessary for this macro.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    // Scalar-only arm.
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            // Integer dtypes: unsupported, panic via `todo!`.
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i16(_: i16) -> i16 {
                todo!("no unary function for i16")
            }
            #[inline(always)]
            fn i32(_: i32) -> i32 {
                todo!("no unary function for i32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
            #[inline(always)]
            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
                $e
            }
        }
    };

    // Arm with f32/f64 slice routines for the `mkl` / `accelerate` features.
    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            // Integer dtypes: unsupported, panic via `todo!`.
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i16(_: i16) -> i16 {
                todo!("no unary function for i16")
            }
            #[inline(always)]
            fn i32(_: i32) -> i32 {
                todo!("no unary function for i32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
            #[inline(always)]
            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
                $e
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
503
// Simple unary ops. `v` names the element inside the expression; the trailing identifiers, where
// present, name the f32/f64 slice routines for the `mkl` / `accelerate` features.
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
// Neg and Recip have no vectorised routine, so they use the scalar-only macro arm.
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
513
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
// Used by the tanh-approximated `Gelu`. The literal has more digits than f32/f64 can hold; the
// compiler rounds it to the nearest representable value, hence the `excessive_precision` allow.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
520
521/// Tanh based approximation of the `gelu` operation
522/// GeluErf is the more precise one.
523/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
524impl UnaryOpT for Gelu {
525    const NAME: &'static str = "gelu";
526    const V: Self = Gelu;
527    #[inline(always)]
528    fn bf16(v: bf16) -> bf16 {
529        bf16::from_f32_const(0.5)
530            * v
531            * (bf16::ONE
532                + bf16::tanh(
533                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
534                        * v
535                        * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
536                ))
537    }
538    #[inline(always)]
539    fn f16(v: f16) -> f16 {
540        f16::from_f32_const(0.5)
541            * v
542            * (f16::ONE
543                + f16::tanh(
544                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
545                        * v
546                        * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
547                ))
548    }
549    #[inline(always)]
550    fn f32(v: f32) -> f32 {
551        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
552    }
553    #[inline(always)]
554    fn f64(v: f64) -> f64 {
555        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
556    }
557    #[inline(always)]
558    fn u8(_: u8) -> u8 {
559        0
560    }
561    #[inline(always)]
562    fn u32(_: u32) -> u32 {
563        0
564    }
565    #[inline(always)]
566    fn i16(_: i16) -> i16 {
567        0
568    }
569    #[inline(always)]
570    fn i32(_: i32) -> i32 {
571        0
572    }
573    #[inline(always)]
574    fn i64(_: i64) -> i64 {
575        0
576    }
577    #[inline(always)]
578    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
579        f8e4m3::from_f32(0.5)
580            * v
581            * (f8e4m3::ONE
582                + f8e4m3::tanh(
583                    f8e4m3::from_f32(SQRT_TWO_OVER_PI_F32)
584                        * v
585                        * (f8e4m3::ONE + f8e4m3::from_f32(0.044715) * v * v),
586                ))
587    }
588    const KERNEL: &'static str = "ugelu";
589
590    #[cfg(feature = "mkl")]
591    const F32_VEC: bool = true;
592
593    #[cfg(feature = "mkl")]
594    #[inline(always)]
595    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
596        crate::mkl::vs_gelu(xs, ys)
597    }
598
599    #[cfg(feature = "mkl")]
600    const F64_VEC: bool = true;
601
602    #[cfg(feature = "mkl")]
603    #[inline(always)]
604    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
605        crate::mkl::vd_gelu(xs, ys)
606    }
607
608    #[cfg(feature = "accelerate")]
609    const F32_VEC: bool = true;
610
611    #[cfg(feature = "accelerate")]
612    #[inline(always)]
613    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
614        crate::accelerate::vs_gelu(xs, ys)
615    }
616
617    #[cfg(feature = "accelerate")]
618    const F64_VEC: bool = true;
619
620    #[cfg(feature = "accelerate")]
621    #[inline(always)]
622    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
623        crate::accelerate::vd_gelu(xs, ys)
624    }
625}
626
627/// `erf` operation
628/// <https://en.wikipedia.org/wiki/Error_function>
629impl UnaryOpT for Erf {
630    const NAME: &'static str = "erf";
631    const KERNEL: &'static str = "uerf";
632    const V: Self = Erf;
633    #[inline(always)]
634    fn bf16(v: bf16) -> bf16 {
635        bf16::from_f64(Self::f64(v.to_f64()))
636    }
637    #[inline(always)]
638    fn f16(v: f16) -> f16 {
639        f16::from_f64(Self::f64(v.to_f64()))
640    }
641    #[inline(always)]
642    fn f32(v: f32) -> f32 {
643        crate::cpu::erf::erf_f32(v)
644    }
645    #[inline(always)]
646    fn f64(v: f64) -> f64 {
647        crate::cpu::erf::erf_f64(v)
648    }
649    #[inline(always)]
650    fn u8(_: u8) -> u8 {
651        0
652    }
653    #[inline(always)]
654    fn u32(_: u32) -> u32 {
655        0
656    }
657    #[inline(always)]
658    fn i16(_: i16) -> i16 {
659        0
660    }
661    #[inline(always)]
662    fn i32(_: i32) -> i32 {
663        0
664    }
665    #[inline(always)]
666    fn i64(_: i64) -> i64 {
667        0
668    }
669    #[inline(always)]
670    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
671        f8e4m3::from_f64(Self::f64(v.to_f64()))
672    }
673}
674
675/// Silu operation
676impl UnaryOpT for Silu {
677    const NAME: &'static str = "silu";
678    const V: Self = Silu;
679    #[inline(always)]
680    fn bf16(v: bf16) -> bf16 {
681        v / (bf16::ONE + (-v).exp())
682    }
683    #[inline(always)]
684    fn f16(v: f16) -> f16 {
685        v / (f16::ONE + (-v).exp())
686    }
687    #[inline(always)]
688    fn f32(v: f32) -> f32 {
689        v / (1.0 + (-v).exp())
690    }
691    #[inline(always)]
692    fn f64(v: f64) -> f64 {
693        v / (1.0 + (-v).exp())
694    }
695    #[inline(always)]
696    fn u8(_: u8) -> u8 {
697        0
698    }
699    #[inline(always)]
700    fn u32(_: u32) -> u32 {
701        0
702    }
703    #[inline(always)]
704    fn i16(_: i16) -> i16 {
705        0
706    }
707    #[inline(always)]
708    fn i32(_: i32) -> i32 {
709        0
710    }
711    #[inline(always)]
712    fn i64(_: i64) -> i64 {
713        0
714    }
715    #[inline(always)]
716    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
717        v / (f8e4m3::ONE + (-v).exp())
718    }
719    const KERNEL: &'static str = "usilu";
720
721    #[cfg(feature = "mkl")]
722    const F32_VEC: bool = true;
723
724    #[cfg(feature = "mkl")]
725    #[inline(always)]
726    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
727        crate::mkl::vs_silu(xs, ys)
728    }
729
730    #[cfg(feature = "mkl")]
731    const F64_VEC: bool = true;
732
733    #[cfg(feature = "mkl")]
734    #[inline(always)]
735    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
736        crate::mkl::vd_silu(xs, ys)
737    }
738
739    #[cfg(feature = "accelerate")]
740    const F32_VEC: bool = true;
741
742    #[cfg(feature = "accelerate")]
743    #[inline(always)]
744    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
745        crate::accelerate::vs_silu(xs, ys)
746    }
747
748    #[cfg(feature = "accelerate")]
749    const F64_VEC: bool = true;
750
751    #[cfg(feature = "accelerate")]
752    #[inline(always)]
753    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
754        crate::accelerate::vd_silu(xs, ys)
755    }
756}
757
758impl UnaryOpT for Abs {
759    const NAME: &'static str = "abs";
760    const KERNEL: &'static str = "uabs";
761    const V: Self = Abs;
762    #[inline(always)]
763    fn bf16(v: bf16) -> bf16 {
764        v.abs()
765    }
766    #[inline(always)]
767    fn f16(v: f16) -> f16 {
768        v.abs()
769    }
770    #[inline(always)]
771    fn f32(v: f32) -> f32 {
772        v.abs()
773    }
774    #[inline(always)]
775    fn f64(v: f64) -> f64 {
776        v.abs()
777    }
778    #[inline(always)]
779    fn u8(v: u8) -> u8 {
780        v
781    }
782    #[inline(always)]
783    fn u32(v: u32) -> u32 {
784        v
785    }
786    #[inline(always)]
787    fn i16(v: i16) -> i16 {
788        v.abs()
789    }
790    #[inline(always)]
791    fn i32(v: i32) -> i32 {
792        v.abs()
793    }
794    #[inline(always)]
795    fn i64(v: i64) -> i64 {
796        v.abs()
797    }
798    #[inline(always)]
799    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
800        v.abs()
801    }
802}
803
804impl UnaryOpT for Ceil {
805    const NAME: &'static str = "ceil";
806    const KERNEL: &'static str = "uceil";
807    const V: Self = Ceil;
808    #[inline(always)]
809    fn bf16(v: bf16) -> bf16 {
810        v.ceil()
811    }
812    #[inline(always)]
813    fn f16(v: f16) -> f16 {
814        v.ceil()
815    }
816    #[inline(always)]
817    fn f32(v: f32) -> f32 {
818        v.ceil()
819    }
820    #[inline(always)]
821    fn f64(v: f64) -> f64 {
822        v.ceil()
823    }
824    #[inline(always)]
825    fn u8(v: u8) -> u8 {
826        v
827    }
828    #[inline(always)]
829    fn u32(v: u32) -> u32 {
830        v
831    }
832    #[inline(always)]
833    fn i16(v: i16) -> i16 {
834        v
835    }
836    #[inline(always)]
837    fn i32(v: i32) -> i32 {
838        v
839    }
840    #[inline(always)]
841    fn i64(v: i64) -> i64 {
842        v
843    }
844    #[inline(always)]
845    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
846        v.ceil()
847    }
848}
849
850impl UnaryOpT for Floor {
851    const NAME: &'static str = "floor";
852    const KERNEL: &'static str = "ufloor";
853    const V: Self = Floor;
854    #[inline(always)]
855    fn bf16(v: bf16) -> bf16 {
856        v.floor()
857    }
858    #[inline(always)]
859    fn f16(v: f16) -> f16 {
860        v.floor()
861    }
862    #[inline(always)]
863    fn f32(v: f32) -> f32 {
864        v.floor()
865    }
866    #[inline(always)]
867    fn f64(v: f64) -> f64 {
868        v.floor()
869    }
870    #[inline(always)]
871    fn u8(v: u8) -> u8 {
872        v
873    }
874    #[inline(always)]
875    fn u32(v: u32) -> u32 {
876        v
877    }
878    #[inline(always)]
879    fn i16(v: i16) -> i16 {
880        v
881    }
882    #[inline(always)]
883    fn i32(v: i32) -> i32 {
884        v
885    }
886    #[inline(always)]
887    fn i64(v: i64) -> i64 {
888        v
889    }
890    #[inline(always)]
891    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
892        v.floor()
893    }
894}
895
896impl UnaryOpT for Round {
897    const NAME: &'static str = "round";
898    const KERNEL: &'static str = "uround";
899    const V: Self = Round;
900    #[inline(always)]
901    fn bf16(v: bf16) -> bf16 {
902        v.round()
903    }
904    #[inline(always)]
905    fn f16(v: f16) -> f16 {
906        v.round()
907    }
908    #[inline(always)]
909    fn f32(v: f32) -> f32 {
910        v.round()
911    }
912    #[inline(always)]
913    fn f64(v: f64) -> f64 {
914        v.round()
915    }
916    #[inline(always)]
917    fn u8(v: u8) -> u8 {
918        v
919    }
920    #[inline(always)]
921    fn u32(v: u32) -> u32 {
922        v
923    }
924    #[inline(always)]
925    fn i16(v: i16) -> i16 {
926        v
927    }
928    #[inline(always)]
929    fn i32(v: i32) -> i32 {
930        v
931    }
932    #[inline(always)]
933    fn i64(v: i64) -> i64 {
934        v
935    }
936    #[inline(always)]
937    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
938        v.round()
939    }
940}
941
942impl UnaryOpT for GeluErf {
943    const NAME: &'static str = "gelu_erf";
944    const KERNEL: &'static str = "ugelu_erf";
945    const V: Self = GeluErf;
946    #[inline(always)]
947    fn bf16(v: bf16) -> bf16 {
948        bf16::from_f64(Self::f64(v.to_f64()))
949    }
950    #[inline(always)]
951    fn f16(v: f16) -> f16 {
952        f16::from_f64(Self::f64(v.to_f64()))
953    }
954    #[inline(always)]
955    fn f32(v: f32) -> f32 {
956        (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
957    }
958    #[inline(always)]
959    fn f64(v: f64) -> f64 {
960        (crate::cpu::erf::erf_f64(v * std::f64::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
961    }
962    #[inline(always)]
963    fn u8(_: u8) -> u8 {
964        0
965    }
966    #[inline(always)]
967    fn u32(_: u32) -> u32 {
968        0
969    }
970    #[inline(always)]
971    fn i16(_: i16) -> i16 {
972        0
973    }
974    #[inline(always)]
975    fn i32(_: i32) -> i32 {
976        0
977    }
978    #[inline(always)]
979    fn i64(_: i64) -> i64 {
980        0
981    }
982    #[inline(always)]
983    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
984        f8e4m3::from_f32(Self::f32(v.to_f32()))
985    }
986}
987
988impl UnaryOpT for Relu {
989    const NAME: &'static str = "relu";
990    const KERNEL: &'static str = "urelu";
991    const V: Self = Relu;
992    #[inline(always)]
993    fn bf16(v: bf16) -> bf16 {
994        v.max(bf16::ZERO)
995    }
996    #[inline(always)]
997    fn f16(v: f16) -> f16 {
998        v.max(f16::ZERO)
999    }
1000    #[inline(always)]
1001    fn f32(v: f32) -> f32 {
1002        v.max(0f32)
1003    }
1004    #[inline(always)]
1005    fn f64(v: f64) -> f64 {
1006        v.max(0f64)
1007    }
1008    #[inline(always)]
1009    fn u8(v: u8) -> u8 {
1010        v
1011    }
1012    #[inline(always)]
1013    fn u32(v: u32) -> u32 {
1014        v
1015    }
1016    #[inline(always)]
1017    fn i16(v: i16) -> i16 {
1018        v.max(0)
1019    }
1020    #[inline(always)]
1021    fn i32(v: i32) -> i32 {
1022        v.max(0)
1023    }
1024    #[inline(always)]
1025    fn i64(v: i64) -> i64 {
1026        v.max(0)
1027    }
1028    #[inline(always)]
1029    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1030        v.max(f8e4m3::ZERO)
1031    }
1032}
1033
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value.
///
/// `None` means no operation is recorded. The `new*` constructors only record an `Op` when at
/// least one input tensor reports `track_op()`.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
1038
1039impl BackpropOp {
1040    pub fn none() -> Self {
1041        BackpropOp(None)
1042    }
1043
1044    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
1045        let op = if arg.track_op() {
1046            Some(f(arg.clone()))
1047        } else {
1048            None
1049        };
1050        Self(op)
1051    }
1052
1053    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
1054        let op = if arg1.track_op() || arg2.track_op() {
1055            Some(f(arg1.clone(), arg2.clone()))
1056        } else {
1057            None
1058        };
1059        Self(op)
1060    }
1061
1062    pub(crate) fn new3(
1063        arg1: &Tensor,
1064        arg2: &Tensor,
1065        arg3: &Tensor,
1066        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
1067    ) -> Self {
1068        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
1069            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
1070        } else {
1071            None
1072        };
1073        Self(op)
1074    }
1075
1076    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
1077        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
1078            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
1079            Some(f(args))
1080        } else {
1081            None
1082        };
1083        Self(op)
1084    }
1085
1086    pub(crate) fn is_none(&self) -> bool {
1087        self.0.is_none()
1088    }
1089}
1090
1091impl std::ops::Deref for BackpropOp {
1092    type Target = Option<Op>;
1093    fn deref(&self) -> &Self::Target {
1094        &self.0
1095    }
1096}
1097
1098impl UnaryOpT for Sign {
1099    const NAME: &'static str = "sign";
1100    const KERNEL: &'static str = "usign";
1101    const V: Self = Sign;
1102    #[inline(always)]
1103    fn bf16(v: bf16) -> bf16 {
1104        bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
1105    }
1106    #[inline(always)]
1107    fn f16(v: f16) -> f16 {
1108        f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
1109    }
1110    #[inline(always)]
1111    fn f32(v: f32) -> f32 {
1112        f32::from(v > 0.) - f32::from(v < 0.)
1113    }
1114    #[inline(always)]
1115    fn f64(v: f64) -> f64 {
1116        f64::from(v > 0.) - f64::from(v < 0.)
1117    }
1118    #[inline(always)]
1119    fn u8(v: u8) -> u8 {
1120        u8::min(1, v)
1121    }
1122    #[inline(always)]
1123    fn u32(v: u32) -> u32 {
1124        u32::min(1, v)
1125    }
1126    #[inline(always)]
1127    fn i16(v: i16) -> i16 {
1128        (v > 0) as i16 - (v < 0) as i16
1129    }
1130    #[inline(always)]
1131    fn i32(v: i32) -> i32 {
1132        (v > 0) as i32 - (v < 0) as i32
1133    }
1134    #[inline(always)]
1135    fn i64(v: i64) -> i64 {
1136        (v > 0) as i64 - (v < 0) as i64
1137    }
1138    #[inline(always)]
1139    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1140        if v > f8e4m3::ZERO {
1141            f8e4m3::ONE
1142        } else if v < f8e4m3::ZERO {
1143            -f8e4m3::ONE
1144        } else {
1145            f8e4m3::ZERO
1146        }
1147    }
1148}