candle_core/
op.rs

//! Tensor Operation Enums and Traits
2//!
3#![allow(clippy::redundant_closure_call)]
4use crate::Tensor;
5use half::{bf16, f16};
6use num_traits::float::Float;
7
/// Comparison operators; the variants are named after the usual `==`, `!=`, `<=`, `>=`, `<`
/// and `>` operators.
///
/// Derives `Debug` like the other op enums in this file so values can be printed in error
/// messages and used with `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
17
/// Kinds of reduction that can be recorded in an `Op::Reduce`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}

impl ReduceOp {
    /// Returns the lower-case identifier associated with this reduction.
    pub(crate) fn name(&self) -> &'static str {
        // Arms are listed in declaration order.
        match *self {
            ReduceOp::Sum => "sum",
            ReduceOp::Min => "min",
            ReduceOp::Max => "max",
            ReduceOp::ArgMin => "argmin",
            ReduceOp::ArgMax => "argmax",
        }
    }
}
38
// Tags for the element-wise binary ops.
// These ops return the same type as their input type.
// Each variant has a same-named marker struct in this file that implements `BinaryOpT`
// through the `bin_op!` macro.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}
49
// Tags for the element-wise unary ops.
// Unary ops with no argument
// Each variant has a same-named marker struct in this file that implements `UnaryOpT`,
// either through the `unary_op!` macro or through a hand-written impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    // Tanh based approximation of gelu; `GeluErf` is the erf based variant.
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Silu,
    Tanh,
    Floor,
    Ceil,
    Round,
    Sign,
}
73
// An operation together with the tensor arguments (and extra parameters) it was built from.
// `BackpropOp` in this file wraps an `Option<Op>` and only stores an `Op` when at least one
// argument reports `track_op()`.
// NOTE(review): the meaning of the positional `usize` fields (presumably dimension indices,
// offsets or lengths) is not visible in this file; confirm against the call sites.
#[derive(Clone)]
pub enum Op {
    // Two operands plus the binary op tag.
    Binary(Tensor, Tensor, BinaryOp),
    // One operand plus the unary op tag.
    Unary(Tensor, UnaryOp),
    Cmp(Tensor, CmpOp),
    // The third argument is the reduced shape with `keepdim=true`.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    Gather(Tensor, Tensor, usize),
    Scatter(Tensor, Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    WhereCond(Tensor, Tensor, Tensor),

    // The convolution variants carry their input, kernel and hyper-parameters as named fields.
    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    // Pooling variants: kernel size and stride are given as pairs.
    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    // Upsampling variants carry the requested output size(s).
    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },

    // A list of tensors plus a `usize`.
    Cat(Vec<Tensor>, usize),

    #[allow(dead_code)] // add is currently unused.
    Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    Narrow(Tensor, usize, usize, usize),
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    // User-provided ops taking one, two or three tensor arguments; the implementation is held
    // behind an `Arc` so that cloning the `Op` does not clone the boxed implementation.
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}
184
/// Static description of an element-wise unary op.
///
/// An implementor supplies a name, a kernel name, a value of itself and one scalar function
/// per dtype. The `*_vec` functions are optional slice-level variants, advertised through the
/// matching `*_VEC` flag.
pub trait UnaryOpT {
    /// Short name of the op, e.g. `"exp"`.
    const NAME: &'static str;
    /// Kernel identifier; the impls in this file use `"u"` followed by `NAME`.
    const KERNEL: &'static str;
    /// A value of the implementing type.
    const V: Self;
    // Scalar implementations, one per dtype.
    fn bf16(v1: bf16) -> bf16;
    fn f16(v1: f16) -> f16;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i64(v1: i64) -> i64;

    // There is no very good way to represent optional function in traits so we go for an explicit
    // boolean flag to mark the function as existing.
    // The defaults below are `false` together with a no-op body; impls in this file switch the
    // flag to `true` and forward to `crate::mkl` / `crate::accelerate` when those features are on.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
208
/// Static description of an element-wise binary op.
///
/// Mirrors `UnaryOpT` with two operands: a name, a kernel name, a value of the implementing
/// type, one scalar function per dtype and optional slice-level variants.
pub trait BinaryOpT {
    /// Short name of the op, e.g. `"add"`.
    const NAME: &'static str;
    /// Kernel identifier; the `bin_op!` macro in this file builds it as `"b"` followed by `NAME`.
    const KERNEL: &'static str;
    /// A value of the implementing type.
    const V: Self;
    // Scalar implementations, one per dtype.
    fn bf16(v1: bf16, v2: bf16) -> bf16;
    fn f16(v1: f16, v2: f16) -> f16;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i64(v1: i64, v2: i64) -> i64;

