Skip to main content

candle_core/
op.rs

1//! Tensor Operation Enums and Traits
2//!
3#![allow(clippy::redundant_closure_call)]
4use crate::Tensor;
5use float8::F8E4M3 as f8e4m3;
6use half::{bf16, f16};
7use num_traits::float::Float;
8
/// Element-wise comparison operator selector.
///
/// `Debug` is derived (in addition to `Clone`, `Copy`, `PartialEq`, `Eq`) so that values can be
/// printed in diagnostics and used with `assert_eq!`, like the other operation enums in this
/// file.
///
/// NOTE(review): the variant names follow the usual abbreviations (presumably `==`, `!=`, `<=`,
/// `>=`, `<`, `>`); the evaluation itself is not part of this file — confirm against the backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}
18
/// Selector for a reduction operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
    ArgMin,
    ArgMax,
}

impl ReduceOp {
    /// Returns the lowercase name associated with this reduction
    /// (`"sum"`, `"min"`, `"max"`, `"argmin"` or `"argmax"`).
    pub(crate) fn name(&self) -> &'static str {
        // Arms are listed in declaration order; the match stays exhaustive so that adding a
        // variant is a compile error here.
        match *self {
            ReduceOp::Sum => "sum",
            ReduceOp::Min => "min",
            ReduceOp::Max => "max",
            ReduceOp::ArgMin => "argmin",
            ReduceOp::ArgMax => "argmax",
        }
    }
}
39
// These ops return the same type as their input type.
// Each variant has a marker struct of the same name further down in this file, for which
// `BinaryOpT` is implemented through the `bin_op!` macro.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Mul,
    Sub,
    Div,
    Maximum,
    Minimum,
}
50
// Unary ops with no argument
// (unary operations that carry an extra scalar, such as `Elu` and `Powf`, are separate variants
// of `Op` instead). Each variant has a marker struct of the same name implementing `UnaryOpT`
// further down in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    Exp,
    Log,
    Sin,
    Cos,
    Abs,
    Neg,
    Recip,
    Sqr,
    Sqrt,
    Gelu,
    GeluErf,
    Erf,
    Relu,
    Silu,
    Tanh,
    Floor,
    Ceil,
    Round,
    Sign,
}
74
/// Description of an operation together with the tensors and scalar parameters it refers to.
///
/// `BackpropOp` (defined later in this file) wraps an `Option<Op>`.
///
/// NOTE(review): the meaning of the positional `usize` fields (dimension, start, length, ...) is
/// not established in this file — confirm against the code that builds and consumes each variant.
#[derive(Clone)]
pub enum Op {
    // Element-wise operations; the trailing enum selects which one.
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    Cmp(Tensor, CmpOp),
    // The third argument is the reduced shape with `keepdim=true`.
    Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
    // Indexing-style operations. The trailing `usize` is presumably the dimension operated
    // on — TODO confirm.
    Gather(Tensor, Tensor, usize),
    Scatter(Tensor, Tensor, Tensor, usize),
    ScatterAdd(Tensor, Tensor, Tensor, usize),
    IndexSelect(Tensor, Tensor, usize),
    IndexAdd(Tensor, Tensor, Tensor, usize),
    WhereCond(Tensor, Tensor, Tensor),

