// candle_core_temp/op.rs

1#![allow(clippy::redundant_closure_call)]
2use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
3use half::{bf16, f16};
4use num_traits::float::Float;
5
/// Element-wise comparison operations between tensors.
// `Debug` is derived for consistency with the other op enums in this module
// (`ReduceOp`, `BinaryOp`, `UnaryOp`), e.g. so the op can appear in error or
// debug output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
15
/// Reduction operations collapsing values along one or more dimensions.
/// `ArgMin`/`ArgMax` return indices rather than values (see their handling
/// elsewhere in the crate — they are named "argmin"/"argmax" below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}
24
25impl ReduceOp {
26    pub(crate) fn name(&self) -> &'static str {
27        match self {
28            Self::ArgMax => "argmax",
29            Self::ArgMin => "argmin",
30            Self::Min => "min",
31            Self::Max => "max",
32            Self::Sum => "sum",
33        }
34    }
35}
36
// These ops return the same type as their input type.
/// Element-wise binary operations between two tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}
47
// Unary ops with no argument
/// Element-wise unary operations; each variant maps to a `UnaryOpT`
/// implementation below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Tanh,
    Floor,
    Ceil,
    Round,
}
69
/// Records how a tensor was produced so that the backward pass can propagate
/// gradients to its arguments. Each variant stores the input tensor(s) plus
/// whatever parameters the forward op used.
#[derive(Clone)]
pub enum Op {
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    Cmp(Tensor, CmpOp),
    // The third argument is the reduced shape with `keepdim=true`.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    // For the indexing ops below, the trailing `usize` is presumably the
    // dimension the operation applies to — TODO confirm against the callers.
    Gather(Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    WhereCond(Tensor, Tensor, Tensor),

    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    UpsampleNearest1D(Tensor),
    UpsampleNearest2D(Tensor),

    Cat(Vec<Tensor>, usize),

    #[allow(dead_code)] // add is currently unused.
    Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    Narrow(Tensor, usize, usize, usize),
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    // User-defined ops; the trait object is shared so cloning the op is cheap.
    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ),
}
159
160/// Unary ops that can be defined in user-land.
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    /// Name of the op, used in error messages and for kernel lookup.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a "no cuda implementation" error.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    ///
    /// The default implementation returns a `BackwardNotSupported` error.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
184
/// Binary (two-argument) ops that can be defined in user-land; see [`CustomOp1`].
pub trait CustomOp2 {
    /// Name of the op, used in error messages and for kernel lookup.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a "no cuda implementation" error.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: given both forward arguments, the forward result and the
    /// gradient of the result, return the gradient for each argument (or `None`
    /// when an argument needs no gradient). Defaults to `BackwardNotSupported`.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
222
/// Ternary (three-argument) ops that can be defined in user-land; see [`CustomOp1`].
pub trait CustomOp3 {
    /// Name of the op, used in error messages and for kernel lookup.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a "no cuda implementation" error.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: given the three forward arguments, the forward result and
    /// the gradient of the result, return the gradient for each argument (or
    /// `None` when no gradient is needed). Defaults to `BackwardNotSupported`.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
265
/// Scalar implementation of a unary op for every supported dtype, with
/// optional vectorized slice variants for the float types.
pub trait UnaryOpT {
    /// Human-readable name of the op.
    const NAME: &'static str;
    /// Kernel identifier (`"u"` + name below) — presumably used to select
    /// device kernels by name; confirm against the kernel registry.
    const KERNEL: &'static str;
    /// A value of the (zero-sized) implementing type.
    const V: Self;
    fn bf16(v1: bf16) -> bf16;
    fn f16(v1: f16) -> f16;
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u8(v1: u8) -> u8;
    fn u32(v1: u32) -> u32;
    fn i64(v1: i64) -> i64;

    // There is no very good way to represent optional function in traits so we go for an explicit
    // boolean flag to mark the function as existing.
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
289
/// Scalar implementation of a binary op for every supported dtype, with
/// optional vectorized slice variants (see `UnaryOpT` for the flag scheme).
pub trait BinaryOpT {
    /// Human-readable name of the op.
    const NAME: &'static str;
    /// Kernel identifier (`"b"` + name below).
    const KERNEL: &'static str;
    /// A value of the (zero-sized) implementing type.
    const V: Self;
    fn bf16(v1: bf16, v2: bf16) -> bf16;
    fn f16(v1: f16, v2: f16) -> f16;
    fn f32(v1: f32, v2: f32) -> f32;
    fn f64(v1: f64, v2: f64) -> f64;
    fn u8(v1: u8, v2: u8) -> u8;
    fn u32(v1: u32, v2: u32) -> u32;
    fn i64(v1: i64, v2: i64) -> i64;