    // Optional slice-level variants: each `*_VEC` flag defaults to `false` and the matching
    // function defaults to a no-op. An impl that provides a real `*_vec` also sets the flag.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
    const U8_VEC: bool = false;
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
236
// Zero-sized marker types, one per op. The trait impls further down attach the actual
// behaviour to them.

// Binary ops: these implement `BinaryOpT` through `bin_op!`.
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
// Unary ops: these implement `UnaryOpT`, either through `unary_op!` or by hand.
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
262
// Generates the `BinaryOpT` impl for a marker struct.
//
// Arguments:
//   $op      - the marker struct to implement the trait for,
//   $name    - string literal used for `NAME`; `KERNEL` is "b" followed by it,
//   $e       - a two-argument closure expression, called once per scalar dtype,
//   $f32_vec - name of the f32 slice function in `crate::mkl` / `crate::accelerate`,
//   $f64_vec - name of the f64 slice function in `crate::mkl` / `crate::accelerate`.
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
            // The closure is re-expanded in every function below, so its parameter types are
            // inferred separately for each dtype.
            #[inline(always)]
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u8(v1: u8, v2: u8) -> u8 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }

            // With the "mkl" feature, advertise and forward the f32/f64 slice variants.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            // Same for the "accelerate" feature. Both sections define the same items, so
            // enabling both features at once would produce duplicate definitions.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
330
331bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
332bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
333bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
334bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
335bin_op!(
336    Minimum,
337    "minimum",
338    |v1, v2| if v1 > v2 { v2 } else { v1 },
339    vs_min,
340    vd_min
341);
342bin_op!(
343    Maximum,
344    "maximum",
345    |v1, v2| if v1 < v2 { v2 } else { v1 },
346    vs_max,
347    vd_max
348);
349
// Generates the `UnaryOpT` impl for a marker struct.
//
// Arguments:
//   $op   - the marker struct to implement the trait for,
//   $name - string literal used for `NAME`; `KERNEL` is "u" followed by it,
//   $a    - identifier bound to the input value inside `$e`,
//   $e    - expression computing the result from `$a`, expanded once per float dtype,
//   $f32_vec / $f64_vec (second arm only) - names of the slice functions in `crate::mkl` /
//   `crate::accelerate`.
//
// In both arms the integer dtypes (u8, u32, i64) are `todo!()` and panic when called.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    // Scalar-only arm: no slice-level variants.
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
        }
    };

    // Arm with slice-level f32/f64 variants; the scalar part is the same as in the arm above.
    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }

            // With the "mkl" feature, advertise and forward the f32/f64 slice variants.
            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            // Same for the "accelerate" feature. Both sections define the same items, so
            // enabling both features at once would produce duplicate definitions.
            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