    // Convolution variants. `dead_code` is allowed on them, i.e. not every field is read in
    // every build configuration.
    #[allow(dead_code)]
    Conv1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose1D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    Conv2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        stride: usize,
        dilation: usize,
    },

    #[allow(dead_code)]
    ConvTranspose2D {
        arg: Tensor,
        kernel: Tensor,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    },

    // 2D pooling; `kernel_size` and `stride` are pairs (one value per pooled dimension,
    // presumably in (height, width) order — TODO confirm).
    AvgPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    MaxPool2D {
        arg: Tensor,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    },

    // Upsampling variants; the `target_*` fields hold the requested output sizes.
    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
    },
    UpsampleBilinear2D {
        arg: Tensor,
        target_h: usize,
        target_w: usize,
        align_corners: bool,
    },

    // Concatenation of several tensors; the `usize` is presumably the dimension — TODO confirm.
    Cat(Vec<Tensor>, usize),

    #[allow(dead_code)] // add is currently unused.
    Affine {
        arg: Tensor,
        mul: f64,
        add: f64,
    },
    // Single-tensor operations (conversions, copies and layout changes).
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
    Narrow(Tensor, usize, usize, usize),
    SliceScatter0(Tensor, Tensor, usize),
    Reshape(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
    Permute(Tensor, Vec<usize>),
    // Unary operations that carry one `f64` parameter.
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    // User-provided operations taking one, two or three tensors. The implementation is kept as
    // a shared (`Arc`) boxed trait object that must be `Send + Sync`.
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}
191
192pub trait UnaryOpT {
193    const NAME: &'static str;
194    const KERNEL: &'static str;
195    const V: Self;
196    fn bf16(v1: bf16) -> bf16;
197    fn f16(v1: f16) -> f16;
198    fn f32(v1: f32) -> f32;
199    fn f64(v1: f64) -> f64;
200    fn u8(v1: u8) -> u8;
201    fn u32(v1: u32) -> u32;
202    fn i16(v1: i16) -> i16;
203    fn i32(v1: i32) -> i32;
204    fn i64(v1: i64) -> i64;
205    fn f8e4m3(v1: f8e4m3) -> f8e4m3;
206
207    #[inline(always)]
208    fn bf16_vec(xs: &[bf16], ys: &mut [bf16]) {
209        xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::bf16(x))
210    }
211    #[inline(always)]
212    fn f16_vec(xs: &[f16], ys: &mut [f16]) {
213        xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::f16(x))
214    }
215    #[inline(always)]
216    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
217        xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::f32(x))
218    }
219    #[inline(always)]
220    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
221        xs.iter().zip(ys).for_each(|(&x, y)| *y = Self::f64(x))
222    }
223}
224
225pub trait BinaryOpT {
226    const NAME: &'static str;
227    const KERNEL: &'static str;
228    const V: Self;
229    fn bf16(v1: bf16, v2: bf16) -> bf16;
230    fn f16(v1: f16, v2: f16) -> f16;
231    fn f32(v1: f32, v2: f32) -> f32;
232    fn f64(v1: f64, v2: f64) -> f64;
233    fn u8(v1: u8, v2: u8) -> u8;
234    fn u32(v1: u32, v2: u32) -> u32;
235    fn i16(v1: i16, v2: i16) -> i16;
236    fn i32(v1: i32, v2: i32) -> i32;
237    fn i64(v1: i64, v2: i64) -> i64;
238    fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3;
239
240    #[inline(always)]
241    fn bf16_vec(xs1: &[bf16], xs2: &[bf16], ys: &mut [bf16]) {
242        xs1.iter()
243            .zip(xs2)
244            .zip(ys)
245            .for_each(|((&a, &b), y)| *y = Self::bf16(a, b))
246    }
247    #[inline(always)]
248    fn f16_vec(xs1: &[f16], xs2: &[f16], ys: &mut [f16]) {
249        xs1.iter()
250            .zip(xs2)
251            .zip(ys)
252            .for_each(|((&a, &b), y)| *y = Self::f16(a, b))
253    }
254    #[inline(always)]
255    fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
256        xs1.iter()
257            .zip(xs2)
258            .zip(ys)
259            .for_each(|((&a, &b), y)| *y = Self::f32(a, b))
260    }
261    #[inline(always)]
262    fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
263        xs1.iter()
264            .zip(xs2)
265            .zip(ys)
266            .for_each(|((&a, &b), y)| *y = Self::f64(a, b))
267    }
268    #[inline(always)]
269    fn u8_vec(xs1: &[u8], xs2: &[u8], ys: &mut [u8]) {
270        xs1.iter()
271            .zip(xs2)
272            .zip(ys)
273            .for_each(|((&a, &b), y)| *y = Self::u8(a, b))
274    }
275    #[inline(always)]
276    fn u32_vec(xs1: &[u32], xs2: &[u32], ys: &mut [u32]) {
277        xs1.iter()
278            .zip(xs2)
279            .zip(ys)
280            .for_each(|((&a, &b), y)| *y = Self::u32(a, b))
281    }
282    #[inline(always)]
283    fn i16_vec(xs1: &[i16], xs2: &[i16], ys: &mut [i16]) {
284        xs1.iter()
285            .zip(xs2)
286            .zip(ys)
287            .for_each(|((&a, &b), y)| *y = Self::i16(a, b))
288    }
289    #[inline(always)]
290    fn i32_vec(xs1: &[i32], xs2: &[i32], ys: &mut [i32]) {
291        xs1.iter()
292            .zip(xs2)
293            .zip(ys)
294            .for_each(|((&a, &b), y)| *y = Self::i32(a, b))
295    }
296    #[inline(always)]
297    fn i64_vec(xs1: &[i64], xs2: &[i64], ys: &mut [i64]) {
298        xs1.iter()
299            .zip(xs2)
300            .zip(ys)
301            .for_each(|((&a, &b), y)| *y = Self::i64(a, b))
302    }
303
304    // Scalar-broadcast variants: ys[i] = f(xs[i], scalar).
305    // Used by binary_map_vec for the (1,0) inner-stride branch where one tensor
306    // is contiguous and the other broadcasts a single value.
307    #[inline(always)]
308    fn bf16_scalar_vec(scalar: bf16, xs: &[bf16], ys: &mut [bf16]) {
309        xs.iter()
310            .zip(ys)
311            .for_each(|(&x, y)| *y = Self::bf16(x, scalar))
312    }
313    #[inline(always)]
314    fn f16_scalar_vec(scalar: f16, xs: &[f16], ys: &mut [f16]) {
315        xs.iter()
316            .zip(ys)
317            .for_each(|(&x, y)| *y = Self::f16(x, scalar))
318    }
319    #[inline(always)]
320    fn f32_scalar_vec(scalar: f32, xs: &[f32], ys: &mut [f32]) {
321        xs.iter()
322            .zip(ys)
323            .for_each(|(&x, y)| *y = Self::f32(x, scalar))
324    }
325    #[inline(always)]
326    fn f64_scalar_vec(scalar: f64, xs: &[f64], ys: &mut [f64]) {
327        xs.iter()
328            .zip(ys)
329            .for_each(|(&x, y)| *y = Self::f64(x, scalar))
330    }
331    #[inline(always)]
332    fn u8_scalar_vec(scalar: u8, xs: &[u8], ys: &mut [u8]) {
333        xs.iter()
334            .zip(ys)
335            .for_each(|(&x, y)| *y = Self::u8(x, scalar))
336    }
337    #[inline(always)]
338    fn u32_scalar_vec(scalar: u32, xs: &[u32], ys: &mut [u32]) {
339        xs.iter()
340            .zip(ys)
341            .for_each(|(&x, y)| *y = Self::u32(x, scalar))
342    }
343    #[inline(always)]
344    fn i16_scalar_vec(scalar: i16, xs: &[i16], ys: &mut [i16]) {
345        xs.iter()
346            .zip(ys)
347            .for_each(|(&x, y)| *y = Self::i16(x, scalar))
348    }
349    #[inline(always)]
350    fn i32_scalar_vec(scalar: i32, xs: &[i32], ys: &mut [i32]) {
351        xs.iter()
352            .zip(ys)
353            .for_each(|(&x, y)| *y = Self::i32(x, scalar))
354    }
355    #[inline(always)]
356    fn i64_scalar_vec(scalar: i64, xs: &[i64], ys: &mut [i64]) {
357        xs.iter()
358            .zip(ys)
359            .for_each(|(&x, y)| *y = Self::i64(x, scalar))
360    }
361}
362
// Zero-sized marker types, one per operation. They carry no data: the behaviour lives in the
// trait implementations below. `Add`, `Div`, `Mul`, `Sub`, `Maximum` and `Minimum` implement
// `BinaryOpT` (through `bin_op!`); all the others implement `UnaryOpT`.
pub struct Add;
pub struct Div;
pub struct Mul;
pub struct Sub;
pub struct Maximum;
pub struct Minimum;
pub struct Exp;
pub struct Log;
pub struct Sin;
pub struct Cos;
pub struct Abs;
pub struct Neg;
pub struct Recip;
pub struct Sqr;
pub struct Sqrt;
pub struct Gelu;
pub struct GeluErf;
pub struct Erf;
pub struct Relu;
pub struct Silu;
pub struct Tanh;
pub struct Floor;
pub struct Ceil;
pub struct Round;
pub struct Sign;
388
389// `$name` is an ident; stringify! derives the NAME string and the KERNEL prefix automatically.
390// The optional `$vec_op` names a method on `crate::cpu::kernels::VecOps` that overrides
391// f32_vec/f64_vec with an optimised implementation (MKL / Accelerate / SIMD).
392macro_rules! bin_op {
393    ($op:ident, $name:ident, $e:expr $(, $vec_op:ident)?) => {
394        impl BinaryOpT for $op {
395            const NAME: &'static str = stringify!($name);
396            const KERNEL: &'static str = concat!("b", stringify!($name));
397            const V: Self = $op;
398            #[inline(always)]
399            fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) }
400            #[inline(always)]
401            fn f16(v1: f16, v2: f16) -> f16 { $e(v1, v2) }
402            #[inline(always)]
403            fn f32(v1: f32, v2: f32) -> f32 { $e(v1, v2) }
404            #[inline(always)]
405            fn f64(v1: f64, v2: f64) -> f64 { $e(v1, v2) }
406            #[inline(always)]
407            fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) }
408            #[inline(always)]
409            fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) }
410            #[inline(always)]
411            fn i16(v1: i16, v2: i16) -> i16 { $e(v1, v2) }
412            #[inline(always)]
413            fn i32(v1: i32, v2: i32) -> i32 { $e(v1, v2) }
414            #[inline(always)]
415            fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) }
416            #[inline(always)]
417            fn f8e4m3(v1: f8e4m3, v2: f8e4m3) -> f8e4m3 { $e(v1, v2) }
418            $(
419                #[inline(always)]
420                fn f32_vec(lhs: &[f32], rhs: &[f32], res: &mut [f32]) {
421                    <f32 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
422                }
423                #[inline(always)]
424                fn f64_vec(lhs: &[f64], rhs: &[f64], res: &mut [f64]) {
425                    <f64 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
426                }
427                #[inline(always)]
428                fn bf16_vec(lhs: &[bf16], rhs: &[bf16], res: &mut [bf16]) {
429                    <bf16 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
430                }
431                #[inline(always)]
432                fn f16_vec(lhs: &[f16], rhs: &[f16], res: &mut [f16]) {
433                    <f16 as crate::cpu::kernels::VecOps>::$vec_op(lhs, rhs, res)
434                }
435                #[inline(always)]
436                fn bf16_scalar_vec(scalar: bf16, xs: &[bf16], ys: &mut [bf16]) {
437                    <bf16 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
438                }
439                #[inline(always)]
440                fn f16_scalar_vec(scalar: f16, xs: &[f16], ys: &mut [f16]) {
441                    <f16 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
442                }
443                #[inline(always)]
444                fn f32_scalar_vec(scalar: f32, xs: &[f32], ys: &mut [f32]) {
445                    <f32 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
446                }
447                #[inline(always)]
448                fn f64_scalar_vec(scalar: f64, xs: &[f64], ys: &mut [f64]) {
449                    <f64 as crate::cpu::kernels::VecOps>::scalar_add(scalar, xs, ys)
450                }
451            )?
452        }
453    };
454}
455
456bin_op!(Add, add, |v1, v2| v1 + v2, vec_add);
457bin_op!(Sub, sub, |v1, v2| v1 - v2);
458bin_op!(Mul, mul, |v1, v2| v1 * v2);
459bin_op!(Div, div, |v1, v2| v1 / v2);
460bin_op!(Minimum, minimum, |v1, v2| if v1 > v2 { v2 } else { v1 });
461bin_op!(Maximum, maximum, |v1, v2| if v1 < v2 { v2 } else { v1 });
462
463#[allow(clippy::redundant_closure_call)]
464macro_rules! unary_op {
465    ($op: ident, $name: literal, $a: ident, $e: expr) => {
466        impl UnaryOpT for $op {
467            const NAME: &'static str = $name;
468            const KERNEL: &'static str = concat!("u", $name);
469            const V: Self = $op;
470            #[inline(always)]
471            fn bf16($a: bf16) -> bf16 {
472                $e
473            }
474            #[inline(always)]
475            fn f16($a: f16) -> f16 {
476                $e
477            }
478            #[inline(always)]
479            fn f32($a: f32) -> f32 {
480                $e
481            }
482            #[inline(always)]
483            fn f64($a: f64) -> f64 {
484                $e
485            }
486            #[inline(always)]
487            fn u8(_: u8) -> u8 {
488                todo!("no unary function for u8")
489            }
490            #[inline(always)]
491            fn u32(_: u32) -> u32 {
492                todo!("no unary function for u32")
493            }
494            #[inline(always)]
495            fn i16(_: i16) -> i16 {
496                todo!("no unary function for i16")
497            }
498            #[inline(always)]
499            fn i32(_: i32) -> i32 {
500                todo!("no unary function for i32")
501            }
502            #[inline(always)]
503            fn i64(_: i64) -> i64 {
504                todo!("no unary function for i64")
505            }
506            #[inline(always)]
507            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
508                $e
509            }
510        }
511    };
512
513    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
514        impl UnaryOpT for $op {
515            const NAME: &'static str = $name;
516            const KERNEL: &'static str = concat!("u", $name);
517            const V: Self = $op;
518            #[inline(always)]
519            fn bf16($a: bf16) -> bf16 {
520                $e
521            }
522            #[inline(always)]
523            fn f16($a: f16) -> f16 {
524                $e
525            }
526            #[inline(always)]
527            fn f32($a: f32) -> f32 {
528                $e
529            }
530            #[inline(always)]
531            fn f64($a: f64) -> f64 {
532                $e
533            }
534            #[inline(always)]
535            fn u8(_: u8) -> u8 {
536                todo!("no unary function for u8")
537            }
538            #[inline(always)]
539            fn u32(_: u32) -> u32 {
540                todo!("no unary function for u32")
541            }
542            #[inline(always)]
543            fn i16(_: i16) -> i16 {
544                todo!("no unary function for i16")
545            }
546            #[inline(always)]
547            fn i32(_: i32) -> i32 {
548                todo!("no unary function for i32")
549            }
550            #[inline(always)]
551            fn i64(_: i64) -> i64 {
552                todo!("no unary function for i64")
553            }
554            #[inline(always)]
555            fn f8e4m3($a: f8e4m3) -> f8e4m3 {
556                $e
557            }
558
559            #[cfg(feature = "mkl")]
560            #[inline(always)]
561            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
562                crate::mkl::$f32_vec(xs, ys)
563            }
564            #[cfg(feature = "mkl")]
565            #[inline(always)]
566            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
567                crate::mkl::$f64_vec(xs, ys)
568            }
569
570            #[cfg(feature = "accelerate")]
571            #[inline(always)]
572            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
573                crate::accelerate::$f32_vec(xs, ys)
574            }
575            #[cfg(feature = "accelerate")]
576            #[inline(always)]
577            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
578                crate::accelerate::$f64_vec(xs, ys)
579            }
580        }
581    };
582}
583
584unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
585unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
586unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
587unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
588unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
589unary_op!(Neg, "neg", v, -v);
590unary_op!(Recip, "recip", v, v.recip());
591unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
592unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
593
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
// The literal carries more digits than either type can represent, hence the
// `excessive_precision` allowances. Both constants are used by the tanh-based `Gelu` below.
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
600
601/// Tanh based approximation of the `gelu` operation
602/// GeluErf is the more precise one.
603/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
604impl UnaryOpT for Gelu {
605    const NAME: &'static str = "gelu";
606    const V: Self = Gelu;
607    #[inline(always)]
608    fn bf16(v: bf16) -> bf16 {
609        bf16::from_f32_const(0.5)
610            * v
611            * (bf16::ONE
612                + bf16::tanh(
613                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
614                        * v
615                        * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
616                ))
617    }
618    #[inline(always)]
619    fn f16(v: f16) -> f16 {
620        f16::from_f32_const(0.5)
621            * v
622            * (f16::ONE
623                + f16::tanh(
624                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
625                        * v
626                        * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
627                ))
628    }
629    #[inline(always)]
630    fn f32(v: f32) -> f32 {
631        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
632    }
633    #[inline(always)]
634    fn f64(v: f64) -> f64 {
635        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
636    }
637    #[inline(always)]
638    fn u8(_: u8) -> u8 {
639        0
640    }
641    #[inline(always)]
642    fn u32(_: u32) -> u32 {
643        0
644    }
645    #[inline(always)]
646    fn i16(_: i16) -> i16 {
647        0
648    }
649    #[inline(always)]
650    fn i32(_: i32) -> i32 {
651        0
652    }
653    #[inline(always)]
654    fn i64(_: i64) -> i64 {
655        0
656    }
657    #[inline(always)]
658    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
659        f8e4m3::from_f32(0.5)
660            * v
661            * (f8e4m3::ONE
662                + f8e4m3::tanh(
663                    f8e4m3::from_f32(SQRT_TWO_OVER_PI_F32)
664                        * v
665                        * (f8e4m3::ONE + f8e4m3::from_f32(0.044715) * v * v),
666                ))
667    }
668    const KERNEL: &'static str = "ugelu";
669
670    #[cfg(feature = "mkl")]
671    #[inline(always)]
672    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
673        crate::mkl::vs_gelu(xs, ys)
674    }
675
676    #[cfg(feature = "mkl")]
677    #[inline(always)]
678    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
679        crate::mkl::vd_gelu(xs, ys)
680    }
681
682    #[cfg(feature = "accelerate")]
683    #[inline(always)]
684    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
685        crate::accelerate::vs_gelu(xs, ys)
686    }
687
688    #[cfg(feature = "accelerate")]
689    #[inline(always)]
690    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
691        crate::accelerate::vd_gelu(xs, ys)
692    }
693}
694
695/// `erf` operation
696/// <https://en.wikipedia.org/wiki/Error_function>
697impl UnaryOpT for Erf {
698    const NAME: &'static str = "erf";
699    const KERNEL: &'static str = "uerf";
700    const V: Self = Erf;
701    #[inline(always)]
702    fn bf16(v: bf16) -> bf16 {
703        bf16::from_f64(Self::f64(v.to_f64()))
704    }
705    #[inline(always)]
706    fn f16(v: f16) -> f16 {
707        f16::from_f64(Self::f64(v.to_f64()))
708    }
709    #[inline(always)]
710    fn f32(v: f32) -> f32 {
711        crate::cpu::erf::erf_f32(v)
712    }
713    #[inline(always)]
714    fn f64(v: f64) -> f64 {
715        crate::cpu::erf::erf_f64(v)
716    }
717    #[inline(always)]
718    fn u8(_: u8) -> u8 {
719        0
720    }
721    #[inline(always)]
722    fn u32(_: u32) -> u32 {
723        0
724    }
725    #[inline(always)]
726    fn i16(_: i16) -> i16 {
727        0
728    }
729    #[inline(always)]
730    fn i32(_: i32) -> i32 {
731        0
732    }
733    #[inline(always)]
734    fn i64(_: i64) -> i64 {
735        0
736    }
737    #[inline(always)]
738    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
739        f8e4m3::from_f64(Self::f64(v.to_f64()))
740    }
741}
742
743/// Silu operation
744impl UnaryOpT for Silu {
745    const NAME: &'static str = "silu";
746    const V: Self = Silu;
747    #[inline(always)]
748    fn bf16(v: bf16) -> bf16 {
749        v / (bf16::ONE + (-v).exp())
750    }
751    #[inline(always)]
752    fn f16(v: f16) -> f16 {
753        v / (f16::ONE + (-v).exp())
754    }
755    #[inline(always)]
756    fn f32(v: f32) -> f32 {
757        v / (1.0 + (-v).exp())
758    }
759    #[inline(always)]
760    fn f64(v: f64) -> f64 {
761        v / (1.0 + (-v).exp())
762    }
763    #[inline(always)]
764    fn u8(_: u8) -> u8 {
765        0
766    }
767    #[inline(always)]
768    fn u32(_: u32) -> u32 {
769        0
770    }
771    #[inline(always)]
772    fn i16(_: i16) -> i16 {
773        0
774    }
775    #[inline(always)]
776    fn i32(_: i32) -> i32 {
777        0
778    }
779    #[inline(always)]
780    fn i64(_: i64) -> i64 {
781        0
782    }
783    #[inline(always)]
784    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
785        v / (f8e4m3::ONE + (-v).exp())
786    }
787    const KERNEL: &'static str = "usilu";
788
789    #[cfg(feature = "mkl")]
790    #[inline(always)]
791    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
792        crate::mkl::vs_silu(xs, ys)
793    }
794
795    #[cfg(feature = "mkl")]
796    #[inline(always)]
797    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
798        crate::mkl::vd_silu(xs, ys)
799    }
800
801    #[cfg(feature = "accelerate")]
802    #[inline(always)]
803    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
804        crate::accelerate::vs_silu(xs, ys)
805    }
806
807    #[cfg(feature = "accelerate")]
808    #[inline(always)]
809    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
810        crate::accelerate::vd_silu(xs, ys)
811    }
812}
813
814impl UnaryOpT for Abs {
815    const NAME: &'static str = "abs";
816    const KERNEL: &'static str = "uabs";
817    const V: Self = Abs;
818    #[inline(always)]
819    fn bf16(v: bf16) -> bf16 {
820        v.abs()
821    }
822    #[inline(always)]
823    fn f16(v: f16) -> f16 {
824        v.abs()
825    }
826    #[inline(always)]
827    fn f32(v: f32) -> f32 {
828        v.abs()
829    }
830    #[inline(always)]
831    fn f64(v: f64) -> f64 {
832        v.abs()
833    }
834    #[inline(always)]
835    fn u8(v: u8) -> u8 {
836        v
837    }
838    #[inline(always)]
839    fn u32(v: u32) -> u32 {
840        v
841    }
842    #[inline(always)]
843    fn i16(v: i16) -> i16 {
844        v.abs()
845    }
846    #[inline(always)]
847    fn i32(v: i32) -> i32 {
848        v.abs()
849    }
850    #[inline(always)]
851    fn i64(v: i64) -> i64 {
852        v.abs()
853    }
854    #[inline(always)]
855    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
856        v.abs()
857    }
858}
859
860impl UnaryOpT for Ceil {
861    const NAME: &'static str = "ceil";
862    const KERNEL: &'static str = "uceil";
863    const V: Self = Ceil;
864    #[inline(always)]
865    fn bf16(v: bf16) -> bf16 {
866        v.ceil()
867    }
868    #[inline(always)]
869    fn f16(v: f16) -> f16 {
870        v.ceil()
871    }
872    #[inline(always)]
873    fn f32(v: f32) -> f32 {
874        v.ceil()
875    }
876    #[inline(always)]
877    fn f64(v: f64) -> f64 {
878        v.ceil()
879    }
880    #[inline(always)]
881    fn u8(v: u8) -> u8 {
882        v
883    }
884    #[inline(always)]
885    fn u32(v: u32) -> u32 {
886        v
887    }
888    #[inline(always)]
889    fn i16(v: i16) -> i16 {
890        v
891    }
892    #[inline(always)]
893    fn i32(v: i32) -> i32 {
894        v
895    }
896    #[inline(always)]
897    fn i64(v: i64) -> i64 {
898        v
899    }
900    #[inline(always)]
901    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
902        v.ceil()
903    }
904}
905
906impl UnaryOpT for Floor {
907    const NAME: &'static str = "floor";
908    const KERNEL: &'static str = "ufloor";
909    const V: Self = Floor;
910    #[inline(always)]
911    fn bf16(v: bf16) -> bf16 {
912        v.floor()
913    }
914    #[inline(always)]
915    fn f16(v: f16) -> f16 {
916        v.floor()
917    }
918    #[inline(always)]
919    fn f32(v: f32) -> f32 {
920        v.floor()
921    }
922    #[inline(always)]
923    fn f64(v: f64) -> f64 {
924        v.floor()
925    }
926    #[inline(always)]
927    fn u8(v: u8) -> u8 {
928        v
929    }
930    #[inline(always)]
931    fn u32(v: u32) -> u32 {
932        v
933    }
934    #[inline(always)]
935    fn i16(v: i16) -> i16 {
936        v
937    }
938    #[inline(always)]
939    fn i32(v: i32) -> i32 {
940        v
941    }
942    #[inline(always)]
943    fn i64(v: i64) -> i64 {
944        v
945    }
946    #[inline(always)]
947    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
948        v.floor()
949    }
950}
951
952impl UnaryOpT for Round {
953    const NAME: &'static str = "round";
954    const KERNEL: &'static str = "uround";
955    const V: Self = Round;
956    #[inline(always)]
957    fn bf16(v: bf16) -> bf16 {
958        v.round()
959    }
960    #[inline(always)]
961    fn f16(v: f16) -> f16 {
962        v.round()
963    }
964    #[inline(always)]
965    fn f32(v: f32) -> f32 {
966        v.round()
967    }
968    #[inline(always)]
969    fn f64(v: f64) -> f64 {
970        v.round()
971    }
972    #[inline(always)]
973    fn u8(v: u8) -> u8 {
974        v
975    }
976    #[inline(always)]
977    fn u32(v: u32) -> u32 {
978        v
979    }
980    #[inline(always)]
981    fn i16(v: i16) -> i16 {
982        v
983    }
984    #[inline(always)]
985    fn i32(v: i32) -> i32 {
986        v
987    }
988    #[inline(always)]
989    fn i64(v: i64) -> i64 {
990        v
991    }
992    #[inline(always)]
993    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
994        v.round()
995    }
996}
997
998impl UnaryOpT for GeluErf {
999    const NAME: &'static str = "gelu_erf";
1000    const KERNEL: &'static str = "ugelu_erf";
1001    const V: Self = GeluErf;
1002    #[inline(always)]
1003    fn bf16(v: bf16) -> bf16 {
1004        bf16::from_f64(Self::f64(v.to_f64()))
1005    }
1006    #[inline(always)]
1007    fn f16(v: f16) -> f16 {
1008        f16::from_f64(Self::f64(v.to_f64()))
1009    }
1010    #[inline(always)]
1011    fn f32(v: f32) -> f32 {
1012        (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
1013    }
1014    #[inline(always)]
1015    fn f64(v: f64) -> f64 {
1016        (crate::cpu::erf::erf_f64(v * std::f64::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v
1017    }
1018    #[inline(always)]
1019    fn u8(_: u8) -> u8 {
1020        0
1021    }
1022    #[inline(always)]
1023    fn u32(_: u32) -> u32 {
1024        0
1025    }
1026    #[inline(always)]
1027    fn i16(_: i16) -> i16 {
1028        0
1029    }
1030    #[inline(always)]
1031    fn i32(_: i32) -> i32 {
1032        0
1033    }
1034    #[inline(always)]
1035    fn i64(_: i64) -> i64 {
1036        0
1037    }
1038    #[inline(always)]
1039    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1040        f8e4m3::from_f32(Self::f32(v.to_f32()))
1041    }
1042}
1043
1044impl UnaryOpT for Relu {
1045    const NAME: &'static str = "relu";
1046    const KERNEL: &'static str = "urelu";
1047    const V: Self = Relu;
1048    #[inline(always)]
1049    fn bf16(v: bf16) -> bf16 {
1050        v.max(bf16::ZERO)
1051    }
1052    #[inline(always)]
1053    fn f16(v: f16) -> f16 {
1054        v.max(f16::ZERO)
1055    }
1056    #[inline(always)]
1057    fn f32(v: f32) -> f32 {
1058        v.max(0f32)
1059    }
1060    #[inline(always)]
1061    fn f64(v: f64) -> f64 {
1062        v.max(0f64)
1063    }
1064    #[inline(always)]
1065    fn u8(v: u8) -> u8 {
1066        v
1067    }
1068    #[inline(always)]
1069    fn u32(v: u32) -> u32 {
1070        v
1071    }
1072    #[inline(always)]
1073    fn i16(v: i16) -> i16 {
1074        v.max(0)
1075    }
1076    #[inline(always)]
1077    fn i32(v: i32) -> i32 {
1078        v.max(0)
1079    }
1080    #[inline(always)]
1081    fn i64(v: i64) -> i64 {
1082        v.max(0)
1083    }
1084    #[inline(always)]
1085    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1086        v.max(f8e4m3::ZERO)
1087    }
1088}
1089
1090/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
1091/// properly checked when creating a new value
1092#[derive(Clone)]
1093pub struct BackpropOp(Option<Op>);
1094
1095impl BackpropOp {
1096    pub fn none() -> Self {
1097        BackpropOp(None)
1098    }
1099
1100    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
1101        let op = if arg.track_op() {
1102            Some(f(arg.clone()))
1103        } else {
1104            None
1105        };
1106        Self(op)
1107    }
1108
1109    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
1110        let op = if arg1.track_op() || arg2.track_op() {
1111            Some(f(arg1.clone(), arg2.clone()))
1112        } else {
1113            None
1114        };
1115        Self(op)
1116    }
1117
1118    pub(crate) fn new3(
1119        arg1: &Tensor,
1120        arg2: &Tensor,
1121        arg3: &Tensor,
1122        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
1123    ) -> Self {
1124        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
1125            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
1126        } else {
1127            None
1128        };
1129        Self(op)
1130    }
1131
1132    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
1133        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
1134            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
1135            Some(f(args))
1136        } else {
1137            None
1138        };
1139        Self(op)
1140    }
1141
1142    pub(crate) fn is_none(&self) -> bool {
1143        self.0.is_none()
1144    }
1145}
1146
1147impl std::ops::Deref for BackpropOp {
1148    type Target = Option<Op>;
1149    fn deref(&self) -> &Self::Target {
1150        &self.0
1151    }
1152}
1153
1154impl UnaryOpT for Sign {
1155    const NAME: &'static str = "sign";
1156    const KERNEL: &'static str = "usign";
1157    const V: Self = Sign;
1158    #[inline(always)]
1159    fn bf16(v: bf16) -> bf16 {
1160        bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
1161    }
1162    #[inline(always)]
1163    fn f16(v: f16) -> f16 {
1164        f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
1165    }
1166    #[inline(always)]
1167    fn f32(v: f32) -> f32 {
1168        f32::from(v > 0.) - f32::from(v < 0.)
1169    }
1170    #[inline(always)]
1171    fn f64(v: f64) -> f64 {
1172        f64::from(v > 0.) - f64::from(v < 0.)
1173    }
1174    #[inline(always)]
1175    fn u8(v: u8) -> u8 {
1176        u8::min(1, v)
1177    }
1178    #[inline(always)]
1179    fn u32(v: u32) -> u32 {
1180        u32::min(1, v)
1181    }
1182    #[inline(always)]
1183    fn i16(v: i16) -> i16 {
1184        (v > 0) as i16 - (v < 0) as i16
1185    }
1186    #[inline(always)]
1187    fn i32(v: i32) -> i32 {
1188        (v > 0) as i32 - (v < 0) as i32
1189    }
1190    #[inline(always)]
1191    fn i64(v: i64) -> i64 {
1192        (v > 0) as i64 - (v < 0) as i64
1193    }
1194    #[inline(always)]
1195    fn f8e4m3(v: f8e4m3) -> f8e4m3 {
1196        if v > f8e4m3::ZERO {
1197            f8e4m3::ONE
1198        } else if v < f8e4m3::ZERO {
1199            -f8e4m3::ONE
1200        } else {
1201            f8e4m3::ZERO
1202        }
1203    }
1204}