    // Optional vectorized variants; the `*_VEC` flag marks the function as
    // actually implemented (defaults are no-ops).
    const BF16_VEC: bool = false;
    fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
    const F16_VEC: bool = false;
    fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
    const F32_VEC: bool = false;
    fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
    const F64_VEC: bool = false;
    fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
    const U8_VEC: bool = false;
    fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
    const U32_VEC: bool = false;
    fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
    const I64_VEC: bool = false;
    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
317
// Zero-sized marker types; each one gets a `BinaryOpT` or `UnaryOpT`
// implementation below, either hand-written or via the `bin_op!`/`unary_op!`
// macros.
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
341
// Implements `BinaryOpT` for a marker type: `$e` is a closure applied per
// scalar for every dtype (hence the `redundant_closure_call` allow at the top
// of the file), while `$f32_vec`/`$f64_vec` name the vectorized kernels used
// when the `mkl` or `accelerate` feature is enabled.
// NOTE(review): both cfg branches define `F32_VEC`/`F64_VEC`/`f32_vec`/
// `f64_vec`, so enabling `mkl` and `accelerate` together would produce
// duplicate definitions — presumably the features are mutually exclusive;
// confirm in the crate manifest.
macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16(v1: bf16, v2: bf16) -> bf16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f16(v1: f16, v2: f16) -> f16 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f32(v1: f32, v2: f32) -> f32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u8(v1: u8, v2: u8) -> u8 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn u32(v1: u32, v2: u32) -> u32 {
                $e(v1, v2)
            }
            #[inline(always)]
            fn i64(v1: i64, v2: i64) -> i64 {
                $e(v1, v2)
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs1, xs2, ys)
            }

            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs1, xs2, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs1, xs2, ys)
            }
        }
    };
}
409
bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
// Note: with the comparison-based closures below, any comparison involving a
// NaN is false, so when either operand is NaN the result is `v1` — i.e. NaN
// handling depends on argument order.
bin_op!(
    Minimum,
    "minimum",
    |v1, v2| if v1 > v2 { v2 } else { v1 },
    vs_min,
    vd_min
);
bin_op!(
    Maximum,
    "maximum",
    |v1, v2| if v1 < v2 { v2 } else { v1 },
    vs_max,
    vd_max
);
428
// Implements `UnaryOpT` for a marker type. Two arms: without and with the
// vectorized f32/f64 kernel names used under the `mkl`/`accelerate` features.
// `$a`/`$e` are the parameter name and the expression computing the result.
// The integer dtypes deliberately panic via `todo!` — these unary ops are
// float-only.
// NOTE(review): as in `bin_op!`, enabling `mkl` and `accelerate` together
// would define `F32_VEC`/`F64_VEC` and the vec functions twice; presumably
// the features are mutually exclusive — confirm in the crate manifest.
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }
        }
    };

    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
            #[inline(always)]
            fn bf16($a: bf16) -> bf16 {
                $e
            }
            #[inline(always)]
            fn f16($a: f16) -> f16 {
                $e
            }
            #[inline(always)]
            fn f32($a: f32) -> f32 {
                $e
            }
            #[inline(always)]
            fn f64($a: f64) -> f64 {
                $e
            }
            #[inline(always)]
            fn u8(_: u8) -> u8 {
                todo!("no unary function for u8")
            }
            #[inline(always)]
            fn u32(_: u32) -> u32 {
                todo!("no unary function for u32")
            }
            #[inline(always)]
            fn i64(_: i64) -> i64 {
                todo!("no unary function for i64")
            }

            #[cfg(feature = "mkl")]
            const F32_VEC: bool = true;
            #[cfg(feature = "mkl")]
            const F64_VEC: bool = true;
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::mkl::$f32_vec(xs, ys)
            }
            #[cfg(feature = "mkl")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::mkl::$f64_vec(xs, ys)
            }

            #[cfg(feature = "accelerate")]
            const F32_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            const F64_VEC: bool = true;
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
                crate::accelerate::$f32_vec(xs, ys)
            }
            #[cfg(feature = "accelerate")]
            #[inline(always)]
            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                crate::accelerate::$f64_vec(xs, ys)
            }
        }
    };
}
533
// Ops invoked with a trailing `vs_*`/`vd_*` pair get vectorized f32/f64
// kernels under the `mkl`/`accelerate` features; `Neg` and `Recip` only have
// the scalar fallback.
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
543
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
///
/// This is the tanh approximation:
/// `0.5 * v * (1 + tanh(sqrt(2/pi) * v * (1 + 0.044715 * v^2)))`.
/// The bf16/f16 variants compute directly in half precision rather than
/// promoting to f32. Integer dtypes return 0.
impl UnaryOpT for Gelu {
    const NAME: &'static str = "gelu";
    const V: Self = Gelu;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f32_const(0.5)
            * v
            * (bf16::ONE
                + bf16::tanh(
                    (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
                        * v
                        * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
                ))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f32_const(0.5)
            * v
            * (f16::ONE
                + f16::tanh(
                    (f16::from_f32_const(2.0) / f16::PI).sqrt()
                        * v
                        * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
                ))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        0.5 * v
            * (1.0
                + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        0.5 * v
            * (1.0
                + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
    const KERNEL: &'static str = "ugelu";

    // Vectorized f32/f64 variants; see the NOTE on `bin_op!` about the
    // mkl/accelerate cfg duplication.
    #[cfg(feature = "mkl")]
    const F32_VEC: bool = true;

    #[cfg(feature = "mkl")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::mkl::vs_gelu(xs, ys)
    }

    #[cfg(feature = "mkl")]
    const F64_VEC: bool = true;

    #[cfg(feature = "mkl")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::mkl::vd_gelu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F32_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::accelerate::vs_gelu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F64_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::accelerate::vd_gelu(xs, ys)
    }
}
633
/// Error function; all float widths defer to the f64 implementation in
/// `crate::cpu::erf`. Integer dtypes return 0.
impl UnaryOpT for Erf {
    const NAME: &'static str = "erf";
    const KERNEL: &'static str = "uerf";
    const V: Self = Erf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        crate::cpu::erf::erf(v)
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
667
668impl UnaryOpT for Abs {
669    const NAME: &'static str = "abs";
670    const KERNEL: &'static str = "uabs";
671    const V: Self = Abs;
672    #[inline(always)]
673    fn bf16(v: bf16) -> bf16 {
674        v.abs()
675    }
676    #[inline(always)]
677    fn f16(v: f16) -> f16 {
678        v.abs()
679    }
680    #[inline(always)]
681    fn f32(v: f32) -> f32 {
682        v.abs()
683    }
684    #[inline(always)]
685    fn f64(v: f64) -> f64 {
686        v.abs()
687    }
688    #[inline(always)]
689    fn u8(v: u8) -> u8 {
690        v
691    }
692    #[inline(always)]
693    fn u32(v: u32) -> u32 {
694        v
695    }
696    #[inline(always)]
697    fn i64(v: i64) -> i64 {
698        v.abs()
699    }
700}
701
702impl UnaryOpT for Ceil {
703    const NAME: &'static str = "ceil";
704    const KERNEL: &'static str = "uceil";
705    const V: Self = Ceil;
706    #[inline(always)]
707    fn bf16(v: bf16) -> bf16 {
708        v.ceil()
709    }
710    #[inline(always)]
711    fn f16(v: f16) -> f16 {
712        v.ceil()
713    }
714    #[inline(always)]
715    fn f32(v: f32) -> f32 {
716        v.ceil()
717    }
718    #[inline(always)]
719    fn f64(v: f64) -> f64 {
720        v.ceil()
721    }
722    #[inline(always)]
723    fn u8(v: u8) -> u8 {
724        v
725    }
726    #[inline(always)]
727    fn u32(v: u32) -> u32 {
728        v
729    }
730    #[inline(always)]
731    fn i64(v: i64) -> i64 {
732        v
733    }
734}
735
736impl UnaryOpT for Floor {
737    const NAME: &'static str = "floor";
738    const KERNEL: &'static str = "ufloor";
739    const V: Self = Floor;
740    #[inline(always)]
741    fn bf16(v: bf16) -> bf16 {
742        v.floor()
743    }
744    #[inline(always)]
745    fn f16(v: f16) -> f16 {
746        v.floor()
747    }
748    #[inline(always)]
749    fn f32(v: f32) -> f32 {
750        v.floor()
751    }
752    #[inline(always)]
753    fn f64(v: f64) -> f64 {
754        v.floor()
755    }
756    #[inline(always)]
757    fn u8(v: u8) -> u8 {
758        v
759    }
760    #[inline(always)]
761    fn u32(v: u32) -> u32 {
762        v
763    }
764    #[inline(always)]
765    fn i64(v: i64) -> i64 {
766        v
767    }
768}
769
770impl UnaryOpT for Round {
771    const NAME: &'static str = "round";
772    const KERNEL: &'static str = "uround";
773    const V: Self = Round;
774    #[inline(always)]
775    fn bf16(v: bf16) -> bf16 {
776        v.round()
777    }
778    #[inline(always)]
779    fn f16(v: f16) -> f16 {
780        v.round()
781    }
782    #[inline(always)]
783    fn f32(v: f32) -> f32 {
784        v.round()
785    }
786    #[inline(always)]
787    fn f64(v: f64) -> f64 {
788        v.round()
789    }
790    #[inline(always)]
791    fn u8(v: u8) -> u8 {
792        v
793    }
794    #[inline(always)]
795    fn u32(v: u32) -> u32 {
796        v
797    }
798    #[inline(always)]
799    fn i64(v: i64) -> i64 {
800        v
801    }
802}
803
/// Exact GELU: `0.5 * v * (1 + erf(v / sqrt(2)))`, in contrast with the tanh
/// approximation used by `Gelu`. All float widths defer to the f64
/// implementation; integer dtypes return 0.
impl UnaryOpT for GeluErf {
    const NAME: &'static str = "gelu_erf";
    const KERNEL: &'static str = "ugelu_erf";
    const V: Self = GeluErf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}
837
838impl UnaryOpT for Relu {
839    const NAME: &'static str = "relu";
840    const KERNEL: &'static str = "urelu";
841    const V: Self = Relu;
842    #[inline(always)]
843    fn bf16(v: bf16) -> bf16 {
844        v.max(bf16::ZERO)
845    }
846    #[inline(always)]
847    fn f16(v: f16) -> f16 {
848        v.max(f16::ZERO)
849    }
850    #[inline(always)]
851    fn f32(v: f32) -> f32 {
852        v.max(0f32)
853    }
854    #[inline(always)]
855    fn f64(v: f64) -> f64 {
856        v.max(0f64)
857    }
858    #[inline(always)]
859    fn u8(v: u8) -> u8 {
860        v
861    }
862    #[inline(always)]
863    fn u32(v: u32) -> u32 {
864        v
865    }
866    #[inline(always)]
867    fn i64(v: i64) -> i64 {
868        v
869    }
870}
871
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value: the constructors in the `impl`
/// below only record the op when at least one argument is tracked.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
876
877impl BackpropOp {
878    pub(crate) fn none() -> Self {
879        BackpropOp(None)
880    }
881
882    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
883        let op = if arg.track_op() {
884            Some(f(arg.clone()))
885        } else {
886            None
887        };
888        Self(op)
889    }
890
891    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
892        let op = if arg1.track_op() || arg2.track_op() {
893            Some(f(arg1.clone(), arg2.clone()))
894        } else {
895            None
896        };
897        Self(op)
898    }
899
900    pub(crate) fn new3(
901        arg1: &Tensor,
902        arg2: &Tensor,
903        arg3: &Tensor,
904        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
905    ) -> Self {
906        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
907            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
908        } else {
909            None
910        };
911        Self(op)
912    }
913
914    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
915        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
916            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
917            Some(f(args))
918        } else {
919            None
920        };
921        Self(op)
922    }
923
924    pub(crate) fn is_none(&self) -> bool {
925        self.0.is_none()
926    }
927}
928
// Dereferencing a `BackpropOp` gives read access to the wrapped `Option<Op>`,
// so callers can pattern-match on it directly.
impl std::ops::Deref for BackpropOp {
    type Target = Option<Op>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}