454
455unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
456unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
457unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
458unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
459unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
460unary_op!(Neg, "neg", v, -v);
461unary_op!(Recip, "recip", v, v.recip());
462unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
463unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
464
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
// The literal carries more digits than either type can represent, hence the clippy allow.
// The f32 constant is used by the bf16, f16 and f32 paths of `Gelu`.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
// The f64 constant is used by the f64 path of `Gelu`.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
471
472/// Tanh based approximation of the `gelu` operation
473/// GeluErf is the more precise one.
474/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
475impl UnaryOpT for Gelu {
476    const NAME: &'static str = "gelu";
477    const V: Self = Gelu;
478    #[inline(always)]
479    fn bf16(v: bf16) -> bf16 {
480        bf16::from_f32_const(0.5)
481            * v
482            * (bf16::ONE
483                + bf16::tanh(
484                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
485                        * v
486                        * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
487                ))
488    }
489    #[inline(always)]
490    fn f16(v: f16) -> f16 {
491        f16::from_f32_const(0.5)
492            * v
493            * (f16::ONE
494                + f16::tanh(
495                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
496                        * v
497                        * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
498                ))
499    }
500    #[inline(always)]
501    fn f32(v: f32) -> f32 {
502        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
503    }
504    #[inline(always)]
505    fn f64(v: f64) -> f64 {
506        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
507    }
508    #[inline(always)]
509    fn u8(_: u8) -> u8 {
510        0
511    }
512    #[inline(always)]
513    fn u32(_: u32) -> u32 {
514        0
515    }
516    #[inline(always)]
517    fn i64(_: i64) -> i64 {
518        0
519    }
520    const KERNEL: &'static str = "ugelu";
521
522    #[cfg(feature = "mkl")]
523    const F32_VEC: bool = true;
524
525    #[cfg(feature = "mkl")]
526    #[inline(always)]
527    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
528        crate::mkl::vs_gelu(xs, ys)
529    }
530
531    #[cfg(feature = "mkl")]
532    const F64_VEC: bool = true;
533
534    #[cfg(feature = "mkl")]
535    #[inline(always)]
536    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
537        crate::mkl::vd_gelu(xs, ys)
538    }
539
540    #[cfg(feature = "accelerate")]
541    const F32_VEC: bool = true;
542
543    #[cfg(feature = "accelerate")]
544    #[inline(always)]
545    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
546        crate::accelerate::vs_gelu(xs, ys)
547    }
548
549    #[cfg(feature = "accelerate")]
550    const F64_VEC: bool = true;
551
552    #[cfg(feature = "accelerate")]
553    #[inline(always)]
554    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
555        crate::accelerate::vd_gelu(xs, ys)
556    }
557}
558
559/// `erf` operation
560/// <https://en.wikipedia.org/wiki/Error_function>
561impl UnaryOpT for Erf {
562    const NAME: &'static str = "erf";
563    const KERNEL: &'static str = "uerf";
564    const V: Self = Erf;
565    #[inline(always)]
566    fn bf16(v: bf16) -> bf16 {
567        bf16::from_f64(Self::f64(v.to_f64()))
568    }
569    #[inline(always)]
570    fn f16(v: f16) -> f16 {
571        f16::from_f64(Self::f64(v.to_f64()))
572    }
573    #[inline(always)]
574    fn f32(v: f32) -> f32 {
575        Self::f64(v as f64) as f32
576    }
577    #[inline(always)]
578    fn f64(v: f64) -> f64 {
579        crate::cpu::erf::erf(v)
580    }
581    #[inline(always)]
582    fn u8(_: u8) -> u8 {
583        0
584    }
585    #[inline(always)]
586    fn u32(_: u32) -> u32 {
587        0
588    }
589    #[inline(always)]
590    fn i64(_: i64) -> i64 {
591        0
592    }
593}
594
595/// Silu operation
596impl UnaryOpT for Silu {
597    const NAME: &'static str = "silu";
598    const V: Self = Silu;
599    #[inline(always)]
600    fn bf16(v: bf16) -> bf16 {
601        v / (bf16::ONE + (-v).exp())
602    }
603    #[inline(always)]
604    fn f16(v: f16) -> f16 {
605        v / (f16::ONE + (-v).exp())
606    }
607    #[inline(always)]
608    fn f32(v: f32) -> f32 {
609        v / (1.0 + (-v).exp())
610    }
611    #[inline(always)]
612    fn f64(v: f64) -> f64 {
613        v / (1.0 + (-v).exp())
614    }
615    #[inline(always)]
616    fn u8(_: u8) -> u8 {
617        0
618    }
619    #[inline(always)]
620    fn u32(_: u32) -> u32 {
621        0
622    }
623    #[inline(always)]
624    fn i64(_: i64) -> i64 {
625        0
626    }
627    const KERNEL: &'static str = "usilu";
628
629    #[cfg(feature = "mkl")]
630    const F32_VEC: bool = true;
631
632    #[cfg(feature = "mkl")]
633    #[inline(always)]
634    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
635        crate::mkl::vs_silu(xs, ys)
636    }
637
638    #[cfg(feature = "mkl")]
639    const F64_VEC: bool = true;
640
641    #[cfg(feature = "mkl")]
642    #[inline(always)]
643    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
644        crate::mkl::vd_silu(xs, ys)
645    }
646
647    #[cfg(feature = "accelerate")]
648    const F32_VEC: bool = true;
649
650    #[cfg(feature = "accelerate")]
651    #[inline(always)]
652    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
653        crate::accelerate::vs_silu(xs, ys)
654    }
655
656    #[cfg(feature = "accelerate")]
657    const F64_VEC: bool = true;
658
659    #[cfg(feature = "accelerate")]
660    #[inline(always)]
661    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
662        crate::accelerate::vd_silu(xs, ys)
663    }
664}
665
666impl UnaryOpT for Abs {
667    const NAME: &'static str = "abs";
668    const KERNEL: &'static str = "uabs";
669    const V: Self = Abs;
670    #[inline(always)]
671    fn bf16(v: bf16) -> bf16 {
672        v.abs()
673    }
674    #[inline(always)]
675    fn f16(v: f16) -> f16 {
676        v.abs()
677    }
678    #[inline(always)]
679    fn f32(v: f32) -> f32 {
680        v.abs()
681    }
682    #[inline(always)]
683    fn f64(v: f64) -> f64 {
684        v.abs()
685    }
686    #[inline(always)]
687    fn u8(v: u8) -> u8 {
688        v
689    }
690    #[inline(always)]
691    fn u32(v: u32) -> u32 {
692        v
693    }
694    #[inline(always)]
695    fn i64(v: i64) -> i64 {
696        v.abs()
697    }
698}
699
700impl UnaryOpT for Ceil {
701    const NAME: &'static str = "ceil";
702    const KERNEL: &'static str = "uceil";
703    const V: Self = Ceil;
704    #[inline(always)]
705    fn bf16(v: bf16) -> bf16 {
706        v.ceil()
707    }
708    #[inline(always)]
709    fn f16(v: f16) -> f16 {
710        v.ceil()
711    }
712    #[inline(always)]
713    fn f32(v: f32) -> f32 {
714        v.ceil()
715    }
716    #[inline(always)]
717    fn f64(v: f64) -> f64 {
718        v.ceil()
719    }
720    #[inline(always)]
721    fn u8(v: u8) -> u8 {
722        v
723    }
724    #[inline(always)]
725    fn u32(v: u32) -> u32 {
726        v
727    }
728    #[inline(always)]
729    fn i64(v: i64) -> i64 {
730        v
731    }
732}
733
734impl UnaryOpT for Floor {
735    const NAME: &'static str = "floor";
736    const KERNEL: &'static str = "ufloor";
737    const V: Self = Floor;
738    #[inline(always)]
739    fn bf16(v: bf16) -> bf16 {
740        v.floor()
741    }
742    #[inline(always)]
743    fn f16(v: f16) -> f16 {
744        v.floor()
745    }
746    #[inline(always)]
747    fn f32(v: f32) -> f32 {
748        v.floor()
749    }
750    #[inline(always)]
751    fn f64(v: f64) -> f64 {
752        v.floor()
753    }
754    #[inline(always)]
755    fn u8(v: u8) -> u8 {
756        v
757    }
758    #[inline(always)]
759    fn u32(v: u32) -> u32 {
760        v
761    }
762    #[inline(always)]
763    fn i64(v: i64) -> i64 {
764        v
765    }
766}
767
768impl UnaryOpT for Round {
769    const NAME: &'static str = "round";
770    const KERNEL: &'static str = "uround";
771    const V: Self = Round;
772    #[inline(always)]
773    fn bf16(v: bf16) -> bf16 {
774        v.round()
775    }
776    #[inline(always)]
777    fn f16(v: f16) -> f16 {
778        v.round()
779    }
780    #[inline(always)]
781    fn f32(v: f32) -> f32 {
782        v.round()
783    }
784    #[inline(always)]
785    fn f64(v: f64) -> f64 {
786        v.round()
787    }
788    #[inline(always)]
789    fn u8(v: u8) -> u8 {
790        v
791    }
792    #[inline(always)]
793    fn u32(v: u32) -> u32 {
794        v
795    }
796    #[inline(always)]
797    fn i64(v: i64) -> i64 {
798        v
799    }
800}
801
802impl UnaryOpT for GeluErf {
803    const NAME: &'static str = "gelu_erf";
804    const KERNEL: &'static str = "ugelu_erf";
805    const V: Self = GeluErf;
806    #[inline(always)]
807    fn bf16(v: bf16) -> bf16 {
808        bf16::from_f64(Self::f64(v.to_f64()))
809    }
810    #[inline(always)]
811    fn f16(v: f16) -> f16 {
812        f16::from_f64(Self::f64(v.to_f64()))
813    }
814    #[inline(always)]
815    fn f32(v: f32) -> f32 {
816        Self::f64(v as f64) as f32
817    }
818    #[inline(always)]
819    fn f64(v: f64) -> f64 {
820        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
821    }
822    #[inline(always)]
823    fn u8(_: u8) -> u8 {
824        0
825    }
826    #[inline(always)]
827    fn u32(_: u32) -> u32 {
828        0
829    }
830    #[inline(always)]
831    fn i64(_: i64) -> i64 {
832        0
833    }
834}
835
836impl UnaryOpT for Relu {
837    const NAME: &'static str = "relu";
838    const KERNEL: &'static str = "urelu";
839    const V: Self = Relu;
840    #[inline(always)]
841    fn bf16(v: bf16) -> bf16 {
842        v.max(bf16::ZERO)
843    }
844    #[inline(always)]
845    fn f16(v: f16) -> f16 {
846        v.max(f16::ZERO)
847    }
848    #[inline(always)]
849    fn f32(v: f32) -> f32 {
850        v.max(0f32)
851    }
852    #[inline(always)]
853    fn f64(v: f64) -> f64 {
854        v.max(0f64)
855    }
856    #[inline(always)]
857    fn u8(v: u8) -> u8 {
858        v
859    }
860    #[inline(always)]
861    fn u32(v: u32) -> u32 {
862        v
863    }
864    #[inline(always)]
865    fn i64(v: i64) -> i64 {
866        v
867    }
868}
869
870/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
871/// properly checked when creating a new value
872#[derive(Clone)]
873pub struct BackpropOp(Option<Op>);
874
875impl BackpropOp {
876    pub(crate) fn none() -> Self {
877        BackpropOp(None)
878    }
879
880    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
881        let op = if arg.track_op() {
882            Some(f(arg.clone()))
883        } else {
884            None
885        };
886        Self(op)
887    }
888
889    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
890        let op = if arg1.track_op() || arg2.track_op() {
891            Some(f(arg1.clone(), arg2.clone()))
892        } else {
893            None
894        };
895        Self(op)
896    }
897
898    pub(crate) fn new3(
899        arg1: &Tensor,
900        arg2: &Tensor,
901        arg3: &Tensor,
902        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
903    ) -> Self {
904        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
905            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
906        } else {
907            None
908        };
909        Self(op)
910    }
911
912    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
913        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
914            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
915            Some(f(args))
916        } else {
917            None
918        };
919        Self(op)
920    }
921
922    pub(crate) fn is_none(&self) -> bool {
923        self.0.is_none()
924    }
925}
926
927impl std::ops::Deref for BackpropOp {
928    type Target = Option<Op>;
929    fn deref(&self) -> &Self::Target {
930        &self.0
931    }
932}
933
934impl UnaryOpT for Sign {
935    const NAME: &'static str = "sign";
936    const KERNEL: &'static str = "usign";
937    const V: Self = Sign;
938    #[inline(always)]
939    fn bf16(v: bf16) -> bf16 {
940        bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
941    }
942    #[inline(always)]
943    fn f16(v: f16) -> f16 {
944        f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
945    }
946    #[inline(always)]
947    fn f32(v: f32) -> f32 {
948        f32::from(v > 0.) - f32::from(v < 0.)
949    }
950    #[inline(always)]
951    fn f64(v: f64) -> f64 {
952        f64::from(v > 0.) - f64::from(v < 0.)
953    }
954    #[inline(always)]
955    fn u8(v: u8) -> u8 {
956        u8::min(1, v)
957    }
958    #[inline(always)]
959    fn u32(v: u32) -> u32 {
960        u32::min(1, v)
961    }
962    #[inline(always)]
963    fn i64(v: i64) -> i64 {
964        (v > 0) as i64 - (v < 0) as i64
965    }
966}